added type hints for auth, incorporated auth into base module

This commit is contained in:
Tristan Lee
2023-09-04 10:40:30 -05:00
parent cf575e6cf6
commit 0f8e865bf3
3 changed files with 38 additions and 16 deletions

View File

@@ -6,6 +6,8 @@ from .base import TikTokDownloader, load_hashtags_from_file
def create_parser():
"""Create parser tp parse input command-line arguments."""
parser = argparse.ArgumentParser(
description="Analyze hashtags within posts scraped from TikTok."
)
@@ -51,12 +53,20 @@ def create_parser():
help="Directory to save scraped data and visualizations to",
default=Path(".").resolve().parent / "data",
)
parser.add_argument(
"--config",
type=str,
help="File name of configuration file to store TikTok credentials to",
default=None,
)
parser.add_argument("--log", type=str, help="File to write logs to", default=None)
return parser
def main():
"""Parse and process command-line arguments, scrape specified hashtags, and perform specified analyses."""
parser = create_parser()
args = parser.parse_args()
@@ -79,7 +89,9 @@ def main():
else:
hashtags = args.hashtags
downloader = TikTokDownloader(hashtags=hashtags, data_dir=args.output_dir)
downloader = TikTokDownloader(
hashtags=hashtags, data_dir=args.output_dir, config_file=args.config
)
downloader.run(
download=args.download, plot=args.plot, table=args.table, number=args.number

View File

@@ -2,17 +2,22 @@ import os
import configparser
from pathlib import Path
import logging
from typing import Optional
class Authorization:
"""Handle authorization for TikTok, using the `msToken`."""
def __init__(self):
self.config_file = Path.home() / ".tiktok"
self.section = "TikTok"
self.ms_token = None
def __init__(self, config_file: Optional[str] = None):
if config_file:
self.config_file = Path(config_file)
else:
self.config_file = Path.home() / ".tiktok"
def get_token(self):
self.section = "TikTok"
self.get_token()
def get_token(self) -> str:
"""Load the "msToken" cookie taken from TikTok, which the scraper requires."""
# Step 1: check if MS_TOKEN is defined as environment variable
@@ -37,14 +42,14 @@ class Authorization:
return self.ms_token
def load_token(self):
def load_token(self) -> Optional[str]:
"""Parse a config file and extract the token."""
config = configparser.ConfigParser()
config.read(self.config_file)
return config.get(section=self.section, option="MS_TOKEN", fallback=None)
def dump_token(self, ms_token):
def dump_token(self, ms_token: str):
"""Write the token to a config file."""
config = configparser.ConfigParser()
@@ -52,10 +57,10 @@ class Authorization:
config.add_section(self.section)
config.set(section=self.section, option="MS_TOKEN", value=ms_token)
with open(self.config_file, "w") as f:
with open(self.config_file, "w", encoding="utf-8") as f:
config.write(f)
def input_token(self):
def input_token(self) -> str:
"""Allow user to manually enter the token in the terminal."""
print(

View File

@@ -17,6 +17,8 @@ import seaborn as sns
from TikTokApi import TikTokApi
from .auth import Authorization
warnings.filterwarnings("ignore", message="Glyph (.*) missing from current font")
sns.set_theme(style="darkgrid")
@@ -38,13 +40,11 @@ def load_hashtags_from_file(file: str) -> List[str]:
return process_hashtag_list(hashtags=hashtags)
async def _fetch_hashtag_data(hashtag: str) -> List[Dict]:
async def _fetch_hashtag_data(hashtag: str, ms_token: str) -> List[Dict]:
"""Fetch data for videos containing a specified hashtag, asynchronously."""
data = []
async with TikTokApi() as api:
await api.create_sessions(
ms_tokens=[os.environ["MS_TOKEN"]], num_sessions=1, sleep_after=3
)
await api.create_sessions(ms_tokens=[ms_token], num_sessions=1, sleep_after=3)
async for video in api.hashtag(name=hashtag).videos(count=1000):
data.append(video.as_dict)
return data
@@ -101,13 +101,16 @@ def aggregate_cooccurring_hashtags(hashtag_file: Path) -> Counter:
class TikTokDownloader:
"""Main class for scraping data from TikTok."""
def __init__(self, hashtags: List[str], data_dir: str):
def __init__(self, hashtags: List[str], data_dir: str, config_file: str = None):
self.hashtags = process_hashtag_list(hashtags)
logging.info(f"Hashtags to scrape: {hashtags}")
self.data_dir = Path(data_dir)
os.makedirs(self.data_dir, exist_ok=True)
self.auth = Authorization(config_file=config_file)
self.ms_token = self.auth.ms_token
def get_hashtag_posts(self, hashtag: str):
"""Fetch data about posts that used a specified hashtag and merge with
existing data, if it exists."""
@@ -125,7 +128,9 @@ class TikTokDownloader:
already_fetched_data = []
# Scrape posts that use the specified hashtag
fetched_data = asyncio.run(_fetch_hashtag_data(hashtag=hashtag))
fetched_data = asyncio.run(
_fetch_hashtag_data(hashtag=hashtag, ms_token=self.ms_token)
)
if len(fetched_data) == 0:
logging.warning(f"No posts were found for the hashtag: {hashtag}")