Add files via upload

This commit is contained in:
Richard Mwewa
2023-12-03 18:55:21 +00:00
committed by GitHub
parent 145d33ef9e
commit 981fbfcac1
3 changed files with 376 additions and 376 deletions

View File

@@ -1,42 +1,42 @@
import os
__author__: str = "Richard Mwewa"
__about_author__: str = "https://about/me/rly0nheart"
__version__: str = "2.0.0.0"
__description__: str = f"""
# RPST (Reddit Post Scraping Tool) {__version__}
> Retrieve Reddit posts that contain the specified keyword from a specified subreddit.
"""
__epilog__: str = f"""
# by [{__author__}]({__about_author__})
```
MIT License
Copyright (c) 2023 {__author__}
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
"""
# Construct path to the program's directory
PROGRAM_DIRECTORY: str = os.path.expanduser(
os.path.join("~", "reddit_post_scraping_tool")
)
import os
__author__: str = "Richard Mwewa"
__about_author__: str = "https://about/me/rly0nheart"
__version__: str = "2.0.0.0"
__description__: str = f"""
# RPST (Reddit Post Scraping Tool) {__version__}
> Retrieve Reddit posts that contain the specified keyword from a specified subreddit.
"""
__epilog__: str = f"""
# by [{__author__}]({__about_author__})
```
MIT License
Copyright (c) 2023 {__author__}
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
"""
# Construct path to the program's directory
PROGRAM_DIRECTORY: str = os.path.expanduser(
os.path.join("~", "reddit_post_scraping_tool")
)

View File

@@ -1,164 +1,164 @@
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
from typing import Union
import aiohttp
from .coreutils import log
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
REDDIT_ENDPOINT: str = "https://www.reddit.com"
PYPI_PROJECT_ENDPOINT: str = "https://pypi.org/pypi/reddit-post-scraping-tool/json"
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
async def get_data(session: aiohttp.ClientSession, endpoint: str) -> Union[dict, list]:
"""
Fetches JSON data from a given API endpoint.
:param session: aiohttp session to use for the request.
:param endpoint: The API endpoint to fetch data from.
:return: Returns JSON data as a dictionary or list. Returns an empty dict if fetching fails.
"""
try:
async with session.get(
endpoint,
) as response:
if response.status == 200:
return await response.json()
else:
error_message = await response.json()
log.error(f"An API error occurred: {error_message}")
return {}
except aiohttp.ClientConnectionError as error:
log.error(f"An HTTP error occurred: {error}")
return {}
except Exception as error:
log.critical(f"An unknown error occurred: {error}")
return {}
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
async def get_updates(session: aiohttp.ClientSession):
"""
Gets and compares the current program version with the remote version.
Assumes version format: major.minor.patch.prefix
:param session: aiohttp session to use for the request.
"""
from . import __version__
# Make a GET request to PyPI to get the project's latest release.
response: dict = await get_data(endpoint=PYPI_PROJECT_ENDPOINT, session=session)
if response.get("info"):
release: dict = response.get("info")
remote_version: str = release.get("version")
# Splitting the version strings into components
remote_parts: list = remote_version.split(".")
local_parts: list = __version__.split(".")
update_message: str = ""
# Check for differences in version parts
if remote_parts[0] != local_parts[0]:
update_message = (
f"MAJOR update ({remote_version}) available."
f" It might introduce significant changes."
)
elif remote_parts[1] != local_parts[1]:
update_message = (
f"MINOR update ({remote_version}) available."
f" Includes small feature changes/improvements."
)
elif remote_parts[2] != local_parts[2]:
update_message = (
f"PATCH update ({remote_version}) available."
f" Generally for bug fixes and small tweaks."
)
elif (
len(remote_parts) > 3
and len(local_parts) > 3
and remote_parts[3] != local_parts[3]
):
update_message = (
f"BUILD update ({remote_version}) available."
f" Might be for specific builds or special versions."
)
if update_message:
log.info(update_message)
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
async def get_posts(
subreddit: str,
listing: str,
timeframe: str,
limit: int,
session: aiohttp.ClientSession,
) -> list:
all_posts = await paginated_posts(
posts_endpoint=f"{REDDIT_ENDPOINT}/r/{subreddit}/{listing}.json?limit={limit}&t={timeframe}",
limit=limit,
session=session,
)
return all_posts[:limit]
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
async def paginated_posts(
posts_endpoint: str, limit: int, session: aiohttp.ClientSession
) -> list:
"""
Paginates and retrieves posts until the specified limit is reached.
:param posts_endpoint: API endpoint for retrieving posts.
:param limit: Limit of the number of posts to retrieve.
:param session: aiohttp session to use for the request.
:return: A list of all posts.
"""
all_posts: list = []
last_post_id: str = ""
# Determine whether to use the 'after' parameter
use_after: bool = limit > 100
while len(all_posts) < limit:
# Make the API request with the 'after' parameter if it's provided and the limit is more than 100
if use_after and last_post_id:
endpoint_with_after: str = f"{posts_endpoint}&after={last_post_id}"
else:
endpoint_with_after: str = posts_endpoint
posts_data: dict = await get_data(endpoint=endpoint_with_after, session=session)
posts_children: list = posts_data.get("data", {}).get("children", [])
# If there are no more posts, break out of the loop
if not posts_children:
break
all_posts.extend(posts_children)
# We use the id of the last post in the list to paginate to the next posts
last_post_id: str = all_posts[-1].get("data").get("id")
return all_posts
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
from typing import Union
import aiohttp
from .coreutils import log
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
REDDIT_ENDPOINT: str = "https://www.reddit.com"
PYPI_PROJECT_ENDPOINT: str = "https://pypi.org/pypi/reddit-post-scraping-tool/json"
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
async def get_data(session: aiohttp.ClientSession, endpoint: str) -> Union[dict, list]:
"""
Fetches JSON data from a given API endpoint.
:param session: aiohttp session to use for the request.
:param endpoint: The API endpoint to fetch data from.
:return: Returns JSON data as a dictionary or list. Returns an empty dict if fetching fails.
"""
try:
async with session.get(
endpoint,
) as response:
if response.status == 200:
return await response.json()
else:
error_message = await response.json()
log.error(f"An API error occurred: {error_message}")
return {}
except aiohttp.ClientConnectionError as error:
log.error(f"An HTTP error occurred: {error}")
return {}
except Exception as error:
log.critical(f"An unknown error occurred: {error}")
return {}
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
async def get_updates(session: aiohttp.ClientSession):
"""
Gets and compares the current program version with the remote version.
Assumes version format: major.minor.patch.prefix
:param session: aiohttp session to use for the request.
"""
from . import __version__
# Make a GET request to PyPI to get the project's latest release.
response: dict = await get_data(endpoint=PYPI_PROJECT_ENDPOINT, session=session)
if response.get("info"):
release: dict = response.get("info")
remote_version: str = release.get("version")
# Splitting the version strings into components
remote_parts: list = remote_version.split(".")
local_parts: list = __version__.split(".")
update_message: str = ""
# Check for differences in version parts
if remote_parts[0] != local_parts[0]:
update_message = (
f"MAJOR update ({remote_version}) available."
f" It might introduce significant changes."
)
elif remote_parts[1] != local_parts[1]:
update_message = (
f"MINOR update ({remote_version}) available."
f" Includes small feature changes/improvements."
)
elif remote_parts[2] != local_parts[2]:
update_message = (
f"PATCH update ({remote_version}) available."
f" Generally for bug fixes and small tweaks."
)
elif (
len(remote_parts) > 3
and len(local_parts) > 3
and remote_parts[3] != local_parts[3]
):
update_message = (
f"BUILD update ({remote_version}) available."
f" Might be for specific builds or special versions."
)
if update_message:
log.info(update_message)
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
async def get_posts(
subreddit: str,
listing: str,
timeframe: str,
limit: int,
session: aiohttp.ClientSession,
) -> list:
all_posts = await paginated_posts(
posts_endpoint=f"{REDDIT_ENDPOINT}/r/{subreddit}/{listing}.json?limit={limit}&t={timeframe}",
limit=limit,
session=session,
)
return all_posts[:limit]
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
async def paginated_posts(
posts_endpoint: str, limit: int, session: aiohttp.ClientSession
) -> list:
"""
Paginates and retrieves posts until the specified limit is reached.
:param posts_endpoint: API endpoint for retrieving posts.
:param limit: Limit of the number of posts to retrieve.
:param session: aiohttp session to use for the request.
:return: A list of all posts.
"""
all_posts: list = []
last_post_id: str = ""
# Determine whether to use the 'after' parameter
use_after: bool = limit > 100
while len(all_posts) < limit:
# Make the API request with the 'after' parameter if it's provided and the limit is more than 100
if use_after and last_post_id:
endpoint_with_after: str = f"{posts_endpoint}&after={last_post_id}"
else:
endpoint_with_after: str = posts_endpoint
posts_data: dict = await get_data(endpoint=endpoint_with_after, session=session)
posts_children: list = posts_data.get("data", {}).get("children", [])
# If there are no more posts, break out of the loop
if not posts_children:
break
all_posts.extend(posts_children)
# We use the id of the last post in the list to paginate to the next posts
last_post_id: str = all_posts[-1].get("data").get("id")
return all_posts
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #

View File

@@ -1,170 +1,170 @@
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
import argparse
import csv
import json
import logging
import os
from datetime import datetime
from rich.logging import RichHandler
from rich.markdown import Markdown
from rich_argparse import RichHelpFormatter
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def timestamp_to_utc(timestamp: int) -> str:
"""
Converts a Unix timestamp to a formatted datetime string.
:param timestamp: The Unix timestamp to be converted.
:return: A formatted datetime string in the format "dd MMMM yyyy, hh:mm:ssAM/PM".
"""
utc_from_timestamp: datetime = datetime.utcfromtimestamp(timestamp)
datetime_string: str = utc_from_timestamp.strftime("%d %B %Y, %I:%M:%S%p")
return datetime_string
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def pathfinder(directories: list[str]):
for directory in directories:
os.makedirs(directory, exist_ok=True)
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def save_posts(
filename: str,
save_to_dir: str,
posts: list,
save_json: bool = False,
save_csv: bool = False,
):
posts_data: list = [post.__dict__ for post in posts]
if save_json:
json_path = os.path.join(os.path.join(save_to_dir, "json"), f"{filename}.json")
with open(json_path, "w", encoding="utf-8") as json_file:
json.dump(posts_data, json_file, indent=4)
log.info(
f"{os.path.getsize(json_file.name)} bytes written to [link file://{json_file.name}]{json_file.name}"
)
if save_csv:
csv_path = os.path.join(os.path.join(save_to_dir, "csv"), f"{filename}.csv")
with open(csv_path, "w", newline="", encoding="utf-8") as csv_file:
writer = csv.writer(csv_file)
if posts:
writer.writerow(
posts_data[0].keys()
) # header from keys of the first item
for post in posts:
writer.writerow(post.__dict__.values())
log.info(
f"{os.path.getsize(csv_file.name)} bytes written to [link file://{csv_file.name}]{csv_file.name}"
)
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def create_parser():
"""
Creates and configures an argument parser for the command line arguments.
:return: A configured argparse.ArgumentParser object ready to parse the command line arguments.
"""
from . import __version__, __description__, __epilog__
parser = argparse.ArgumentParser(
description=Markdown(__description__, style="argparse.text"),
epilog=Markdown(__epilog__, style="argparse.text"),
formatter_class=RichHelpFormatter,
)
parser.add_argument(
"keyword",
help="keyword to search for, in posts",
)
parser.add_argument("subreddit", help="subreddit to scrape")
parser.add_argument(
"-l",
"--limit",
help="maximum number of posts to scrape (default: %(default)s)",
default=200,
type=int,
)
parser.add_argument(
"-ls",
"--listing",
default="top",
const="top",
nargs="?",
choices=["best", "controversial", "hot", "new", "rising", "top"],
help="listing of posts to scrape (default: %(default)s)",
)
parser.add_argument(
"-t",
"--timeframe",
default="all",
const="all",
nargs="?",
choices=["hour", "day", "week", "month", "year", "all"],
help="timeframe from which to scrape posts (default: %(default)s)",
)
parser.add_argument(
"-j",
"--json",
help="write found posts to a json file",
action="store_true",
)
parser.add_argument(
"-c",
"--csv",
help="write found posts to a csv file",
action="store_true",
)
parser.add_argument(
"-d",
"--debug",
help="(dev) run rpst in debug mode",
action="store_true",
)
parser.add_argument("-v", "--version", action="version", version=__version__)
return parser
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def set_loglevel(debug_mode: bool) -> logging.getLogger:
"""
Configure and return a logging object with the specified log level.
:param debug_mode: If True, the log level is set to "NOTSET". Otherwise, it is set to "INFO".
:return: A logging object configured with the specified log level.
"""
logging.basicConfig(
level="DEBUG" if debug_mode else "INFO",
format="%(message)s",
handlers=[
RichHandler(
markup=True, log_time_format="%I:%M:%S%p", show_level=debug_mode
)
],
)
return logging.getLogger("RPST (Reddit Post Scraping Tool)")
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
args: argparse = create_parser().parse_args()
log: logging.getLogger = set_loglevel(debug_mode=args.debug)
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
import argparse
import csv
import json
import logging
import os
from datetime import datetime
from rich.logging import RichHandler
from rich.markdown import Markdown
from rich_argparse import RichHelpFormatter
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def timestamp_to_utc(timestamp: int) -> str:
"""
Converts a Unix timestamp to a formatted datetime string.
:param timestamp: The Unix timestamp to be converted.
:return: A formatted datetime string in the format "dd MMMM yyyy, hh:mm:ssAM/PM".
"""
utc_from_timestamp: datetime = datetime.utcfromtimestamp(timestamp)
datetime_string: str = utc_from_timestamp.strftime("%d %B %Y, %I:%M:%S%p")
return datetime_string
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def pathfinder(directories: list[str]):
for directory in directories:
os.makedirs(directory, exist_ok=True)
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def save_posts(
filename: str,
save_to_dir: str,
posts: list,
save_json: bool = False,
save_csv: bool = False,
):
posts_data: list = [post.__dict__ for post in posts]
if save_json:
json_path = os.path.join(os.path.join(save_to_dir, "json"), f"{filename}.json")
with open(json_path, "w", encoding="utf-8") as json_file:
json.dump(posts_data, json_file, indent=4)
log.info(
f"{os.path.getsize(json_file.name)} bytes written to [link file://{json_file.name}]{json_file.name}"
)
if save_csv:
csv_path = os.path.join(os.path.join(save_to_dir, "csv"), f"{filename}.csv")
with open(csv_path, "w", newline="", encoding="utf-8") as csv_file:
writer = csv.writer(csv_file)
if posts:
writer.writerow(
posts_data[0].keys()
) # header from keys of the first item
for post in posts:
writer.writerow(post.__dict__.values())
log.info(
f"{os.path.getsize(csv_file.name)} bytes written to [link file://{csv_file.name}]{csv_file.name}"
)
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def create_parser():
"""
Creates and configures an argument parser for the command line arguments.
:return: A configured argparse.ArgumentParser object ready to parse the command line arguments.
"""
from . import __version__, __description__, __epilog__
parser = argparse.ArgumentParser(
description=Markdown(__description__, style="argparse.text"),
epilog=Markdown(__epilog__, style="argparse.text"),
formatter_class=RichHelpFormatter,
)
parser.add_argument(
"keyword",
help="keyword to search for, in posts",
)
parser.add_argument("subreddit", help="subreddit to scrape")
parser.add_argument(
"-l",
"--limit",
help="maximum number of posts to scrape (default: %(default)s)",
default=200,
type=int,
)
parser.add_argument(
"-ls",
"--listing",
default="top",
const="top",
nargs="?",
choices=["best", "controversial", "hot", "new", "rising", "top"],
help="listing of posts to scrape (default: %(default)s)",
)
parser.add_argument(
"-t",
"--timeframe",
default="all",
const="all",
nargs="?",
choices=["hour", "day", "week", "month", "year", "all"],
help="timeframe from which to scrape posts (default: %(default)s)",
)
parser.add_argument(
"-j",
"--json",
help="write found posts to a json file",
action="store_true",
)
parser.add_argument(
"-c",
"--csv",
help="write found posts to a csv file",
action="store_true",
)
parser.add_argument(
"-d",
"--debug",
help="(dev) run rpst in debug mode",
action="store_true",
)
parser.add_argument("-v", "--version", action="version", version=__version__)
return parser
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def set_loglevel(debug_mode: bool) -> logging.getLogger:
"""
Configure and return a logging object with the specified log level.
:param debug_mode: If True, the log level is set to "NOTSET". Otherwise, it is set to "INFO".
:return: A logging object configured with the specified log level.
"""
logging.basicConfig(
level="DEBUG" if debug_mode else "INFO",
format="%(message)s",
handlers=[
RichHandler(
markup=True, log_time_format="%I:%M:%S%p", show_level=debug_mode
)
],
)
return logging.getLogger("RPST (Reddit Post Scraping Tool)")
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
args: argparse = create_parser().parse_args()
log: logging.getLogger = set_loglevel(debug_mode=args.debug)
# +++++++++++++++++++++++++++++++++++++++++++++++++ #