diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
new file mode 100644
index 0000000..c26c897
--- /dev/null
+++ b/.github/workflows/pytest.yml
@@ -0,0 +1,18 @@
+name: Test
+
+on: [push]
+
+jobs:
+  pytest:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v3
+        with:
+          python-version: '3.9'
+      - run: pip install -e .[dev]
+      - run: python -m playwright install
+      - run: |
+          export DISPLAY=:99
+          sudo Xvfb -ac :99 -screen 0 1280x1024x24 > /dev/null 2>&1 &
+          pytest
\ No newline at end of file
diff --git a/.github/workflows/pytest_windows.yml b/.github/workflows/pytest_windows.yml
new file mode 100644
index 0000000..baf3eb4
--- /dev/null
+++ b/.github/workflows/pytest_windows.yml
@@ -0,0 +1,15 @@
+name: Test on Windows
+
+on: [push]
+
+jobs:
+  pytest:
+    runs-on: windows-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v3
+        with:
+          python-version: '3.9'
+      - run: pip install -e .[dev]
+      - run: python -m playwright install
+      - run: pytest
\ No newline at end of file
diff --git a/README.md b/README.md
index e0524fe..121c149 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@ You should now be ready to start using it.
 ## About the tool
 ### Command-line arguments
 ```
-usage: tiktok-hashtag-analysis [-h] [--file FILE] [-d] [--number NUMBER] [-p] [-t] [--output-dir OUTPUT_DIR] [--config CONFIG] [--log LOG] [--limit LIMIT] [-v] [hashtags ...]
+usage: tiktok-hashtag-analysis [-h] [--file FILE] [-d] [--number NUMBER] [-p] [-t] [--output-dir OUTPUT_DIR] [--config CONFIG] [--log LOG] [--limit LIMIT] [-v] [--headed] [hashtags ...]
 
 Analyze hashtags within posts scraped from TikTok.
 
@@ -35,6 +35,7 @@ optional arguments:
   --log LOG             File to write logs to
   --limit LIMIT         Maximum number of videos to download for each hashtag
   -v, --verbose         Increase output verbosity
+  --headed              Don't use headless version of TikTok scraper
 ```
 
 ### Structure of output data
diff --git a/tests/base.py b/tests/base.py
index dbc139d..df9e62b 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -3,7 +3,16 @@ from tiktok_hashtag_analysis.base import TikTokDownloader, load_hashtags_from_fi
 
 def test_scrape(tmp_path, hashtags):
     downloader = TikTokDownloader(hashtags=hashtags[:1], data_dir=tmp_path)
-    downloader.run(limit=1000, download=True, plot=True, table=True, number=20)
+    downloader.run(
+        limit=10, download=True, plot=True, table=True, number=5, headed=True
+    )
+
+
+def test_scrape_headless(tmp_path, hashtags):
+    downloader = TikTokDownloader(hashtags=hashtags[:1], data_dir=tmp_path)
+    downloader.run(
+        limit=10, download=True, plot=True, table=True, number=5, headed=False
+    )
 
 
 def test_load_hashtags_from_file(tmp_path, hashtags):
diff --git a/tests/cli.py b/tests/cli.py
index cf40f5a..d4fe319 100644
--- a/tests/cli.py
+++ b/tests/cli.py
@@ -20,6 +20,7 @@ PARSER_ARGUMENTS = [
     ("table", True, "--table"),
     ("table", True, "-t"),
     ("verbose", True, "--verbose"),
+    ("headed", True, "--headed"),
     ("verbose", True, "-v"),
     ("output_dir", "/tmp/tiktok_download", "--output-dir"),
     ("config", "~/.tiktok", "--config"),
@@ -41,44 +42,35 @@ def test_parser(hashtags, attribute, value, flag):
     assert args.get("hashtags") == hashtags
 
 
-def test_process_output_dir(monkeypatch, tmp_path):
-    home_dir = Path.home().resolve()
-
-    # Specified nonexistent output directory without write permissions
-    parser = create_parser()
-    specified_output_dir = home_dir.parent / "test"
-    with pytest.raises(SystemExit) as system_exit:
-        result = process_output_dir(
-            specified_output_dir=specified_output_dir, parser=parser
-        )
-    assert system_exit.type == SystemExit
-
-    # Specified existing output directory without write permissions
-    parser = create_parser()
-    specified_output_dir = home_dir.parent
-    with pytest.raises(SystemExit) as system_exit:
-        result = process_output_dir(
-            specified_output_dir=specified_output_dir, parser=parser
-        )
-    assert system_exit.type == SystemExit
-
+def test_output_dir_unspec_nowrite(monkeypatch, tmp_path):
     # Unspecified, in current directory without write permissions
+    parser = create_parser()
     cwd = os.getcwd()
+    specified_output_dir = tmp_path
     monkeypatch.chdir(specified_output_dir)
+    os.chmod(tmp_path, 0o444)
     result = process_output_dir(specified_output_dir=None, parser=parser)
     monkeypatch.chdir(cwd)
+    # Restore a traversable mode before asserting so tmp_path can be cleaned up
+    os.chmod(tmp_path, 0o700)
     assert result == DEFAULT_OUTPUT_DIR
+
 
+def test_output_dir_spec_noexist_write(tmp_path):
     # Specified nonexisting output directory with write permissions
     parser = create_parser()
-    specified_output_dir = tmp_path / "test" / "tiktok"
+    specified_output_dir = tmp_path / "test"
     result = process_output_dir(
         specified_output_dir=specified_output_dir, parser=parser
     )
     assert result == specified_output_dir
 
+
+def test_output_dir_unspec_write(monkeypatch, tmp_path):
     # Unspecified, in current directory with write permissions
+    parser = create_parser()
     cwd = os.getcwd()
+    specified_output_dir = tmp_path
     monkeypatch.chdir(specified_output_dir)
     result = process_output_dir(specified_output_dir=None, parser=parser)
     monkeypatch.chdir(cwd)
diff --git a/tiktok_hashtag_analysis/base.py b/tiktok_hashtag_analysis/base.py
index 71d7e63..93616f3 100644
--- a/tiktok_hashtag_analysis/base.py
+++ b/tiktok_hashtag_analysis/base.py
@@ -52,11 +52,15 @@ def load_hashtags_from_file(file: str) -> List[str]:
 
 # Retry upon encountering transient playwright errors
 @retry(retry=retry_if_exception_type(Error), stop=stop_after_attempt(3))
-async def _fetch_hashtag_data(hashtag: str, limit: int) -> List[Dict]:
+async def _fetch_hashtag_data(
+    hashtag: str, limit: int, headed: bool = False
+) -> List[Dict]:
     """Fetch data for videos containing a specified hashtag, asynchronously."""
     data = []
     async with TikTokApi() as api:
-        await api.create_sessions(ms_tokens=[], num_sessions=1, sleep_after=3)
+        await api.create_sessions(
+            ms_tokens=[], num_sessions=1, sleep_after=3, headless=not headed
+        )
         async for video in api.hashtag(name=hashtag).videos(count=limit):
             data.append(video.as_dict)
     return data
@@ -157,7 +161,7 @@ class TikTokDownloader:
         }
         self.hashtags.sort(key=lambda h: last_edited.get(h, 0))
 
-    def get_hashtag_posts(self, hashtag: str, limit: int):
+    def get_hashtag_posts(self, hashtag: str, limit: int, headed: bool = False):
         """Fetch data about posts that used a specified hashtag and merge with
         existing data, if it exists."""
 
@@ -172,8 +176,23 @@ class TikTokDownloader:
             already_fetched_data = []
         already_fetched_ids = set(video["id"] for video in already_fetched_data)
 
         # Scrape posts that use the specified hashtag
-        fetched_data = asyncio.run(_fetch_hashtag_data(hashtag=hashtag, limit=limit))
+        # Attempt to be robust against TikTok's countermeasures for headless browsing
+        try:
+            fetched_data = asyncio.run(
+                _fetch_hashtag_data(hashtag=hashtag, limit=limit, headed=headed)
+            )
+        except Exception as e:
+            # Already headed: retrying the same way cannot help, so re-raise
+            if headed:
+                raise
+            logger.warning(
+                f"Encountered error {e} when fetching data, retrying in headed mode"
+            )
+            fetched_data = asyncio.run(
+                _fetch_hashtag_data(hashtag=hashtag, limit=limit, headed=True)
+            )
+
         fetched_ids = set(video["id"] for video in fetched_data)
 
         if len(fetched_data) == 0:
@@ -303,13 +322,21 @@ class TikTokDownloader:
         plt.savefig(plot_file, bbox_inches="tight", facecolor="white", dpi=300)
         logger.info(f"Plot saved to file: {plot_file}")
 
-    def run(self, limit: int, download: bool, plot: bool, table: bool, number: int):
+    def run(
+        self,
+        limit: int,
+        download: bool,
+        plot: bool,
+        table: bool,
+        number: int,
+        headed: bool = False,
+    ):
         """Execute the specified operations on all specified hashtags."""
 
         # Scrape all specified hashtags and perform analyses, depending on if
         # `--table`, `--plot`, and `--download` flags are used in the command
         for hashtag in self.hashtags:
-            self.get_hashtag_posts(hashtag=hashtag, limit=limit)
+            self.get_hashtag_posts(hashtag=hashtag, limit=limit, headed=headed)
             if plot:
                 self.plot(hashtag=hashtag, number=number)
             if table:
diff --git a/tiktok_hashtag_analysis/cli.py b/tiktok_hashtag_analysis/cli.py
index 37141d5..63342e5 100644
--- a/tiktok_hashtag_analysis/cli.py
+++ b/tiktok_hashtag_analysis/cli.py
@@ -77,7 +77,11 @@ def create_parser():
         help="Increase output verbosity",
         action="store_true",
     )
-
+    parser.add_argument(
+        "--headed",
+        help="Don't use headless version of TikTok scraper",
+        action="store_true",
+    )
     return parser
 
 
@@ -99,6 +103,12 @@ def process_output_dir(
             if not os.access(path=_output_dir, mode=os.W_OK):
                 parser.error(error_message(_output_dir))
             else:
+                # On Windows, os.access is unreliable, so probe with a real
+                # write; the random name and "x" mode never clobber user files
+                temp_file = _output_dir / f".write-probe-{os.urandom(8).hex()}"
+                with open(temp_file, "x"):
+                    pass
+                os.remove(temp_file)
                 return _output_dir
         except PermissionError:
             parser.error(error_message(_output_dir))
@@ -141,6 +151,7 @@ def main():
         plot=args.plot,
         table=args.table,
         number=args.number,
+        headed=args.headed,
     )
 
 