mirror of
https://github.com/bellingcat/reddit-post-scraping-tool.git
synced 2026-06-07 19:18:29 +03:00
Add files via upload
This commit is contained in:
@@ -1,42 +1,42 @@
|
||||
import os
|
||||
|
||||
__author__: str = "Richard Mwewa"
|
||||
__about_author__: str = "https://about/me/rly0nheart"
|
||||
__version__: str = "2.0.0.0"
|
||||
|
||||
__description__: str = f"""
|
||||
# RPST (Reddit Post Scraping Tool) {__version__}
|
||||
> Retrieve Reddit posts that contain the specified keyword from a specified subreddit.
|
||||
"""
|
||||
__epilog__: str = f"""
|
||||
# by [{__author__}]({__about_author__})
|
||||
|
||||
```
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 {__author__}
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
```
|
||||
"""
|
||||
|
||||
# Construct path to the program's directory
|
||||
PROGRAM_DIRECTORY: str = os.path.expanduser(
|
||||
os.path.join("~", "reddit_post_scraping_tool")
|
||||
)
|
||||
import os
|
||||
|
||||
__author__: str = "Richard Mwewa"
|
||||
__about_author__: str = "https://about/me/rly0nheart"
|
||||
__version__: str = "2.0.0.0"
|
||||
|
||||
__description__: str = f"""
|
||||
# RPST (Reddit Post Scraping Tool) {__version__}
|
||||
> Retrieve Reddit posts that contain the specified keyword from a specified subreddit.
|
||||
"""
|
||||
__epilog__: str = f"""
|
||||
# by [{__author__}]({__about_author__})
|
||||
|
||||
```
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 {__author__}
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
```
|
||||
"""
|
||||
|
||||
# Construct path to the program's directory
|
||||
PROGRAM_DIRECTORY: str = os.path.expanduser(
|
||||
os.path.join("~", "reddit_post_scraping_tool")
|
||||
)
|
||||
|
||||
328
rpst/api.py
328
rpst/api.py
@@ -1,164 +1,164 @@
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
from typing import Union
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .coreutils import log
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
REDDIT_ENDPOINT: str = "https://www.reddit.com"
|
||||
PYPI_PROJECT_ENDPOINT: str = "https://pypi.org/pypi/reddit-post-scraping-tool/json"
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
async def get_data(session: aiohttp.ClientSession, endpoint: str) -> Union[dict, list]:
|
||||
"""
|
||||
Fetches JSON data from a given API endpoint.
|
||||
|
||||
:param session: aiohttp session to use for the request.
|
||||
:param endpoint: The API endpoint to fetch data from.
|
||||
:return: Returns JSON data as a dictionary or list. Returns an empty dict if fetching fails.
|
||||
"""
|
||||
|
||||
try:
|
||||
async with session.get(
|
||||
endpoint,
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
error_message = await response.json()
|
||||
log.error(f"An API error occurred: {error_message}")
|
||||
return {}
|
||||
|
||||
except aiohttp.ClientConnectionError as error:
|
||||
log.error(f"An HTTP error occurred: {error}")
|
||||
return {}
|
||||
except Exception as error:
|
||||
log.critical(f"An unknown error occurred: {error}")
|
||||
return {}
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
async def get_updates(session: aiohttp.ClientSession):
|
||||
"""
|
||||
Gets and compares the current program version with the remote version.
|
||||
|
||||
Assumes version format: major.minor.patch.prefix
|
||||
|
||||
:param session: aiohttp session to use for the request.
|
||||
"""
|
||||
from . import __version__
|
||||
|
||||
# Make a GET request to PyPI to get the project's latest release.
|
||||
response: dict = await get_data(endpoint=PYPI_PROJECT_ENDPOINT, session=session)
|
||||
|
||||
if response.get("info"):
|
||||
release: dict = response.get("info")
|
||||
remote_version: str = release.get("version")
|
||||
# Splitting the version strings into components
|
||||
remote_parts: list = remote_version.split(".")
|
||||
local_parts: list = __version__.split(".")
|
||||
|
||||
update_message: str = ""
|
||||
|
||||
# Check for differences in version parts
|
||||
if remote_parts[0] != local_parts[0]:
|
||||
update_message = (
|
||||
f"MAJOR update ({remote_version}) available."
|
||||
f" It might introduce significant changes."
|
||||
)
|
||||
|
||||
elif remote_parts[1] != local_parts[1]:
|
||||
update_message = (
|
||||
f"MINOR update ({remote_version}) available."
|
||||
f" Includes small feature changes/improvements."
|
||||
)
|
||||
|
||||
elif remote_parts[2] != local_parts[2]:
|
||||
update_message = (
|
||||
f"PATCH update ({remote_version}) available."
|
||||
f" Generally for bug fixes and small tweaks."
|
||||
)
|
||||
|
||||
elif (
|
||||
len(remote_parts) > 3
|
||||
and len(local_parts) > 3
|
||||
and remote_parts[3] != local_parts[3]
|
||||
):
|
||||
update_message = (
|
||||
f"BUILD update ({remote_version}) available."
|
||||
f" Might be for specific builds or special versions."
|
||||
)
|
||||
|
||||
if update_message:
|
||||
log.info(update_message)
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
async def get_posts(
|
||||
subreddit: str,
|
||||
listing: str,
|
||||
timeframe: str,
|
||||
limit: int,
|
||||
session: aiohttp.ClientSession,
|
||||
) -> list:
|
||||
all_posts = await paginated_posts(
|
||||
posts_endpoint=f"{REDDIT_ENDPOINT}/r/{subreddit}/{listing}.json?limit={limit}&t={timeframe}",
|
||||
limit=limit,
|
||||
session=session,
|
||||
)
|
||||
|
||||
return all_posts[:limit]
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
async def paginated_posts(
|
||||
posts_endpoint: str, limit: int, session: aiohttp.ClientSession
|
||||
) -> list:
|
||||
"""
|
||||
Paginates and retrieves posts until the specified limit is reached.
|
||||
|
||||
:param posts_endpoint: API endpoint for retrieving posts.
|
||||
:param limit: Limit of the number of posts to retrieve.
|
||||
:param session: aiohttp session to use for the request.
|
||||
:return: A list of all posts.
|
||||
"""
|
||||
all_posts: list = []
|
||||
last_post_id: str = ""
|
||||
|
||||
# Determine whether to use the 'after' parameter
|
||||
use_after: bool = limit > 100
|
||||
|
||||
while len(all_posts) < limit:
|
||||
# Make the API request with the 'after' parameter if it's provided and the limit is more than 100
|
||||
if use_after and last_post_id:
|
||||
endpoint_with_after: str = f"{posts_endpoint}&after={last_post_id}"
|
||||
else:
|
||||
endpoint_with_after: str = posts_endpoint
|
||||
|
||||
posts_data: dict = await get_data(endpoint=endpoint_with_after, session=session)
|
||||
posts_children: list = posts_data.get("data", {}).get("children", [])
|
||||
|
||||
# If there are no more posts, break out of the loop
|
||||
if not posts_children:
|
||||
break
|
||||
|
||||
all_posts.extend(posts_children)
|
||||
|
||||
# We use the id of the last post in the list to paginate to the next posts
|
||||
last_post_id: str = all_posts[-1].get("data").get("id")
|
||||
|
||||
return all_posts
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
from typing import Union
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .coreutils import log
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
REDDIT_ENDPOINT: str = "https://www.reddit.com"
|
||||
PYPI_PROJECT_ENDPOINT: str = "https://pypi.org/pypi/reddit-post-scraping-tool/json"
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
async def get_data(session: aiohttp.ClientSession, endpoint: str) -> Union[dict, list]:
|
||||
"""
|
||||
Fetches JSON data from a given API endpoint.
|
||||
|
||||
:param session: aiohttp session to use for the request.
|
||||
:param endpoint: The API endpoint to fetch data from.
|
||||
:return: Returns JSON data as a dictionary or list. Returns an empty dict if fetching fails.
|
||||
"""
|
||||
|
||||
try:
|
||||
async with session.get(
|
||||
endpoint,
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
error_message = await response.json()
|
||||
log.error(f"An API error occurred: {error_message}")
|
||||
return {}
|
||||
|
||||
except aiohttp.ClientConnectionError as error:
|
||||
log.error(f"An HTTP error occurred: {error}")
|
||||
return {}
|
||||
except Exception as error:
|
||||
log.critical(f"An unknown error occurred: {error}")
|
||||
return {}
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
async def get_updates(session: aiohttp.ClientSession):
|
||||
"""
|
||||
Gets and compares the current program version with the remote version.
|
||||
|
||||
Assumes version format: major.minor.patch.prefix
|
||||
|
||||
:param session: aiohttp session to use for the request.
|
||||
"""
|
||||
from . import __version__
|
||||
|
||||
# Make a GET request to PyPI to get the project's latest release.
|
||||
response: dict = await get_data(endpoint=PYPI_PROJECT_ENDPOINT, session=session)
|
||||
|
||||
if response.get("info"):
|
||||
release: dict = response.get("info")
|
||||
remote_version: str = release.get("version")
|
||||
# Splitting the version strings into components
|
||||
remote_parts: list = remote_version.split(".")
|
||||
local_parts: list = __version__.split(".")
|
||||
|
||||
update_message: str = ""
|
||||
|
||||
# Check for differences in version parts
|
||||
if remote_parts[0] != local_parts[0]:
|
||||
update_message = (
|
||||
f"MAJOR update ({remote_version}) available."
|
||||
f" It might introduce significant changes."
|
||||
)
|
||||
|
||||
elif remote_parts[1] != local_parts[1]:
|
||||
update_message = (
|
||||
f"MINOR update ({remote_version}) available."
|
||||
f" Includes small feature changes/improvements."
|
||||
)
|
||||
|
||||
elif remote_parts[2] != local_parts[2]:
|
||||
update_message = (
|
||||
f"PATCH update ({remote_version}) available."
|
||||
f" Generally for bug fixes and small tweaks."
|
||||
)
|
||||
|
||||
elif (
|
||||
len(remote_parts) > 3
|
||||
and len(local_parts) > 3
|
||||
and remote_parts[3] != local_parts[3]
|
||||
):
|
||||
update_message = (
|
||||
f"BUILD update ({remote_version}) available."
|
||||
f" Might be for specific builds or special versions."
|
||||
)
|
||||
|
||||
if update_message:
|
||||
log.info(update_message)
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
async def get_posts(
|
||||
subreddit: str,
|
||||
listing: str,
|
||||
timeframe: str,
|
||||
limit: int,
|
||||
session: aiohttp.ClientSession,
|
||||
) -> list:
|
||||
all_posts = await paginated_posts(
|
||||
posts_endpoint=f"{REDDIT_ENDPOINT}/r/{subreddit}/{listing}.json?limit={limit}&t={timeframe}",
|
||||
limit=limit,
|
||||
session=session,
|
||||
)
|
||||
|
||||
return all_posts[:limit]
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
async def paginated_posts(
|
||||
posts_endpoint: str, limit: int, session: aiohttp.ClientSession
|
||||
) -> list:
|
||||
"""
|
||||
Paginates and retrieves posts until the specified limit is reached.
|
||||
|
||||
:param posts_endpoint: API endpoint for retrieving posts.
|
||||
:param limit: Limit of the number of posts to retrieve.
|
||||
:param session: aiohttp session to use for the request.
|
||||
:return: A list of all posts.
|
||||
"""
|
||||
all_posts: list = []
|
||||
last_post_id: str = ""
|
||||
|
||||
# Determine whether to use the 'after' parameter
|
||||
use_after: bool = limit > 100
|
||||
|
||||
while len(all_posts) < limit:
|
||||
# Make the API request with the 'after' parameter if it's provided and the limit is more than 100
|
||||
if use_after and last_post_id:
|
||||
endpoint_with_after: str = f"{posts_endpoint}&after={last_post_id}"
|
||||
else:
|
||||
endpoint_with_after: str = posts_endpoint
|
||||
|
||||
posts_data: dict = await get_data(endpoint=endpoint_with_after, session=session)
|
||||
posts_children: list = posts_data.get("data", {}).get("children", [])
|
||||
|
||||
# If there are no more posts, break out of the loop
|
||||
if not posts_children:
|
||||
break
|
||||
|
||||
all_posts.extend(posts_children)
|
||||
|
||||
# We use the id of the last post in the list to paginate to the next posts
|
||||
last_post_id: str = all_posts[-1].get("data").get("id")
|
||||
|
||||
return all_posts
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
@@ -1,170 +1,170 @@
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from rich.logging import RichHandler
|
||||
from rich.markdown import Markdown
|
||||
from rich_argparse import RichHelpFormatter
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def timestamp_to_utc(timestamp: int) -> str:
|
||||
"""
|
||||
Converts a Unix timestamp to a formatted datetime string.
|
||||
|
||||
:param timestamp: The Unix timestamp to be converted.
|
||||
:return: A formatted datetime string in the format "dd MMMM yyyy, hh:mm:ssAM/PM".
|
||||
"""
|
||||
utc_from_timestamp: datetime = datetime.utcfromtimestamp(timestamp)
|
||||
datetime_string: str = utc_from_timestamp.strftime("%d %B %Y, %I:%M:%S%p")
|
||||
return datetime_string
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def pathfinder(directories: list[str]):
|
||||
for directory in directories:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def save_posts(
|
||||
filename: str,
|
||||
save_to_dir: str,
|
||||
posts: list,
|
||||
save_json: bool = False,
|
||||
save_csv: bool = False,
|
||||
):
|
||||
posts_data: list = [post.__dict__ for post in posts]
|
||||
|
||||
if save_json:
|
||||
json_path = os.path.join(os.path.join(save_to_dir, "json"), f"{filename}.json")
|
||||
with open(json_path, "w", encoding="utf-8") as json_file:
|
||||
json.dump(posts_data, json_file, indent=4)
|
||||
log.info(
|
||||
f"{os.path.getsize(json_file.name)} bytes written to [link file://{json_file.name}]{json_file.name}"
|
||||
)
|
||||
|
||||
if save_csv:
|
||||
csv_path = os.path.join(os.path.join(save_to_dir, "csv"), f"{filename}.csv")
|
||||
with open(csv_path, "w", newline="", encoding="utf-8") as csv_file:
|
||||
writer = csv.writer(csv_file)
|
||||
if posts:
|
||||
writer.writerow(
|
||||
posts_data[0].keys()
|
||||
) # header from keys of the first item
|
||||
for post in posts:
|
||||
writer.writerow(post.__dict__.values())
|
||||
log.info(
|
||||
f"{os.path.getsize(csv_file.name)} bytes written to [link file://{csv_file.name}]{csv_file.name}"
|
||||
)
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def create_parser():
|
||||
"""
|
||||
Creates and configures an argument parser for the command line arguments.
|
||||
|
||||
:return: A configured argparse.ArgumentParser object ready to parse the command line arguments.
|
||||
"""
|
||||
from . import __version__, __description__, __epilog__
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=Markdown(__description__, style="argparse.text"),
|
||||
epilog=Markdown(__epilog__, style="argparse.text"),
|
||||
formatter_class=RichHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"keyword",
|
||||
help="keyword to search for, in posts",
|
||||
)
|
||||
parser.add_argument("subreddit", help="subreddit to scrape")
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--limit",
|
||||
help="maximum number of posts to scrape (default: %(default)s)",
|
||||
default=200,
|
||||
type=int,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ls",
|
||||
"--listing",
|
||||
default="top",
|
||||
const="top",
|
||||
nargs="?",
|
||||
choices=["best", "controversial", "hot", "new", "rising", "top"],
|
||||
help="listing of posts to scrape (default: %(default)s)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--timeframe",
|
||||
default="all",
|
||||
const="all",
|
||||
nargs="?",
|
||||
choices=["hour", "day", "week", "month", "year", "all"],
|
||||
help="timeframe from which to scrape posts (default: %(default)s)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
"--json",
|
||||
help="write found posts to a json file",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--csv",
|
||||
help="write found posts to a csv file",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--debug",
|
||||
help="(dev) run rpst in debug mode",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument("-v", "--version", action="version", version=__version__)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def set_loglevel(debug_mode: bool) -> logging.getLogger:
|
||||
"""
|
||||
Configure and return a logging object with the specified log level.
|
||||
|
||||
:param debug_mode: If True, the log level is set to "NOTSET". Otherwise, it is set to "INFO".
|
||||
:return: A logging object configured with the specified log level.
|
||||
"""
|
||||
logging.basicConfig(
|
||||
level="DEBUG" if debug_mode else "INFO",
|
||||
format="%(message)s",
|
||||
handlers=[
|
||||
RichHandler(
|
||||
markup=True, log_time_format="%I:%M:%S%p", show_level=debug_mode
|
||||
)
|
||||
],
|
||||
)
|
||||
return logging.getLogger("RPST (Reddit Post Scraping Tool)")
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
args: argparse = create_parser().parse_args()
|
||||
log: logging.getLogger = set_loglevel(debug_mode=args.debug)
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from rich.logging import RichHandler
|
||||
from rich.markdown import Markdown
|
||||
from rich_argparse import RichHelpFormatter
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def timestamp_to_utc(timestamp: int) -> str:
|
||||
"""
|
||||
Converts a Unix timestamp to a formatted datetime string.
|
||||
|
||||
:param timestamp: The Unix timestamp to be converted.
|
||||
:return: A formatted datetime string in the format "dd MMMM yyyy, hh:mm:ssAM/PM".
|
||||
"""
|
||||
utc_from_timestamp: datetime = datetime.utcfromtimestamp(timestamp)
|
||||
datetime_string: str = utc_from_timestamp.strftime("%d %B %Y, %I:%M:%S%p")
|
||||
return datetime_string
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def pathfinder(directories: list[str]):
|
||||
for directory in directories:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def save_posts(
|
||||
filename: str,
|
||||
save_to_dir: str,
|
||||
posts: list,
|
||||
save_json: bool = False,
|
||||
save_csv: bool = False,
|
||||
):
|
||||
posts_data: list = [post.__dict__ for post in posts]
|
||||
|
||||
if save_json:
|
||||
json_path = os.path.join(os.path.join(save_to_dir, "json"), f"{filename}.json")
|
||||
with open(json_path, "w", encoding="utf-8") as json_file:
|
||||
json.dump(posts_data, json_file, indent=4)
|
||||
log.info(
|
||||
f"{os.path.getsize(json_file.name)} bytes written to [link file://{json_file.name}]{json_file.name}"
|
||||
)
|
||||
|
||||
if save_csv:
|
||||
csv_path = os.path.join(os.path.join(save_to_dir, "csv"), f"{filename}.csv")
|
||||
with open(csv_path, "w", newline="", encoding="utf-8") as csv_file:
|
||||
writer = csv.writer(csv_file)
|
||||
if posts:
|
||||
writer.writerow(
|
||||
posts_data[0].keys()
|
||||
) # header from keys of the first item
|
||||
for post in posts:
|
||||
writer.writerow(post.__dict__.values())
|
||||
log.info(
|
||||
f"{os.path.getsize(csv_file.name)} bytes written to [link file://{csv_file.name}]{csv_file.name}"
|
||||
)
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def create_parser():
|
||||
"""
|
||||
Creates and configures an argument parser for the command line arguments.
|
||||
|
||||
:return: A configured argparse.ArgumentParser object ready to parse the command line arguments.
|
||||
"""
|
||||
from . import __version__, __description__, __epilog__
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=Markdown(__description__, style="argparse.text"),
|
||||
epilog=Markdown(__epilog__, style="argparse.text"),
|
||||
formatter_class=RichHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"keyword",
|
||||
help="keyword to search for, in posts",
|
||||
)
|
||||
parser.add_argument("subreddit", help="subreddit to scrape")
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--limit",
|
||||
help="maximum number of posts to scrape (default: %(default)s)",
|
||||
default=200,
|
||||
type=int,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ls",
|
||||
"--listing",
|
||||
default="top",
|
||||
const="top",
|
||||
nargs="?",
|
||||
choices=["best", "controversial", "hot", "new", "rising", "top"],
|
||||
help="listing of posts to scrape (default: %(default)s)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--timeframe",
|
||||
default="all",
|
||||
const="all",
|
||||
nargs="?",
|
||||
choices=["hour", "day", "week", "month", "year", "all"],
|
||||
help="timeframe from which to scrape posts (default: %(default)s)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
"--json",
|
||||
help="write found posts to a json file",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--csv",
|
||||
help="write found posts to a csv file",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--debug",
|
||||
help="(dev) run rpst in debug mode",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument("-v", "--version", action="version", version=__version__)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
|
||||
def set_loglevel(debug_mode: bool) -> logging.getLogger:
|
||||
"""
|
||||
Configure and return a logging object with the specified log level.
|
||||
|
||||
:param debug_mode: If True, the log level is set to "NOTSET". Otherwise, it is set to "INFO".
|
||||
:return: A logging object configured with the specified log level.
|
||||
"""
|
||||
logging.basicConfig(
|
||||
level="DEBUG" if debug_mode else "INFO",
|
||||
format="%(message)s",
|
||||
handlers=[
|
||||
RichHandler(
|
||||
markup=True, log_time_format="%I:%M:%S%p", show_level=debug_mode
|
||||
)
|
||||
],
|
||||
)
|
||||
return logging.getLogger("RPST (Reddit Post Scraping Tool)")
|
||||
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
args: argparse = create_parser().parse_args()
|
||||
log: logging.getLogger = set_loglevel(debug_mode=args.debug)
|
||||
|
||||
# +++++++++++++++++++++++++++++++++++++++++++++++++ #
|
||||
|
||||
Reference in New Issue
Block a user