diff --git a/pyproject.toml b/pyproject.toml index b881263..fe4579a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,27 +7,27 @@ packages = ["rpst"] [project] name = "reddit-post-scraping-tool" -version = "1.9.1.1" +version = "2.0.0.0" description = "Retrieve Reddit posts that contain the specified keyword from a specified subreddit." readme = "README.md" requires-python = ">=3.8" -license = {file = "LICENSE"} +license = { file = "LICENSE" } keywords = ["reddit-crawler", "reddit-scraping", "reddit", "reddit-api"] -authors = [{name = "Richard Mwewa", email = "rly0nheart@duck.com"}] +authors = [{ name = "Richard Mwewa", email = "rly0nheart@duck.com" }] classifiers = [ - "Development Status :: 5 - Production/Stable", - "Programming Language :: Python :: 3", - "Programming Language :: Visual Basic", - "Intended Audience :: End Users/Desktop", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - "Natural Language :: English" + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python :: 3", + "Programming Language :: Visual Basic", + "Intended Audience :: End Users/Desktop", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Natural Language :: English" ] dependencies = [ "rich", - "glyphoji", - "requests", + "aiohttp", + "rich-argparse" ] [project.urls] @@ -36,4 +36,5 @@ documentation = "https://github.com/bellingcat/reddit-post-scraping-tool/wiki" repository = "https://github.com/bellingcat/reddit-post-scraping-tool.git" [project.scripts] -rpst = "rpst.main:run" +rpst = "rpst.scraper:run" +reddit_post_scraping_tool = "rpst.scraper:run" diff --git a/rpst/__init__.py b/rpst/__init__.py index 8b13789..f636336 100644 --- a/rpst/__init__.py +++ b/rpst/__init__.py @@ -1 +1,42 @@ +import os +__author__: str = "Richard Mwewa" +__about_author__: str = "https://about/me/rly0nheart" +__version__: str = "2.0.0.0" + +__description__: str = f""" +# RPST (Reddit Post Scraping Tool) {__version__} +> Retrieve Reddit posts that contain the specified keyword from a specified subreddit. +""" +__epilog__: str = f""" +# by [{__author__}]({__about_author__}) + +``` +MIT License + +Copyright (c) 2023 {__author__} + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +``` +""" + +# Construct path to the program's directory +PROGRAM_DIRECTORY: str = os.path.expanduser( + os.path.join("~", "reddit_post_scraping_tool") +) diff --git a/rpst/api.py b/rpst/api.py new file mode 100644 index 0000000..2f40f85 --- /dev/null +++ b/rpst/api.py @@ -0,0 +1,164 @@ +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # + +from typing import Union + +import aiohttp + +from .coreutils import log + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # + + +REDDIT_ENDPOINT: str = "https://www.reddit.com" +PYPI_PROJECT_ENDPOINT: str = "https://pypi.org/pypi/reddit-post-scraping-tool/json" + + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # + + +async def get_data(session: aiohttp.ClientSession, endpoint: str) -> Union[dict, list]: + """ + Fetches JSON data from a given API endpoint. + + :param session: aiohttp session to use for the request. + :param endpoint: The API endpoint to fetch data from. + :return: Returns JSON data as a dictionary or list. Returns an empty dict if fetching fails. + """ + + try: + async with session.get( + endpoint, + ) as response: + if response.status == 200: + return await response.json() + else: + error_message = await response.json() + log.error(f"An API error occurred: {error_message}") + return {} + + except aiohttp.ClientConnectionError as error: + log.error(f"An HTTP error occurred: {error}") + return {} + except Exception as error: + log.critical(f"An unknown error occurred: {error}") + return {} + + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # + + +async def get_updates(session: aiohttp.ClientSession): + """ + Gets and compares the current program version with the remote version. + + Assumes version format: major.minor.patch.prefix + + :param session: aiohttp session to use for the request. + """ + from . import __version__ + + # Make a GET request to PyPI to get the project's latest release. + response: dict = await get_data(endpoint=PYPI_PROJECT_ENDPOINT, session=session) + + if response.get("info"): + release: dict = response.get("info") + remote_version: str = release.get("version") + # Splitting the version strings into components + remote_parts: list = remote_version.split(".") + local_parts: list = __version__.split(".") + + update_message: str = "" + + # Check for differences in version parts + if remote_parts[0] != local_parts[0]: + update_message = ( + f"MAJOR update ({remote_version}) available." + f" It might introduce significant changes." + ) + + elif remote_parts[1] != local_parts[1]: + update_message = ( + f"MINOR update ({remote_version}) available." + f" Includes small feature changes/improvements." + ) + + elif remote_parts[2] != local_parts[2]: + update_message = ( + f"PATCH update ({remote_version}) available." + f" Generally for bug fixes and small tweaks." + ) + + elif ( + len(remote_parts) > 3 + and len(local_parts) > 3 + and remote_parts[3] != local_parts[3] + ): + update_message = ( + f"BUILD update ({remote_version}) available." + f" Might be for specific builds or special versions." + ) + + if update_message: + log.info(update_message) + + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # +async def get_posts( + subreddit: str, + listing: str, + timeframe: str, + limit: int, + session: aiohttp.ClientSession, +) -> list: + all_posts = await paginated_posts( + posts_endpoint=f"{REDDIT_ENDPOINT}/r/{subreddit}/{listing}.json?limit={limit}&t={timeframe}", + limit=limit, + session=session, + ) + + return all_posts[:limit] + + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # + + +async def paginated_posts( + posts_endpoint: str, limit: int, session: aiohttp.ClientSession +) -> list: + """ + Paginates and retrieves posts until the specified limit is reached. + + :param posts_endpoint: API endpoint for retrieving posts. + :param limit: Limit of the number of posts to retrieve. + :param session: aiohttp session to use for the request. + :return: A list of all posts. + """ + all_posts: list = [] + last_post_id: str = "" + + # Determine whether to use the 'after' parameter + use_after: bool = limit > 100 + + while len(all_posts) < limit: + # Make the API request with the 'after' parameter if it's provided and the limit is more than 100 + if use_after and last_post_id: + endpoint_with_after: str = f"{posts_endpoint}&after={last_post_id}" + else: + endpoint_with_after: str = posts_endpoint + + posts_data: dict = await get_data(endpoint=endpoint_with_after, session=session) + posts_children: list = posts_data.get("data", {}).get("children", []) + + # If there are no more posts, break out of the loop + if not posts_children: + break + + all_posts.extend(posts_children) + + # We use the id of the last post in the list to paginate to the next posts + last_post_id: str = all_posts[-1].get("data").get("id") + + return all_posts + + +# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # diff --git a/rpst/base.py b/rpst/base.py new file mode 100644 index 0000000..012ab1b --- /dev/null +++ b/rpst/base.py @@ -0,0 +1,103 @@ +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + +from dataclasses import dataclass +from typing import List + +import aiohttp + +from .api import get_posts, get_updates +from .coreutils import timestamp_to_utc + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +@dataclass +class Post: + id: str + thumbnail: str + title: str + text: str + author: str + subreddit: str + subreddit_id: str + subreddit_type: str + upvotes: int + upvote_ratio: float + downvotes: int + gilded: int + is_nsfw: bool + is_shareable: bool + is_edited: bool + comments: int + hide_from_bots: bool + score: float + domain: str + permalink: str + is_locked: bool + is_archived: bool + created_at: str + raw_post: dict + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +async def find_posts( + keyword: str, + subreddit: str, + listing: str, + timeframe: str, + limit: int, +) -> List[Post]: + async with aiohttp.ClientSession() as session: + found_posts_count: int = 0 + found_posts_list: list = [] + + await get_updates(session=session) + raw_posts: list = await get_posts( + subreddit=subreddit, + listing=listing, + timeframe=timeframe, + limit=limit, + session=session, + ) + for raw_post in raw_posts: + post_data: dict = raw_post.get("data") + + if keyword.lower() in post_data.get( + "selftext" + ) or keyword.lower() in post_data.get("title"): + found_posts_count += 1 + post = Post( + id=post_data.get("id"), + thumbnail=post_data.get("thumbnail"), + title=post_data.get("title"), + text=post_data.get("selftext"), + author=post_data.get("author"), + subreddit=post_data.get("subreddit"), + subreddit_id=post_data.get("subreddit_id"), + subreddit_type=post_data.get("subreddit_type"), + upvotes=post_data.get("ups"), + upvote_ratio=post_data.get("upvote_ratio"), + downvotes=post_data.get("downs"), + gilded=post_data.get("gilded"), + is_nsfw=post_data.get("over_18"), + is_shareable=post_data.get("is_reddit_media_domain"), + is_edited=post_data.get("edited"), + comments=post_data.get("num_comments"), + hide_from_bots=post_data.get("is_robot_indexable"), + score=post_data.get("score"), + domain=post_data.get("domain"), + permalink=post_data.get("permalink"), + is_locked=post_data.get("locked"), + is_archived=post_data.get("archived"), + created_at=timestamp_to_utc(timestamp=post_data.get("created_utc")), + raw_post=post_data, + ) + found_posts_list.append(post) + + return found_posts_list + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # diff --git a/rpst/coreutils.py b/rpst/coreutils.py new file mode 100644 index 0000000..b2c8e12 --- /dev/null +++ b/rpst/coreutils.py @@ -0,0 +1,170 @@ +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + +import argparse +import csv +import json +import logging +import os +from datetime import datetime + +from rich.logging import RichHandler +from rich.markdown import Markdown +from rich_argparse import RichHelpFormatter + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +def timestamp_to_utc(timestamp: int) -> str: + """ + Converts a Unix timestamp to a formatted datetime string. + + :param timestamp: The Unix timestamp to be converted. + :return: A formatted datetime string in the format "dd MMMM yyyy, hh:mm:ssAM/PM". + """ + utc_from_timestamp: datetime = datetime.utcfromtimestamp(timestamp) + datetime_string: str = utc_from_timestamp.strftime("%d %B %Y, %I:%M:%S%p") + return datetime_string + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +def pathfinder(directories: list[str]): + for directory in directories: + os.makedirs(directory, exist_ok=True) + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +def save_posts( + filename: str, + save_to_dir: str, + posts: list, + save_json: bool = False, + save_csv: bool = False, +): + posts_data: list = [post.__dict__ for post in posts] + + if save_json: + json_path = os.path.join(os.path.join(save_to_dir, "json"), f"{filename}.json") + with open(json_path, "w", encoding="utf-8") as json_file: + json.dump(posts_data, json_file, indent=4) + log.info( + f"{os.path.getsize(json_file.name)} bytes written to [link file://{json_file.name}]{json_file.name}" + ) + + if save_csv: + csv_path = os.path.join(os.path.join(save_to_dir, "csv"), f"{filename}.csv") + with open(csv_path, "w", newline="", encoding="utf-8") as csv_file: + writer = csv.writer(csv_file) + if posts: + writer.writerow( + posts_data[0].keys() + ) # header from keys of the first item + for post in posts: + writer.writerow(post.__dict__.values()) + log.info( + f"{os.path.getsize(csv_file.name)} bytes written to [link file://{csv_file.name}]{csv_file.name}" + ) + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +def create_parser(): + """ + Creates and configures an argument parser for the command line arguments. + + :return: A configured argparse.ArgumentParser object ready to parse the command line arguments. + """ + from . import __version__, __description__, __epilog__ + + parser = argparse.ArgumentParser( + description=Markdown(__description__, style="argparse.text"), + epilog=Markdown(__epilog__, style="argparse.text"), + formatter_class=RichHelpFormatter, + ) + + parser.add_argument( + "keyword", + help="keyword to search for, in posts", + ) + parser.add_argument("subreddit", help="subreddit to scrape") + parser.add_argument( + "-l", + "--limit", + help="maximum number of posts to scrape (default: %(default)s)", + default=200, + type=int, + ) + parser.add_argument( + "-ls", + "--listing", + default="top", + const="top", + nargs="?", + choices=["best", "controversial", "hot", "new", "rising", "top"], + help="listing of posts to scrape (default: %(default)s)", + ) + parser.add_argument( + "-t", + "--timeframe", + default="all", + const="all", + nargs="?", + choices=["hour", "day", "week", "month", "year", "all"], + help="timeframe from which to scrape posts (default: %(default)s)", + ) + parser.add_argument( + "-j", + "--json", + help="write found posts to a json file", + action="store_true", + ) + parser.add_argument( + "-c", + "--csv", + help="write found posts to a csv file", + action="store_true", + ) + parser.add_argument( + "-d", + "--debug", + help="(dev) run rpst in debug mode", + action="store_true", + ) + parser.add_argument("-v", "--version", action="version", version=__version__) + + return parser + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +def set_loglevel(debug_mode: bool) -> logging.getLogger: + """ + Configure and return a logging object with the specified log level. + + :param debug_mode: If True, the log level is set to "NOTSET". Otherwise, it is set to "INFO". + :return: A logging object configured with the specified log level. + """ + logging.basicConfig( + level="DEBUG" if debug_mode else "INFO", + format="%(message)s", + handlers=[ + RichHandler( + markup=True, log_time_format="%I:%M:%S%p", show_level=debug_mode + ) + ], + ) + return logging.getLogger("RPST (Reddit Post Scraping Tool)") + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + +args: argparse = create_parser().parse_args() +log: logging.getLogger = set_loglevel(debug_mode=args.debug) + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # diff --git a/rpst/main.py b/rpst/main.py deleted file mode 100644 index 13771d6..0000000 --- a/rpst/main.py +++ /dev/null @@ -1,33 +0,0 @@ -from datetime import datetime - -from .rpst import get_posts -from .utils import create_parser, set_loglevel, check_updates - - -def run(): - """ - Main entry point for the program. It creates a parser, parses the command line arguments, - checks for updates, gets posts, and handles any exceptions that occur during the execution. - """ - - # Create a parser and parse the command line arguments - parser = create_parser() - args = parser.parse_args() - - log = set_loglevel(debug_mode=args.debug) - - # Record the start time - start_time = datetime.now() - - try: - # Check for updates - check_updates(version_tag="1.9.1.1") - - # Get posts with the provided/parsed arguments - get_posts(args=args) - except KeyboardInterrupt: - log.warning("User interruption detected ([yellow]Ctrl+C[/]).") - except Exception as e: - log.error(f"An error occurred: [red]{e}[/]") - finally: - log.info(f"Finished in {datetime.now() - start_time} seconds.") diff --git a/rpst/rpst.py b/rpst/rpst.py deleted file mode 100644 index d30f8b9..0000000 --- a/rpst/rpst.py +++ /dev/null @@ -1,131 +0,0 @@ -import argparse -from datetime import datetime - -import requests -from glyphoji import glyph -from rich import print -from rich.tree import Tree - -from .utils import convert_timestamp_to_datetime, write_post_data - - -def create_post_branch(post: dict, keyword: str, tree: Tree, args: argparse) -> Tree: - """ - This function extracts relevant data from a Reddit post and adds it in a tree branch structure, - followed by the post's selftext. - - :param post: A dictionary containing the data of a Reddit post. - :param keyword: The keyword that is used to find posts, in his case gets uses as the filename. - :param tree: Tree where the post branch will be added. - :param args: A namespace object from argparse. - :returns: The main tree with added post branches. - """ - # Define the data to extract from the post. - post_data = { - # "Author": post["data"]["author"], - f"{glyph.id_button} ID": post["data"]["id"], - f"{glyph.people_hugging} Subreddit": post["data"]["subreddit_name_prefixed"], - f"{glyph.face_with_peeking_eye} Visibility": post["data"]["subreddit_type"], - f"{glyph.framed_picture} Thumbnail": post["data"]["thumbnail"], - f"{glyph.white_question_mark} Gilded": post["data"]["gilded"], - f"{glyph.up_arrow} Upvotes": post["data"]["ups"], - f"{glyph.chart_increasing} Upvote ratio": post["data"]["upvote_ratio"], - f"{glyph.down_arrow} Downvotes": post["data"]["downs"], - f"{glyph.trophy} Awards": post["data"]["total_awards_received"], - f"{glyph.trophy} Top award": post["data"]["top_awarded_type"], - f"{glyph.no_one_under_eighteen} Is NSFW?": post["data"]["over_18"], - f"{glyph.left_arrow_curving_right} Is crosspostable?": post["data"][ - "is_crosspostable" - ], - f"{glyph.bar_chart} Score": post["data"]["score"], - f"{glyph.card_file_box} Category": post["data"]["category"], - f"{glyph.globe_with_meridians} Domain": post["data"]["domain"], - f"{glyph.calendar} Posted on": convert_timestamp_to_datetime( - post["data"]["created"] - ), - f"{glyph.calendar} Approved at": post["data"]["approved_at_utc"], - f"{glyph.bust_in_silhouette} Approved by": post["data"]["approved_by"], - } - - # Add the post's branch to the main tree. - post_branch = tree.add(f"{glyph.page_with_curl} {post['data']['title']}") - - # Add each piece of extracted data as a branch of the post_branch. - for post_key, post_value in post_data.items(): - post_branch.add(f"{post_key}: {post_value}", style="dim") - - # This ensures that the post's selftext is also added to the written json/csv file. - post_data[f"{glyph.clipboard} Text"] = post["data"]["selftext"] - write_post_data( - filename=keyword, post_data=post_data, tree_branch=post_branch, args=args - ) - post_branch.add(post["data"]["selftext"], style="italic") - - return tree - - -def get_posts(args: argparse): - """ - Scrapes a given subreddit for posts that contain a specified keyword. - The search is limited by the number of posts and timeframe specified. - - :param args: Namespace object from argparse. - - Expected Object Attributes - -------------------------- - - keyword: The keyword to search for in the posts. - - subreddit: The subreddit to scrape. - - listing: The type of posts to scrape. This could be 'hot', 'new', etc. - - timeframe: The timeframe from which to scrape posts. This could be 'day', 'week', etc. - - limit: The maximum number of posts to scrape. - - json: If specified, all found posts will be written to a json file. - """ - keyword = args.keyword - subreddit = args.subreddit - listing = args.listing - timeframe = args.timeframe - limit = args.limit - - # Create main result tree. - main_tree = Tree( - f"[bold]{glyph.calendar} {datetime.now()}[/]", guide_style="bold bright_blue" - ) - - # Start a new session - session = requests.session() - # Set the User-Agent to mimic a Safari browser on a Mac. - session.headers = { - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, " - "like Gecko) Version/14.1.1 Safari/605.1.15" - } - - # Send a GET request to the specified subreddit and listing, - # limiting the response by the specified limit and timeframe. - response = session.get( - f"https://reddit.com/r/{subreddit}/{listing}" - f".json?limit={limit}&t={timeframe}" - ).json() - - # Initialize a counter for the number of posts found that contain the keyword. - found_posts = 0 - - # Loop through each post in the response - for post_index, post in enumerate(response["data"]["children"], start=1): - # If the keyword is found in the post's selftext or title, increment the counter and process the post. - if ( - keyword.lower() in post["data"]["selftext"] - or keyword.lower() in post["data"]["title"] - ): - # Create a branch for found post(s) and show post index and post author as the title - found_tree = main_tree.add( - f"{glyph.bust_in_silhouette} #{post_index} by [bold]@{post['data']['author']}[/]" - ) - found_posts += 1 - create_post_branch(post=post, keyword=keyword, tree=found_tree, args=args) - - # Log the number of posts in which the keyword was found - main_tree.add( - f"{glyph.check_mark_button} Keyword ('{keyword}') was found in " - f"{found_posts}/{len(response['data']['children'])} {listing} posts from r/{subreddit}." - ) - print(main_tree) diff --git a/rpst/scraper.py b/rpst/scraper.py new file mode 100644 index 0000000..a9abd16 --- /dev/null +++ b/rpst/scraper.py @@ -0,0 +1,94 @@ +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + +import asyncio +import os +from datetime import datetime + +from rich.pretty import pprint + +from . import __version__, PROGRAM_DIRECTORY +from .base import find_posts +from .coreutils import args, log, save_posts, pathfinder + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # + + +def run(): + """Main entry point for rpst or rpst.""" + # ------------------------------------- # + + keyword: str = args.keyword + subreddit: str = args.subreddit + listing: str = args.listing + limit: int = args.limit + + # ------------------------------------- # + + start_time = datetime.now() + + # ------------------------------------- # + + print( + """ +┳┓┏┓┏┓┏┳┓ +┣┫┃┃┗┓ ┃ +┛┗┣┛┗┛ ┻ """ + ) + + # ------------------------------------- # + + try: + log.info( + f"[bold]RPST[/] {__version__} started at {start_time.strftime('%a %b %d %Y, %I:%M:%S%p')}..." + ) + + found_posts = asyncio.run( + find_posts( + keyword=keyword, + subreddit=subreddit, + listing=listing, + timeframe=args.timeframe, + limit=limit, + ), + ) + + if found_posts: + pprint( + found_posts, + expand_all=True, + ) + log.info( + f"'{subreddit}': Found {len(found_posts)}/{limit} {listing} posts containing the keyword ('{keyword}')" + ) + if args.json or args.csv: + target_dir: str = os.path.join(PROGRAM_DIRECTORY, subreddit) + pathfinder( + directories=[ + os.path.join(target_dir, "csv"), + os.path.join(target_dir, "json"), + ] + ) + save_posts( + filename=keyword, + save_to_dir=target_dir, + posts=found_posts, + save_json=args.json, + save_csv=args.csv, + ) + else: + log.info( + f"'r/{subreddit}': No {listing} posts found that contain the keyword ('{keyword}')" + ) + + except KeyboardInterrupt: + log.warning("User interruption detected ([yellow]Ctrl+C[/])") + except Exception as error: + log.error(f"An error occurred: [red]{error}[/]") + finally: + log.info(f"Finished in {datetime.now() - start_time} seconds") + + # ------------------------------------- # + + +# +++++++++++++++++++++++++++++++++++++++++++++++++ # diff --git a/rpst/utils.py b/rpst/utils.py deleted file mode 100644 index 8e5bc94..0000000 --- a/rpst/utils.py +++ /dev/null @@ -1,182 +0,0 @@ -import os -import csv -import json -import logging -import argparse -from datetime import datetime - -import requests -from glyphoji import glyph -from rich import print -from rich.tree import Tree - -from rich.markdown import Markdown -from rich.logging import RichHandler - - -def convert_timestamp_to_datetime(timestamp: int) -> str: - """ - Converts a Unix timestamp to a formatted datetime string. - - :param timestamp: The Unix timestamp to be converted. - :return: A formatted datetime string in the format "dd MMMM yyyy, hh:mm:ssAM/PM". - """ - utc_from_timestamp = datetime.utcfromtimestamp(timestamp) - datetime_object = utc_from_timestamp.strftime("%d %B %Y, %I:%M:%S%p") - return datetime_object - - -def create_parser(): - """ - Creates and configures an argument parser for the command line arguments. - - :return: A configured argparse.ArgumentParser object ready to parse the command line arguments. - """ - parser = argparse.ArgumentParser( - description="RPST (Reddit Post Scraping Tool) —by Richard Mwewa | https://about.me/rly0nheart", - epilog="Retrieve Reddit posts that contain the specified keyword from a specified subreddit." - ) - - parser.add_argument( - "-k", "--keyword", help="The keyword to search for in the posts.", required=True - ) - parser.add_argument( - "-s", "--subreddit", help="The subreddit to scrape.", required=True - ) - parser.add_argument( - "-c", - "--limit", - help="The maximum number of posts to scrape (1-100). (default: %(default)s)", - default=10, - type=int, - choices=range( - 1, 101 - ), # This enforces that the limit must be between 1 and 100 inclusive. - ) - parser.add_argument( - "-l", - "--listing", - default="top", - const="top", - nargs="?", - choices=["controversial", "hot", "best", "new", "rising"], - help="The type of posts to scrape (default: %(default)s)", - ) - parser.add_argument( - "-t", - "--timeframe", - default="all", - const="all", - nargs="?", - choices=["hour", "day", "week", "month", "year", "all"], - help="The timeframe from which to scrape posts (default: %(default)s)", - ) - parser.add_argument( - "--json", - help="Write all found posts to a json file.", - action="store_true", - ) - parser.add_argument( - "--csv", - help="Write all found posts to a csv file.", - action="store_true", - ) - parser.add_argument( - "-d", - "--debug", - help="run rpst in debug mode", - action="store_true", - ) - - return parser - - -def check_updates(version_tag: str): - """ - This function checks if there's a new release of a project on GitHub. If there is, it logs an - information message and prints the release notes. - - :param version_tag: A string representing the current version of the project. - """ - - # Make a GET request to the GitHub API to get the latest release of the project. - response = requests.get( - "https://api.github.com/repos/bellingcat/reddit-post-scraping-tool/releases/latest" - ).json() - - # Check if the latest release's tag matches the current version tag. - if response["tag_name"] != version_tag: - # If not, convert the release notes from Markdown to HTML. - raw_release_notes = response["body"] - - # Log an info message about the new release. - print( - f"{glyph.up_arrow} A new release of RPST is available ({response['tag_name']}). " - f"Run 'pip install --upgrade reddit-post-scraping-tool' to get the updates." - ) - - # Print the release notes. - print(Markdown(raw_release_notes)) - - -def set_loglevel(debug_mode: bool) -> logging.getLogger: - """ - Configure and return a logging object with the specified log level. - - :param debug_mode: If True, the log level is set to "NOTSET". Otherwise, it is set to "INFO". - :return: A logging object configured with the specified log level. - """ - logging.basicConfig( - level="NOTSET" if debug_mode else "INFO", - format="%(message)s", - handlers=[ - RichHandler(markup=True, log_time_format="[%I:%M:%S %p]", show_level=False) - ], - ) - return logging.getLogger("RPST") - - -def write_post_data(post_data: dict, filename: str, args, tree_branch: Tree): - """ - Writes post data to a specified JSON or CSV file based on the args provided, and updates - the provided tree with the status. - - :param post_data: A dictionary containing post data. - :param filename: The name of the file to which post data will be written. - :param args: A namespace object from argparse containing the output format options (args.json or args.csv). - :param tree_branch: A rich Tree object to which status information will be added. - """ - home_directory = os.path.expanduser("~") - - if args.json: - json_file_path = os.path.join(home_directory, f"{filename}.json") - with open(json_file_path, "a", encoding="utf-8") as file: - file.write(json.dumps(post_data, ensure_ascii=False)) - file.write("\n") # Separate posts with newline - tree_branch.add( - f"{glyph.page_facing_up} JSON data successfully written/appended to file: " - f"[italic][link file://{json_file_path}]{json_file_path}[/]" - ) - else: - tree_branch.add( - f"{glyph.cross_mark_button} JSON data writing operation was skipped. No changes made." - ) - - if args.csv: - csv_file_path = os.path.join(home_directory, f"{filename}.csv") - with open(csv_file_path, "a", newline="", encoding="utf-8") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=post_data.keys()) - - # Write headers if file is empty - if csvfile.tell() == 0: - writer.writeheader() - - writer.writerow(post_data) - tree_branch.add( - f"{glyph.page_facing_up} CSV data successfully written/appended to file: " - f"[italic][link file://{csv_file_path}]{csv_file_path}[/]" - ) - else: - tree_branch.add( - f"{glyph.cross_mark_button} CSV data writing operation was skipped. No changes made." - )