diff --git a/README.md b/README.md index 01cd57a..166550c 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ The scraper this tool uses requires an `msToken` taken from the TikTok website o ## About the tool ### Command-line arguments ``` -usage: tiktok-hashtag-analysis [-h] [--file FILE] [-d] [--number NUMBER] [-p] [-t] [--output-dir OUTPUT_DIR] [--config CONFIG] [--log LOG] [hashtags ...] +usage: tiktok-hashtag-analysis [-h] [--file FILE] [-d] [--number NUMBER] [-p] [-t] [--output-dir OUTPUT_DIR] [--config CONFIG] [--log LOG] [--limit LIMIT] [-v] [hashtags ...] Analyze hashtags within posts scraped from TikTok. @@ -34,6 +34,8 @@ optional arguments: Directory to save scraped data and visualizations to --config CONFIG File name of configuration file to store TikTok credentials to --log LOG File to write logs to + --limit LIMIT Maximum number of videos to download for each hashtag + -v, --verbose Increase output verbosity ``` ### Structure of output data @@ -138,7 +140,7 @@ Assume we want to analyze the 20 most frequently co-occurring hashtags in the do To run the build-in tests in the `tests/` directory, first install the test dependency packages: ``` -pip install .[test] +pip install .[dev] ``` and then run the tests using the following command: @@ -147,4 +149,7 @@ and then run the tests using the following command: pytest ``` -This repo uses [black](https://github.com/psf/black) to format source code, please run the `black` command before submitting a PR. \ No newline at end of file +This repo uses [black](https://github.com/psf/black) to format source code and [mypy](https://mypy.readthedocs.io/en/stable/) for static type checking. Before submitting a pull request, please run both tools on the source code. + +- yt-dlp warning: (unable to find video in feed) +https://www.tiktok.com/@sa_diya_34/video/7261180335763754242 diff --git a/setup.py b/setup.py index c53600d..13d1599 100644 --- a/setup.py +++ b/setup.py @@ -21,8 +21,25 @@ setup( long_description_content_type="text/markdown", url="https://github.com/bellingcat/tiktok-hashtag-analysis", license="MIT License", - install_requires=["seaborn", "matplotlib", "TikTokApi", "requests", "yt_dlp", "tenacity"], - extras_require={"test": ["pytest", "pytest-cov", "pytest-html", "pytest-metadata"]}, + install_requires=[ + "seaborn", + "matplotlib", + "TikTokApi", + "requests", + "yt_dlp", + "tenacity", + "msvc-runtime; os_name=='nt'", + ], + extras_require={ + "dev": [ + "pytest", + "pytest-cov", + "pytest-html", + "pytest-metadata", + "black", + "mypy", + ] + }, classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Information Technology", diff --git a/tests/base.py b/tests/base.py index c0d2a07..dbc139d 100644 --- a/tests/base.py +++ b/tests/base.py @@ -3,7 +3,7 @@ from tiktok_hashtag_analysis.base import TikTokDownloader, load_hashtags_from_fi def test_scrape(tmp_path, hashtags): downloader = TikTokDownloader(hashtags=hashtags[:1], data_dir=tmp_path) - downloader.run(download=True, plot=True, table=True, number=20) + downloader.run(limit=1000, download=True, plot=True, table=True, number=20) def test_load_hashtags_from_file(tmp_path, hashtags): diff --git a/tests/cli.py b/tests/cli.py index ea7f399..58999cb 100644 --- a/tests/cli.py +++ b/tests/cli.py @@ -13,11 +13,14 @@ PARSER_ARGUMENTS = [ ("file", "hashtags.txt", "--file"), ("download", True, "--download"), ("download", True, "-d"), + ("limit", 1000, "--limit"), ("number", 20, "--number"), ("plot", True, "--plot"), ("plot", True, "-p"), ("table", True, "--table"), ("table", True, "-t"), + ("verbose", True, "--verbose"), + ("verbose", True, "-v"), ("output_dir", "/tmp/tiktok_download", "--output-dir"), ("config", "~/.tiktok", "--config"), ("log", "../logfile.log", "--log"), diff --git a/tiktok_hashtag_analysis/auth.py b/tiktok_hashtag_analysis/auth.py index 3255ad9..16252ab 100644 --- a/tiktok_hashtag_analysis/auth.py +++ b/tiktok_hashtag_analysis/auth.py @@ -22,20 +22,20 @@ class Authorization: # Step 1: check if MS_TOKEN is defined as environment variable if ms_token := os.environ.get("MS_TOKEN"): self.ms_token = ms_token - logging.info("Loaded token from environment variable") + logging.debug("Loaded token from environment variable") # Step 2: check if MS_TOKEN is defined in config file elif self.config_file.is_file(): if ms_token := self.load_token(): self.ms_token = ms_token - logging.info(f"Loaded token from config file: {self.config_file}") + logging.debug(f"Loaded token from config file: {self.config_file}") # Step 3: have user enter MS_TOKEN via terminal else: ms_token = self.input_token() self.dump_token(ms_token=ms_token) self.ms_token = ms_token - logging.info( + logging.debug( f"Loaded token from user input and saved to config file: {self.config_file}" ) diff --git a/tiktok_hashtag_analysis/base.py b/tiktok_hashtag_analysis/base.py index 71884d5..92d6d1e 100644 --- a/tiktok_hashtag_analysis/base.py +++ b/tiktok_hashtag_analysis/base.py @@ -7,14 +7,22 @@ import warnings import asyncio import logging import re +from urllib.error import HTTPError from typing import List, Dict, Optional import yt_dlp +from yt_dlp.utils import ExtractorError, DownloadError import requests import matplotlib.pyplot as plt import matplotlib.ticker as mtick import seaborn as sns -from tenacity import retry, retry_if_exception_type, stop_after_attempt +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + TryAgain, + wait_exponential, +) from playwright._impl._api_types import Error from TikTokApi import TikTokApi @@ -43,12 +51,12 @@ def load_hashtags_from_file(file: str) -> List[str]: # Retry upon encountering transient playwright errors @retry(retry=retry_if_exception_type(Error), stop=stop_after_attempt(3)) -async def _fetch_hashtag_data(hashtag: str, ms_token: str) -> List[Dict]: +async def _fetch_hashtag_data(hashtag: str, ms_token: str, limit: int) -> List[Dict]: """Fetch data for videos containing a specified hashtag, asynchronously.""" data = [] async with TikTokApi() as api: - await api.create_sessions(ms_tokens=[ms_token], num_sessions=1, sleep_after=3) - async for video in api.hashtag(name=hashtag).videos(count=1000): + await api.create_sessions(ms_tokens=[], num_sessions=1, sleep_after=3) + async for video in api.hashtag(name=hashtag).videos(count=limit): data.append(video.as_dict) return data @@ -66,22 +74,44 @@ def json_dump(file_path: Path, data: List): json.dump(obj=data, fp=f) +@retry(wait=wait_exponential(multiplier=1, max=10)) +def _get(url: str) -> requests.Response: + """Safe version of requests.get that can handle timeouts and retries""" + + r = requests.get(url=url, timeout=30) + if r.status_code not in {200, 403}: + raise TryAgain + else: + return r + + +def download_file_and_save(url: str, filepath: Path): + """Download a file from a specified URL and write its contents to a file""" + + r = _get(url=url) + if r.status_code == 403: + return + ext = r.headers["Content-Type"].split("/")[-1] + path_with_ext = filepath.with_suffix(f".{ext}") + with open(path_with_ext, "wb") as f: + f.write(r.content) + logging.debug(f"Saved file to: {path_with_ext}") + + def download_gallery(video_data: Dict, video_dir: Path): - """yt-dlp doesn't seem to support downloading images from an image gallery, - so this is a quick fix that likely will fail on edge cases.""" + """yt-dlp doesn't support downloading images from an image gallery, + so this downloads all images and audio files from image galleries.""" video_id = video_data["id"] + # A small percentage of image galleries don't have an associated audio file if play_url := video_data["music"]["playUrl"]: - r = requests.get(play_url) - with open(video_dir / f"{video_id}.mp3", "wb") as f: - f.write(r.content) + filepath = video_dir / f"{video_id}" + download_file_and_save(url=play_url, filepath=filepath) for i, image in enumerate(video_data["imagePost"]["images"]): image_url = image["imageURL"]["urlList"][0] - r = requests.get(image_url) - ext = r.headers["Content-Type"].split("/")[-1] - with open(video_dir / f"{video_id}_{i:02d}.{ext}", "wb") as f: - f.write(r.content) + filepath = video_dir / f"{video_id}_{i:02d}" + download_file_and_save(url=image_url, filepath=filepath) def aggregate_cooccurring_hashtags(hashtag_file: Path) -> Counter: @@ -120,17 +150,16 @@ class TikTokDownloader: self.ms_token = self.auth.get_token() def prioritize_hashtags(self): - """Order hashtags basd on whether they've been scraped before, and + """Order hashtags based on whether they've been scraped before, and the time they were most recently scraped""" - previously_scraped_hashtags = set(os.listdir(self.data_dir)) last_edited = { - hashtag: (self.data_dir / hashtag / "posts.json").lstat().st_mtime - for hashtag in previously_scraped_hashtags + file.parts[-2]: file.lstat().st_mtime + for file in self.data_dir.glob("*/posts.json") } self.hashtags.sort(key=lambda h: last_edited.get(h, 0)) - def get_hashtag_posts(self, hashtag: str): + def get_hashtag_posts(self, hashtag: str, limit: int): """Fetch data about posts that used a specified hashtag and merge with existing data, if it exists.""" @@ -141,31 +170,32 @@ class TikTokDownloader: # If there are previously scraped posts, load them if hashtag_file.is_file(): already_fetched_data = json_load(file_path=hashtag_file) - already_fetched_ids = set(video["id"] for video in already_fetched_data) else: - already_fetched_ids = set() already_fetched_data = [] + already_fetched_ids = set(video["id"] for video in already_fetched_data) # Scrape posts that use the specified hashtag fetched_data = asyncio.run( - _fetch_hashtag_data(hashtag=hashtag, ms_token=self.ms_token) + _fetch_hashtag_data(hashtag=hashtag, ms_token=self.ms_token, limit=limit) ) + fetched_ids = set(video["id"] for video in fetched_data) + if len(fetched_data) == 0: logging.warning(f"No posts were found for the hashtag: {hashtag}") # Determine which newly scraped posts haven't been scraped before - new_fetched_data = [ - video for video in fetched_data if video["id"] not in already_fetched_ids + old_fetched_data = [ + video for video in already_fetched_data if video["id"] not in fetched_ids ] - if len(new_fetched_data) == 0: - logging.warning(f"No new posts were found for the hashtag: {hashtag}") + new_post_count = len(fetched_ids - already_fetched_ids) + old_post_count = len(already_fetched_ids) # Merge new and old data and write to file - all_fetched_data = already_fetched_data + new_fetched_data + all_fetched_data = old_fetched_data + fetched_data json_dump(file_path=hashtag_file, data=all_fetched_data) logging.info( - f"Scraped {len(new_fetched_data)} new posts containing the hashtag " - f"'{hashtag}', with {len(already_fetched_data)} posts previously scraped" + f"Scraped {new_post_count} new posts containing the hashtag " + f"'{hashtag}', with {old_post_count} posts previously scraped" ) def get_hashtag_videos(self, hashtag: str): @@ -186,10 +216,6 @@ class TikTokDownloader: new_video_list = [ video for video in video_list if video["id"] not in already_downloaded_ids ] - if len(new_video_list) == 0: - logging.warning( - f"No new videos to be downloaded for the hashtag: {hashtag}" - ) # Populate list of URLs to download using yt-dlp, and list of image # galleries to download using the `download_gallery` function @@ -197,6 +223,8 @@ class TikTokDownloader: galleries_to_download = [] for video in new_video_list: if video.get("imagePost") is None: + if video.get("author") is None: + continue url = f"https://www.tiktok.com/@{video['author']['uniqueId']}/video/{video['id']}" urls_to_download.append(url) else: @@ -206,6 +234,7 @@ class TikTokDownloader: if len(galleries_to_download) > 0: logging.info(f"Downloading image galleries for hashtag {hashtag}") for video in galleries_to_download: + logging.debug(f"Downloading image gallery for video: {video['id']}") download_gallery(video_data=video, video_dir=video_dir) # Download video files for all video posts @@ -216,7 +245,14 @@ class TikTokDownloader: "ignore_errors": True, } with yt_dlp.YoutubeDL(ydl_opts) as ydl: - ydl.download(urls_to_download) + for url in urls_to_download: + try: + ydl.download([url]) + except (HTTPError, TypeError, ExtractorError, DownloadError) as e: + # catch urllib and yt-dlp errors when video not found + logging.warning( + f"Encountered error {e} when attempting to download url: {url}" + ) def frequency_table(self, hashtag: str, number: int): """Print `number`-most commonly co-occurring hashtags for a specified @@ -269,19 +305,16 @@ class TikTokDownloader: plt.savefig(plot_file, bbox_inches="tight", facecolor="white", dpi=300) logging.info(f"Plot saved to file: {plot_file}") - def run(self, download: bool, plot: bool, table: bool, number: int): + def run(self, limit: int, download: bool, plot: bool, table: bool, number: int): """Execute the specified operations on all specified hashtags.""" # Scrape all specified hashtags and perform analyses, depending on if - # `--table` and `--plot` flags are used in the command + # `--table`, `--plot`, and `--download` flags are used in the command for hashtag in self.hashtags: - self.get_hashtag_posts(hashtag=hashtag) + self.get_hashtag_posts(hashtag=hashtag, limit=limit) if plot: self.plot(hashtag=hashtag, number=number) if table: self.frequency_table(hashtag=hashtag, number=number) - - # Download media for all hashtags if `--download` flag is used in the command - for hashtag in self.hashtags: if download: self.get_hashtag_videos(hashtag=hashtag) diff --git a/tiktok_hashtag_analysis/cli.py b/tiktok_hashtag_analysis/cli.py index 333ed49..10818fb 100644 --- a/tiktok_hashtag_analysis/cli.py +++ b/tiktok_hashtag_analysis/cli.py @@ -63,6 +63,18 @@ def create_parser(): default=None, ) parser.add_argument("--log", type=str, help="File to write logs to", default=None) + parser.add_argument( + "--limit", + type=int, + help="Maximum number of videos to download for each hashtag", + default=1000, + ) + parser.add_argument( + "-v", + "--verbose", + help="Increase output verbosity", + action="store_true", + ) return parser @@ -97,7 +109,7 @@ def main(): args = parser.parse_args() logging.basicConfig( - level=logging.INFO, + level=logging.DEBUG if args.verbose else logging.INFO, filename=args.log, format="%(asctime)s %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", @@ -122,7 +134,11 @@ def main(): ) downloader.run( - download=args.download, plot=args.plot, table=args.table, number=args.number + limit=args.limit, + download=args.download, + plot=args.plot, + table=args.table, + number=args.number, )