mirror of
https://github.com/bellingcat/reddit-post-scraping-tool.git
synced 2026-06-07 19:18:29 +03:00
Scraping more than 100 posts. Code refactor and optimisation. Improved the file writer and update checker functions
This commit is contained in:
@@ -7,27 +7,27 @@ packages = ["rpst"]
|
||||
|
||||
[project]
|
||||
name = "reddit-post-scraping-tool"
|
||||
version = "1.9.1.1"
|
||||
version = "2.0.0.0"
|
||||
description = "Retrieve Reddit posts that contain the specified keyword from a specified subreddit."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
license = {file = "LICENSE"}
|
||||
license = { file = "LICENSE" }
|
||||
keywords = ["reddit-crawler", "reddit-scraping", "reddit", "reddit-api"]
|
||||
authors = [{name = "Richard Mwewa", email = "rly0nheart@duck.com"}]
|
||||
authors = [{ name = "Richard Mwewa", email = "rly0nheart@duck.com" }]
|
||||
classifiers = [
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Visual Basic",
|
||||
"Intended Audience :: End Users/Desktop",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Natural Language :: English"
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Visual Basic",
|
||||
"Intended Audience :: End Users/Desktop",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Natural Language :: English"
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"rich",
|
||||
"glyphoji",
|
||||
"requests",
|
||||
"aiohttp",
|
||||
"rich-argparse"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -36,4 +36,5 @@ documentation = "https://github.com/bellingcat/reddit-post-scraping-tool/wiki"
|
||||
repository = "https://github.com/bellingcat/reddit-post-scraping-tool.git"
|
||||
|
||||
[project.scripts]
|
||||
rpst = "rpst.main:run"
|
||||
rpst = "rpst.scraper:run"
|
||||
reddit_post_scraping_tool = "rpst.scraper:run"
|
||||
|
||||
@@ -1 +1,42 @@
|
||||
import os
|
||||
|
||||
# Package metadata; interpolated into the CLI help text below.
__author__: str = "Richard Mwewa"
# The profile lives on the "about.me" host ("about/me" is not a valid host name).
__about_author__: str = "https://about.me/rly0nheart"
__version__: str = "2.0.0.0"
|
||||
|
||||
# Markdown shown as the description of the command-line help
# (create_parser wraps it in rich's Markdown renderable).
__description__: str = f"""
# RPST (Reddit Post Scraping Tool) {__version__}
> Retrieve Reddit posts that contain the specified keyword from a specified subreddit.
"""
# Markdown shown as the epilog of the command-line help: author credit
# followed by the MIT license text inside a fenced code block.
__epilog__: str = f"""
# by [{__author__}]({__about_author__})

```
MIT License

Copyright (c) 2023 {__author__}

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
"""
|
||||
|
||||
# Construct path to the program's directory
|
||||
# "~/reddit_post_scraping_tool" with the "~" expanded to the current user's home.
_unexpanded_program_directory: str = os.path.join("~", "reddit_post_scraping_tool")
PROGRAM_DIRECTORY: str = os.path.expanduser(_unexpanded_program_directory)
|
||||
|
||||
164
rpst/api.py
Normal file
164
rpst/api.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
from typing import Union
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .coreutils import log
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
REDDIT_ENDPOINT: str = "https://www.reddit.com"
|
||||
PYPI_PROJECT_ENDPOINT: str = "https://pypi.org/pypi/reddit-post-scraping-tool/json"
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
async def get_data(session: aiohttp.ClientSession, endpoint: str) -> Union[dict, list]:
    """
    Fetches JSON data from a given API endpoint.

    Failures are logged and never raised: callers always get a value back.

    :param session: aiohttp session to use for the request.
    :param endpoint: The API endpoint to fetch data from.
    :return: Returns JSON data as a dictionary or list. Returns an empty dict if fetching fails.
    """
    try:
        async with session.get(endpoint) as response:
            if response.status == 200:
                return await response.json()

            # An error body is not guaranteed to be JSON, and response.json()
            # raises on a non-JSON content type. Read it as text so the real
            # status and message are reported instead of an "unknown error".
            error_message = await response.text()
            log.error(
                f"An API error occurred: HTTP {response.status}: {error_message[:500]}"
            )
            return {}

    except aiohttp.ClientConnectionError as error:
        log.error(f"An HTTP error occurred: {error}")
        return {}
    except Exception as error:
        # Last-resort boundary: a failed request must not abort the whole run.
        log.critical(f"An unknown error occurred: {error}")
        return {}
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
# (label, explanatory note) for each version position: major.minor.patch.build
_UPDATE_KINDS: tuple = (
    ("MAJOR", " It might introduce significant changes."),
    ("MINOR", " Includes small feature changes/improvements."),
    ("PATCH", " Generally for bug fixes and small tweaks."),
    ("BUILD", " Might be for specific builds or special versions."),
)


def _version_parts(version: str) -> tuple:
    """Turn a dotted version string into a tuple of four integers (missing/non-numeric parts count as 0)."""
    numbers: list = []
    for part in version.split(".")[: len(_UPDATE_KINDS)]:
        digits: str = ""
        for character in part:
            if character not in "0123456789":
                break
            digits += character
        numbers.append(int(digits) if digits else 0)

    numbers.extend([0] * (len(_UPDATE_KINDS) - len(numbers)))
    return tuple(numbers)


def _update_message(local_version: str, remote_version: str) -> str:
    """Return the update notice for ``remote_version``, or "" when it is not newer than ``local_version``."""
    local_parts: tuple = _version_parts(local_version)
    remote_parts: tuple = _version_parts(remote_version)

    # Only a strictly newer remote release is an update; a local development
    # version that is ahead of PyPI must not trigger a notice.
    if remote_parts <= local_parts:
        return ""

    for (label, note), local_part, remote_part in zip(
        _UPDATE_KINDS, local_parts, remote_parts
    ):
        if local_part != remote_part:
            return f"{label} update ({remote_version}) available.{note}"
    return ""


async def get_updates(session: aiohttp.ClientSession):
    """
    Gets and compares the current program version with the remote version.

    Assumes version format: major.minor.patch.build
    The most significant part that differs decides which notice is logged.
    Nothing is logged when the lookup fails or the remote version is not newer.

    :param session: aiohttp session to use for the request.
    """
    from . import __version__

    # Make a GET request to PyPI to get the project's latest release.
    response = await get_data(endpoint=PYPI_PROJECT_ENDPOINT, session=session)

    release = response.get("info") if isinstance(response, dict) else None
    remote_version = release.get("version") if release else None
    if not remote_version:
        return

    update_message: str = _update_message(
        local_version=__version__, remote_version=remote_version
    )
    if update_message:
        log.info(update_message)
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
async def get_posts(
    subreddit: str,
    listing: str,
    timeframe: str,
    limit: int,
    session: aiohttp.ClientSession,
) -> list:
    """
    Retrieves at most `limit` raw posts from a subreddit listing.

    :param subreddit: Name of the subreddit to read.
    :param listing: Listing to read (placed in the URL path before ".json").
    :param timeframe: Value sent as the `t` query parameter.
    :param limit: Maximum number of posts to return.
    :param session: aiohttp session to use for the request.
    :return: A list of raw posts, truncated to `limit` items.
    """
    listing_endpoint: str = (
        f"{REDDIT_ENDPOINT}/r/{subreddit}/{listing}.json?limit={limit}&t={timeframe}"
    )
    collected: list = await paginated_posts(
        session=session,
        limit=limit,
        posts_endpoint=listing_endpoint,
    )

    # Pagination may overshoot, so trim to the requested amount.
    return collected[:limit]
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
async def paginated_posts(
    posts_endpoint: str, limit: int, session: aiohttp.ClientSession
) -> list:
    """
    Paginates and retrieves posts until the specified limit is reached.

    Pagination follows the `after` cursor that each listing response carries
    in its "data" object. The loop stops as soon as a page is empty or the
    response has no further cursor, so a listing shorter than `limit` is
    returned once instead of being fetched again and duplicated.

    :param posts_endpoint: API endpoint for retrieving posts.
    :param limit: Limit of the number of posts to retrieve.
    :param session: aiohttp session to use for the request.
    :return: A list of all posts (may exceed `limit`; the caller trims it).
    """
    all_posts: list = []
    after_cursor: str = ""

    while len(all_posts) < limit:
        if after_cursor:
            page_endpoint: str = f"{posts_endpoint}&after={after_cursor}"
        else:
            page_endpoint = posts_endpoint

        posts_data = await get_data(endpoint=page_endpoint, session=session)
        # get_data returns {} on failure and may return a list for other endpoints.
        listing: dict = (
            posts_data.get("data") or {} if isinstance(posts_data, dict) else {}
        )
        posts_children: list = listing.get("children") or []

        # If there are no more posts, break out of the loop
        if not posts_children:
            break

        all_posts.extend(posts_children)

        # The cursor must be the value supplied by the listing itself (a
        # fullname); a bare post id is not a valid anchor. No cursor means
        # this was the last page.
        after_cursor = listing.get("after") or ""
        if not after_cursor:
            break

    return all_posts
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
103
rpst/base.py
Normal file
103
rpst/base.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .api import get_posts, get_updates
|
||||
from .coreutils import timestamp_to_utc
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
@dataclass
class Post:
    """
    A single Reddit post that matched the search keyword.

    Instances are built by find_posts from the "data" object of a listing
    child. The comments below name the JSON key each field is read from
    whenever it differs from the field name.
    """

    id: str
    thumbnail: str
    title: str
    text: str  # "selftext"
    author: str
    subreddit: str
    subreddit_id: str
    subreddit_type: str
    upvotes: int  # "ups"
    upvote_ratio: float
    downvotes: int  # "downs"
    gilded: int
    is_nsfw: bool  # "over_18"
    is_shareable: bool  # "is_reddit_media_domain"
    is_edited: bool  # "edited"
    comments: int  # "num_comments"
    hide_from_bots: bool  # "is_robot_indexable" (copied as-is, not negated)
    score: float
    domain: str
    permalink: str
    is_locked: bool  # "locked"
    is_archived: bool  # "archived"
    created_at: str  # "created_utc" formatted by timestamp_to_utc
    raw_post: dict  # the complete, unmodified "data" object
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
async def find_posts(
    keyword: str,
    subreddit: str,
    listing: str,
    timeframe: str,
    limit: int,
) -> List[Post]:
    """
    Finds posts in a subreddit listing whose title or text contains a keyword.

    The match is case-insensitive. An update check is performed on the same
    HTTP session before the posts are fetched.

    :param keyword: Keyword to look for in each post's title and selftext.
    :param subreddit: Subreddit to read posts from.
    :param listing: Listing to read.
    :param timeframe: Timeframe passed on to the listing request.
    :param limit: Maximum number of posts to fetch (not the number of matches).
    :return: A list of Post objects, one per matching post.
    """
    # Both sides of the comparison must be lower-cased for the search to be
    # case-insensitive; lower-casing only the keyword misses "Python" in a
    # title when searching for "python" (and never matches a capitalised keyword).
    needle: str = keyword.lower()
    found_posts_list: list = []

    async with aiohttp.ClientSession() as session:
        await get_updates(session=session)
        raw_posts: list = await get_posts(
            subreddit=subreddit,
            listing=listing,
            timeframe=timeframe,
            limit=limit,
            session=session,
        )

    for raw_post in raw_posts:
        post_data: dict = raw_post.get("data") or {}

        # Missing/null fields are treated as empty text instead of raising.
        title_text: str = post_data.get("title") or ""
        body_text: str = post_data.get("selftext") or ""
        if needle not in body_text.lower() and needle not in title_text.lower():
            continue

        found_posts_list.append(
            Post(
                id=post_data.get("id"),
                thumbnail=post_data.get("thumbnail"),
                title=post_data.get("title"),
                text=post_data.get("selftext"),
                author=post_data.get("author"),
                subreddit=post_data.get("subreddit"),
                subreddit_id=post_data.get("subreddit_id"),
                subreddit_type=post_data.get("subreddit_type"),
                upvotes=post_data.get("ups"),
                upvote_ratio=post_data.get("upvote_ratio"),
                downvotes=post_data.get("downs"),
                gilded=post_data.get("gilded"),
                is_nsfw=post_data.get("over_18"),
                is_shareable=post_data.get("is_reddit_media_domain"),
                is_edited=post_data.get("edited"),
                comments=post_data.get("num_comments"),
                # NOTE(review): the field name suggests the opposite of
                # "is_robot_indexable" — confirm whether this should be negated.
                hide_from_bots=post_data.get("is_robot_indexable"),
                score=post_data.get("score"),
                domain=post_data.get("domain"),
                permalink=post_data.get("permalink"),
                is_locked=post_data.get("locked"),
                is_archived=post_data.get("archived"),
                created_at=timestamp_to_utc(timestamp=post_data.get("created_utc")),
                raw_post=post_data,
            )
        )

    return found_posts_list
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
170
rpst/coreutils.py
Normal file
170
rpst/coreutils.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from rich.logging import RichHandler
|
||||
from rich.markdown import Markdown
|
||||
from rich_argparse import RichHelpFormatter
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def timestamp_to_utc(timestamp: int) -> str:
    """
    Converts a Unix timestamp to a formatted UTC datetime string.

    :param timestamp: The Unix timestamp to be converted.
    :return: A formatted datetime string in the format "dd MMMM yyyy, hh:mm:ssAM/PM".
    """
    # Imported here because the module only imports `datetime` at file level.
    from datetime import timezone

    # datetime.utcfromtimestamp() is deprecated; an aware UTC datetime
    # formats to exactly the same string.
    utc_from_timestamp: datetime = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    return utc_from_timestamp.strftime("%d %B %Y, %I:%M:%S%p")
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def pathfinder(directories: list) -> None:
    """
    Creates every directory in the given list, including missing parents.

    Directories that already exist are left untouched.

    :param directories: A list of directory paths (strings) to create.
    """
    # The annotation is the plain `list`: a subscripted `list[str]` is
    # evaluated when the function is defined and fails on Python 3.8.
    for directory in directories:
        os.makedirs(directory, exist_ok=True)
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def save_posts(
    filename: str,
    save_to_dir: str,
    posts: list,
    save_json: bool = False,
    save_csv: bool = False,
):
    """
    Writes the found posts to a JSON and/or CSV file and logs the size of each file written.

    Files are written to "<save_to_dir>/json/<filename>.json" and
    "<save_to_dir>/csv/<filename>.csv"; those directories must already exist.

    :param filename: Name of the output file(s), without extension.
    :param save_to_dir: Directory containing the "json" and "csv" sub-directories.
    :param posts: Post objects to save (each one's attributes become a record).
    :param save_json: If True, write the posts to a JSON file.
    :param save_csv: If True, write the posts to a CSV file.
    """
    posts_data: list = [vars(post) for post in posts]

    if save_json:
        json_path: str = os.path.join(save_to_dir, "json", f"{filename}.json")
        with open(json_path, "w", encoding="utf-8") as json_file:
            json.dump(posts_data, json_file, indent=4)

        # Measure only after the file is closed: while it is still open the
        # buffered data may not be on disk yet and the size would be wrong.
        log.info(
            f"{os.path.getsize(json_path)} bytes written to [link file://{json_path}]{json_path}"
        )

    if save_csv:
        csv_path: str = os.path.join(save_to_dir, "csv", f"{filename}.csv")
        with open(csv_path, "w", newline="", encoding="utf-8") as csv_file:
            writer = csv.writer(csv_file)
            if posts_data:
                # header from keys of the first item
                writer.writerow(posts_data[0].keys())
                writer.writerows(record.values() for record in posts_data)

        log.info(
            f"{os.path.getsize(csv_path)} bytes written to [link file://{csv_path}]{csv_path}"
        )
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def create_parser():
    """
    Builds the command-line interface of the program.

    The parser takes two positional arguments (keyword, subreddit) and
    options for the post limit, listing, timeframe, JSON/CSV output,
    debug mode and the program version.

    :return: The configured argparse.ArgumentParser instance.
    """
    from . import __version__, __description__, __epilog__

    cli = argparse.ArgumentParser(
        formatter_class=RichHelpFormatter,
        description=Markdown(__description__, style="argparse.text"),
        epilog=Markdown(__epilog__, style="argparse.text"),
    )

    # Positional arguments
    cli.add_argument("keyword", help="keyword to search for, in posts")
    cli.add_argument("subreddit", help="subreddit to scrape")

    # Options that shape the request
    cli.add_argument(
        "-l",
        "--limit",
        type=int,
        default=200,
        help="maximum number of posts to scrape (default: %(default)s)",
    )
    cli.add_argument(
        "-ls",
        "--listing",
        nargs="?",
        const="top",
        default="top",
        choices=["best", "controversial", "hot", "new", "rising", "top"],
        help="listing of posts to scrape (default: %(default)s)",
    )
    cli.add_argument(
        "-t",
        "--timeframe",
        nargs="?",
        const="all",
        default="all",
        choices=["hour", "day", "week", "month", "year", "all"],
        help="timeframe from which to scrape posts (default: %(default)s)",
    )

    # Output and diagnostic flags
    cli.add_argument(
        "-j", "--json", action="store_true", help="write found posts to a json file"
    )
    cli.add_argument(
        "-c", "--csv", action="store_true", help="write found posts to a csv file"
    )
    cli.add_argument(
        "-d", "--debug", action="store_true", help="(dev) run rpst in debug mode"
    )
    cli.add_argument("-v", "--version", action="version", version=__version__)

    return cli
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def set_loglevel(debug_mode: bool) -> logging.Logger:
    """
    Configure logging and return the program's logger.

    The root logger is configured with a single RichHandler that renders
    console markup; the level column is only shown in debug mode.

    :param debug_mode: If True, the log level is set to "DEBUG". Otherwise, it is set to "INFO".
    :return: The logger named "RPST (Reddit Post Scraping Tool)".
    """
    logging.basicConfig(
        level="DEBUG" if debug_mode else "INFO",
        format="%(message)s",
        handlers=[
            RichHandler(
                markup=True, log_time_format="%I:%M:%S%p", show_level=debug_mode
            )
        ],
    )
    return logging.getLogger("RPST (Reddit Post Scraping Tool)")
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
# NOTE: both statements run when this module is first imported, so the command
# line is parsed at import time. `log` is imported by the API module and both
# `args` and `log` by the scraper module.
args: argparse.Namespace = create_parser().parse_args()
log: logging.Logger = set_loglevel(debug_mode=args.debug)
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
33
rpst/main.py
33
rpst/main.py
@@ -1,33 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from .rpst import get_posts
|
||||
from .utils import create_parser, set_loglevel, check_updates
|
||||
|
||||
|
||||
def run():
    """
    Program entry point.

    Parses the command line, configures logging, checks for a newer release
    and then fetches the posts. Interruptions and errors are logged rather
    than raised, and the elapsed time is always logged at the end.
    """
    # Parse the command line and configure logging from it
    cli_args = create_parser().parse_args()
    logger = set_loglevel(debug_mode=cli_args.debug)

    # Reference point for the elapsed-time message
    started_at = datetime.now()

    try:
        check_updates(version_tag="1.9.1.1")
        get_posts(args=cli_args)
    except KeyboardInterrupt:
        logger.warning("User interruption detected ([yellow]Ctrl+C[/]).")
    except Exception as error:
        logger.error(f"An error occurred: [red]{error}[/]")
    finally:
        logger.info(f"Finished in {datetime.now() - started_at} seconds.")
|
||||
131
rpst/rpst.py
131
rpst/rpst.py
@@ -1,131 +0,0 @@
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
from glyphoji import glyph
|
||||
from rich import print
|
||||
from rich.tree import Tree
|
||||
|
||||
from .utils import convert_timestamp_to_datetime, write_post_data
|
||||
|
||||
|
||||
def create_post_branch(post: dict, keyword: str, tree: Tree, args: argparse) -> Tree:
    """
    Adds a branch for one Reddit post to the given tree and writes the post to file.

    The branch is titled with the post's title, lists the extracted fields
    as dimmed leaves, and ends with the post's selftext in italics.

    :param post: A dictionary containing the data of a Reddit post.
    :param keyword: The search keyword; used here as the output filename.
    :param tree: Tree where the post branch will be added.
    :param args: A namespace object from argparse.
    :returns: The tree that was passed in, with the post branch added.
    """
    data = post["data"]

    # Fields to display, in display order.
    details = {
        # "Author": data["author"],
        f"{glyph.id_button} ID": data["id"],
        f"{glyph.people_hugging} Subreddit": data["subreddit_name_prefixed"],
        f"{glyph.face_with_peeking_eye} Visibility": data["subreddit_type"],
        f"{glyph.framed_picture} Thumbnail": data["thumbnail"],
        f"{glyph.white_question_mark} Gilded": data["gilded"],
        f"{glyph.up_arrow} Upvotes": data["ups"],
        f"{glyph.chart_increasing} Upvote ratio": data["upvote_ratio"],
        f"{glyph.down_arrow} Downvotes": data["downs"],
        f"{glyph.trophy} Awards": data["total_awards_received"],
        f"{glyph.trophy} Top award": data["top_awarded_type"],
        f"{glyph.no_one_under_eighteen} Is NSFW?": data["over_18"],
        f"{glyph.left_arrow_curving_right} Is crosspostable?": data["is_crosspostable"],
        f"{glyph.bar_chart} Score": data["score"],
        f"{glyph.card_file_box} Category": data["category"],
        f"{glyph.globe_with_meridians} Domain": data["domain"],
        f"{glyph.calendar} Posted on": convert_timestamp_to_datetime(data["created"]),
        f"{glyph.calendar} Approved at": data["approved_at_utc"],
        f"{glyph.bust_in_silhouette} Approved by": data["approved_by"],
    }

    # One branch per post, one dimmed leaf per field.
    branch = tree.add(f"{glyph.page_with_curl} {data['title']}")
    for label, value in details.items():
        branch.add(f"{label}: {value}", style="dim")

    # The selftext is added after the leaves were rendered so that it ends up
    # in the written json/csv record without being listed as a field above.
    details[f"{glyph.clipboard} Text"] = data["selftext"]
    write_post_data(
        post_data=details, filename=keyword, args=args, tree_branch=branch
    )
    branch.add(data["selftext"], style="italic")

    return tree
|
||||
|
||||
|
||||
def get_posts(args: argparse):
    """
    Scrapes a given subreddit for posts that contain a specified keyword.
    The search is limited by the number of posts and timeframe specified.
    The keyword match against each post's selftext and title is case-insensitive.

    :param args: Namespace object from argparse.

    Expected Object Attributes
    --------------------------
    - keyword: The keyword to search for in the posts.
    - subreddit: The subreddit to scrape.
    - listing: The type of posts to scrape. This could be 'hot', 'new', etc.
    - timeframe: The timeframe from which to scrape posts. This could be 'day', 'week', etc.
    - limit: The maximum number of posts to scrape.
    - json: If specified, all found posts will be written to a json file.
    - csv: If specified, all found posts will be written to a csv file.
    """
    keyword = args.keyword
    subreddit = args.subreddit
    listing = args.listing
    timeframe = args.timeframe
    limit = args.limit

    # Create main result tree.
    main_tree = Tree(
        f"[bold]{glyph.calendar} {datetime.now()}[/]", guide_style="bold bright_blue"
    )

    # The session is used as a context manager so its connections are
    # released even when the request or the JSON decoding fails.
    with requests.session() as session:
        # Set the User-Agent to mimic a Safari browser on a Mac.
        session.headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, "
            "like Gecko) Version/14.1.1 Safari/605.1.15"
        }

        # Send a GET request to the specified subreddit and listing,
        # limiting the response by the specified limit and timeframe.
        response = session.get(
            f"https://reddit.com/r/{subreddit}/{listing}"
            f".json?limit={limit}&t={timeframe}"
        ).json()

    posts = response["data"]["children"]

    # Lower-case both sides: lower-casing only the keyword would miss any
    # post in which the keyword appears with different capitalisation.
    needle = keyword.lower()

    # Number of posts found that contain the keyword.
    found_posts = 0

    for post_index, post in enumerate(posts, start=1):
        data = post["data"]
        if needle in data["selftext"].lower() or needle in data["title"].lower():
            # Create a branch for found post(s) and show post index and post author as the title
            found_tree = main_tree.add(
                f"{glyph.bust_in_silhouette} #{post_index} by [bold]@{data['author']}[/]"
            )
            found_posts += 1
            create_post_branch(post=post, keyword=keyword, tree=found_tree, args=args)

    # Report the number of posts in which the keyword was found
    main_tree.add(
        f"{glyph.check_mark_button} Keyword ('{keyword}') was found in "
        f"{found_posts}/{len(posts)} {listing} posts from r/{subreddit}."
    )
    print(main_tree)
|
||||
94
rpst/scraper.py
Normal file
94
rpst/scraper.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from rich.pretty import pprint
|
||||
|
||||
from . import __version__, PROGRAM_DIRECTORY
|
||||
from .base import find_posts
|
||||
from .coreutils import args, log, save_posts, pathfinder
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def run():
    """
    Main entry point for the `rpst` and `reddit_post_scraping_tool` commands.

    Prints the banner, searches the requested subreddit listing for posts
    containing the keyword, pretty-prints the matches and optionally saves
    them as JSON/CSV under "<PROGRAM_DIRECTORY>/<subreddit>". Interruptions
    and errors are logged rather than raised, and the elapsed time is always
    logged at the end.
    """
    # ------------------------------------- #

    keyword: str = args.keyword
    subreddit: str = args.subreddit
    listing: str = args.listing
    limit: int = args.limit

    # ------------------------------------- #

    start_time = datetime.now()

    # ------------------------------------- #

    print(
        """
┳┓┏┓┏┓┏┳┓
┣┫┃┃┗┓ ┃
┛┗┣┛┗┛ ┻ """
    )

    # ------------------------------------- #

    try:
        log.info(
            f"[bold]RPST[/] {__version__} started at {start_time.strftime('%a %b %d %Y, %I:%M:%S%p')}..."
        )

        found_posts = asyncio.run(
            find_posts(
                keyword=keyword,
                subreddit=subreddit,
                listing=listing,
                timeframe=args.timeframe,
                limit=limit,
            ),
        )

        if found_posts:
            pprint(
                found_posts,
                expand_all=True,
            )
            # Same "r/<subreddit>" prefix as the "no posts found" message below.
            log.info(
                f"'r/{subreddit}': Found {len(found_posts)}/{limit} {listing} posts containing the keyword ('{keyword}')"
            )
            if args.json or args.csv:
                target_dir: str = os.path.join(PROGRAM_DIRECTORY, subreddit)
                pathfinder(
                    directories=[
                        os.path.join(target_dir, "csv"),
                        os.path.join(target_dir, "json"),
                    ]
                )
                save_posts(
                    filename=keyword,
                    save_to_dir=target_dir,
                    posts=found_posts,
                    save_json=args.json,
                    save_csv=args.csv,
                )
        else:
            log.info(
                f"'r/{subreddit}': No {listing} posts found that contain the keyword ('{keyword}')"
            )

    except KeyboardInterrupt:
        log.warning("User interruption detected ([yellow]Ctrl+C[/])")
    except Exception as error:
        log.error(f"An error occurred: [red]{error}[/]")
    finally:
        # A timedelta renders as "H:MM:SS.ffffff", which is not a number of
        # seconds; report the elapsed time as actual seconds.
        elapsed_seconds: float = (datetime.now() - start_time).total_seconds()
        log.info(f"Finished in {elapsed_seconds:.2f} seconds")

    # ------------------------------------- #
|
||||
|
||||
# ------------------------------------- #
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
182
rpst/utils.py
182
rpst/utils.py
@@ -1,182 +0,0 @@
|
||||
import os
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
from glyphoji import glyph
|
||||
from rich import print
|
||||
from rich.tree import Tree
|
||||
|
||||
from rich.markdown import Markdown
|
||||
from rich.logging import RichHandler
|
||||
|
||||
|
||||
def convert_timestamp_to_datetime(timestamp: int) -> str:
    """
    Converts a Unix timestamp to a formatted UTC datetime string.

    :param timestamp: The Unix timestamp to be converted.
    :return: A formatted datetime string in the format "dd MMMM yyyy, hh:mm:ssAM/PM".
    """
    # Imported here because the module only imports `datetime` at file level.
    from datetime import timezone

    # datetime.utcfromtimestamp() is deprecated; an aware UTC datetime
    # formats to exactly the same string.
    utc_from_timestamp = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    return utc_from_timestamp.strftime("%d %B %Y, %I:%M:%S%p")
|
||||
|
||||
|
||||
def create_parser():
    """
    Creates and configures an argument parser for the command line arguments.

    :return: A configured argparse.ArgumentParser object ready to parse the command line arguments.
    """
    parser = argparse.ArgumentParser(
        description="RPST (Reddit Post Scraping Tool) —by Richard Mwewa | https://about.me/rly0nheart",
        epilog="Retrieve Reddit posts that contain the specified keyword from a specified subreddit."
    )

    parser.add_argument(
        "-k", "--keyword", help="The keyword to search for in the posts.", required=True
    )
    parser.add_argument(
        "-s", "--subreddit", help="The subreddit to scrape.", required=True
    )
    parser.add_argument(
        "-c",
        "--limit",
        help="The maximum number of posts to scrape (1-100). (default: %(default)s)",
        default=10,
        type=int,
        choices=range(
            1, 101
        ),  # This enforces that the limit must be between 1 and 100 inclusive.
    )
    parser.add_argument(
        "-l",
        "--listing",
        default="top",
        const="top",
        nargs="?",
        # "top" is the default and the const, so it must also be an accepted
        # choice; without it `-l top` is rejected as an invalid choice.
        choices=["controversial", "hot", "best", "new", "rising", "top"],
        help="The type of posts to scrape (default: %(default)s)",
    )
    parser.add_argument(
        "-t",
        "--timeframe",
        default="all",
        const="all",
        nargs="?",
        choices=["hour", "day", "week", "month", "year", "all"],
        help="The timeframe from which to scrape posts (default: %(default)s)",
    )
    parser.add_argument(
        "--json",
        help="Write all found posts to a json file.",
        action="store_true",
    )
    parser.add_argument(
        "--csv",
        help="Write all found posts to a csv file.",
        action="store_true",
    )
    parser.add_argument(
        "-d",
        "--debug",
        help="run rpst in debug mode",
        action="store_true",
    )

    return parser
|
||||
|
||||
|
||||
def check_updates(version_tag: str):
    """
    Checks if there's a new release of the project on GitHub. If there is, it prints a
    notice with the upgrade command, followed by the release notes.

    Nothing is printed when the response carries no release tag.

    :param version_tag: A string representing the current version of the project.
    """
    # Make a GET request to the GitHub API to get the latest release of the project.
    # The timeout keeps an unresponsive server from blocking the program forever.
    response = requests.get(
        "https://api.github.com/repos/bellingcat/reddit-post-scraping-tool/releases/latest",
        timeout=10,
    ).json()

    # A response without "tag_name" is not a release description; skip the
    # check instead of failing with a KeyError.
    latest_tag = response.get("tag_name")
    if latest_tag is None:
        return

    # Check if the latest release's tag matches the current version tag.
    if latest_tag != version_tag:
        raw_release_notes = response.get("body") or ""

        # Announce the new release.
        print(
            f"{glyph.up_arrow} A new release of RPST is available ({latest_tag}). "
            f"Run 'pip install --upgrade reddit-post-scraping-tool' to get the updates."
        )

        # Print the release notes, rendered from Markdown.
        print(Markdown(raw_release_notes))
|
||||
|
||||
|
||||
def set_loglevel(debug_mode: bool) -> logging.getLogger:
    """
    Sets up logging through a RichHandler and returns the "RPST" logger.

    :param debug_mode: When True the level is "NOTSET"; otherwise it is "INFO".
    :return: The logger named "RPST".
    """
    level_name = "INFO"
    if debug_mode:
        level_name = "NOTSET"

    # Markup-enabled console handler; the level column is hidden.
    console_handler = RichHandler(
        show_level=False, log_time_format="[%I:%M:%S %p]", markup=True
    )
    logging.basicConfig(
        handlers=[console_handler],
        format="%(message)s",
        level=level_name,
    )
    return logging.getLogger("RPST")
|
||||
|
||||
|
||||
def write_post_data(post_data: dict, filename: str, args, tree_branch: Tree):
    """
    Appends one post to "<home>/<filename>.json" and/or "<home>/<filename>.csv".

    For each format, a leaf is added to the tree saying whether the data was
    written or the step was skipped.

    :param post_data: A dictionary containing post data.
    :param filename: Base name (no extension) of the output file(s).
    :param args: A namespace object from argparse; `args.json` and `args.csv` select the formats.
    :param tree_branch: A rich Tree object that receives the status leaves.
    """
    user_home = os.path.expanduser("~")

    # --- JSON: one JSON document per line, appended ---
    if not args.json:
        tree_branch.add(
            f"{glyph.cross_mark_button} JSON data writing operation was skipped. No changes made."
        )
    else:
        json_file_path = os.path.join(user_home, f"{filename}.json")
        with open(json_file_path, "a", encoding="utf-8") as json_handle:
            serialised = json.dumps(post_data, ensure_ascii=False)
            json_handle.write(serialised)
            json_handle.write("\n")  # Separate posts with newline
        tree_branch.add(
            f"{glyph.page_facing_up} JSON data successfully written/appended to file: "
            f"[italic][link file://{json_file_path}]{json_file_path}[/]"
        )

    # --- CSV: one row per post, appended ---
    if not args.csv:
        tree_branch.add(
            f"{glyph.cross_mark_button} CSV data writing operation was skipped. No changes made."
        )
    else:
        csv_file_path = os.path.join(user_home, f"{filename}.csv")
        with open(csv_file_path, "a", newline="", encoding="utf-8") as csv_handle:
            row_writer = csv.DictWriter(csv_handle, fieldnames=post_data.keys())

            # Position 0 means nothing has been written yet: emit the header first.
            at_start = csv_handle.tell() == 0
            if at_start:
                row_writer.writeheader()

            row_writer.writerow(post_data)
        tree_branch.add(
            f"{glyph.page_facing_up} CSV data successfully written/appended to file: "
            f"[italic][link file://{csv_file_path}]{csv_file_path}[/]"
        )
|
||||
Reference in New Issue
Block a user