diff --git a/tiktok_hashtag_analysis/__main__.py b/tiktok_hashtag_analysis/__main__.py index 8a9e5ee..3c3bbfd 100644 --- a/tiktok_hashtag_analysis/__main__.py +++ b/tiktok_hashtag_analysis/__main__.py @@ -6,6 +6,8 @@ from .base import TikTokDownloader, load_hashtags_from_file def create_parser(): + """Create parser tp parse input command-line arguments.""" + parser = argparse.ArgumentParser( description="Analyze hashtags within posts scraped from TikTok." ) @@ -51,12 +53,20 @@ def create_parser(): help="Directory to save scraped data and visualizations to", default=Path(".").resolve().parent / "data", ) + parser.add_argument( + "--config", + type=str, + help="File name of configuration file to store TikTok credentials to", + default=None, + ) parser.add_argument("--log", type=str, help="File to write logs to", default=None) return parser def main(): + """Parse and process command-line arguments, scrape specified hashtags, and perform specified analyses.""" + parser = create_parser() args = parser.parse_args() @@ -79,7 +89,9 @@ def main(): else: hashtags = args.hashtags - downloader = TikTokDownloader(hashtags=hashtags, data_dir=args.output_dir) + downloader = TikTokDownloader( + hashtags=hashtags, data_dir=args.output_dir, config_file=args.config + ) downloader.run( download=args.download, plot=args.plot, table=args.table, number=args.number diff --git a/tiktok_hashtag_analysis/auth.py b/tiktok_hashtag_analysis/auth.py index 17b8f3c..5d5ac16 100644 --- a/tiktok_hashtag_analysis/auth.py +++ b/tiktok_hashtag_analysis/auth.py @@ -2,17 +2,22 @@ import os import configparser from pathlib import Path import logging +from typing import Optional class Authorization: """Handle authorization for TikTok, using the `msToken`.""" - def __init__(self): - self.config_file = Path.home() / ".tiktok" - self.section = "TikTok" - self.ms_token = None + def __init__(self, config_file: Optional[str] = None): + if config_file: + self.config_file = Path(config_file) + else: + self.config_file = Path.home() / ".tiktok" - def get_token(self): + self.section = "TikTok" + self.get_token() + + def get_token(self) -> str: """Load the "msToken" cookie taken from TikTok, which the scraper requires.""" # Step 1: check if MS_TOKEN is defined as environment variable @@ -37,14 +42,14 @@ class Authorization: return self.ms_token - def load_token(self): + def load_token(self) -> Optional[str]: """Parse a config file and extract the token.""" config = configparser.ConfigParser() config.read(self.config_file) return config.get(section=self.section, option="MS_TOKEN", fallback=None) - def dump_token(self, ms_token): + def dump_token(self, ms_token: str): """Write the token to a config file.""" config = configparser.ConfigParser() @@ -52,10 +57,10 @@ class Authorization: config.add_section(self.section) config.set(section=self.section, option="MS_TOKEN", value=ms_token) - with open(self.config_file, "w") as f: + with open(self.config_file, "w", encoding="utf-8") as f: config.write(f) - def input_token(self): + def input_token(self) -> str: """Allow user to manually enter the token in the terminal.""" print( diff --git a/tiktok_hashtag_analysis/base.py b/tiktok_hashtag_analysis/base.py index 74df81b..77623a3 100644 --- a/tiktok_hashtag_analysis/base.py +++ b/tiktok_hashtag_analysis/base.py @@ -17,6 +17,8 @@ import seaborn as sns from TikTokApi import TikTokApi +from .auth import Authorization + warnings.filterwarnings("ignore", message="Glyph (.*) missing from current font") sns.set_theme(style="darkgrid") @@ -38,13 +40,11 @@ def load_hashtags_from_file(file: str) -> List[str]: return process_hashtag_list(hashtags=hashtags) -async def _fetch_hashtag_data(hashtag: str) -> List[Dict]: +async def _fetch_hashtag_data(hashtag: str, ms_token: str) -> List[Dict]: """Fetch data for videos containing a specified hashtag, asynchronously.""" data = [] async with TikTokApi() as api: - await api.create_sessions( - ms_tokens=[os.environ["MS_TOKEN"]], num_sessions=1, sleep_after=3 - ) + await api.create_sessions(ms_tokens=[ms_token], num_sessions=1, sleep_after=3) async for video in api.hashtag(name=hashtag).videos(count=1000): data.append(video.as_dict) return data @@ -101,13 +101,16 @@ def aggregate_cooccurring_hashtags(hashtag_file: Path) -> Counter: class TikTokDownloader: """Main class for scraping data from TikTok.""" - def __init__(self, hashtags: List[str], data_dir: str): + def __init__(self, hashtags: List[str], data_dir: str, config_file: str = None): self.hashtags = process_hashtag_list(hashtags) logging.info(f"Hashtags to scrape: {hashtags}") self.data_dir = Path(data_dir) os.makedirs(self.data_dir, exist_ok=True) + self.auth = Authorization(config_file=config_file) + self.ms_token = self.auth.ms_token + def get_hashtag_posts(self, hashtag: str): """Fetch data about posts that used a specified hashtag and merge with existing data, if it exists.""" @@ -125,7 +128,9 @@ class TikTokDownloader: already_fetched_data = [] # Scrape posts that use the specified hashtag - fetched_data = asyncio.run(_fetch_hashtag_data(hashtag=hashtag)) + fetched_data = asyncio.run( + _fetch_hashtag_data(hashtag=hashtag, ms_token=self.ms_token) + ) if len(fetched_data) == 0: logging.warning(f"No posts were found for the hashtag: {hashtag}")