Scraping more than 100 posts. Code refactor and optimisation. Improved the file writer and update checker functions

This commit is contained in:
rly0nheart
2023-12-03 17:39:32 +02:00
parent 6dd075a44e
commit ff905764cf
9 changed files with 586 additions and 359 deletions

View File

@@ -7,27 +7,27 @@ packages = ["rpst"]
[project]
name = "reddit-post-scraping-tool"
version = "1.9.1.1"
version = "2.0.0.0"
description = "Retrieve Reddit posts that contain the specified keyword from a specified subreddit."
readme = "README.md"
requires-python = ">=3.8"
license = {file = "LICENSE"}
license = { file = "LICENSE" }
keywords = ["reddit-crawler", "reddit-scraping", "reddit", "reddit-api"]
authors = [{name = "Richard Mwewa", email = "rly0nheart@duck.com"}]
authors = [{ name = "Richard Mwewa", email = "rly0nheart@duck.com" }]
classifiers = [
"Development Status :: 5 - Production/Stable",
"Programming Language :: Python :: 3",
"Programming Language :: Visual Basic",
"Intended Audience :: End Users/Desktop",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Natural Language :: English"
"Development Status :: 5 - Production/Stable",
"Programming Language :: Python :: 3",
"Programming Language :: Visual Basic",
"Intended Audience :: End Users/Desktop",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Natural Language :: English"
]
dependencies = [
"rich",
"glyphoji",
"requests",
"aiohttp",
"rich-argparse"
]
[project.urls]
@@ -36,4 +36,5 @@ documentation = "https://github.com/bellingcat/reddit-post-scraping-tool/wiki"
repository = "https://github.com/bellingcat/reddit-post-scraping-tool.git"
[project.scripts]
rpst = "rpst.main:run"
rpst = "rpst.scraper:run"
reddit_post_scraping_tool = "rpst.scraper:run"

View File

@@ -1 +1,42 @@
import os
__author__: str = "Richard Mwewa"
__about_author__: str = "https://about/me/rly0nheart"
__version__: str = "2.0.0.0"
__description__: str = f"""
# RPST (Reddit Post Scraping Tool) {__version__}
> Retrieve Reddit posts that contain the specified keyword from a specified subreddit.
"""
__epilog__: str = f"""
# by [{__author__}]({__about_author__})
```
MIT License
Copyright (c) 2023 {__author__}
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
"""
# Construct path to the program's directory
PROGRAM_DIRECTORY: str = os.path.expanduser(
os.path.join("~", "reddit_post_scraping_tool")
)

164
rpst/api.py Normal file
View File

@@ -0,0 +1,164 @@
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
from typing import Union
import aiohttp
from .coreutils import log
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
REDDIT_ENDPOINT: str = "https://www.reddit.com"
PYPI_PROJECT_ENDPOINT: str = "https://pypi.org/pypi/reddit-post-scraping-tool/json"
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
async def get_data(session: aiohttp.ClientSession, endpoint: str) -> Union[dict, list]:
"""
Fetches JSON data from a given API endpoint.
:param session: aiohttp session to use for the request.
:param endpoint: The API endpoint to fetch data from.
:return: Returns JSON data as a dictionary or list. Returns an empty dict if fetching fails.
"""
try:
async with session.get(
endpoint,
) as response:
if response.status == 200:
return await response.json()
else:
error_message = await response.json()
log.error(f"An API error occurred: {error_message}")
return {}
except aiohttp.ClientConnectionError as error:
log.error(f"An HTTP error occurred: {error}")
return {}
except Exception as error:
log.critical(f"An unknown error occurred: {error}")
return {}
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
async def get_updates(session: aiohttp.ClientSession):
"""
Gets and compares the current program version with the remote version.
Assumes version format: major.minor.patch.prefix
:param session: aiohttp session to use for the request.
"""
from . import __version__
# Make a GET request to PyPI to get the project's latest release.
response: dict = await get_data(endpoint=PYPI_PROJECT_ENDPOINT, session=session)
if response.get("info"):
release: dict = response.get("info")
remote_version: str = release.get("version")
# Splitting the version strings into components
remote_parts: list = remote_version.split(".")
local_parts: list = __version__.split(".")
update_message: str = ""
# Check for differences in version parts
if remote_parts[0] != local_parts[0]:
update_message = (
f"MAJOR update ({remote_version}) available."
f" It might introduce significant changes."
)
elif remote_parts[1] != local_parts[1]:
update_message = (
f"MINOR update ({remote_version}) available."
f" Includes small feature changes/improvements."
)
elif remote_parts[2] != local_parts[2]:
update_message = (
f"PATCH update ({remote_version}) available."
f" Generally for bug fixes and small tweaks."
)
elif (
len(remote_parts) > 3
and len(local_parts) > 3
and remote_parts[3] != local_parts[3]
):
update_message = (
f"BUILD update ({remote_version}) available."
f" Might be for specific builds or special versions."
)
if update_message:
log.info(update_message)
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
async def get_posts(
subreddit: str,
listing: str,
timeframe: str,
limit: int,
session: aiohttp.ClientSession,
) -> list:
all_posts = await paginated_posts(
posts_endpoint=f"{REDDIT_ENDPOINT}/r/{subreddit}/{listing}.json?limit={limit}&t={timeframe}",
limit=limit,
session=session,
)
return all_posts[:limit]
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
async def paginated_posts(
posts_endpoint: str, limit: int, session: aiohttp.ClientSession
) -> list:
"""
Paginates and retrieves posts until the specified limit is reached.
:param posts_endpoint: API endpoint for retrieving posts.
:param limit: Limit of the number of posts to retrieve.
:param session: aiohttp session to use for the request.
:return: A list of all posts.
"""
all_posts: list = []
last_post_id: str = ""
# Determine whether to use the 'after' parameter
use_after: bool = limit > 100
while len(all_posts) < limit:
# Make the API request with the 'after' parameter if it's provided and the limit is more than 100
if use_after and last_post_id:
endpoint_with_after: str = f"{posts_endpoint}&after={last_post_id}"
else:
endpoint_with_after: str = posts_endpoint
posts_data: dict = await get_data(endpoint=endpoint_with_after, session=session)
posts_children: list = posts_data.get("data", {}).get("children", [])
# If there are no more posts, break out of the loop
if not posts_children:
break
all_posts.extend(posts_children)
# We use the id of the last post in the list to paginate to the next posts
last_post_id: str = all_posts[-1].get("data").get("id")
return all_posts
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #

103
rpst/base.py Normal file
View File

@@ -0,0 +1,103 @@
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
from dataclasses import dataclass
from typing import List
import aiohttp
from .api import get_posts, get_updates
from .coreutils import timestamp_to_utc
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
@dataclass
class Post:
id: str
thumbnail: str
title: str
text: str
author: str
subreddit: str
subreddit_id: str
subreddit_type: str
upvotes: int
upvote_ratio: float
downvotes: int
gilded: int
is_nsfw: bool
is_shareable: bool
is_edited: bool
comments: int
hide_from_bots: bool
score: float
domain: str
permalink: str
is_locked: bool
is_archived: bool
created_at: str
raw_post: dict
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
async def find_posts(
keyword: str,
subreddit: str,
listing: str,
timeframe: str,
limit: int,
) -> List[Post]:
async with aiohttp.ClientSession() as session:
found_posts_count: int = 0
found_posts_list: list = []
await get_updates(session=session)
raw_posts: list = await get_posts(
subreddit=subreddit,
listing=listing,
timeframe=timeframe,
limit=limit,
session=session,
)
for raw_post in raw_posts:
post_data: dict = raw_post.get("data")
if keyword.lower() in post_data.get(
"selftext"
) or keyword.lower() in post_data.get("title"):
found_posts_count += 1
post = Post(
id=post_data.get("id"),
thumbnail=post_data.get("thumbnail"),
title=post_data.get("title"),
text=post_data.get("selftext"),
author=post_data.get("author"),
subreddit=post_data.get("subreddit"),
subreddit_id=post_data.get("subreddit_id"),
subreddit_type=post_data.get("subreddit_type"),
upvotes=post_data.get("ups"),
upvote_ratio=post_data.get("upvote_ratio"),
downvotes=post_data.get("downs"),
gilded=post_data.get("gilded"),
is_nsfw=post_data.get("over_18"),
is_shareable=post_data.get("is_reddit_media_domain"),
is_edited=post_data.get("edited"),
comments=post_data.get("num_comments"),
hide_from_bots=post_data.get("is_robot_indexable"),
score=post_data.get("score"),
domain=post_data.get("domain"),
permalink=post_data.get("permalink"),
is_locked=post_data.get("locked"),
is_archived=post_data.get("archived"),
created_at=timestamp_to_utc(timestamp=post_data.get("created_utc")),
raw_post=post_data,
)
found_posts_list.append(post)
return found_posts_list
# +++++++++++++++++++++++++++++++++++++++++++++++++ #

170
rpst/coreutils.py Normal file
View File

@@ -0,0 +1,170 @@
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
import argparse
import csv
import json
import logging
import os
from datetime import datetime
from rich.logging import RichHandler
from rich.markdown import Markdown
from rich_argparse import RichHelpFormatter
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def timestamp_to_utc(timestamp: int) -> str:
"""
Converts a Unix timestamp to a formatted datetime string.
:param timestamp: The Unix timestamp to be converted.
:return: A formatted datetime string in the format "dd MMMM yyyy, hh:mm:ssAM/PM".
"""
utc_from_timestamp: datetime = datetime.utcfromtimestamp(timestamp)
datetime_string: str = utc_from_timestamp.strftime("%d %B %Y, %I:%M:%S%p")
return datetime_string
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def pathfinder(directories: list[str]):
for directory in directories:
os.makedirs(directory, exist_ok=True)
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def save_posts(
filename: str,
save_to_dir: str,
posts: list,
save_json: bool = False,
save_csv: bool = False,
):
posts_data: list = [post.__dict__ for post in posts]
if save_json:
json_path = os.path.join(os.path.join(save_to_dir, "json"), f"{filename}.json")
with open(json_path, "w", encoding="utf-8") as json_file:
json.dump(posts_data, json_file, indent=4)
log.info(
f"{os.path.getsize(json_file.name)} bytes written to [link file://{json_file.name}]{json_file.name}"
)
if save_csv:
csv_path = os.path.join(os.path.join(save_to_dir, "csv"), f"{filename}.csv")
with open(csv_path, "w", newline="", encoding="utf-8") as csv_file:
writer = csv.writer(csv_file)
if posts:
writer.writerow(
posts_data[0].keys()
) # header from keys of the first item
for post in posts:
writer.writerow(post.__dict__.values())
log.info(
f"{os.path.getsize(csv_file.name)} bytes written to [link file://{csv_file.name}]{csv_file.name}"
)
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def create_parser():
"""
Creates and configures an argument parser for the command line arguments.
:return: A configured argparse.ArgumentParser object ready to parse the command line arguments.
"""
from . import __version__, __description__, __epilog__
parser = argparse.ArgumentParser(
description=Markdown(__description__, style="argparse.text"),
epilog=Markdown(__epilog__, style="argparse.text"),
formatter_class=RichHelpFormatter,
)
parser.add_argument(
"keyword",
help="keyword to search for, in posts",
)
parser.add_argument("subreddit", help="subreddit to scrape")
parser.add_argument(
"-l",
"--limit",
help="maximum number of posts to scrape (default: %(default)s)",
default=200,
type=int,
)
parser.add_argument(
"-ls",
"--listing",
default="top",
const="top",
nargs="?",
choices=["best", "controversial", "hot", "new", "rising", "top"],
help="listing of posts to scrape (default: %(default)s)",
)
parser.add_argument(
"-t",
"--timeframe",
default="all",
const="all",
nargs="?",
choices=["hour", "day", "week", "month", "year", "all"],
help="timeframe from which to scrape posts (default: %(default)s)",
)
parser.add_argument(
"-j",
"--json",
help="write found posts to a json file",
action="store_true",
)
parser.add_argument(
"-c",
"--csv",
help="write found posts to a csv file",
action="store_true",
)
parser.add_argument(
"-d",
"--debug",
help="(dev) run rpst in debug mode",
action="store_true",
)
parser.add_argument("-v", "--version", action="version", version=__version__)
return parser
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def set_loglevel(debug_mode: bool) -> logging.getLogger:
"""
Configure and return a logging object with the specified log level.
:param debug_mode: If True, the log level is set to "NOTSET". Otherwise, it is set to "INFO".
:return: A logging object configured with the specified log level.
"""
logging.basicConfig(
level="DEBUG" if debug_mode else "INFO",
format="%(message)s",
handlers=[
RichHandler(
markup=True, log_time_format="%I:%M:%S%p", show_level=debug_mode
)
],
)
return logging.getLogger("RPST (Reddit Post Scraping Tool)")
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
args: argparse = create_parser().parse_args()
log: logging.getLogger = set_loglevel(debug_mode=args.debug)
# +++++++++++++++++++++++++++++++++++++++++++++++++ #

View File

@@ -1,33 +0,0 @@
from datetime import datetime
from .rpst import get_posts
from .utils import create_parser, set_loglevel, check_updates
def run():
"""
Main entry point for the program. It creates a parser, parses the command line arguments,
checks for updates, gets posts, and handles any exceptions that occur during the execution.
"""
# Create a parser and parse the command line arguments
parser = create_parser()
args = parser.parse_args()
log = set_loglevel(debug_mode=args.debug)
# Record the start time
start_time = datetime.now()
try:
# Check for updates
check_updates(version_tag="1.9.1.1")
# Get posts with the provided/parsed arguments
get_posts(args=args)
except KeyboardInterrupt:
log.warning("User interruption detected ([yellow]Ctrl+C[/]).")
except Exception as e:
log.error(f"An error occurred: [red]{e}[/]")
finally:
log.info(f"Finished in {datetime.now() - start_time} seconds.")

View File

@@ -1,131 +0,0 @@
import argparse
from datetime import datetime
import requests
from glyphoji import glyph
from rich import print
from rich.tree import Tree
from .utils import convert_timestamp_to_datetime, write_post_data
def create_post_branch(post: dict, keyword: str, tree: Tree, args: argparse) -> Tree:
"""
This function extracts relevant data from a Reddit post and adds it in a tree branch structure,
followed by the post's selftext.
:param post: A dictionary containing the data of a Reddit post.
:param keyword: The keyword that is used to find posts, in his case gets uses as the filename.
:param tree: Tree where the post branch will be added.
:param args: A namespace object from argparse.
:returns: The main tree with added post branches.
"""
# Define the data to extract from the post.
post_data = {
# "Author": post["data"]["author"],
f"{glyph.id_button} ID": post["data"]["id"],
f"{glyph.people_hugging} Subreddit": post["data"]["subreddit_name_prefixed"],
f"{glyph.face_with_peeking_eye} Visibility": post["data"]["subreddit_type"],
f"{glyph.framed_picture} Thumbnail": post["data"]["thumbnail"],
f"{glyph.white_question_mark} Gilded": post["data"]["gilded"],
f"{glyph.up_arrow} Upvotes": post["data"]["ups"],
f"{glyph.chart_increasing} Upvote ratio": post["data"]["upvote_ratio"],
f"{glyph.down_arrow} Downvotes": post["data"]["downs"],
f"{glyph.trophy} Awards": post["data"]["total_awards_received"],
f"{glyph.trophy} Top award": post["data"]["top_awarded_type"],
f"{glyph.no_one_under_eighteen} Is NSFW?": post["data"]["over_18"],
f"{glyph.left_arrow_curving_right} Is crosspostable?": post["data"][
"is_crosspostable"
],
f"{glyph.bar_chart} Score": post["data"]["score"],
f"{glyph.card_file_box} Category": post["data"]["category"],
f"{glyph.globe_with_meridians} Domain": post["data"]["domain"],
f"{glyph.calendar} Posted on": convert_timestamp_to_datetime(
post["data"]["created"]
),
f"{glyph.calendar} Approved at": post["data"]["approved_at_utc"],
f"{glyph.bust_in_silhouette} Approved by": post["data"]["approved_by"],
}
# Add the post's branch to the main tree.
post_branch = tree.add(f"{glyph.page_with_curl} {post['data']['title']}")
# Add each piece of extracted data as a branch of the post_branch.
for post_key, post_value in post_data.items():
post_branch.add(f"{post_key}: {post_value}", style="dim")
# This ensures that the post's selftext is also added to the written json/csv file.
post_data[f"{glyph.clipboard} Text"] = post["data"]["selftext"]
write_post_data(
filename=keyword, post_data=post_data, tree_branch=post_branch, args=args
)
post_branch.add(post["data"]["selftext"], style="italic")
return tree
def get_posts(args: argparse):
"""
Scrapes a given subreddit for posts that contain a specified keyword.
The search is limited by the number of posts and timeframe specified.
:param args: Namespace object from argparse.
Expected Object Attributes
--------------------------
- keyword: The keyword to search for in the posts.
- subreddit: The subreddit to scrape.
- listing: The type of posts to scrape. This could be 'hot', 'new', etc.
- timeframe: The timeframe from which to scrape posts. This could be 'day', 'week', etc.
- limit: The maximum number of posts to scrape.
- json: If specified, all found posts will be written to a json file.
"""
keyword = args.keyword
subreddit = args.subreddit
listing = args.listing
timeframe = args.timeframe
limit = args.limit
# Create main result tree.
main_tree = Tree(
f"[bold]{glyph.calendar} {datetime.now()}[/]", guide_style="bold bright_blue"
)
# Start a new session
session = requests.session()
# Set the User-Agent to mimic a Safari browser on a Mac.
session.headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, "
"like Gecko) Version/14.1.1 Safari/605.1.15"
}
# Send a GET request to the specified subreddit and listing,
# limiting the response by the specified limit and timeframe.
response = session.get(
f"https://reddit.com/r/{subreddit}/{listing}"
f".json?limit={limit}&t={timeframe}"
).json()
# Initialize a counter for the number of posts found that contain the keyword.
found_posts = 0
# Loop through each post in the response
for post_index, post in enumerate(response["data"]["children"], start=1):
# If the keyword is found in the post's selftext or title, increment the counter and process the post.
if (
keyword.lower() in post["data"]["selftext"]
or keyword.lower() in post["data"]["title"]
):
# Create a branch for found post(s) and show post index and post author as the title
found_tree = main_tree.add(
f"{glyph.bust_in_silhouette} #{post_index} by [bold]@{post['data']['author']}[/]"
)
found_posts += 1
create_post_branch(post=post, keyword=keyword, tree=found_tree, args=args)
# Log the number of posts in which the keyword was found
main_tree.add(
f"{glyph.check_mark_button} Keyword ('{keyword}') was found in "
f"{found_posts}/{len(response['data']['children'])} {listing} posts from r/{subreddit}."
)
print(main_tree)

94
rpst/scraper.py Normal file
View File

@@ -0,0 +1,94 @@
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
import asyncio
import os
from datetime import datetime
from rich.pretty import pprint
from . import __version__, PROGRAM_DIRECTORY
from .base import find_posts
from .coreutils import args, log, save_posts, pathfinder
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
def run():
"""Main entry point for rpst or rpst."""
# ------------------------------------- #
keyword: str = args.keyword
subreddit: str = args.subreddit
listing: str = args.listing
limit: int = args.limit
# ------------------------------------- #
start_time = datetime.now()
# ------------------------------------- #
print(
"""
┳┓┏┓┏┓┏┳┓
┣┫┃┃┗┓ ┃
┛┗┣┛┗┛ ┻ """
)
# ------------------------------------- #
try:
log.info(
f"[bold]RPST[/] {__version__} started at {start_time.strftime('%a %b %d %Y, %I:%M:%S%p')}..."
)
found_posts = asyncio.run(
find_posts(
keyword=keyword,
subreddit=subreddit,
listing=listing,
timeframe=args.timeframe,
limit=limit,
),
)
if found_posts:
pprint(
found_posts,
expand_all=True,
)
log.info(
f"'{subreddit}': Found {len(found_posts)}/{limit} {listing} posts containing the keyword ('{keyword}')"
)
if args.json or args.csv:
target_dir: str = os.path.join(PROGRAM_DIRECTORY, subreddit)
pathfinder(
directories=[
os.path.join(target_dir, "csv"),
os.path.join(target_dir, "json"),
]
)
save_posts(
filename=keyword,
save_to_dir=target_dir,
posts=found_posts,
save_json=args.json,
save_csv=args.csv,
)
else:
log.info(
f"'r/{subreddit}': No {listing} posts found that contain the keyword ('{keyword}')"
)
except KeyboardInterrupt:
log.warning("User interruption detected ([yellow]Ctrl+C[/])")
except Exception as error:
log.error(f"An error occurred: [red]{error}[/]")
finally:
log.info(f"Finished in {datetime.now() - start_time} seconds")
# ------------------------------------- #
# +++++++++++++++++++++++++++++++++++++++++++++++++ #

View File

@@ -1,182 +0,0 @@
import os
import csv
import json
import logging
import argparse
from datetime import datetime
import requests
from glyphoji import glyph
from rich import print
from rich.tree import Tree
from rich.markdown import Markdown
from rich.logging import RichHandler
def convert_timestamp_to_datetime(timestamp: int) -> str:
"""
Converts a Unix timestamp to a formatted datetime string.
:param timestamp: The Unix timestamp to be converted.
:return: A formatted datetime string in the format "dd MMMM yyyy, hh:mm:ssAM/PM".
"""
utc_from_timestamp = datetime.utcfromtimestamp(timestamp)
datetime_object = utc_from_timestamp.strftime("%d %B %Y, %I:%M:%S%p")
return datetime_object
def create_parser():
"""
Creates and configures an argument parser for the command line arguments.
:return: A configured argparse.ArgumentParser object ready to parse the command line arguments.
"""
parser = argparse.ArgumentParser(
description="RPST (Reddit Post Scraping Tool) —by Richard Mwewa | https://about.me/rly0nheart",
epilog="Retrieve Reddit posts that contain the specified keyword from a specified subreddit."
)
parser.add_argument(
"-k", "--keyword", help="The keyword to search for in the posts.", required=True
)
parser.add_argument(
"-s", "--subreddit", help="The subreddit to scrape.", required=True
)
parser.add_argument(
"-c",
"--limit",
help="The maximum number of posts to scrape (1-100). (default: %(default)s)",
default=10,
type=int,
choices=range(
1, 101
), # This enforces that the limit must be between 1 and 100 inclusive.
)
parser.add_argument(
"-l",
"--listing",
default="top",
const="top",
nargs="?",
choices=["controversial", "hot", "best", "new", "rising"],
help="The type of posts to scrape (default: %(default)s)",
)
parser.add_argument(
"-t",
"--timeframe",
default="all",
const="all",
nargs="?",
choices=["hour", "day", "week", "month", "year", "all"],
help="The timeframe from which to scrape posts (default: %(default)s)",
)
parser.add_argument(
"--json",
help="Write all found posts to a json file.",
action="store_true",
)
parser.add_argument(
"--csv",
help="Write all found posts to a csv file.",
action="store_true",
)
parser.add_argument(
"-d",
"--debug",
help="run rpst in debug mode",
action="store_true",
)
return parser
def check_updates(version_tag: str):
"""
This function checks if there's a new release of a project on GitHub. If there is, it logs an
information message and prints the release notes.
:param version_tag: A string representing the current version of the project.
"""
# Make a GET request to the GitHub API to get the latest release of the project.
response = requests.get(
"https://api.github.com/repos/bellingcat/reddit-post-scraping-tool/releases/latest"
).json()
# Check if the latest release's tag matches the current version tag.
if response["tag_name"] != version_tag:
# If not, convert the release notes from Markdown to HTML.
raw_release_notes = response["body"]
# Log an info message about the new release.
print(
f"{glyph.up_arrow} A new release of RPST is available ({response['tag_name']}). "
f"Run 'pip install --upgrade reddit-post-scraping-tool' to get the updates."
)
# Print the release notes.
print(Markdown(raw_release_notes))
def set_loglevel(debug_mode: bool) -> logging.getLogger:
"""
Configure and return a logging object with the specified log level.
:param debug_mode: If True, the log level is set to "NOTSET". Otherwise, it is set to "INFO".
:return: A logging object configured with the specified log level.
"""
logging.basicConfig(
level="NOTSET" if debug_mode else "INFO",
format="%(message)s",
handlers=[
RichHandler(markup=True, log_time_format="[%I:%M:%S %p]", show_level=False)
],
)
return logging.getLogger("RPST")
def write_post_data(post_data: dict, filename: str, args, tree_branch: Tree):
"""
Writes post data to a specified JSON or CSV file based on the args provided, and updates
the provided tree with the status.
:param post_data: A dictionary containing post data.
:param filename: The name of the file to which post data will be written.
:param args: A namespace object from argparse containing the output format options (args.json or args.csv).
:param tree_branch: A rich Tree object to which status information will be added.
"""
home_directory = os.path.expanduser("~")
if args.json:
json_file_path = os.path.join(home_directory, f"{filename}.json")
with open(json_file_path, "a", encoding="utf-8") as file:
file.write(json.dumps(post_data, ensure_ascii=False))
file.write("\n") # Separate posts with newline
tree_branch.add(
f"{glyph.page_facing_up} JSON data successfully written/appended to file: "
f"[italic][link file://{json_file_path}]{json_file_path}[/]"
)
else:
tree_branch.add(
f"{glyph.cross_mark_button} JSON data writing operation was skipped. No changes made."
)
if args.csv:
csv_file_path = os.path.join(home_directory, f"{filename}.csv")
with open(csv_file_path, "a", newline="", encoding="utf-8") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=post_data.keys())
# Write headers if file is empty
if csvfile.tell() == 0:
writer.writeheader()
writer.writerow(post_data)
tree_branch.add(
f"{glyph.page_facing_up} CSV data successfully written/appended to file: "
f"[italic][link file://{csv_file_path}]{csv_file_path}[/]"
)
else:
tree_branch.add(
f"{glyph.cross_mark_button} CSV data writing operation was skipped. No changes made."
)