diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 73765d5..5e1d2a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: nbqa-ruff args: - --fix - - --target-version=py311 + - --target-version=py310 - --ignore=E721,E722 - --line-length=80 - id: nbqa-black @@ -62,18 +62,17 @@ repos: - --line-length=80 # - repo: https://github.com/astral-sh/ruff-pre-commit -# rev: v0.4.10 +# rev: v0.9.7 # hooks: # - id: ruff # types_or: [python,pyi] # args: # - --fix -# - --target-version=py311 # - --select=B,C,E,F,W,B9 # - --line-length=80 # - --ignore=E203,E402,E501,E261 # - id: ruff-format # types_or: [ python,pyi] # args: -# - --target-version=py311 +# - --target-version=py310 # - --line-length=80 diff --git a/app/tests/conftest.py b/app/tests/conftest.py index f7da39e..997ea62 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -1,4 +1,6 @@ import os +from datetime import datetime +from http import HTTPStatus from typing import AsyncGenerator from unittest.mock import patch @@ -7,15 +9,31 @@ import pytest_asyncio from fastapi.testclient import TestClient from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession +from app.shared.db import models +from app.shared.db.database import ( + make_async_engine, + make_async_session_local, + make_engine, + make_session_local, +) from app.shared.settings import Settings from app.web.config import ALLOW_ANY_EMAIL +from app.web.db import crud +from app.web.db.crud import get_user_group_names from app.web.db.user_state import UserState +from app.web.main import app_factory +from app.web.security import ( + get_token_or_user_auth, + get_user_auth, + get_user_state, + token_api_key_auth, +) @pytest.fixture(autouse=True) def mock_logger_add(): """Fixture to mock loguru.logger.add for all tests.""" - with patch('loguru.logger.add') as mock_add: + with patch("loguru.logger.add") as mock_add: yield mock_add # This makes the mock available to tests @@ -26,23 +44,22 @@ def get_settings(): @pytest.fixture(autouse=True) def mock_settings(): - with patch('app.shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings: + with patch( + "app.shared.settings.Settings", + return_value=Settings(_env_file=".env.test"), + ) as mock_settings: yield mock_settings @pytest.fixture() def test_db(get_settings: Settings): - from app.shared.db import models - from app.shared.db.database import make_engine - from app.web.db.crud import get_user_group_names - get_user_group_names.cache_clear() make_engine.cache_clear() engine = make_engine(get_settings.DATABASE_PATH) fs = get_settings.DATABASE_PATH.replace("sqlite:///", "") if not os.path.exists(fs): - open(fs, 'w').close() + open(fs, "w").close() models.Base.metadata.create_all(engine) @@ -59,7 +76,6 @@ def test_db(get_settings: Settings): @pytest.fixture() def db_session(test_db): - from app.shared.db.database import make_session_local session_local = make_session_local(test_db) with session_local() as session: yield session @@ -67,18 +83,12 @@ def db_session(test_db): @pytest_asyncio.fixture() async def async_test_db(get_settings: Settings): - import asyncio - - from app.shared.db import models - from app.shared.db.database import make_async_engine - from app.web.db.crud import get_user_group_names - get_user_group_names.cache_clear() engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH) fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "") if not os.path.exists(fs): - open(fs, 'w').close() + open(fs, "w").close() async def create_all(): async with engine.begin() as conn: @@ -102,8 +112,9 @@ async def async_test_db(get_settings: Settings): @pytest_asyncio.fixture() -async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSession, None]: - from app.shared.db.database import make_async_session_local +async def async_db_session( + async_test_db: AsyncEngine, +) -> AsyncGenerator[AsyncSession, None]: session_local = await make_async_session_local(async_test_db) async with session_local() as session: yield session @@ -111,8 +122,6 @@ async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSe @pytest.fixture() def app(db_session): - from app.web.db import crud - from app.web.main import app_factory app = app_factory() crud.upsert_user_groups(db_session) return app @@ -126,14 +135,13 @@ def client(app): @pytest.fixture() def app_with_auth(app, db_session): - from app.web.security import ( - get_token_or_user_auth, - get_user_auth, - get_user_state, + app.dependency_overrides[get_token_or_user_auth] = ( + lambda: "rick@example.com" ) - app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com" app.dependency_overrides[get_user_auth] = lambda: "morty@example.com" - app.dependency_overrides[get_user_state] = lambda: UserState(db_session, "MORTY@example.com") + app.dependency_overrides[get_user_state] = lambda: UserState( + db_session, "MORTY@example.com" + ) return app @@ -145,7 +153,6 @@ def client_with_auth(app_with_auth): @pytest.fixture() def app_with_token(app): - from app.web.security import get_token_or_user_auth, token_api_key_auth app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL return app @@ -162,6 +169,93 @@ def test_no_auth(): # reusable code to ensure a method/endpoint combination is unauthorized def no_auth(http_method, endpoint): response = http_method(endpoint) - assert response.status_code == 403 + assert response.status_code == HTTPStatus.FORBIDDEN assert response.json() == {"detail": "Not authenticated"} + return no_auth + + +@pytest.fixture() +def test_data(db_session): + author_emails = [ + "rick@example.com", + "morty@example.com", + "jerry@example.com", + ] + + # creates 3 users + for email in author_emails: + db_session.add(models.User(email=email)) + db_session.commit() + assert db_session.query(models.User).count() == 3 + + # creates 100 archives for 3 users over 2 months with repeating URLs + for i in range(100): + author = author_emails[i % 3] + archive = models.Archive( + id=f"archive-id-456-{i}", + url=f"https://example-{i % 3}.com", + result={}, + public=author == "jerry@example.com", + author_id=author, + group_id="spaceship" + if author == "morty@example.com" and i % 2 == 0 + else None, + created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1), + ) + if i % 5 == 0: + archive.tags.append(models.Tag(id=f"tag-{i}")) + if i % 10 == 0: + archive.tags.append(models.Tag(id=f"tag-second-{i}")) + if i % 4 == 0: + archive.tags.append(models.Tag(id=f"tag-third-{i}")) + for j in range(10): + archive.urls.append( + models.ArchiveUrl( + url=f"https://example-{i}.com/{j}", key=f"media_{j}" + ) + ) + db_session.add(archive) + + # creates a sheet for each user + for i, email in enumerate( + ["rick@example.com", "morty@example.com", "jerry@example.com"] + ): + db_session.add( + models.Sheet( + id=f"sheet-{i}", + name=f"sheet-{i}", + author_id=email, + group_id=None, + frequency="daily", + ) + ) + if email == "rick@example.com": + db_session.add( + models.Sheet( + id=f"sheet-{i}-2", + name=f"sheet-{i}-2", + author_id=email, + group_id="spaceship", + frequency="hourly", + ) + ) + + db_session.commit() + + assert db_session.query(models.Archive).count() == 100 + assert db_session.query(models.Tag).count() == 20 + 10 + 25 + assert db_session.query(models.ArchiveUrl).count() == 1000 + assert ( + db_session.query(models.ArchiveUrl) + .filter(models.ArchiveUrl.archive_id == "archive-id-456-0") + .count() + == 10 + ) + + # setup groups + assert db_session.query(models.Group).count() == 0 + + crud.upsert_user_groups(db_session) + assert db_session.query(models.Group).count() == 4 + assert db_session.query(models.User).count() == 3 diff --git a/app/tests/shared/db/test_models.py b/app/tests/shared/db/test_models.py index 4da9571..537532b 100644 --- a/app/tests/shared/db/test_models.py +++ b/app/tests/shared/db/test_models.py @@ -1,6 +1,7 @@ -def test_generate_uuid(): - from app.shared.db.models import generate_uuid +from app.shared.db.models import generate_uuid - assert generate_uuid() != generate_uuid() - assert len(generate_uuid()) == 36 - assert generate_uuid().count("-") == 4 + +def test_generate_uuid(): + assert generate_uuid() != generate_uuid() + assert len(generate_uuid()) == 36 + assert generate_uuid().count("-") == 4 diff --git a/app/tests/shared/db/test_worker_crud.py b/app/tests/shared/db/test_worker_crud.py index 09e9a76..2258781 100644 --- a/app/tests/shared/db/test_worker_crud.py +++ b/app/tests/shared/db/test_worker_crud.py @@ -1,11 +1,10 @@ from datetime import datetime +from app.shared import schemas from app.shared.db import models, worker_crud -from app.tests.web.db.test_crud import test_data def test_update_sheet_last_url_archived_at(db_session): - # Create test sheet test_sheet = models.Sheet(id="sheet-123") db_session.add(test_sheet) @@ -14,17 +13,24 @@ def test_update_sheet_last_url_archived_at(db_session): # Test updating existing sheet assert isinstance(test_sheet.last_url_archived_at, datetime) before = test_sheet.last_url_archived_at - assert worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123") is True + assert ( + worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123") + is True + ) db_session.refresh(test_sheet) assert isinstance(test_sheet.last_url_archived_at, datetime) assert test_sheet.last_url_archived_at > before # Test non-existent sheet - assert worker_crud.update_sheet_last_url_archived_at(db_session, "non-existent-sheet") is False + assert ( + worker_crud.update_sheet_last_url_archived_at( + db_session, "non-existent-sheet" + ) + is False + ) + def test_get_group(test_data, db_session): - from app.shared.db import worker_crud - assert worker_crud.get_group(db_session, "spaceship") is not None assert worker_crud.get_group(db_session, "interdimensional") is not None assert worker_crud.get_group(db_session, "animated-characters") is not None @@ -32,24 +38,24 @@ def test_get_group(test_data, db_session): def test_create_or_get_user(test_data, db_session): - from app.shared.db import worker_crud - assert db_session.query(models.User).count() == 3 # already exists - assert (u1 := worker_crud.create_or_get_user(db_session, "rick@example.com")) is not None + assert ( + u1 := worker_crud.create_or_get_user(db_session, "rick@example.com") + ) is not None assert u1.email == "rick@example.com" # new user - assert (u2 := worker_crud.create_or_get_user(db_session, "beth@example.com")) is not None + assert ( + u2 := worker_crud.create_or_get_user(db_session, "beth@example.com") + ) is not None assert u2.email == "beth@example.com" assert db_session.query(models.User).count() == 4 def test_create_tag(db_session): - from app.shared.db import worker_crud - assert db_session.query(models.Tag).count() == 0 # create first @@ -57,7 +63,10 @@ def test_create_tag(db_session): assert create_tag is not None assert create_tag.id == "tag-101" assert db_session.query(models.Tag).count() == 1 - assert db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() == create_tag + assert ( + db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() + == create_tag + ) # same id does not add new db entry existing_tag = worker_crud.create_tag(db_session, "tag-101") @@ -72,9 +81,6 @@ def test_create_tag(db_session): def test_create_task(db_session): - from app.shared import schemas - from app.shared.db import worker_crud - task = schemas.ArchiveCreate( id="archive-id-456-101", url="https://example-0.com", @@ -83,17 +89,22 @@ def test_create_task(db_session): author_id="rick@example.com", group_id="spaceship", tags=[], - urls=[] + urls=[], ) # with tags and urls - nt = worker_crud.create_archive(db_session, task, [models.Tag(id="tag-101")], [models.ArchiveUrl(url="https://example-0.com/0", key="media_0")]) + nt = worker_crud.create_archive( + db_session, + task, + [models.Tag(id="tag-101")], + [models.ArchiveUrl(url="https://example-0.com/0", key="media_0")], + ) assert nt is not None assert nt.id == "archive-id-456-101" assert nt.url == "https://example-0.com" assert nt.author_id == "rick@example.com" - assert nt.public == False + assert nt.public is False assert nt.group_id == "spaceship" assert len(nt.tags) == 1 assert nt.tags[0].id == "tag-101" @@ -109,7 +120,7 @@ def test_create_task(db_session): assert nt.id == "archive-id-456-102" assert nt.url == "https://example-0.com" assert nt.author_id == "rick@example.com" - assert nt.public == False + assert nt.public is False assert nt.group_id == "spaceship" assert len(nt.tags) == 0 assert len(nt.urls) == 0 diff --git a/app/tests/shared/test_business_logic.py b/app/tests/shared/test_business_logic.py index 830fa7b..e10d402 100644 --- a/app/tests/shared/test_business_logic.py +++ b/app/tests/shared/test_business_logic.py @@ -9,7 +9,7 @@ from app.shared.business_logic import ( ) -class Test_get_store_archive_until: +class TestGetStoreArchiveUntil: GROUP_ID = "test-group" def test_group_not_found(self, db_session): @@ -17,7 +17,10 @@ class Test_get_store_archive_until: get_store_archive_until(db_session, self.GROUP_ID) assert str(exc.value) == f"Group {self.GROUP_ID} not found." - @patch("app.shared.db.worker_crud.get_group", return_value=MagicMock(permissions=None)) + @patch( + "app.shared.db.worker_crud.get_group", + return_value=MagicMock(permissions=None), + ) def test_group_no_permissions(self, db_session): with pytest.raises(AssertionError) as exc: get_store_archive_until(db_session, self.GROUP_ID) @@ -48,14 +51,17 @@ class Test_get_store_archive_until: mock_get_group.assert_called_once_with(db_session, self.GROUP_ID) -class Test_get_store_archive_until_or_never: +class TestGetStoreArchiveUntilOrNever: GROUP_ID = "test-group" def test_group_not_found(self, db_session): result = get_store_archive_until_or_never(db_session, self.GROUP_ID) assert result is None - @patch("app.shared.db.worker_crud.get_group", return_value=MagicMock(permissions=None)) + @patch( + "app.shared.db.worker_crud.get_group", + return_value=MagicMock(permissions=None), + ) def test_group_no_permissions(self, db_session): result = get_store_archive_until_or_never(db_session, self.GROUP_ID) assert result is None diff --git a/app/tests/web/db/test_crud.py b/app/tests/web/db/test_crud.py index 298a087..676aa24 100644 --- a/app/tests/web/db/test_crud.py +++ b/app/tests/web/db/test_crud.py @@ -2,125 +2,379 @@ from datetime import datetime, timedelta from unittest.mock import patch import pytest +import sqlalchemy import yaml +from sqlalchemy import false, true +from sqlalchemy.sql import select from app.shared.db import models from app.shared.settings import Settings +from app.web.config import ALLOW_ANY_EMAIL from app.web.db import crud -authors = ["rick@example.com", "morty@example.com", "jerry@example.com"] - - -@pytest.fixture() -def test_data(db_session): - - # creates 3 users - for email in authors: - db_session.add(models.User(email=email)) - db_session.commit() - assert db_session.query(models.User).count() == 3 - - # creates 100 archives for 3 users over 2 months with repeating URLs - for i in range(100): - author = authors[i % 3] - archive = models.Archive( - id=f"archive-id-456-{i}", - url=f"https://example-{i%3}.com", - result={}, - public=author == "jerry@example.com", - author_id=author, - group_id="spaceship" if author == "morty@example.com" and i % 2 == 0 else None, - created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1) - ) - if i % 5 == 0: - archive.tags.append(models.Tag(id=f"tag-{i}")) - if i % 10 == 0: - archive.tags.append(models.Tag(id=f"tag-second-{i}")) - if i % 4 == 0: - archive.tags.append(models.Tag(id=f"tag-third-{i}")) - for j in range(10): - archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}")) - db_session.add(archive) - - # creates a sheet for each user - for i, email in enumerate(authors): - db_session.add(models.Sheet(id=f"sheet-{i}", name=f"sheet-{i}", author_id=email, group_id=None, frequency="daily")) - if email == "rick@example.com": - db_session.add(models.Sheet(id=f"sheet-{i}-2", name=f"sheet-{i}-2", author_id=email, group_id="spaceship", frequency="hourly")) - - db_session.commit() - - assert db_session.query(models.Archive).count() == 100 - assert db_session.query(models.Tag).count() == 20 + 10 + 25 - assert db_session.query(models.ArchiveUrl).count() == 1000 - assert db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.archive_id == "archive-id-456-0").count() == 10 - - # setup groups - assert db_session.query(models.Group).count() == 0 - from app.web.db import crud - crud.upsert_user_groups(db_session) - assert db_session.query(models.Group).count() == 4 - assert db_session.query(models.User).count() == 3 - - def test_search_archives_by_url(test_data, db_session): - from app.web.config import ALLOW_ANY_EMAIL - - # rick's archives are private - assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", True, False)) == 34 - assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", [], False)) == 34 - assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", [], True)) == 34 - assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL, [], False)) == 34 - assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL, True, False)) == 34 - assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com", [], False)) == 0 - assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com", [], True)) == 0 + # Rick's archives are private + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example-0.com", + "rick@example.com", + True, + False, + ) + ) + == 34 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example-0.com", + "rick@example.com", + [], + False, + ) + ) + == 34 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example-0.com", + "rick@example.com", + [], + True, + ) + ) + == 34 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, "https://example-0.com", ALLOW_ANY_EMAIL, [], False + ) + ) + == 34 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example-0.com", + ALLOW_ANY_EMAIL, + True, + False, + ) + ) + == 34 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example-0.com", + "morty@example.com", + [], + False, + ) + ) + == 0 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example-0.com", + "morty@example.com", + [], + True, + ) + ) + == 0 + ) # morty's archives are public but half are in spaceship group - assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com", ["spaceship"], False)) == 16 - assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com", True, False)) == 16 - assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "jerry@example.com", True, True)) == 16 + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example-1.com", + "rick@example.com", + ["spaceship"], + False, + ) + ) + == 16 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example-1.com", + "rick@example.com", + True, + False, + ) + ) + == 16 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example-1.com", + "jerry@example.com", + True, + True, + ) + ) + == 16 + ) - # jerry's archives are public - assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "jerry@example.com", [], True)) == 33 - assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "rick@example.com", [], True)) == 33 + # Jerry's archives are public + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example-2.com", + "jerry@example.com", + [], + True, + ) + ) + == 33 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example-2.com", + "rick@example.com", + [], + True, + ) + ) + == 33 + ) # fuzzy search - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False)) == 100 - assert len(crud.search_archives_by_url(db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL, False, False)) == 100 - assert len(crud.search_archives_by_url(db_session, "2.com", ALLOW_ANY_EMAIL, False, False)) == 33 + assert ( + len( + crud.search_archives_by_url( + db_session, "https://example", ALLOW_ANY_EMAIL, False, False + ) + ) + == 100 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL, False, False + ) + ) + == 100 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, "2.com", ALLOW_ANY_EMAIL, False, False + ) + ) + == 33 + ) # absolute search - assert len(crud.search_archives_by_url(db_session, "example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True)) == 0 - assert len(crud.search_archives_by_url(db_session, "https://example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True)) == 33 + assert ( + len( + crud.search_archives_by_url( + db_session, + "example-2.com", + ALLOW_ANY_EMAIL, + [], + False, + absolute_search=True, + ) + ) + == 0 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example-2.com", + ALLOW_ANY_EMAIL, + [], + False, + absolute_search=True, + ) + ) + == 33 + ) # archived_after - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, archived_after=datetime(2010, 1, 1))) == 100 - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 15))) == 70 - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2031, 1, 1))) == 0 + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example", + ALLOW_ANY_EMAIL, + True, + True, + archived_after=datetime(2010, 1, 1), + ) + ) + == 100 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example", + ALLOW_ANY_EMAIL, + False, + False, + archived_after=datetime(2021, 1, 15), + ) + ) + == 70 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example", + ALLOW_ANY_EMAIL, + False, + False, + archived_after=datetime(2031, 1, 1), + ) + ) + == 0 + ) # archived before - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2010, 1, 1))) == 0 - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2021, 1, 15))) == 28 - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2031, 1, 1))) == 100 + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example", + ALLOW_ANY_EMAIL, + False, + False, + archived_before=datetime(2010, 1, 1), + ) + ) + == 0 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example", + ALLOW_ANY_EMAIL, + False, + False, + archived_before=datetime(2021, 1, 15), + ) + ) + == 28 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example", + ALLOW_ANY_EMAIL, + False, + False, + archived_before=datetime(2031, 1, 1), + ) + ) + == 100 + ) # archived before and after - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2001, 1, 1), archived_before=datetime(2031, 1, 11))) == 100 - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 14), archived_before=datetime(2021, 1, 16))) == 2 + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example", + ALLOW_ANY_EMAIL, + False, + False, + archived_after=datetime(2001, 1, 1), + archived_before=datetime(2031, 1, 11), + ) + ) + == 100 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example", + ALLOW_ANY_EMAIL, + False, + False, + archived_after=datetime(2021, 1, 14), + archived_before=datetime(2021, 1, 16), + ) + ) + == 2 + ) # limit - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, limit=10)) == 10 - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, limit=-1)) == 1 + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example", + ALLOW_ANY_EMAIL, + False, + False, + limit=10, + ) + ) + == 10 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example", + ALLOW_ANY_EMAIL, + False, + False, + limit=-1, + ) + ) + == 1 + ) # skip - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, skip=10)) == 90 + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example", + ALLOW_ANY_EMAIL, + False, + False, + skip=10, + ) + ) + == 90 + ) def test_search_archives_by_email(test_data, db_session): - from app.web.config import ALLOW_ANY_EMAIL - # lower/upper case - assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34 + assert ( + len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34 + ) # ALLOW_ANY_EMAIL is not a user assert len(crud.search_archives_by_email(db_session, ALLOW_ANY_EMAIL)) == 0 @@ -138,45 +392,108 @@ def test_search_archives_by_email(test_data, db_session): @patch("app.web.db.crud.DATABASE_QUERY_LIMIT", new=25) def test_max_query_limit(test_data, db_session): - from app.web.config import ALLOW_ANY_EMAIL + assert ( + len( + crud.search_archives_by_url( + db_session, "https://example", ALLOW_ANY_EMAIL, [], False + ) + ) + == 25 + ) + assert ( + len( + crud.search_archives_by_url( + db_session, + "https://example", + ALLOW_ANY_EMAIL, + True, + True, + limit=1000, + ) + ) + == 25 + ) - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, [], False)) == 25 - assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, limit=1000)) == 25 - - assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25 - assert len(crud.search_archives_by_email(db_session, "rick@example.com", limit=1000)) == 25 + assert ( + len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25 + ) + assert ( + len( + crud.search_archives_by_email( + db_session, "rick@example.com", limit=1000 + ) + ) + == 25 + ) def test_soft_delete(test_data, db_session): # none deleted yet - db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").first() is not None - assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 0 + assert ( + db_session.query(models.Archive) + .filter(models.Archive.id == "archive-id-456-0") + .first() + is not None + ) + assert ( + db_session.query(models.Archive) + .filter(models.Archive.deleted.is_(true())) + .count() + == 0 + ) # delete - assert crud.soft_delete_archive(db_session, "archive-id-456-0", "rick@example.com") == True + assert ( + crud.soft_delete_archive( + db_session, "archive-id-456-0", "rick@example.com" + ) + is True + ) # ensure soft delete - assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 1 - db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").first() is None + assert ( + db_session.query(models.Archive) + .filter(models.Archive.deleted.is_(true())) + .count() + == 1 + ) + assert ( + db_session.query(models.Archive) + .filter(models.Archive.id == "archive-id-456-0") + .filter(models.Archive.deleted.is_(false())) + .first() + is None + ) # already deleted - assert crud.soft_delete_archive(db_session, "archive-id-456-0", "rick@example.com") == False + assert ( + crud.soft_delete_archive( + db_session, "archive-id-456-0", "rick@example.com" + ) + is False + ) def test_count_archives(test_data, db_session): assert crud.count_archives(db_session) == 100 - db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete() + db_session.query(models.Archive).filter( + models.Archive.id == "archive-id-456-0" + ).delete() db_session.commit() assert crud.count_archives(db_session) == 99 def test_count_archive_urls(test_data, db_session): assert crud.count_archive_urls(db_session) == 1000 - db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.url == "https://example-0.com/0").delete() + db_session.query(models.ArchiveUrl).filter( + models.ArchiveUrl.url == "https://example-0.com/0" + ).delete() db_session.commit() assert crud.count_archive_urls(db_session) == 999 - db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete() + db_session.query(models.Archive).filter( + models.Archive.id == "archive-id-456-0" + ).delete() db_session.commit() # no Cascade is enabled assert crud.count_archives(db_session) == 99 @@ -185,16 +502,23 @@ def test_count_archive_urls(test_data, db_session): def test_count_users(test_data, db_session): assert crud.count_users(db_session) == 3 - db_session.query(models.User).filter(models.User.email == "rick@example.com").delete() + db_session.query(models.User).filter( + models.User.email == "rick@example.com" + ).delete() db_session.commit() assert crud.count_users(db_session) == 2 def test_count_by_users_since(test_data, db_session): - from app.web.db import crud - # 100y window - assert len(cu := crud.count_by_user_since(db_session, 60 * 60 * 24 * 31 * 12 * 100)) == 3 + assert ( + len( + cu := crud.count_by_user_since( + db_session, 60 * 60 * 24 * 31 * 12 * 100 + ) + ) + == 3 + ) assert cu[0].total == 34 assert cu[1].total == 33 assert cu[2].total == 33 @@ -203,9 +527,18 @@ def test_count_by_users_since(test_data, db_session): def test_upsert_group(test_data, db_session): assert db_session.query(models.Group).count() == 4 - repeatable_params = ["desc 1", "orch.yaml", "sheet.yaml", "service_account_email@example.com", {"read": ["all"]}, ["example.com"]] + repeatable_params = [ + "desc 1", + "orch.yaml", + "sheet.yaml", + "service_account_email@example.com", + {"read": ["all"]}, + ["example.com"], + ] - assert (g1 := crud.upsert_group(db_session, "spaceship", *repeatable_params)) is not None + assert ( + g1 := crud.upsert_group(db_session, "spaceship", *repeatable_params) + ) is not None assert g1.id == "spaceship" assert g1.description == "desc 1" assert g1.orchestrator == "orch.yaml" @@ -214,14 +547,25 @@ def test_upsert_group(test_data, db_session): assert g1.permissions == {"read": ["all"]} assert g1.domains == ["example.com"] assert len(g1.users) == 2 - assert [u.email for u in g1.users] == ["rick@example.com", "morty@example.com"] + assert [u.email for u in g1.users] == [ + "rick@example.com", + "morty@example.com", + ] - assert (g2 := crud.upsert_group(db_session, "interdimensional", *repeatable_params)) is not None + assert ( + g2 := crud.upsert_group( + db_session, "interdimensional", *repeatable_params + ) + ) is not None assert g2.id == "interdimensional" assert len(g2.users) == 1 assert [u.email for u in g2.users] == ["rick@example.com"] - assert (g3 := crud.upsert_group(db_session, "this-is-a-new-group", *repeatable_params)) is not None + assert ( + g3 := crud.upsert_group( + db_session, "this-is-a-new-group", *repeatable_params + ) + ) is not None assert g3.id == "this-is-a-new-group" assert len(g3.users) == 0 @@ -229,29 +573,38 @@ def test_upsert_group(test_data, db_session): def test_upsert_user_groups(db_session): - @patch('app.web.db.crud.get_settings', new=lambda: bad_setings) + @patch("app.web.db.crud.get_settings", new=lambda: bad_settings) def test_missing_yaml(db_session): with pytest.raises(FileNotFoundError): crud.upsert_user_groups(db_session) - @patch('app.web.db.crud.get_settings', new=lambda: bad_setings) + @patch("app.web.db.crud.get_settings", new=lambda: bad_settings) def test_broken_yaml(db_session): with pytest.raises(yaml.YAMLError): crud.upsert_user_groups(db_session) - bad_setings = Settings(_env_file=".env.test") + bad_settings = Settings(_env_file=".env.test") - bad_setings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.missing.yaml" + bad_settings.USER_GROUPS_FILENAME = ( + "app/tests/user-groups.test.missing.yaml" + ) test_missing_yaml(db_session) - bad_setings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.broken.yaml" + bad_settings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.broken.yaml" test_broken_yaml(db_session) def test_create_sheet(db_session): assert db_session.query(models.Sheet).count() == 0 - s = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly") + s = crud.create_sheet( + db_session, + "sheet-id-123", + "sheet name", + "email@example.com", + "group-id", + "hourly", + ) assert s is not None assert s.id == "sheet-id-123" assert s.name == "sheet name" @@ -261,19 +614,35 @@ def test_create_sheet(db_session): assert db_session.query(models.Sheet).count() == 1 - # duplicate id - import sqlalchemy with pytest.raises(sqlalchemy.exc.IntegrityError): - crud.create_sheet(db_session, "sheet-id-123", "I thought this was another sheet", "email", "group-id", "hourly") + crud.create_sheet( + db_session, + "sheet-id-123", + "I thought this was another sheet", + "email", + "group-id", + "hourly", + ) def test_get_user_sheet(test_data, db_session): assert crud.get_user_sheet(db_session, "", "sheet-0") is None - assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None + assert ( + crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None + ) - assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0") is not None - assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2") is not None - assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-1") is not None + assert ( + crud.get_user_sheet(db_session, "rick@example.com", "sheet-0") + is not None + ) + assert ( + crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2") + is not None + ) + assert ( + crud.get_user_sheet(db_session, "morty@example.com", "sheet-1") + is not None + ) def test_get_user_sheets(test_data, db_session): @@ -285,9 +654,9 @@ def test_get_user_sheets(test_data, db_session): def test_delete_sheet(test_data, db_session): - assert crud.delete_sheet(db_session, "sheet-0", "") == False - assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == True - assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == False + assert crud.delete_sheet(db_session, "sheet-0", "") is False + assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") is True + assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") is False @pytest.mark.asyncio @@ -299,21 +668,21 @@ async def test_find_by_store_until(async_db_session): url="https://example-expired-1.com", result={}, author_id="rick@example.com", - store_until=now - timedelta(days=1) + store_until=now - timedelta(days=1), ) archive2 = models.Archive( id="archive-expired-2", url="https://example-expired-2.com", result={}, author_id="rick@example.com", - store_until=now - timedelta(hours=1) + store_until=now - timedelta(hours=1), ) archive3 = models.Archive( id="archive-active", url="https://example-active.com", result={}, author_id="rick@example.com", - store_until=now + timedelta(days=1) + store_until=now + timedelta(days=1), ) async_db_session.add_all([archive1, archive2, archive3]) await async_db_session.commit() @@ -323,11 +692,15 @@ async def test_find_by_store_until(async_db_session): assert len(list(expired)) == 2 # Should find 1 archive expired before 2 hours ago - expired = await crud.find_by_store_until(async_db_session, now - timedelta(hours=2)) + expired = await crud.find_by_store_until( + async_db_session, now - timedelta(hours=2) + ) assert len(list(expired)) == 1 # Should find no archives expired before 2 days ago - expired = await crud.find_by_store_until(async_db_session, now - timedelta(days=2)) + expired = await crud.find_by_store_until( + async_db_session, now - timedelta(days=2) + ) assert len(list(expired)) == 0 # Should not find deleted archives @@ -339,45 +712,78 @@ async def test_find_by_store_until(async_db_session): @pytest.mark.asyncio async def test_get_sheets_by_id_hash(async_db_session): + author_emails = ["rick@example.com", "morty@example.com", "jerry@example.com"] + # Add test data - authors = ["rick@example.com", "morty@example.com", "jerry@example.com"] sheets = [ - models.Sheet(id="sheet-0", name="sheet-0", author_id=authors[0], group_id=None, frequency="daily"), - models.Sheet(id="sheet-0-2", name="sheet-0-2", author_id=authors[0], group_id="spaceship", frequency="hourly"), - models.Sheet(id="sheet-1", name="sheet-1", author_id=authors[1], group_id=None, frequency="daily"), - models.Sheet(id="sheet-2", name="sheet-2", author_id=authors[2], group_id=None, frequency="daily") + models.Sheet( + id="sheet-0", + name="sheet-0", + author_id=author_emails[0], + group_id=None, + frequency="daily", + ), + models.Sheet( + id="sheet-0-2", + name="sheet-0-2", + author_id=author_emails[0], + group_id="spaceship", + frequency="hourly", + ), + models.Sheet( + id="sheet-1", + name="sheet-1", + author_id=author_emails[1], + group_id=None, + frequency="daily", + ), + models.Sheet( + id="sheet-2", + name="sheet-2", + author_id=author_emails[2], + group_id=None, + frequency="daily", + ), ] async_db_session.add_all(sheets) await async_db_session.commit() with patch("app.web.db.crud.fnv1a_hash_mod", return_value=1): # Test retrieving hourly sheets - hourly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "hourly", 4, 1) + hourly_sheets = await crud.get_sheets_by_id_hash( + async_db_session, "hourly", 4, 1 + ) assert len(hourly_sheets) == 1 assert hourly_sheets[0].id == "sheet-0-2" assert hourly_sheets[0].frequency == "hourly" # Test retrieving daily sheets - daily_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 1) + daily_sheets = await crud.get_sheets_by_id_hash( + async_db_session, "daily", 4, 1 + ) assert len(daily_sheets) == 3 assert all(sheet.frequency == "daily" for sheet in daily_sheets) - assert {sheet.id for sheet in daily_sheets} == {"sheet-0", "sheet-1", "sheet-2"} + assert {sheet.id for sheet in daily_sheets} == { + "sheet-0", + "sheet-1", + "sheet-2", + } # Test with non-matching hash - no_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 3) + no_sheets = await crud.get_sheets_by_id_hash( + async_db_session, "daily", 4, 3 + ) assert len(no_sheets) == 0 # Test with non-existent frequency - weekly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "weekly", 4, 1) + weekly_sheets = await crud.get_sheets_by_id_hash( + async_db_session, "weekly", 4, 1 + ) assert len(weekly_sheets) == 0 @pytest.mark.asyncio async def test_delete_stale_sheets(async_db_session): - from datetime import datetime, timedelta - - from sqlalchemy.sql import select - now = datetime.now() active_date = now - timedelta(days=5) stale_date = now - timedelta(days=15) @@ -389,29 +795,29 @@ async def test_delete_stale_sheets(async_db_session): name="Active Sheet 1", author_id="rick@example.com", frequency="daily", - last_url_archived_at=active_date + last_url_archived_at=active_date, ), models.Sheet( id="sheet-active-2", name="Active Sheet 2", author_id="morty@example.com", frequency="hourly", - last_url_archived_at=active_date + last_url_archived_at=active_date, ), models.Sheet( id="sheet-stale-1", name="Stale Sheet 1", author_id="rick@example.com", frequency="daily", - last_url_archived_at=stale_date + last_url_archived_at=stale_date, ), models.Sheet( id="sheet-stale-2", name="Stale Sheet 2", author_id="morty@example.com", frequency="daily", - last_url_archived_at=stale_date - ) + last_url_archived_at=stale_date, + ), ] async_db_session.add_all(sheets) await async_db_session.commit() diff --git a/app/tests/web/db/test_user_state.py b/app/tests/web/db/test_user_state.py index 665bf08..5d18cea 100644 --- a/app/tests/web/db/test_user_state.py +++ b/app/tests/web/db/test_user_state.py @@ -1,4 +1,3 @@ - from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -6,6 +5,7 @@ import pytest from app.shared.db import models from app.shared.user_groups import GroupInfo, GroupPermissions from app.web.db.user_state import UserState +from app.web.utils.misc import convert_priority_to_queue_dict def fresh_user_state(): @@ -21,39 +21,73 @@ def user_state(): def user_state_with_groups(user_state): user_groups = [ models.Group(id="no-permissions", permissions={}), - models.Group(id="group1", description="this is g1", service_account_email="sa1@example.com", permissions={"read": ["group1", "no-permissions"], "read_public": True, "archive_url": True, "archive_sheet": True, "max_archive_lifespan_months": 24, "max_monthly_urls": 100, "max_monthly_mbs": 1000, "priority": "high"}), - models.Group(id="group2", description="this is g2", service_account_email="sa2@example.com", permissions={"read": ["all"], "read_public": True, "archive_url": False, "archive_sheet": False, "max_archive_lifespan_months": -1, "max_monthly_urls": -1, "max_monthly_mbs": -1, "priority": "low", "sheet_frequency": {"daily"}}), + models.Group( + id="group1", + description="this is g1", + service_account_email="sa1@example.com", + permissions={ + "read": ["group1", "no-permissions"], + "read_public": True, + "archive_url": True, + "archive_sheet": True, + "max_archive_lifespan_months": 24, + "max_monthly_urls": 100, + "max_monthly_mbs": 1000, + "priority": "high", + }, + ), + models.Group( + id="group2", + description="this is g2", + service_account_email="sa2@example.com", + permissions={ + "read": ["all"], + "read_public": True, + "archive_url": False, + "archive_sheet": False, + "max_archive_lifespan_months": -1, + "max_monthly_urls": -1, + "max_monthly_mbs": -1, + "priority": "low", + "sheet_frequency": {"daily"}, + }, + ), ] - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=user_groups): + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=user_groups, + ): yield user_state def test_permissions(user_state_with_groups): permissions = user_state_with_groups.permissions - assert permissions["all"].read == True - assert permissions["all"].read_public == True - assert permissions["all"].archive_url == True - assert permissions["all"].archive_sheet == True + assert permissions["all"].read is True + assert permissions["all"].read_public is True + assert permissions["all"].archive_url is True + assert permissions["all"].archive_sheet is True assert permissions["all"].max_archive_lifespan_months == -1 assert permissions["all"].max_monthly_urls == -1 assert permissions["all"].max_monthly_mbs == -1 assert permissions["all"].priority == "high" - assert permissions["group1"].read == set(["group1", "no-permissions"]) - assert permissions["group1"].read_public == True - assert permissions["group1"].archive_url == True - assert permissions["group1"].archive_sheet == True + assert permissions["group1"].read == {"group1", "no-permissions"} + assert permissions["group1"].read_public is True + assert permissions["group1"].archive_url is True + assert permissions["group1"].archive_sheet is True assert permissions["group1"].max_archive_lifespan_months == 24 assert permissions["group1"].max_monthly_urls == 100 assert permissions["group1"].max_monthly_mbs == 1000 assert permissions["group1"].priority == "high" - assert permissions["group2"].read == set(["all"]) - assert permissions["group2"].read_public == True - assert permissions["group2"].archive_url == False - assert permissions["group2"].archive_sheet == False + assert permissions["group2"].read == {"all"} + assert permissions["group2"].read_public is True + assert permissions["group2"].archive_url is False + assert permissions["group2"].archive_sheet is False assert permissions["group2"].max_archive_lifespan_months == -1 assert permissions["group2"].max_monthly_urls == -1 assert permissions["group2"].max_monthly_mbs == -1 @@ -63,13 +97,19 @@ def test_permissions(user_state_with_groups): def test_user_groups_names(user_state): - with patch('app.web.db.crud.get_user_group_names', return_value=["group1", "group2"]) as mock: + with patch( + "app.web.db.crud.get_user_group_names", + return_value=["group1", "group2"], + ) as mock: assert user_state.user_groups_names == ["group1", "group2", "default"] mock.assert_called_once_with(None, "test@example.com") def test_user_groups(user_state): - with patch('app.web.db.crud.get_user_groups_by_name', return_value=[MagicMock(), MagicMock()]) as mock: + with patch( + "app.web.db.crud.get_user_groups_by_name", + return_value=[MagicMock(), MagicMock()], + ) as mock: user_state._user_groups_names = ["group1", "group2"] assert len(user_state.user_groups) == 2 mock.assert_called_once_with(None, ["group1", "group2"]) @@ -78,85 +118,166 @@ def test_user_groups(user_state): def test_read(): us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[models.Group(id="no-permissions", permissions={})], + ) as mock: assert not hasattr(us, "_read") assert us.read == set() assert us._read == set() mock.assert_called_once() us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read": ["group1", "no-permissions"]})]): - assert us.read == set(["group1", "no-permissions"]) + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group( + id="group1", permissions={"read": ["group1", "no-permissions"]} + ) + ], + ): + assert us.read == {"group1", "no-permissions"} us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read": ["all"]})]): - assert us.read == True + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[models.Group(id="group1", permissions={"read": ["all"]})], + ): + assert us.read is True def test_read_public(): us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[models.Group(id="no-permissions", permissions={})], + ) as mock: assert not hasattr(us, "_read_public") - assert us.read_public == False - assert us._read_public == False + assert us.read_public is False + assert us._read_public is False mock.assert_called_once() # no new calls - assert us.read_public == False + assert us.read_public is False mock.assert_called_once() us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read_public": True})]): - assert us.read_public == True + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"read_public": True}) + ], + ): + assert us.read_public is True us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read_public": False})]): - assert us.read_public == False + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"read_public": False}) + ], + ): + assert us.read_public is False def test_archive_url(): us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[models.Group(id="no-permissions", permissions={})], + ) as mock: assert not hasattr(us, "_archive_url") - assert us.archive_url == False - assert us._archive_url == False + assert us.archive_url is False + assert us._archive_url is False mock.assert_called_once() # no new calls - assert us.archive_url == False + assert us.archive_url is False mock.assert_called_once() us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_url": False})]): - assert us.archive_url == False + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"archive_url": False}) + ], + ): + assert us.archive_url is False us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_url": True})]): - assert us.archive_url == True + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"archive_url": True}) + ], + ): + assert us.archive_url is True def test_archive_sheet(): us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[models.Group(id="no-permissions", permissions={})], + ) as mock: assert not hasattr(us, "_archive_sheet") - assert us.archive_sheet == False - assert us._archive_sheet == False + assert us.archive_sheet is False + assert us._archive_sheet is False mock.assert_called_once() # no new calls - assert us.archive_sheet == False + assert us.archive_sheet is False mock.assert_called_once() us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_sheet": False})]): - assert us.archive_sheet == False + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"archive_sheet": False}) + ], + ): + assert us.archive_sheet is False us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_sheet": True})]): - assert us.archive_sheet == True + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"archive_sheet": True}) + ], + ): + assert us.archive_sheet is True def test_sheet_frequency(): us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[models.Group(id="no-permissions", permissions={})], + ) as mock: assert not hasattr(us, "_sheet_frequency") assert us.sheet_frequency == set() assert us._sheet_frequency == set() @@ -166,18 +287,42 @@ def test_sheet_frequency(): mock.assert_called_once() us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"sheet_frequency": ["daily", "hourly"]})]): + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group( + id="group1", + permissions={"sheet_frequency": ["daily", "hourly"]}, + ) + ], + ): assert us.sheet_frequency == {"daily", "hourly"} us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"sheet_frequency": []})]): + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"sheet_frequency": []}) + ], + ): assert us.sheet_frequency == set() def test_max_archive_lifespan_months(): us = fresh_user_state() - default = GroupPermissions.model_fields["max_archive_lifespan_months"].default - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: + default = GroupPermissions.model_fields[ + "max_archive_lifespan_months" + ].default + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[models.Group(id="no-permissions", permissions={})], + ) as mock: assert not hasattr(us, "_max_archive_lifespan_months") assert us.max_archive_lifespan_months == default assert us._max_archive_lifespan_months == default @@ -187,18 +332,44 @@ def test_max_archive_lifespan_months(): mock.assert_called_once() us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_archive_lifespan_months": 24})]): + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group( + id="group1", permissions={"max_archive_lifespan_months": 24} + ) + ], + ): assert us.max_archive_lifespan_months == 24 us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_archive_lifespan_months": 150}), models.Group(id="group2", permissions={"max_archive_lifespan_months": -1})]): + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group( + id="group1", permissions={"max_archive_lifespan_months": 150} + ), + models.Group( + id="group2", permissions={"max_archive_lifespan_months": -1} + ), + ], + ): assert us.max_archive_lifespan_months == -1 def test_max_monthly_urls(): us = fresh_user_state() default = GroupPermissions.model_fields["max_monthly_urls"].default - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[models.Group(id="no-permissions", permissions={})], + ) as mock: assert not hasattr(us, "_max_monthly_urls") assert us.max_monthly_urls == default assert us._max_monthly_urls == default @@ -208,18 +379,38 @@ def test_max_monthly_urls(): mock.assert_called_once() us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_urls": 100})]): + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"max_monthly_urls": 100}) + ], + ): assert us.max_monthly_urls == 100 us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_urls": 150}), models.Group(id="group2", permissions={"max_monthly_urls": -1})]): + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"max_monthly_urls": 150}), + models.Group(id="group2", permissions={"max_monthly_urls": -1}), + ], + ): assert us.max_monthly_urls == -1 def test_max_monthly_mbs(): us = fresh_user_state() default = GroupPermissions.model_fields["max_monthly_mbs"].default - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[models.Group(id="no-permissions", permissions={})], + ) as mock: assert not hasattr(us, "_max_monthly_mbs") assert us.max_monthly_mbs == default assert us._max_monthly_mbs == default @@ -229,17 +420,37 @@ def test_max_monthly_mbs(): mock.assert_called_once() us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_mbs": 1000})]): + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"max_monthly_mbs": 1000}) + ], + ): assert us.max_monthly_mbs == 1000 us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_mbs": 1500}), models.Group(id="group2", permissions={"max_monthly_mbs": -1})]): + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"max_monthly_mbs": 1500}), + models.Group(id="group2", permissions={"max_monthly_mbs": -1}), + ], + ): assert us.max_monthly_mbs == -1 def test_priority(user_state): default = GroupPermissions.model_fields["priority"].default - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[models.Group(id="no-permissions", permissions={})], + ) as mock: assert not hasattr(user_state, "_priority") assert user_state.priority == default assert user_state._priority == default @@ -249,11 +460,26 @@ def test_priority(user_state): mock.assert_called_once() us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"priority": "high"})]): + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"priority": "high"}) + ], + ): assert us.priority == "high" us = fresh_user_state() - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"priority": "low"}), models.Group(id="group2", permissions={"priority": "medium"})]): + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"priority": "low"}), + models.Group(id="group2", permissions={"priority": "medium"}), + ], + ): assert us.priority == "low" @@ -263,21 +489,45 @@ def test_active(): (True, False, False, False, True), (False, True, False, False, True), (False, False, True, False, True), - (False, False, False, True, True) + (False, False, False, True, True), ]: us = fresh_user_state() - with patch.object(UserState, 'read', new_callable=PropertyMock, return_value=read), \ - patch.object(UserState, 'read_public', new_callable=PropertyMock, return_value=read_public), \ - patch.object(UserState, 'archive_url', new_callable=PropertyMock, return_value=archive_url), \ - patch.object(UserState, 'archive_sheet', new_callable=PropertyMock, return_value=archive_sheet): + with ( + patch.object( + UserState, "read", new_callable=PropertyMock, return_value=read + ), + patch.object( + UserState, + "read_public", + new_callable=PropertyMock, + return_value=read_public, + ), + patch.object( + UserState, + "archive_url", + new_callable=PropertyMock, + return_value=archive_url, + ), + patch.object( + UserState, + "archive_sheet", + new_callable=PropertyMock, + return_value=archive_sheet, + ), + ): assert us.active == is_active def test_in_group(user_state): - with patch.object(UserState, 'user_groups_names', new_callable=PropertyMock, return_value=["group1", "group2"]): - assert user_state.in_group("group1") == True - assert user_state.in_group("group2") == True - assert user_state.in_group("group3") == False + with patch.object( + UserState, + "user_groups_names", + new_callable=PropertyMock, + return_value=["group1", "group2"], + ): + assert user_state.in_group("group1") is True + assert user_state.in_group("group2") is True + assert user_state.in_group("group3") is False def test_usage(db_session): @@ -295,10 +545,34 @@ def test_usage(db_session): ] megabytes = int(sum(bytes) / 1024 / 1024) - with patch.object(db_session, 'query', side_effect=[ - MagicMock(filter=MagicMock(return_value=MagicMock(group_by=MagicMock(return_value=MagicMock(all=MagicMock(return_value=user_sheets)))))), - MagicMock(filter=MagicMock(return_value=MagicMock(group_by=MagicMock(return_value=MagicMock(all=MagicMock(return_value=urls_by_group)))))) - ]): + with patch.object( + db_session, + "query", + side_effect=[ + MagicMock( + filter=MagicMock( + return_value=MagicMock( + group_by=MagicMock( + return_value=MagicMock( + all=MagicMock(return_value=user_sheets) + ) + ) + ) + ) + ), + MagicMock( + filter=MagicMock( + return_value=MagicMock( + group_by=MagicMock( + return_value=MagicMock( + all=MagicMock(return_value=urls_by_group) + ) + ) + ) + ) + ), + ], + ): usage_response = user_state.usage() assert usage_response.monthly_urls == 155 @@ -306,11 +580,15 @@ def test_usage(db_session): assert usage_response.total_sheets == 115 assert usage_response.groups["group1"].monthly_urls == 50 - assert usage_response.groups["group1"].monthly_mbs == int(bytes[0] / 1024 / 1024) + assert usage_response.groups["group1"].monthly_mbs == int( + bytes[0] / 1024 / 1024 + ) assert usage_response.groups["group1"].total_sheets == 5 assert usage_response.groups["group2"].monthly_urls == 100 - assert usage_response.groups["group2"].monthly_mbs == int(bytes[1] / 1024 / 1024) + assert usage_response.groups["group2"].monthly_mbs == int( + bytes[1] / 1024 / 1024 + ) assert usage_response.groups["group2"].total_sheets == 10 assert usage_response.groups["group3"].monthly_urls == 0 @@ -318,7 +596,9 @@ def test_usage(db_session): assert usage_response.groups["group3"].total_sheets == 100 assert usage_response.groups["group4"].monthly_urls == 5 - assert usage_response.groups["group4"].monthly_mbs == int(bytes[2] / 1024 / 1024) + assert usage_response.groups["group4"].monthly_mbs == int( + bytes[2] / 1024 / 1024 + ) assert usage_response.groups["group4"].total_sheets == 0 @@ -334,8 +614,23 @@ def test_has_quota_monthly_sheets(db_session): ] for permissions, count, expected in test_cases: - with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions): - with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))): + with patch.object( + UserState, + "permissions", + new_callable=PropertyMock, + return_value=permissions, + ): + with patch.object( + us.db, + "query", + return_value=MagicMock( + filter=MagicMock( + return_value=MagicMock( + count=MagicMock(return_value=count) + ) + ) + ), + ): assert us.has_quota_monthly_sheets("group1") == expected @@ -350,8 +645,23 @@ def test_has_quota_max_monthly_urls(db_session): ] for permissions, count, expected in test_cases: - with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions): - with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))): + with patch.object( + UserState, + "permissions", + new_callable=PropertyMock, + return_value=permissions, + ): + with patch.object( + us.db, + "query", + return_value=MagicMock( + filter=MagicMock( + return_value=MagicMock( + count=MagicMock(return_value=count) + ) + ) + ), + ): assert us.has_quota_max_monthly_urls("group1") == expected test_cases = [ (-1, 1000, True), @@ -361,8 +671,23 @@ def test_has_quota_max_monthly_urls(db_session): ] for max_urls, count, expected in test_cases: - with patch.object(UserState, 'max_monthly_urls', new_callable=PropertyMock, return_value=max_urls): - with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))): + with patch.object( + UserState, + "max_monthly_urls", + new_callable=PropertyMock, + return_value=max_urls, + ): + with patch.object( + us.db, + "query", + return_value=MagicMock( + filter=MagicMock( + return_value=MagicMock( + count=MagicMock(return_value=count) + ) + ) + ), + ): assert us.has_quota_max_monthly_urls("") == expected @@ -377,8 +702,29 @@ def test_has_quota_max_monthly_mbs(db_session): ] for permissions, mbs, expected in test_cases: - with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions): - with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(with_entities=MagicMock(return_value=MagicMock(scalar=MagicMock(return_value=mbs * 1024 * 1024))))))): + with patch.object( + UserState, + "permissions", + new_callable=PropertyMock, + return_value=permissions, + ): + with patch.object( + us.db, + "query", + return_value=MagicMock( + filter=MagicMock( + return_value=MagicMock( + with_entities=MagicMock( + return_value=MagicMock( + scalar=MagicMock( + return_value=mbs * 1024 * 1024 + ) + ) + ) + ) + ) + ), + ): assert us.has_quota_max_monthly_mbs("group1") == expected test_cases = [ @@ -389,8 +735,29 @@ def test_has_quota_max_monthly_mbs(db_session): ] for max_mbs, mbs, expected in test_cases: - with patch.object(UserState, 'max_monthly_mbs', new_callable=PropertyMock, return_value=max_mbs): - with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(with_entities=MagicMock(return_value=MagicMock(scalar=MagicMock(return_value=mbs * 1024 * 1024))))))): + with patch.object( + UserState, + "max_monthly_mbs", + new_callable=PropertyMock, + return_value=max_mbs, + ): + with patch.object( + us.db, + "query", + return_value=MagicMock( + filter=MagicMock( + return_value=MagicMock( + with_entities=MagicMock( + return_value=MagicMock( + scalar=MagicMock( + return_value=mbs * 1024 * 1024 + ) + ) + ) + ) + ) + ), + ): assert us.has_quota_max_monthly_mbs("") == expected @@ -400,10 +767,15 @@ def test_can_manually_trigger(user_state): "group2": GroupInfo(manually_trigger_sheet=False), } - with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions): - assert user_state.can_manually_trigger("group1") == True - assert user_state.can_manually_trigger("group2") == False - assert user_state.can_manually_trigger("group3") == False + with patch.object( + UserState, + "permissions", + new_callable=PropertyMock, + return_value=permissions, + ): + assert user_state.can_manually_trigger("group1") is True + assert user_state.can_manually_trigger("group2") is False + assert user_state.can_manually_trigger("group3") is False def test_is_sheet_frequency_allowed(user_state): @@ -412,23 +784,44 @@ def test_is_sheet_frequency_allowed(user_state): "group2": GroupInfo(sheet_frequency={"daily"}), } - with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions): - assert user_state.is_sheet_frequency_allowed("group1", "daily") == True - assert user_state.is_sheet_frequency_allowed("group1", "hourly") == True - assert user_state.is_sheet_frequency_allowed("group1", "weekly") == False - assert user_state.is_sheet_frequency_allowed("group2", "hourly") == False - assert user_state.is_sheet_frequency_allowed("group2", "daily") == True - assert user_state.is_sheet_frequency_allowed("group3", "daily") == False + with patch.object( + UserState, + "permissions", + new_callable=PropertyMock, + return_value=permissions, + ): + assert user_state.is_sheet_frequency_allowed("group1", "daily") is True + assert user_state.is_sheet_frequency_allowed("group1", "hourly") is True + assert ( + user_state.is_sheet_frequency_allowed("group1", "weekly") is False + ) + assert ( + user_state.is_sheet_frequency_allowed("group2", "hourly") is False + ) + assert user_state.is_sheet_frequency_allowed("group2", "daily") is True + assert user_state.is_sheet_frequency_allowed("group3", "daily") is False def test_priority_group(user_state): - from app.web.utils.misc import convert_priority_to_queue_dict - with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[ - models.Group(id="group1", permissions={"priority": "high"}), - models.Group(id="group2", permissions={"priority": "medium"}), - models.Group(id="group3", permissions={"priority": "low"}), - ]): - assert user_state.priority_group("group1") == convert_priority_to_queue_dict("high") - assert user_state.priority_group("group2") == convert_priority_to_queue_dict("medium") - assert user_state.priority_group("group3") == convert_priority_to_queue_dict("low") - assert user_state.priority_group("group4") == convert_priority_to_queue_dict("low") + with patch.object( + UserState, + "user_groups", + new_callable=PropertyMock, + return_value=[ + models.Group(id="group1", permissions={"priority": "high"}), + models.Group(id="group2", permissions={"priority": "medium"}), + models.Group(id="group3", permissions={"priority": "low"}), + ], + ): + assert user_state.priority_group( + "group1" + ) == convert_priority_to_queue_dict("high") + assert user_state.priority_group( + "group2" + ) == convert_priority_to_queue_dict("medium") + assert user_state.priority_group( + "group3" + ) == convert_priority_to_queue_dict("low") + assert user_state.priority_group( + "group4" + ) == convert_priority_to_queue_dict("low") diff --git a/app/tests/web/endpoints/test_default.py b/app/tests/web/endpoints/test_default.py index e4e34cc..970e705 100644 --- a/app/tests/web/endpoints/test_default.py +++ b/app/tests/web/endpoints/test_default.py @@ -1,17 +1,20 @@ +from http import HTTPStatus from unittest.mock import MagicMock import pytest from fastapi.testclient import TestClient +from loguru import logger from app.shared.schemas import Usage, UsageResponse from app.shared.user_groups import GroupInfo -from app.tests.web.db.test_crud import test_data from app.web.config import VERSION +from app.web.security import get_user_state +from app.web.utils.metrics import measure_regular_metrics def test_endpoint_home(client_with_auth): r = client_with_auth.get("/") - assert r.status_code == 200 + assert r.status_code == HTTPStatus.OK j = r.json() assert "version" in j and j["version"] == VERSION assert "breakingChanges" in j @@ -20,7 +23,7 @@ def test_endpoint_home(client_with_auth): def test_endpoint_health(client_with_auth): r = client_with_auth.get("/health") - assert r.status_code == 200 + assert r.status_code == HTTPStatus.OK assert r.json() == {"status": "ok"} @@ -31,32 +34,31 @@ def test_endpoint_active_no_auth(client, test_no_auth): def test_endpoint_active(app): m_user_state = MagicMock() - from app.web.security import get_user_state app.dependency_overrides[get_user_state] = lambda: m_user_state # inactive user m_user_state.active = False client = TestClient(app) r = client.get("/user/active") - assert r.status_code == 200 + assert r.status_code == HTTPStatus.OK assert r.json() == {"active": False} # active user m_user_state.active = True client = TestClient(app) r = client.get("/user/active") - assert r.status_code == 200 + assert r.status_code == HTTPStatus.OK assert r.json() == {"active": True} def test_no_serve_local_archive_by_default(client_with_auth): r = client_with_auth.get("/app/local_archive_test/temp.txt") - assert r.status_code == 404 + assert r.status_code == HTTPStatus.NOT_FOUND def test_favicon(client_with_auth): r = client_with_auth.get("/favicon.ico") - assert r.status_code == 200 + assert r.status_code == HTTPStatus.OK assert r.headers["content-type"] == "image/vnd.microsoft.icon" @@ -72,8 +74,10 @@ def test_endpoint_test_prometheus_no_user_auth(client_with_auth, test_no_auth): async def test_prometheus_metrics(test_data, client_with_token, get_settings): # before metrics calculation r = client_with_token.get("/metrics") - assert r.status_code == 200 - assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8" + assert r.status_code == HTTPStatus.OK + assert ( + r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8" + ) assert "disk_utilization" in r.text assert "database_metrics" in r.text assert "exceptions" in r.text @@ -81,8 +85,9 @@ async def test_prometheus_metrics(test_data, client_with_token, get_settings): assert 'disk_utilization{type="used"}' not in r.text # after metrics calculation - from app.web.utils.metrics import measure_regular_metrics - await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100) + await measure_regular_metrics( + get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100 + ) r2 = client_with_token.get("/metrics") assert 'disk_utilization{type="used"}' in r2.text assert 'disk_utilization{type="free"}' in r2.text @@ -90,20 +95,37 @@ async def test_prometheus_metrics(test_data, client_with_token, get_settings): assert 'database_metrics{query="count_archives"} 100.0' in r2.text assert 'database_metrics{query="count_archive_urls"} 1000.0' in r2.text assert 'database_metrics{query="count_users"} 3.0' in r2.text - assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r2.text - assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r2.text - assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r2.text + assert ( + 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' + in r2.text + ) + assert ( + 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' + in r2.text + ) + assert ( + 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' + in r2.text + ) # 30s window, should not change the gauges nor the total in the counters - from app.web.utils.metrics import measure_regular_metrics await measure_regular_metrics(get_settings.DATABASE_PATH, 30) r3 = client_with_token.get("/metrics") assert 'database_metrics{query="count_archives"} 100.0' in r3.text assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text assert 'database_metrics{query="count_users"} 3.0' in r3.text - assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r3.text - assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r3.text - assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r3.text + assert ( + 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' + in r3.text + ) + assert ( + 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' + in r3.text + ) + assert ( + 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' + in r3.text + ) def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth): @@ -111,14 +133,12 @@ def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth): def test_endpoint_get_user_permissions(app): - from app.web.security import get_user_state - m_user_state = MagicMock() rv = { "all": GroupInfo(read=True), "group1": GroupInfo(archive_url=True), } - from loguru import logger + logger.info(rv) m_user_state.permissions = rv @@ -126,13 +146,13 @@ def test_endpoint_get_user_permissions(app): client = TestClient(app) r = client.get("/user/permissions") - assert r.status_code == 200 + assert r.status_code == HTTPStatus.OK response = r.json() assert response.keys() == {"all", "group1"} assert response["all"]["read"] assert response["group1"]["read"] == [] assert response["group1"]["archive_url"] - assert response["all"]["archive_url"] == False + assert response["all"]["archive_url"] is False def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth): @@ -140,8 +160,6 @@ def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth): def test_endpoint_get_user_usage_inactive(app): - from app.web.security import get_user_state - m_user_state = MagicMock() m_user_state.active = False @@ -149,13 +167,11 @@ def test_endpoint_get_user_usage_inactive(app): client = TestClient(app) r = client.get("/user/usage") - assert r.status_code == 403 + assert r.status_code == HTTPStatus.FORBIDDEN assert r.json() == {"detail": "User is not active."} def test_endpoint_get_user_usage_active(app): - from app.web.security import get_user_state - m_user_state = MagicMock() m_user_state.active = True mock_usage = UsageResponse( @@ -164,8 +180,8 @@ def test_endpoint_get_user_usage_active(app): total_sheets=3, groups={ "group1": Usage(monthly_urls=4, monthly_mbs=5, total_sheets=6), - "group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9) - } + "group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9), + }, ) m_user_state.usage.return_value = mock_usage @@ -173,5 +189,5 @@ def test_endpoint_get_user_usage_active(app): client = TestClient(app) r = client.get("/user/usage") - assert r.status_code == 200 + assert r.status_code == HTTPStatus.OK assert UsageResponse(**r.json()) == mock_usage diff --git a/app/tests/web/endpoints/test_interoperability.py b/app/tests/web/endpoints/test_interoperability.py index 703f69a..d102116 100644 --- a/app/tests/web/endpoints/test_interoperability.py +++ b/app/tests/web/endpoints/test_interoperability.py @@ -1,10 +1,9 @@ import json from datetime import datetime +from http import HTTPStatus from unittest.mock import MagicMock, patch from app.shared.db import models -from app.web.config import ALLOW_ANY_EMAIL -from app.web.db import crud def test_submit_manual_archive_unauthenticated(client, test_no_auth): @@ -15,46 +14,134 @@ def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth): test_no_auth(client_with_auth.post, "/interop/submit-archive") -@patch("app.web.endpoints.interoperability.business_logic", return_value=MagicMock(get_store_archive_until=MagicMock(return_value=datetime))) +@patch( + "app.web.endpoints.interoperability.business_logic", + return_value=MagicMock( + get_store_archive_until=MagicMock(return_value=datetime) + ), +) def test_submit_manual_archive(m1, client_with_token, db_session): # normal workflow - aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]}) - r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"}) - assert r.status_code == 201 + aa_metadata = json.dumps( + { + "status": "test: success", + "metadata": {"url": "http://example.com"}, + "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}], + } + ) + r = client_with_token.post( + "/interop/submit-archive", + json={ + "result": aa_metadata, + "public": True, + "author_id": "jerry@gmail.com", + "group_id": "spaceship", + "tags": ["test"], + "url": "http://example.com", + }, + ) + assert r.status_code == HTTPStatus.CREATED assert "id" in r.json() - inserted = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"]).first() + inserted = ( + db_session.query(models.Archive) + .filter(models.Archive.id == r.json()["id"]) + .first() + ) assert inserted.url == "http://example.com" assert inserted.group_id == "spaceship" assert inserted.author_id == "jerry@gmail.com" assert sorted([t.id for t in inserted.tags]) == sorted(["test", "manual"]) assert inserted.public - assert type(inserted.result) == dict + assert isinstance(inserted.result, dict) assert [u.url for u in inserted.urls] == ["http://example.s3.com"] - assert type(inserted.store_until) == datetime + assert isinstance(inserted.store_until, datetime) # cannot have the same URL twice - aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]}) - r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "tags": ["test"], "url": "http://example.com"}) - assert r.status_code == 422 - assert r.json() == {"detail": "Cannot insert into DB due to integrity error, likely duplicate urls."} + aa_metadata = json.dumps( + { + "status": "test: success", + "metadata": {"url": "http://example.com"}, + "media": [ + { + "filename": "fn1", + "urls": ["http://example.com", "http://example.com"], + } + ], + } + ) + r = client_with_token.post( + "/interop/submit-archive", + json={ + "result": aa_metadata, + "public": False, + "author_id": "jerry@gmail.com", + "tags": ["test"], + "url": "http://example.com", + }, + ) + assert r.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + assert r.json() == { + "detail": "Cannot insert into DB due to integrity error, likely duplicate urls." + } # test with invalid JSON def test_submit_manual_archive_invalid_json(client_with_token): - r = client_with_token.post("/interop/submit-archive", json={"result": "invalid json", "public": False, "author_id": "jer", "tags": ["test"], "url": "http://example.com"}) - assert r.status_code == 422 + r = client_with_token.post( + "/interop/submit-archive", + json={ + "result": "invalid json", + "public": False, + "author_id": "jer", + "tags": ["test"], + "url": "http://example.com", + }, + ) + assert r.status_code == HTTPStatus.UNPROCESSABLE_ENTITY assert r.json() == {"detail": "Invalid JSON in result field."} -@patch("app.web.endpoints.interoperability.business_logic.get_store_archive_until", side_effect=AssertionError("AssertionError")) -def test_submit_manual_archive_no_store_until(m_sau, client_with_token, db_session): - aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]}) - r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"}) - assert r.status_code == 201 +@patch( + "app.web.endpoints.interoperability.business_logic.get_store_archive_until", + side_effect=AssertionError("AssertionError"), +) +def test_submit_manual_archive_no_store_until( + m_sau, client_with_token, db_session +): + aa_metadata = json.dumps( + { + "status": "test: success", + "metadata": {"url": "http://example.com"}, + "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}], + } + ) + r = client_with_token.post( + "/interop/submit-archive", + json={ + "result": aa_metadata, + "public": True, + "author_id": "jerry@gmail.com", + "group_id": "spaceship", + "tags": ["test"], + "url": "http://example.com", + }, + ) + assert r.status_code == HTTPStatus.CREATED assert len(r.json()["id"]) == 36 - res = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"]).first() + res = ( + db_session.query(models.Archive) + .filter(models.Archive.id == r.json()["id"]) + .first() + ) assert res.store_until is None # testing that store_until = None is not comparable with datetime, and will always return False - res = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"], models.Archive.store_until < datetime.now()).first() + res = ( + db_session.query(models.Archive) + .filter( + models.Archive.id == r.json()["id"], + models.Archive.store_until < datetime.now(), + ) + .first() + ) assert res is None diff --git a/app/tests/web/endpoints/test_sheet.py b/app/tests/web/endpoints/test_sheet.py index 9b47228..c318496 100644 --- a/app/tests/web/endpoints/test_sheet.py +++ b/app/tests/web/endpoints/test_sheet.py @@ -1,10 +1,13 @@ -import json from datetime import datetime +from http import HTTPStatus from unittest.mock import MagicMock, patch from fastapi.testclient import TestClient +from app.shared.db import models from app.shared.schemas import TaskResult +from app.web.db.user_state import UserState +from app.web.security import get_user_state def test_endpoints_no_auth(client, test_no_auth): @@ -20,34 +23,38 @@ def test_create_sheet_endpoint(app_with_auth, db_session): "id": "123-sheet-id", "name": "Test Sheet", "group_id": "spaceship", - "frequency": "daily" + "frequency": "daily", } # with good data response = client_with_auth.post("/sheet/create", json=good_data) - assert response.status_code == 201 + assert response.status_code == HTTPStatus.CREATED j = response.json() assert datetime.fromisoformat(j.pop("created_at")) assert datetime.fromisoformat(j.pop("last_url_archived_at")) - assert j.pop("author_id") == 'morty@example.com' + assert j.pop("author_id") == "morty@example.com" assert j == good_data # already exists response = client_with_auth.post("/sheet/create", json=good_data) - assert response.status_code == 400 - assert response.json() == {"detail": "Sheet with this ID is already being archived."} + assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.json() == { + "detail": "Sheet with this ID is already being archived." + } # bad group bad_data = good_data.copy() bad_data["group_id"] = "not a group" response = client_with_auth.post("/sheet/create", json=bad_data) - assert response.status_code == 403 - assert response.json() == {"detail": "User does not have access to this group."} + assert response.status_code == HTTPStatus.FORBIDDEN + assert response.json() == { + "detail": "User does not have access to this group." + } # switch to jerry who's got less quota/permissions - from app.web.db.user_state import UserState - from app.web.security import get_user_state - app_with_auth.dependency_overrides[get_user_state] = lambda: UserState(db_session, "jerry@example.com") + app_with_auth.dependency_overrides[get_user_state] = lambda: UserState( + db_session, "jerry@example.com" + ) client_jerry = TestClient(app_with_auth) # frequency not allowed @@ -56,39 +63,62 @@ def test_create_sheet_endpoint(app_with_auth, db_session): jerry_data["frequency"] = "hourly" jerry_data["id"] = "jerry-sheet-id" response = client_jerry.post("/sheet/create", json=jerry_data) - assert response.status_code == 422 - assert response.json() == {"detail": "Invalid frequency selected for this group."} + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + assert response.json() == { + "detail": "Invalid frequency selected for this group." + } jerry_data["frequency"] = "daily" # success for the first sheet, bad quota on second response = client_jerry.post("/sheet/create", json=jerry_data) - assert response.status_code == 201 + assert response.status_code == HTTPStatus.CREATED response = client_jerry.post("/sheet/create", json=jerry_data) - assert response.status_code == 429 - assert response.json() == {"detail": "User has reached their sheet quota for this group."} + assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS + assert response.json() == { + "detail": "User has reached their sheet quota for this group." + } def test_get_user_sheets_endpoint(client_with_auth, db_session): # no data response = client_with_auth.get("/sheet/mine") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert response.json() == [] # with data - from app.shared.db import models db_session.add( - models.Sheet(id="123", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly") + models.Sheet( + id="123", + name="Test Sheet 1", + author_id="morty@example.com", + group_id="spaceship", + frequency="hourly", + ) ) db_session.commit() - db_session.add_all([ - models.Sheet(id="456", name="Test Sheet 2", author_id="morty@example.com", group_id="interdimensional", frequency="daily"), - models.Sheet(id="789", name="Test Sheet 3", author_id="rick@example.com", group_id="interdimensional", frequency="hourly"), - ]) + db_session.add_all( + [ + models.Sheet( + id="456", + name="Test Sheet 2", + author_id="morty@example.com", + group_id="interdimensional", + frequency="daily", + ), + models.Sheet( + id="789", + name="Test Sheet 3", + author_id="rick@example.com", + group_id="interdimensional", + frequency="hourly", + ), + ] + ) db_session.commit() response = client_with_auth.get("/sheet/mine") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK r = response.json() assert isinstance(r, list) assert len(r) == 2 @@ -97,65 +127,84 @@ def test_get_user_sheets_endpoint(client_with_auth, db_session): assert datetime.fromisoformat(r[1].pop("created_at")) assert datetime.fromisoformat(r[1].pop("last_url_archived_at")) assert r[0] == { - 'id': '123', - 'author_id': 'morty@example.com', - 'frequency': 'hourly', - 'group_id': 'spaceship', - 'name': 'Test Sheet 1', + "id": "123", + "author_id": "morty@example.com", + "frequency": "hourly", + "group_id": "spaceship", + "name": "Test Sheet 1", } assert r[1] == { - 'id': '456', - 'author_id': 'morty@example.com', - 'frequency': 'daily', - 'group_id': 'interdimensional', - 'name': 'Test Sheet 2', + "id": "456", + "author_id": "morty@example.com", + "frequency": "daily", + "group_id": "interdimensional", + "name": "Test Sheet 2", } def test_delete_sheet_endpoint(client_with_auth, db_session): # missing sheet response = client_with_auth.delete("/sheet/123-sheet-id") - assert response.status_code == 200 - assert response.json() == { - "id": "123-sheet-id", - "deleted": False - } + assert response.status_code == HTTPStatus.OK + assert response.json() == {"id": "123-sheet-id", "deleted": False} # add sheets for deletion - from app.shared.db import models - db_session.add_all([ - models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="daily"), - models.Sheet(id="456-sheet-id", name="Test Sheet 2", author_id="rick@example.com", group_id="spaceship", frequency="hourly"), - ]) + db_session.add_all( + [ + models.Sheet( + id="123-sheet-id", + name="Test Sheet 1", + author_id="morty@example.com", + group_id="interdimensional", + frequency="daily", + ), + models.Sheet( + id="456-sheet-id", + name="Test Sheet 2", + author_id="rick@example.com", + group_id="spaceship", + frequency="hourly", + ), + ] + ) db_session.commit() # morty can delete his response = client_with_auth.delete("/sheet/123-sheet-id") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert response.json() == {"id": "123-sheet-id", "deleted": True} # but only once response = client_with_auth.delete("/sheet/123-sheet-id") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert response.json() == {"id": "123-sheet-id", "deleted": False} - # and not rick's + # and not Rick's response = client_with_auth.delete("/sheet/456-sheet-id") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert response.json() == {"id": "456-sheet-id", "deleted": False} class TestArchiveUserSheetEndpoint: @patch("app.web.endpoints.sheet.celery", return_value=MagicMock()) def test_normal_flow(self, m_celery, client_with_auth, db_session): - from app.shared.db import models - db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")) + db_session.add( + models.Sheet( + id="123-sheet-id", + name="Test Sheet 1", + author_id="morty@example.com", + group_id="spaceship", + frequency="hourly", + ) + ) db_session.commit() m_signature = MagicMock() - m_signature.apply_async.return_value = TaskResult(id="123-taskid", status="PENDING", result="") + m_signature.apply_async.return_value = TaskResult( + id="123-taskid", status="PENDING", result="" + ) m_celery.signature.return_value = m_signature r = client_with_auth.post("/sheet/123-sheet-id/archive") - assert r.status_code == 201 + assert r.status_code == HTTPStatus.CREATED assert r.json() == {"id": "123-taskid"} m_celery.signature.assert_called_once() m_signature.apply_async.assert_called_once() @@ -165,29 +214,54 @@ class TestArchiveUserSheetEndpoint: def test_missing_data(self, client_with_auth): r = client_with_auth.post("/sheet/123-sheet-id/archive") - assert r.status_code == 403 + assert r.status_code == HTTPStatus.FORBIDDEN assert r.json() == {"detail": "No access to this sheet."} def test_no_access(self, client_with_auth, db_session): - from app.shared.db import models - db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="rick@example.com", group_id="spaceship", frequency="hourly")) + db_session.add( + models.Sheet( + id="123-sheet-id", + name="Test Sheet 1", + author_id="rick@example.com", + group_id="spaceship", + frequency="hourly", + ) + ) db_session.commit() r = client_with_auth.post("/sheet/123-sheet-id/archive") - assert r.status_code == 403 + assert r.status_code == HTTPStatus.FORBIDDEN assert r.json() == {"detail": "No access to this sheet."} def test_user_not_in_group(self, client_with_auth, db_session): - from app.shared.db import models - db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="hourly")) + db_session.add( + models.Sheet( + id="123-sheet-id", + name="Test Sheet 1", + author_id="morty@example.com", + group_id="interdimensional", + frequency="hourly", + ) + ) db_session.commit() r = client_with_auth.post("/sheet/123-sheet-id/archive") - assert r.status_code == 403 - assert r.json() == {"detail": "User does not have access to this group."} + assert r.status_code == HTTPStatus.FORBIDDEN + assert r.json() == { + "detail": "User does not have access to this group." + } def test_user_cannot_manually_trigger(self, client_with_auth, db_session): - from app.shared.db import models - db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="default", frequency="hourly")) + db_session.add( + models.Sheet( + id="123-sheet-id", + name="Test Sheet 1", + author_id="morty@example.com", + group_id="default", + frequency="hourly", + ) + ) db_session.commit() r = client_with_auth.post("/sheet/123-sheet-id/archive") - assert r.status_code == 429 - assert r.json() == {"detail": "User cannot manually trigger sheet archiving in this group."} + assert r.status_code == HTTPStatus.TOO_MANY_REQUESTS + assert r.json() == { + "detail": "User cannot manually trigger sheet archiving in this group." + } diff --git a/app/tests/web/endpoints/test_task.py b/app/tests/web/endpoints/test_task.py index 937ad46..038babf 100644 --- a/app/tests/web/endpoints/test_task.py +++ b/app/tests/web/endpoints/test_task.py @@ -1,3 +1,4 @@ +from http import HTTPStatus from unittest.mock import patch @@ -12,27 +13,26 @@ def test_get_status_success(mock_async_result, client_with_auth): response = client_with_auth.get("/task/test-task-id") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert response.json() == { "id": "test-task-id", "status": "SUCCESS", - "result": {"data": "some result"} + "result": {"data": "some result"}, } @patch("app.web.endpoints.task.AsyncResult") def test_get_status_failure(mock_async_result, client_with_auth): - mock_async_result.return_value.status = "FAILURE" mock_async_result.return_value.result = Exception("Some error") response = client_with_auth.get("/task/test-task-id") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert response.json() == { "id": "test-task-id", "status": "FAILURE", - "result": {"error": "Some error"} + "result": {"error": "Some error"}, } @@ -43,9 +43,9 @@ def test_get_status_pending(mock_async_result, client_with_auth): response = client_with_auth.get("/task/test-task-id") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert response.json() == { "id": "test-task-id", "status": "PENDING", - "result": None + "result": None, } diff --git a/app/tests/web/endpoints/test_url.py b/app/tests/web/endpoints/test_url.py index 1b6ee85..cd64262 100644 --- a/app/tests/web/endpoints/test_url.py +++ b/app/tests/web/endpoints/test_url.py @@ -1,6 +1,9 @@ import json +from http import HTTPStatus from unittest.mock import MagicMock, patch +from app.shared import schemas +from app.shared.db import worker_crud from app.shared.schemas import ArchiveCreate, TaskResult from app.web.config import ALLOW_ANY_EMAIL @@ -13,7 +16,9 @@ def test_archive_url_unauthenticated(client, test_no_auth): @patch("app.web.endpoints.url.celery", return_value=MagicMock()) def test_archive_url(m_celery, m2, client_with_auth): m_signature = MagicMock() - m_signature.apply_async.return_value = TaskResult(id="123-456-789", status="PENDING", result="") + m_signature.apply_async.return_value = TaskResult( + id="123-456-789", status="PENDING", result="" + ) m_celery.signature.return_value = m_signature m_user_state = MagicMock() @@ -21,62 +26,98 @@ def test_archive_url(m_celery, m2, client_with_auth): # url is too short response = client_with_auth.post("/url/archive", json={"url": "bad"}) - assert response.status_code == 422 - assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters' + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + assert ( + response.json()["detail"][0]["msg"] + == "String should have at least 5 characters" + ) m_celery.signature.assert_not_called() # url is invalid - response = client_with_auth.post("/url/archive", json={"url": "example.com"}) - assert response.status_code == 400 + response = client_with_auth.post( + "/url/archive", json={"url": "example.com"} + ) + assert response.status_code == HTTPStatus.BAD_REQUEST assert response.json()["detail"] == "Invalid URL received." # valid request m_user_state.has_quota_max_monthly_urls.return_value = True m_user_state.has_quota_max_monthly_mbs.return_value = True - response = client_with_auth.post("/url/archive", json={"url": "https://example.com"}) - assert response.status_code == 201 - assert response.json() == {'id': '123-456-789'} + response = client_with_auth.post( + "/url/archive", json={"url": "https://example.com"} + ) + assert response.status_code == HTTPStatus.CREATED + assert response.json() == {"id": "123-456-789"} m_celery.signature.assert_called_once() m_signature.apply_async.assert_called_once() called_val = m_celery.signature.call_args assert called_val[0][0] == "create_archive_task" - assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": "rick@example.com", "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None} + assert json.loads(called_val[1]["args"][0]) == { + "id": None, + "url": "https://example.com", + "result": None, + "public": False, + "author_id": "rick@example.com", + "group_id": "default", + "tags": None, + "sheet_id": None, + "store_until": None, + "urls": None, + } m_user_state.has_quota_max_monthly_urls.assert_called_once() m_user_state.has_quota_max_monthly_mbs.assert_called_once() m_user_state.in_group.assert_called_once_with("default") # user is not in group m_user_state.in_group.return_value = False - response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "new-group"}) - assert response.status_code == 403 - assert response.json()["detail"] == "User does not have access to this group." + response = client_with_auth.post( + "/url/archive", + json={"url": "https://example.com", "group_id": "new-group"}, + ) + assert response.status_code == HTTPStatus.FORBIDDEN + assert ( + response.json()["detail"] == "User does not have access to this group." + ) m_user_state.in_group.assert_called_with("new-group") # user is in group m_user_state.in_group.return_value = True - response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"}) - assert response.status_code == 201 - assert response.json() == {'id': '123-456-789'} + response = client_with_auth.post( + "/url/archive", + json={"url": "https://example.com", "group_id": "spaceship"}, + ) + assert response.status_code == HTTPStatus.CREATED + assert response.json() == {"id": "123-456-789"} assert m_celery.signature.call_count == 2 assert m_signature.apply_async.call_count == 2 called_val = m_celery.signature.call_args - assert json.loads(called_val[1]['args'][0])["group_id"] == "spaceship" + assert json.loads(called_val[1]["args"][0])["group_id"] == "spaceship" m_user_state.in_group.assert_called_with("spaceship") # user is over monthly URL quota m_user_state.has_quota_max_monthly_urls.return_value = False m_user_state.has_quota_max_monthly_mbs.return_value = True - response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"}) - assert response.status_code == 429 - assert response.json()["detail"] == "User has reached their monthly URL quota." + response = client_with_auth.post( + "/url/archive", + json={"url": "https://example.com", "group_id": "spaceship"}, + ) + assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS + assert ( + response.json()["detail"] == "User has reached their monthly URL quota." + ) m_user_state.has_quota_max_monthly_urls.assert_called_with("spaceship") # user is over monthly MB quota m_user_state.has_quota_max_monthly_urls.return_value = True m_user_state.has_quota_max_monthly_mbs.return_value = False - response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spacesuit"}) - assert response.status_code == 429 - assert response.json()["detail"] == "User has reached their monthly MB quota." + response = client_with_auth.post( + "/url/archive", + json={"url": "https://example.com", "group_id": "spacesuit"}, + ) + assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS + assert ( + response.json()["detail"] == "User has reached their monthly MB quota." + ) m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit") assert m_celery.signature.call_count == 2 assert m_signature.apply_async.call_count == 2 @@ -89,40 +130,77 @@ def test_archive_url_quotas(m1, client_with_auth): # misses on monthly URLs quota m_user_state.has_quota_max_monthly_urls.return_value = False - response = client_with_auth.post("/url/archive", json={"url": "https://example.com"}) - assert response.status_code == 429 - assert response.json()["detail"] == "User has reached their monthly URL quota." + response = client_with_auth.post( + "/url/archive", json={"url": "https://example.com"} + ) + assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS + assert ( + response.json()["detail"] == "User has reached their monthly URL quota." + ) m_user_state.has_quota_max_monthly_urls.assert_called_once() # misses on monthly MBs quota m_user_state.has_quota_max_monthly_urls.return_value = True m_user_state.has_quota_max_monthly_mbs.return_value = False - response = client_with_auth.post("/url/archive", json={"url": "https://example.com"}) - assert response.status_code == 429 - assert response.json()["detail"] == "User has reached their monthly MB quota." + response = client_with_auth.post( + "/url/archive", json={"url": "https://example.com"} + ) + assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS + assert ( + response.json()["detail"] == "User has reached their monthly MB quota." + ) m_user_state.has_quota_max_monthly_mbs.assert_called_once() @patch("app.web.endpoints.url.celery", return_value=MagicMock()) def test_archive_url_with_api_token(m_celery, client_with_token): m_signature = MagicMock() - m_signature.apply_async.return_value = TaskResult(id="123-456-789", status="PENDING", result="") + m_signature.apply_async.return_value = TaskResult( + id="123-456-789", status="PENDING", result="" + ) m_celery.signature.return_value = m_signature - response = client_with_token.post("/url/archive", json={"url": "https://example.com", "author_id": "someone@example.com"}) - assert response.status_code == 201 - assert response.json() == {'id': '123-456-789'} + response = client_with_token.post( + "/url/archive", + json={"url": "https://example.com", "author_id": "someone@example.com"}, + ) + assert response.status_code == HTTPStatus.CREATED + assert response.json() == {"id": "123-456-789"} m_celery.signature.assert_called_once() m_signature.apply_async.assert_called_once() called_val = m_celery.signature.call_args assert called_val[0][0] == "create_archive_task" - assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": "someone@example.com", "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None} + assert json.loads(called_val[1]["args"][0]) == { + "id": None, + "url": "https://example.com", + "result": None, + "public": False, + "author_id": "someone@example.com", + "group_id": "default", + "tags": None, + "sheet_id": None, + "store_until": None, + "urls": None, + } # missing id should use ALLOW_ANY_EMAIL - response = client_with_token.post("/url/archive", json={"url": "https://example.com", "author_id": None}) - assert response.status_code == 201 + response = client_with_token.post( + "/url/archive", json={"url": "https://example.com", "author_id": None} + ) + assert response.status_code == HTTPStatus.CREATED called_val = m_celery.signature.call_args assert called_val[0][0] == "create_archive_task" - assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": ALLOW_ANY_EMAIL, "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None} + assert json.loads(called_val[1]["args"][0]) == { + "id": None, + "url": "https://example.com", + "result": None, + "public": False, + "author_id": ALLOW_ANY_EMAIL, + "group_id": "default", + "tags": None, + "sheet_id": None, + "store_until": None, + "urls": None, + } def test_search_by_url_unauthenticated(client, test_no_auth): @@ -132,46 +210,67 @@ def test_search_by_url_unauthenticated(client, test_no_auth): def test_search_by_url(client_with_auth, client_with_token, db_session): # tests the search endpoint, including through some db data for the endpoint params response = client_with_auth.get("/url/search") - assert response.status_code == 422 + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY assert response.json()["detail"][0]["msg"] == "Field required" response = client_with_auth.get("/url/search?url=https://example.com") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert response.json() == [] - from app.shared import schemas - from app.shared.db import worker_crud for i in range(11): - worker_crud.create_archive(db_session, ArchiveCreate(id=f"url-456-{i}", url="https://example.com" if i < 10 else "https://something-else.com", result={}, public=True, author_id="rick@example.com"), [], []) + worker_crud.create_archive( + db_session, + ArchiveCreate( + id=f"url-456-{i}", + url="https://example.com" + if i < 10 + else "https://something-else.com", + result={}, + public=True, + author_id="rick@example.com", + ), + [], + [], + ) # NB: this insertion is too fast for the ordering to be correct as they are within the same second response = client_with_auth.get("/url/search?url=https://example.com") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert len(j := response.json()) == 10 assert "url-456-0" in [i["id"] for i in j] assert "url-456-9" in [i["id"] for i in j] assert "url-456-10" not in [i["id"] for i in j] assert j[0].keys() == schemas.ArchiveResult.model_fields.keys() - response = client_with_auth.get("/url/search?url=https://example.com&limit=5") - assert response.status_code == 200 + response = client_with_auth.get( + "/url/search?url=https://example.com&limit=5" + ) + assert response.status_code == HTTPStatus.OK assert len(response.json()) == 5 - response = client_with_auth.get("/url/search?url=https://example.com&skip=5&limit=2") - assert response.status_code == 200 + response = client_with_auth.get( + "/url/search?url=https://example.com&skip=5&limit=2" + ) + assert response.status_code == HTTPStatus.OK assert len(response.json()) == 2 - response = client_with_auth.get("/url/search?url=https://example.com&archived_before=2010-01-01") - assert response.status_code == 200 + response = client_with_auth.get( + "/url/search?url=https://example.com&archived_before=2010-01-01" + ) + assert response.status_code == HTTPStatus.OK assert len(response.json()) == 0 - response = client_with_auth.get("/url/search?url=https://example.com&archived_after=2010-01-01") - assert response.status_code == 200 + response = client_with_auth.get( + "/url/search?url=https://example.com&archived_after=2010-01-01" + ) + assert response.status_code == HTTPStatus.OK assert len(response.json()) == 10 # API token will also work - response = client_with_token.get("/url/search?url=https://example.com&archived_after=2010-01-01") - assert response.status_code == 200 + response = client_with_token.get( + "/url/search?url=https://example.com&archived_after=2010-01-01" + ) + assert response.status_code == HTTPStatus.OK assert len(response.json()) == 10 @@ -181,7 +280,7 @@ def test_search_no_read_access(mock_user_state, client_with_auth): mock_user_state.return_value.read_public = False response = client_with_auth.get("/url/search?url=https://example.com") - assert response.status_code == 403 + assert response.status_code == HTTPStatus.FORBIDDEN assert response.json() == {"detail": "User does not have read access."} @@ -191,12 +290,22 @@ def test_delete_task_unauthenticated(client, test_no_auth): def test_delete_task(client_with_auth, db_session): response = client_with_auth.delete("/url/delete-123-456-789") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert response.json() == {"id": "delete-123-456-789", "deleted": False} - from app.shared.db import worker_crud - worker_crud.create_archive(db_session, ArchiveCreate(id="delete-123-456-789", url="https://example.com", result={}, public=True, author_id="morty@example.com"), [], []) + worker_crud.create_archive( + db_session, + ArchiveCreate( + id="delete-123-456-789", + url="https://example.com", + result={}, + public=True, + author_id="morty@example.com", + ), + [], + [], + ) response = client_with_auth.delete("/url/delete-123-456-789") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK assert response.json() == {"id": "delete-123-456-789", "deleted": True} diff --git a/app/tests/web/test_main.py b/app/tests/web/test_main.py index a4ddf1e..0dbd5fa 100644 --- a/app/tests/web/test_main.py +++ b/app/tests/web/test_main.py @@ -1,25 +1,32 @@ import os import shutil +from http import HTTPStatus from unittest.mock import patch +import alembic.config import pytest from fastapi.testclient import TestClient +from app.web.utils.metrics import EXCEPTION_COUNTER + def test_lifespan(app): with TestClient(app) as client: r = client.get("/health") - assert r.status_code == 200 + assert r.status_code == HTTPStatus.OK assert r.json() == {"status": "ok"} -def test_alembic(db_session): - import alembic.config - alembic.config.main(argv=['--raiseerr', 'upgrade', 'head']) - alembic.config.main(argv=['--raiseerr', 'downgrade', 'base']) -@patch("app.web.endpoints.url.crud.soft_delete_archive", side_effect=Exception('mocked error')) +def test_alembic(db_session): + alembic.config.main(argv=["--raiseerr", "upgrade", "head"]) + alembic.config.main(argv=["--raiseerr", "downgrade", "base"]) + + +@patch( + "app.web.endpoints.url.crud.soft_delete_archive", + side_effect=Exception("mocked error"), +) def test_logging_middleware(m1, client_with_auth): - from app.web.utils.metrics import EXCEPTION_COUNTER assert len(EXCEPTION_COUNTER.collect()[0].samples) == 0 with pytest.raises(Exception, match="mocked error"): client_with_auth.delete("/url/123") @@ -37,12 +44,13 @@ def test_serve_local_archive_logic(get_settings): # modify the settings get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test" from app.web.main import app_factory + app = app_factory(get_settings) # test client = TestClient(app) r = client.get("/app/local_archive_test/temp.txt") - assert r.status_code == 200 + assert r.status_code == HTTPStatus.OK assert r.text == "test" finally: # cleanup diff --git a/app/tests/web/test_security.py b/app/tests/web/test_security.py index 55a434b..07943e5 100644 --- a/app/tests/web/test_security.py +++ b/app/tests/web/test_security.py @@ -1,3 +1,4 @@ +from http import HTTPStatus from unittest.mock import Mock, patch import pytest @@ -5,112 +6,177 @@ from fastapi import HTTPException from fastapi.security import HTTPAuthorizationCredentials from app.web.config import ALLOW_ANY_EMAIL +from app.web.db.user_state import UserState +from app.web.security import ( + authenticate_user, + get_token_or_user_auth, + get_user_auth, + get_user_state, + secure_compare, + token_api_key_auth, +) def test_secure_compare(): - from app.web.security import secure_compare - assert secure_compare("test", "test") assert not secure_compare("test", "test2") @pytest.mark.asyncio async def test_get_token_or_user_auth_with_api(): - from app.web.security import get_token_or_user_auth - mock_api = HTTPAuthorizationCredentials(scheme="lorem", credentials="this_is_the_test_api_token") + mock_api = HTTPAuthorizationCredentials( + scheme="lorem", credentials="this_is_the_test_api_token" + ) assert await get_token_or_user_auth(mock_api) == ALLOW_ANY_EMAIL @pytest.mark.asyncio async def test_get_token_or_user_auth_with_user(): - from app.web.security import get_token_or_user_auth - bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="invalid") - e: pytest.ExceptionInfo = None + bad_user = HTTPAuthorizationCredentials( + scheme="ipsum", credentials="invalid" + ) with pytest.raises(HTTPException) as e: await get_token_or_user_auth(bad_user) - assert e.value.status_code == 401 + assert e.value.status_code == HTTPStatus.UNAUTHORIZED assert e.value.detail == "invalid access_token" -@patch("app.web.security.authenticate_user", return_value=(True, "summer@example.com")) +@patch( + "app.web.security.authenticate_user", + return_value=(True, "summer@example.com"), +) @pytest.mark.asyncio async def test_get_user_auth(m1): - from app.web.security import get_user_auth - good_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good") + good_user = HTTPAuthorizationCredentials( + scheme="ipsum", credentials="valid-and-good" + ) assert await get_user_auth(good_user) == "summer@example.com" @patch("app.web.security.secure_compare", return_value=False) @pytest.mark.asyncio async def test_token_api_key_auth_exception(m1): - from app.web.security import token_api_key_auth - - e: pytest.ExceptionInfo = None with pytest.raises(HTTPException) as e: - await token_api_key_auth(HTTPAuthorizationCredentials(scheme="ipsum", credentials="does-not-matter"), auto_error=True) - assert e.value.status_code == 401 + await token_api_key_auth( + HTTPAuthorizationCredentials( + scheme="ipsum", credentials="does-not-matter" + ), + auto_error=True, + ) + assert e.value.status_code == HTTPStatus.UNAUTHORIZED assert e.value.detail == "Wrong auth credentials" @pytest.mark.asyncio async def test_authenticate_user(): - from app.web.security import authenticate_user - assert authenticate_user("test") == (False, "invalid access_token") assert authenticate_user(123) == (False, "invalid access_token") with patch("app.web.security.requests.get") as mock_get: # bad response from oauth2 - mock_get.return_value.status_code = 403 - assert authenticate_user("this-will-call-requests") == (False, "invalid token") + mock_get.return_value.status_code = HTTPStatus.FORBIDDEN + assert authenticate_user("this-will-call-requests") == ( + False, + "invalid token", + ) assert mock_get.call_count == 1 # 200 but invalid json - mock_get.return_value.status_code = 200 - assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID") + mock_get.return_value.status_code = HTTPStatus.OK + assert authenticate_user("this-will-call-requests") == ( + False, + "token does not belong to valid APP_ID", + ) assert mock_get.call_count == 2 # 200 but invalid azp and aud - mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app"} - assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID") + mock_get.return_value.json.return_value = { + "email": "summer@example.com", + "azp": "not_an_app", + } + assert authenticate_user("this-will-call-requests") == ( + False, + "token does not belong to valid APP_ID", + ) - mock_get.return_value.json.return_value = {"email": "summer@example.com", "aud": "not_an_app"} - assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID") + mock_get.return_value.json.return_value = { + "email": "summer@example.com", + "aud": "not_an_app", + } + assert authenticate_user("this-will-call-requests") == ( + False, + "token does not belong to valid APP_ID", + ) - mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "not_an_app"} - assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID") + mock_get.return_value.json.return_value = { + "email": "summer@example.com", + "azp": "not_an_app", + "aud": "not_an_app", + } + assert authenticate_user("this-will-call-requests") == ( + False, + "token does not belong to valid APP_ID", + ) # blocked email - mock_get.return_value.json.return_value = {"email": "blocked@example.com", "azp": "test_app_id_1", "aud": "not_an_app"} - assert authenticate_user("this-will-call-requests") == (False, "email 'blocked@example.com' not allowed") + mock_get.return_value.json.return_value = { + "email": "blocked@example.com", + "azp": "test_app_id_1", + "aud": "not_an_app", + } + assert authenticate_user("this-will-call-requests") == ( + False, + "email 'blocked@example.com' not allowed", + ) # not verified - mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "test_app_id_1"} - assert authenticate_user("this-will-call-requests") == (False, "email 'summer@example.com' not verified") + mock_get.return_value.json.return_value = { + "email": "summer@example.com", + "azp": "not_an_app", + "aud": "test_app_id_1", + } + assert authenticate_user("this-will-call-requests") == ( + False, + "email 'summer@example.com' not verified", + ) # token expired - mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true"} - assert authenticate_user("this-will-call-requests") == (False, "Token expired") + mock_get.return_value.json.return_value = { + "email": "summer@example.com", + "azp": "test_app_id_2", + "email_verified": "true", + } + assert authenticate_user("this-will-call-requests") == ( + False, + "Token expired", + ) # 200 and valid azp and aup and verified - mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true", "expires_in": 100} - assert authenticate_user("this-will-call-requests") == (True, "summer@example.com") + mock_get.return_value.json.return_value = { + "email": "summer@example.com", + "azp": "test_app_id_2", + "email_verified": "true", + "expires_in": 100, + } + assert authenticate_user("this-will-call-requests") == ( + True, + "summer@example.com", + ) assert mock_get.call_count == 9 @pytest.mark.asyncio async def test_authenticate_user_exception(): - from app.web.security import authenticate_user with patch("app.web.security.requests.get") as mock_get: - mock_get.return_value.status_code = 200 + mock_get.return_value.status_code = HTTPStatus.OK mock_get.return_value.json.side_effect = Exception("mocked error") - assert authenticate_user("this-will-call-requests") == (False, "exception occurred") + assert authenticate_user("this-will-call-requests") == ( + False, + "exception occurred", + ) def test_get_user_state(): - from app.web.db.user_state import UserState - from app.web.security import get_user_state - mock_session = Mock() test_email = "test@example.com" diff --git a/app/tests/worker/test_worker_main.py b/app/tests/worker/test_worker_main.py index 9a77528..e67aa88 100644 --- a/app/tests/worker/test_worker_main.py +++ b/app/tests/worker/test_worker_main.py @@ -6,23 +6,41 @@ from auto_archiver.core import Media, Metadata from app.shared import schemas from app.shared.db import models +from app.worker.main import create_archive_task, create_sheet_task, get_all_urls -class Test_create_archive_task(): +class TestCreateArchiveTask: URL = "https://example-live.com" - archive = schemas.ArchiveCreate(url=URL, tags=["tag-celery"], public=True, author_id="rick@example.com", group_id="interstellar") + archive = schemas.ArchiveCreate( + url=URL, + tags=["tag-celery"], + public=True, + author_id="rick@example.com", + group_id="interstellar", + ) @patch("app.worker.main.ArchivingOrchestrator") @patch("app.worker.main.get_all_urls", return_value=[]) @patch("app.worker.main.insert_result_into_db") @patch("app.worker.main.get_store_until", return_value=datetime.now()) - @patch("app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"]) + @patch( + "app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"] + ) @patch("celery.app.task.Task.request") - def test_success(self, m_req, m_args, m_store, m_insert, m_urls, m_orchestrator, db_session): - from app.worker.main import create_archive_task - + def test_success( + self, + m_req, + m_args, + m_store, + m_insert, + m_urls, + m_orchestrator, + db_session, + ): m_req.id = "this-just-in" - m_orchestrator.return_value.feed.return_value = iter([Metadata().set_url(self.URL).success()]) + m_orchestrator.return_value.feed.return_value = iter( + [Metadata().set_url(self.URL).success()] + ) task = create_archive_task(self.archive.model_dump_json()) @@ -38,15 +56,15 @@ class Test_create_archive_task(): assert len(task["media"]) == 0 def test_raise_invalid(self): - from app.worker.main import create_archive_task - with pytest.raises(Exception): + with pytest.raises(Exception) as _: create_archive_task(self.archive.model_dump_json()) @patch("app.worker.main.ArchivingOrchestrator") @patch("app.worker.main.get_orchestrator_args") def test_raise_db_error(self, m_args, m_orchestrator): - from app.worker.main import create_archive_task - m_orchestrator.return_value.feed.side_effect = Exception("Orchestrator failed") + m_orchestrator.return_value.feed.side_effect = Exception( + "Orchestrator failed" + ) with pytest.raises(Exception) as e: create_archive_task(self.archive.model_dump_json()) @@ -58,7 +76,6 @@ class Test_create_archive_task(): @patch("app.worker.main.insert_result_into_db", return_value=None) @patch("app.worker.main.get_orchestrator_args") def test_raise_empty_result(self, m_args, m_insert, m_orchestrator): - from app.worker.main import create_archive_task m_orchestrator.return_value.feed.return_value = iter([None]) with pytest.raises(Exception) as e: @@ -67,61 +84,83 @@ class Test_create_archive_task(): m_orchestrator.return_value.feed.assert_called_once() -class Test_create_sheet_task(): +class TestCreateSheetTask: URL = "https://example-live.com" - sheet = schemas.SubmitSheet(sheet_id="123", author_id="rick@example.com", group_id="interstellar", tags=["spaceship"]) + sheet = schemas.SubmitSheet( + sheet_id="123", + author_id="rick@example.com", + group_id="interstellar", + tags=["spaceship"], + ) @patch("app.worker.main.get_all_urls", return_value=[]) @patch("app.worker.main.ArchivingOrchestrator") @patch("app.worker.main.models.generate_uuid", return_value="constant-uuid") @patch("app.worker.main.get_store_until", return_value=datetime.now()) @patch("app.worker.main.get_orchestrator_args") - def test_success(self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session): - from app.worker.main import create_sheet_task - - assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0 + def test_success( + self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session + ): + assert ( + db_session.query(models.Archive) + .filter(models.Archive.url == self.URL) + .count() + == 0 + ) mock_metadata = Metadata().set_url(self.URL).success() mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"])) - m_orchestrator.return_value.feed.return_value = iter([False, mock_metadata, mock_metadata]) + m_orchestrator.return_value.feed.return_value = iter( + [False, mock_metadata, mock_metadata] + ) res = create_sheet_task(self.sheet.model_dump_json()) - m_args.assert_called_once_with("interstellar", True, ["--gsheet_feeder.sheet_id", "123"]) + m_args.assert_called_once_with( + "interstellar", True, ["--gsheet_feeder.sheet_id", "123"] + ) m_orchestrator.return_value.setup.assert_called_once() m_orchestrator.return_value.feed.assert_called_once() m_store.assert_called_with("interstellar") - m_store.call_count == 2 - m_uuid.call_count == 2 - assert type(res) == dict + assert m_store.call_count == 2 + assert m_uuid.call_count == 2 + assert isinstance(res, dict) assert res["stats"]["archived"] == 1 assert res["stats"]["failed"] == 1 assert len(res["stats"]["errors"]) == 1 assert res["sheet_id"] == "123" assert res["success"] - assert type(res["time"]) == datetime + assert isinstance(res["time"], datetime) # query created archive entry - inserted = db_session.query(models.Archive).filter(models.Archive.url == self.URL).one() + inserted = ( + db_session.query(models.Archive) + .filter(models.Archive.url == self.URL) + .one() + ) assert inserted is not None assert inserted.url == self.URL assert len(inserted.tags) == 1 assert inserted.tags[0].id == "spaceship" assert inserted.group_id == "interstellar" assert inserted.author_id == "rick@example.com" - assert inserted.public == False + assert inserted.public is False def test_get_all_urls(db_session): - from app.worker.main import get_all_urls - meta = Metadata().set_url("https://example.com") m1 = meta.add_media(Media("fn1.txt", urls=["outcome1.com"])) m2 = meta.add_media(Media("fn2.txt", urls=["outcome2.com"])) m3 = meta.add_media(Media("fn3.txt", urls=["outcome3.com"])) m1.set("screenshot", Media("screenshot.png", urls=["screenshot.com"])) - m2.set("thumbnails", [Media("thumb1.png", urls=["thumb1.com"]), Media("thumb2.png", urls=["thumb2.com"])]) + m2.set( + "thumbnails", + [ + Media("thumb1.png", urls=["thumb1.com"]), + Media("thumb2.png", urls=["thumb2.com"]), + ], + ) m3.set("ssl_data", Media("ssl_data.txt", urls=["ssl_data.com"]).to_dict()) m3.set("bad_data", {"bad": "dict is ignored"})