diff --git a/rpst/__init__.py b/rpst/__init__.py index f636336..0542a85 100644 --- a/rpst/__init__.py +++ b/rpst/__init__.py @@ -1,42 +1,42 @@ -import os - -__author__: str = "Richard Mwewa" -__about_author__: str = "https://about/me/rly0nheart" -__version__: str = "2.0.0.0" - -__description__: str = f""" -# RPST (Reddit Post Scraping Tool) {__version__} -> Retrieve Reddit posts that contain the specified keyword from a specified subreddit. -""" -__epilog__: str = f""" -# by [{__author__}]({__about_author__}) - -``` -MIT License - -Copyright (c) 2023 {__author__} - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -``` -""" - -# Construct path to the program's directory -PROGRAM_DIRECTORY: str = os.path.expanduser( - os.path.join("~", "reddit_post_scraping_tool") -) +import os + +__author__: str = "Richard Mwewa" +__about_author__: str = "https://about/me/rly0nheart" +__version__: str = "2.0.0.0" + +__description__: str = f""" +# RPST (Reddit Post Scraping Tool) {__version__} +> Retrieve Reddit posts that contain the specified keyword from a specified subreddit. +""" +__epilog__: str = f""" +# by [{__author__}]({__about_author__}) + +``` +MIT License + +Copyright (c) 2023 {__author__} + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +``` +""" + +# Construct path to the program's directory +PROGRAM_DIRECTORY: str = os.path.expanduser( + os.path.join("~", "reddit_post_scraping_tool") +) diff --git a/rpst/api.py b/rpst/api.py index 2f40f85..10133e6 100644 --- a/rpst/api.py +++ b/rpst/api.py @@ -1,164 +1,164 @@ -# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # - -from typing import Union - -import aiohttp - -from .coreutils import log - -# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # - - -REDDIT_ENDPOINT: str = "https://www.reddit.com" -PYPI_PROJECT_ENDPOINT: str = "https://pypi.org/pypi/reddit-post-scraping-tool/json" - - -# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # - - -async def get_data(session: aiohttp.ClientSession, endpoint: str) -> Union[dict, list]: - """ - Fetches JSON data from a given API endpoint. - - :param session: aiohttp session to use for the request. - :param endpoint: The API endpoint to fetch data from. - :return: Returns JSON data as a dictionary or list. Returns an empty dict if fetching fails. - """ - - try: - async with session.get( - endpoint, - ) as response: - if response.status == 200: - return await response.json() - else: - error_message = await response.json() - log.error(f"An API error occurred: {error_message}") - return {} - - except aiohttp.ClientConnectionError as error: - log.error(f"An HTTP error occurred: {error}") - return {} - except Exception as error: - log.critical(f"An unknown error occurred: {error}") - return {} - - -# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # - - -async def get_updates(session: aiohttp.ClientSession): - """ - Gets and compares the current program version with the remote version. - - Assumes version format: major.minor.patch.prefix - - :param session: aiohttp session to use for the request. - """ - from . import __version__ - - # Make a GET request to PyPI to get the project's latest release. - response: dict = await get_data(endpoint=PYPI_PROJECT_ENDPOINT, session=session) - - if response.get("info"): - release: dict = response.get("info") - remote_version: str = release.get("version") - # Splitting the version strings into components - remote_parts: list = remote_version.split(".") - local_parts: list = __version__.split(".") - - update_message: str = "" - - # Check for differences in version parts - if remote_parts[0] != local_parts[0]: - update_message = ( - f"MAJOR update ({remote_version}) available." - f" It might introduce significant changes." - ) - - elif remote_parts[1] != local_parts[1]: - update_message = ( - f"MINOR update ({remote_version}) available." - f" Includes small feature changes/improvements." - ) - - elif remote_parts[2] != local_parts[2]: - update_message = ( - f"PATCH update ({remote_version}) available." - f" Generally for bug fixes and small tweaks." - ) - - elif ( - len(remote_parts) > 3 - and len(local_parts) > 3 - and remote_parts[3] != local_parts[3] - ): - update_message = ( - f"BUILD update ({remote_version}) available." - f" Might be for specific builds or special versions." - ) - - if update_message: - log.info(update_message) - - -# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # -async def get_posts( - subreddit: str, - listing: str, - timeframe: str, - limit: int, - session: aiohttp.ClientSession, -) -> list: - all_posts = await paginated_posts( - posts_endpoint=f"{REDDIT_ENDPOINT}/r/{subreddit}/{listing}.json?limit={limit}&t={timeframe}", - limit=limit, - session=session, - ) - - return all_posts[:limit] - - -# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # - - -async def paginated_posts( - posts_endpoint: str, limit: int, session: aiohttp.ClientSession -) -> list: - """ - Paginates and retrieves posts until the specified limit is reached. - - :param posts_endpoint: API endpoint for retrieving posts. - :param limit: Limit of the number of posts to retrieve. - :param session: aiohttp session to use for the request. - :return: A list of all posts. - """ - all_posts: list = [] - last_post_id: str = "" - - # Determine whether to use the 'after' parameter - use_after: bool = limit > 100 - - while len(all_posts) < limit: - # Make the API request with the 'after' parameter if it's provided and the limit is more than 100 - if use_after and last_post_id: - endpoint_with_after: str = f"{posts_endpoint}&after={last_post_id}" - else: - endpoint_with_after: str = posts_endpoint - - posts_data: dict = await get_data(endpoint=endpoint_with_after, session=session) - posts_children: list = posts_data.get("data", {}).get("children", []) - - # If there are no more posts, break out of the loop - if not posts_children: - break - - all_posts.extend(posts_children) - - # We use the id of the last post in the list to paginate to the next posts - last_post_id: str = all_posts[-1].get("data").get("id") - - return all_posts - - -# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # + +from typing import Union + +import aiohttp + +from .coreutils import log + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # + + +REDDIT_ENDPOINT: str = "https://www.reddit.com" +PYPI_PROJECT_ENDPOINT: str = "https://pypi.org/pypi/reddit-post-scraping-tool/json" + + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # + + +async def get_data(session: aiohttp.ClientSession, endpoint: str) -> Union[dict, list]: + """ + Fetches JSON data from a given API endpoint. + + :param session: aiohttp session to use for the request. + :param endpoint: The API endpoint to fetch data from. + :return: Returns JSON data as a dictionary or list. Returns an empty dict if fetching fails. + """ + + try: + async with session.get( + endpoint, + ) as response: + if response.status == 200: + return await response.json() + else: + error_message = await response.json() + log.error(f"An API error occurred: {error_message}") + return {} + + except aiohttp.ClientConnectionError as error: + log.error(f"An HTTP error occurred: {error}") + return {} + except Exception as error: + log.critical(f"An unknown error occurred: {error}") + return {} + + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # + + +async def get_updates(session: aiohttp.ClientSession): + """ + Gets and compares the current program version with the remote version. + + Assumes version format: major.minor.patch.prefix + + :param session: aiohttp session to use for the request. + """ + from . import __version__ + + # Make a GET request to PyPI to get the project's latest release. + response: dict = await get_data(endpoint=PYPI_PROJECT_ENDPOINT, session=session) + + if response.get("info"): + release: dict = response.get("info") + remote_version: str = release.get("version") + # Splitting the version strings into components + remote_parts: list = remote_version.split(".") + local_parts: list = __version__.split(".") + + update_message: str = "" + + # Check for differences in version parts + if remote_parts[0] != local_parts[0]: + update_message = ( + f"MAJOR update ({remote_version}) available." + f" It might introduce significant changes." + ) + + elif remote_parts[1] != local_parts[1]: + update_message = ( + f"MINOR update ({remote_version}) available." + f" Includes small feature changes/improvements." + ) + + elif remote_parts[2] != local_parts[2]: + update_message = ( + f"PATCH update ({remote_version}) available." + f" Generally for bug fixes and small tweaks." + ) + + elif ( + len(remote_parts) > 3 + and len(local_parts) > 3 + and remote_parts[3] != local_parts[3] + ): + update_message = ( + f"BUILD update ({remote_version}) available." + f" Might be for specific builds or special versions." + ) + + if update_message: + log.info(update_message) + + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # +async def get_posts( + subreddit: str, + listing: str, + timeframe: str, + limit: int, + session: aiohttp.ClientSession, +) -> list: + all_posts = await paginated_posts( + posts_endpoint=f"{REDDIT_ENDPOINT}/r/{subreddit}/{listing}.json?limit={limit}&t={timeframe}", + limit=limit, + session=session, + ) + + return all_posts[:limit] + + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # + + +async def paginated_posts( + posts_endpoint: str, limit: int, session: aiohttp.ClientSession +) -> list: + """ + Paginates and retrieves posts until the specified limit is reached. + + :param posts_endpoint: API endpoint for retrieving posts. + :param limit: Limit of the number of posts to retrieve. + :param session: aiohttp session to use for the request. + :return: A list of all posts. + """ + all_posts: list = [] + last_post_id: str = "" + + # Determine whether to use the 'after' parameter + use_after: bool = limit > 100 + + while len(all_posts) < limit: + # Make the API request with the 'after' parameter if it's provided and the limit is more than 100 + if use_after and last_post_id: + endpoint_with_after: str = f"{posts_endpoint}&after={last_post_id}" + else: + endpoint_with_after: str = posts_endpoint + + posts_data: dict = await get_data(endpoint=endpoint_with_after, session=session) + posts_children: list = posts_data.get("data", {}).get("children", []) + + # If there are no more posts, break out of the loop + if not posts_children: + break + + all_posts.extend(posts_children) + + # We use the id of the last post in the list to paginate to the next posts + last_post_id: str = all_posts[-1].get("data").get("id") + + return all_posts + + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # diff --git a/rpst/coreutils.py b/rpst/coreutils.py index b2c8e12..e9146b1 100644 --- a/rpst/coreutils.py +++ b/rpst/coreutils.py @@ -1,170 +1,170 @@ -# +++++++++++++++++++++++++++++++++++++++++++++++++ # - -import argparse -import csv -import json -import logging -import os -from datetime import datetime - -from rich.logging import RichHandler -from rich.markdown import Markdown -from rich_argparse import RichHelpFormatter - - -# +++++++++++++++++++++++++++++++++++++++++++++++++ # - - -def timestamp_to_utc(timestamp: int) -> str: - """ - Converts a Unix timestamp to a formatted datetime string. - - :param timestamp: The Unix timestamp to be converted. - :return: A formatted datetime string in the format "dd MMMM yyyy, hh:mm:ssAM/PM". - """ - utc_from_timestamp: datetime = datetime.utcfromtimestamp(timestamp) - datetime_string: str = utc_from_timestamp.strftime("%d %B %Y, %I:%M:%S%p") - return datetime_string - - -# +++++++++++++++++++++++++++++++++++++++++++++++++ # - - -def pathfinder(directories: list[str]): - for directory in directories: - os.makedirs(directory, exist_ok=True) - - -# +++++++++++++++++++++++++++++++++++++++++++++++++ # - - -def save_posts( - filename: str, - save_to_dir: str, - posts: list, - save_json: bool = False, - save_csv: bool = False, -): - posts_data: list = [post.__dict__ for post in posts] - - if save_json: - json_path = os.path.join(os.path.join(save_to_dir, "json"), f"{filename}.json") - with open(json_path, "w", encoding="utf-8") as json_file: - json.dump(posts_data, json_file, indent=4) - log.info( - f"{os.path.getsize(json_file.name)} bytes written to [link file://{json_file.name}]{json_file.name}" - ) - - if save_csv: - csv_path = os.path.join(os.path.join(save_to_dir, "csv"), f"{filename}.csv") - with open(csv_path, "w", newline="", encoding="utf-8") as csv_file: - writer = csv.writer(csv_file) - if posts: - writer.writerow( - posts_data[0].keys() - ) # header from keys of the first item - for post in posts: - writer.writerow(post.__dict__.values()) - log.info( - f"{os.path.getsize(csv_file.name)} bytes written to [link file://{csv_file.name}]{csv_file.name}" - ) - - -# +++++++++++++++++++++++++++++++++++++++++++++++++ # - - -def create_parser(): - """ - Creates and configures an argument parser for the command line arguments. - - :return: A configured argparse.ArgumentParser object ready to parse the command line arguments. - """ - from . import __version__, __description__, __epilog__ - - parser = argparse.ArgumentParser( - description=Markdown(__description__, style="argparse.text"), - epilog=Markdown(__epilog__, style="argparse.text"), - formatter_class=RichHelpFormatter, - ) - - parser.add_argument( - "keyword", - help="keyword to search for, in posts", - ) - parser.add_argument("subreddit", help="subreddit to scrape") - parser.add_argument( - "-l", - "--limit", - help="maximum number of posts to scrape (default: %(default)s)", - default=200, - type=int, - ) - parser.add_argument( - "-ls", - "--listing", - default="top", - const="top", - nargs="?", - choices=["best", "controversial", "hot", "new", "rising", "top"], - help="listing of posts to scrape (default: %(default)s)", - ) - parser.add_argument( - "-t", - "--timeframe", - default="all", - const="all", - nargs="?", - choices=["hour", "day", "week", "month", "year", "all"], - help="timeframe from which to scrape posts (default: %(default)s)", - ) - parser.add_argument( - "-j", - "--json", - help="write found posts to a json file", - action="store_true", - ) - parser.add_argument( - "-c", - "--csv", - help="write found posts to a csv file", - action="store_true", - ) - parser.add_argument( - "-d", - "--debug", - help="(dev) run rpst in debug mode", - action="store_true", - ) - parser.add_argument("-v", "--version", action="version", version=__version__) - - return parser - - -# +++++++++++++++++++++++++++++++++++++++++++++++++ # - - -def set_loglevel(debug_mode: bool) -> logging.getLogger: - """ - Configure and return a logging object with the specified log level. - - :param debug_mode: If True, the log level is set to "NOTSET". Otherwise, it is set to "INFO". - :return: A logging object configured with the specified log level. - """ - logging.basicConfig( - level="DEBUG" if debug_mode else "INFO", - format="%(message)s", - handlers=[ - RichHandler( - markup=True, log_time_format="%I:%M:%S%p", show_level=debug_mode - ) - ], - ) - return logging.getLogger("RPST (Reddit Post Scraping Tool)") - - -# +++++++++++++++++++++++++++++++++++++++++++++++++ # - -args: argparse = create_parser().parse_args() -log: logging.getLogger = set_loglevel(debug_mode=args.debug) - -# +++++++++++++++++++++++++++++++++++++++++++++++++ # +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + +import argparse +import csv +import json +import logging +import os +from datetime import datetime + +from rich.logging import RichHandler +from rich.markdown import Markdown +from rich_argparse import RichHelpFormatter + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +def timestamp_to_utc(timestamp: int) -> str: + """ + Converts a Unix timestamp to a formatted datetime string. + + :param timestamp: The Unix timestamp to be converted. + :return: A formatted datetime string in the format "dd MMMM yyyy, hh:mm:ssAM/PM". + """ + utc_from_timestamp: datetime = datetime.utcfromtimestamp(timestamp) + datetime_string: str = utc_from_timestamp.strftime("%d %B %Y, %I:%M:%S%p") + return datetime_string + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +def pathfinder(directories: list[str]): + for directory in directories: + os.makedirs(directory, exist_ok=True) + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +def save_posts( + filename: str, + save_to_dir: str, + posts: list, + save_json: bool = False, + save_csv: bool = False, +): + posts_data: list = [post.__dict__ for post in posts] + + if save_json: + json_path = os.path.join(os.path.join(save_to_dir, "json"), f"{filename}.json") + with open(json_path, "w", encoding="utf-8") as json_file: + json.dump(posts_data, json_file, indent=4) + log.info( + f"{os.path.getsize(json_file.name)} bytes written to [link file://{json_file.name}]{json_file.name}" + ) + + if save_csv: + csv_path = os.path.join(os.path.join(save_to_dir, "csv"), f"{filename}.csv") + with open(csv_path, "w", newline="", encoding="utf-8") as csv_file: + writer = csv.writer(csv_file) + if posts: + writer.writerow( + posts_data[0].keys() + ) # header from keys of the first item + for post in posts: + writer.writerow(post.__dict__.values()) + log.info( + f"{os.path.getsize(csv_file.name)} bytes written to [link file://{csv_file.name}]{csv_file.name}" + ) + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +def create_parser(): + """ + Creates and configures an argument parser for the command line arguments. + + :return: A configured argparse.ArgumentParser object ready to parse the command line arguments. + """ + from . import __version__, __description__, __epilog__ + + parser = argparse.ArgumentParser( + description=Markdown(__description__, style="argparse.text"), + epilog=Markdown(__epilog__, style="argparse.text"), + formatter_class=RichHelpFormatter, + ) + + parser.add_argument( + "keyword", + help="keyword to search for, in posts", + ) + parser.add_argument("subreddit", help="subreddit to scrape") + parser.add_argument( + "-l", + "--limit", + help="maximum number of posts to scrape (default: %(default)s)", + default=200, + type=int, + ) + parser.add_argument( + "-ls", + "--listing", + default="top", + const="top", + nargs="?", + choices=["best", "controversial", "hot", "new", "rising", "top"], + help="listing of posts to scrape (default: %(default)s)", + ) + parser.add_argument( + "-t", + "--timeframe", + default="all", + const="all", + nargs="?", + choices=["hour", "day", "week", "month", "year", "all"], + help="timeframe from which to scrape posts (default: %(default)s)", + ) + parser.add_argument( + "-j", + "--json", + help="write found posts to a json file", + action="store_true", + ) + parser.add_argument( + "-c", + "--csv", + help="write found posts to a csv file", + action="store_true", + ) + parser.add_argument( + "-d", + "--debug", + help="(dev) run rpst in debug mode", + action="store_true", + ) + parser.add_argument("-v", "--version", action="version", version=__version__) + + return parser + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +def set_loglevel(debug_mode: bool) -> logging.getLogger: + """ + Configure and return a logging object with the specified log level. + + :param debug_mode: If True, the log level is set to "NOTSET". Otherwise, it is set to "INFO". + :return: A logging object configured with the specified log level. + """ + logging.basicConfig( + level="DEBUG" if debug_mode else "INFO", + format="%(message)s", + handlers=[ + RichHandler( + markup=True, log_time_format="%I:%M:%S%p", show_level=debug_mode + ) + ], + ) + return logging.getLogger("RPST (Reddit Post Scraping Tool)") + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + +args: argparse = create_parser().parse_args() +log: logging.getLogger = set_loglevel(debug_mode=args.debug) + +# +++++++++++++++++++++++++++++++++++++++++++++++++ #