mirror of
https://github.com/bellingcat/tiktok-hashtag-analysis.git
synced 2026-06-13 13:58:31 +03:00
added type hints for auth, incorporated auth into base module
This commit is contained in:
@@ -6,6 +6,8 @@ from .base import TikTokDownloader, load_hashtags_from_file
|
||||
|
||||
|
||||
def create_parser():
|
||||
"""Create parser to parse input command-line arguments."""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Analyze hashtags within posts scraped from TikTok."
|
||||
)
|
||||
@@ -51,12 +53,20 @@ def create_parser():
|
||||
help="Directory to save scraped data and visualizations to",
|
||||
default=Path(".").resolve().parent / "data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
help="File name of configuration file to store TikTok credentials to",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("--log", type=str, help="File to write logs to", default=None)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
"""Parse and process command-line arguments, scrape specified hashtags, and perform specified analyses."""
|
||||
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -79,7 +89,9 @@ def main():
|
||||
else:
|
||||
hashtags = args.hashtags
|
||||
|
||||
downloader = TikTokDownloader(hashtags=hashtags, data_dir=args.output_dir)
|
||||
downloader = TikTokDownloader(
|
||||
hashtags=hashtags, data_dir=args.output_dir, config_file=args.config
|
||||
)
|
||||
|
||||
downloader.run(
|
||||
download=args.download, plot=args.plot, table=args.table, number=args.number
|
||||
|
||||
@@ -2,17 +2,22 @@ import os
|
||||
import configparser
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Authorization:
|
||||
"""Handle authorization for TikTok, using the `msToken`."""
|
||||
|
||||
def __init__(self):
|
||||
self.config_file = Path.home() / ".tiktok"
|
||||
self.section = "TikTok"
|
||||
self.ms_token = None
|
||||
def __init__(self, config_file: Optional[str] = None):
|
||||
if config_file:
|
||||
self.config_file = Path(config_file)
|
||||
else:
|
||||
self.config_file = Path.home() / ".tiktok"
|
||||
|
||||
def get_token(self):
|
||||
self.section = "TikTok"
|
||||
self.get_token()
|
||||
|
||||
def get_token(self) -> str:
|
||||
"""Load the "msToken" cookie taken from TikTok, which the scraper requires."""
|
||||
|
||||
# Step 1: check if MS_TOKEN is defined as environment variable
|
||||
@@ -37,14 +42,14 @@ class Authorization:
|
||||
|
||||
return self.ms_token
|
||||
|
||||
def load_token(self):
|
||||
def load_token(self) -> Optional[str]:
|
||||
"""Parse a config file and extract the token."""
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read(self.config_file)
|
||||
return config.get(section=self.section, option="MS_TOKEN", fallback=None)
|
||||
|
||||
def dump_token(self, ms_token):
|
||||
def dump_token(self, ms_token: str):
|
||||
"""Write the token to a config file."""
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
@@ -52,10 +57,10 @@ class Authorization:
|
||||
config.add_section(self.section)
|
||||
config.set(section=self.section, option="MS_TOKEN", value=ms_token)
|
||||
|
||||
with open(self.config_file, "w") as f:
|
||||
with open(self.config_file, "w", encoding="utf-8") as f:
|
||||
config.write(f)
|
||||
|
||||
def input_token(self):
|
||||
def input_token(self) -> str:
|
||||
"""Allow user to manually enter the token in the terminal."""
|
||||
|
||||
print(
|
||||
|
||||
@@ -17,6 +17,8 @@ import seaborn as sns
|
||||
|
||||
from TikTokApi import TikTokApi
|
||||
|
||||
from .auth import Authorization
|
||||
|
||||
warnings.filterwarnings("ignore", message="Glyph (.*) missing from current font")
|
||||
sns.set_theme(style="darkgrid")
|
||||
|
||||
@@ -38,13 +40,11 @@ def load_hashtags_from_file(file: str) -> List[str]:
|
||||
return process_hashtag_list(hashtags=hashtags)
|
||||
|
||||
|
||||
async def _fetch_hashtag_data(hashtag: str) -> List[Dict]:
|
||||
async def _fetch_hashtag_data(hashtag: str, ms_token: str) -> List[Dict]:
|
||||
"""Fetch data for videos containing a specified hashtag, asynchronously."""
|
||||
data = []
|
||||
async with TikTokApi() as api:
|
||||
await api.create_sessions(
|
||||
ms_tokens=[os.environ["MS_TOKEN"]], num_sessions=1, sleep_after=3
|
||||
)
|
||||
await api.create_sessions(ms_tokens=[ms_token], num_sessions=1, sleep_after=3)
|
||||
async for video in api.hashtag(name=hashtag).videos(count=1000):
|
||||
data.append(video.as_dict)
|
||||
return data
|
||||
@@ -101,13 +101,16 @@ def aggregate_cooccurring_hashtags(hashtag_file: Path) -> Counter:
|
||||
class TikTokDownloader:
|
||||
"""Main class for scraping data from TikTok."""
|
||||
|
||||
def __init__(self, hashtags: List[str], data_dir: str):
|
||||
def __init__(self, hashtags: List[str], data_dir: str, config_file: str = None):
|
||||
self.hashtags = process_hashtag_list(hashtags)
|
||||
logging.info(f"Hashtags to scrape: {hashtags}")
|
||||
|
||||
self.data_dir = Path(data_dir)
|
||||
os.makedirs(self.data_dir, exist_ok=True)
|
||||
|
||||
self.auth = Authorization(config_file=config_file)
|
||||
self.ms_token = self.auth.ms_token
|
||||
|
||||
def get_hashtag_posts(self, hashtag: str):
|
||||
"""Fetch data about posts that used a specified hashtag and merge with
|
||||
existing data, if it exists."""
|
||||
@@ -125,7 +128,9 @@ class TikTokDownloader:
|
||||
already_fetched_data = []
|
||||
|
||||
# Scrape posts that use the specified hashtag
|
||||
fetched_data = asyncio.run(_fetch_hashtag_data(hashtag=hashtag))
|
||||
fetched_data = asyncio.run(
|
||||
_fetch_hashtag_data(hashtag=hashtag, ms_token=self.ms_token)
|
||||
)
|
||||
if len(fetched_data) == 0:
|
||||
logging.warning(f"No posts were found for the hashtag: {hashtag}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user