From 834e3d86da4b6f56a5059241f158f0ccd480cc16 Mon Sep 17 00:00:00 2001 From: msramalho <19508417+msramalho@users.noreply.github.com> Date: Mon, 17 Feb 2025 15:42:57 +0000 Subject: [PATCH] adds missing tests --- app/tests/conftest.py | 46 +++++ app/tests/shared/db/test_worker_crud.py | 21 ++- app/tests/shared/test_business_logic.py | 36 ++++ app/tests/shared/utils/test_misc.py | 31 +++ app/tests/web/db/test_crud.py | 177 +++++++++++++++--- app/tests/web/endpoints/test_default.py | 77 +++++++- .../web/endpoints/test_interoperability.py | 19 +- app/tests/web/test_security.py | 17 +- app/tests/worker/test_worker_main.py | 58 +++--- app/web/db/crud.py | 2 +- app/worker/main.py | 1 - 11 files changed, 421 insertions(+), 64 deletions(-) create mode 100644 app/tests/shared/test_business_logic.py create mode 100644 app/tests/shared/utils/test_misc.py diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 37acbf2..afa76f9 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -1,7 +1,10 @@ import os +from typing import AsyncGenerator from fastapi.testclient import TestClient import pytest from unittest.mock import patch +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine from app.web.config import ALLOW_ANY_EMAIL from app.shared.settings import Settings from app.web.db.user_state import UserState @@ -60,6 +63,49 @@ def db_session(test_db): yield session +@pytest_asyncio.fixture() +async def async_test_db(get_settings: Settings): + from app.shared.db import models + from app.shared.db.database import make_async_engine + from app.web.db.crud import get_user_group_names + import asyncio + + get_user_group_names.cache_clear() + engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH) + + fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "") + if not os.path.exists(fs): + open(fs, 'w').close() + + async def create_all(): + async with engine.begin() as conn: + await conn.run_sync(models.Base.metadata.create_all) + + await create_all() + + yield engine + + async def drop_all(): + async with engine.begin() as conn: + await conn.run_sync(models.Base.metadata.drop_all) + + await drop_all() + + engine.dispose() + for suffix in ["", "-wal", "-shm"]: + new_fs = fs + suffix + if os.path.exists(new_fs): + os.remove(new_fs) + + +@pytest_asyncio.fixture() +async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSession, None]: + from app.shared.db.database import make_async_session_local + session_local = await make_async_session_local(async_test_db) + async with session_local() as session: + yield session + + @pytest.fixture() def app(db_session): from app.web.main import app_factory diff --git a/app/tests/shared/db/test_worker_crud.py b/app/tests/shared/db/test_worker_crud.py index 70f0fda..c4e6247 100644 --- a/app/tests/shared/db/test_worker_crud.py +++ b/app/tests/shared/db/test_worker_crud.py @@ -1,8 +1,27 @@ from app.shared.db import models +from app.shared.db import worker_crud, models +from datetime import datetime from app.tests.web.db.test_crud import test_data +def test_update_sheet_last_url_archived_at(db_session): + + # Create test sheet + test_sheet = models.Sheet(id="sheet-123") + db_session.add(test_sheet) + db_session.commit() + + # Test updating existing sheet + assert isinstance(test_sheet.last_url_archived_at, datetime) + before = test_sheet.last_url_archived_at + assert worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123") is True + db_session.refresh(test_sheet) + assert isinstance(test_sheet.last_url_archived_at, datetime) + assert test_sheet.last_url_archived_at > before + + # Test non-existent sheet + assert worker_crud.update_sheet_last_url_archived_at(db_session, "non-existent-sheet") is False def test_get_group(test_data, db_session): from app.shared.db import worker_crud @@ -95,4 +114,4 @@ def test_create_task(db_session): assert nt.group_id == "spaceship" assert len(nt.tags) == 0 assert len(nt.urls) == 0 - assert nt.created_at is not None + assert nt.created_at is not None \ No newline at end of file diff --git a/app/tests/shared/test_business_logic.py b/app/tests/shared/test_business_logic.py new file mode 100644 index 0000000..80eecb3 --- /dev/null +++ b/app/tests/shared/test_business_logic.py @@ -0,0 +1,36 @@ +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch +import pytest +from app.shared.business_logic import get_store_archive_until + +class Test_get_store_archive_until: + GROUP_ID = "test-group" + + def test_group_not_found(self, db_session): + with pytest.raises(AssertionError) as exc: + get_store_archive_until(db_session, self.GROUP_ID) + assert str(exc.value) == f"Group {self.GROUP_ID} not found." + + @patch("app.shared.db.worker_crud.get_group") + def test_no_max_lifespan(self, mock_get_group, db_session): + group = MagicMock() + group.permissions = {"max_archive_lifespan_months": -1} + mock_get_group.return_value = group + + result = get_store_archive_until(db_session, self.GROUP_ID) + assert result is None + mock_get_group.assert_called_once_with(db_session, self.GROUP_ID) + + @patch("app.shared.db.worker_crud.get_group") + def test_with_max_lifespan(self, mock_get_group, db_session): + group = MagicMock() + group.permissions = {"max_archive_lifespan_months": 6} + mock_get_group.return_value = group + + result = get_store_archive_until(db_session, self.GROUP_ID) + expected = datetime.now() + timedelta(days=180) # 6 months + + assert isinstance(result, datetime) + # Allow 1 second difference due to execution time + assert abs(result - expected) < timedelta(seconds=1) + mock_get_group.assert_called_once_with(db_session, self.GROUP_ID) \ No newline at end of file diff --git a/app/tests/shared/utils/test_misc.py b/app/tests/shared/utils/test_misc.py new file mode 100644 index 0000000..d7595c8 --- /dev/null +++ b/app/tests/shared/utils/test_misc.py @@ -0,0 +1,31 @@ +from app.shared.utils.misc import fnv1a_hash_mod + + +def test_fnv1a_hash_mod(): + # Test basic string hashing + assert fnv1a_hash_mod("test", 10) == fnv1a_hash_mod("test", 10) + assert 0 <= fnv1a_hash_mod("test", 10) < 10 + + # Test different strings give different hashes + assert fnv1a_hash_mod("test1", 100) != fnv1a_hash_mod("test2", 100) + + # Test different modulos + hash1 = fnv1a_hash_mod("test", 5) + hash2 = fnv1a_hash_mod("test", 10) + assert 0 <= hash1 < 5 + assert 0 <= hash2 < 10 + + # Test empty string + assert isinstance(fnv1a_hash_mod("", 10), int) + assert 0 <= fnv1a_hash_mod("", 10) < 10 + + # Test long string + long_str = "a" * 1000 + assert 0 <= fnv1a_hash_mod(long_str, 20) < 20 + + # Test unicode string + assert isinstance(fnv1a_hash_mod("测试", 10), int) + assert 0 <= fnv1a_hash_mod("测试", 10) < 10 + + # Test modulo = 1 edge case + assert fnv1a_hash_mod("test", 1) == 0 \ No newline at end of file diff --git a/app/tests/web/db/test_crud.py b/app/tests/web/db/test_crud.py index c96c74d..c29cbfa 100644 --- a/app/tests/web/db/test_crud.py +++ b/app/tests/web/db/test_crud.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timedelta from unittest.mock import patch import pytest @@ -6,6 +6,7 @@ import yaml from app.shared.db import models from app.shared.settings import Settings +from app.web.db import crud authors = ["rick@example.com", "morty@example.com", "jerry@example.com"] @@ -62,7 +63,6 @@ def test_data(db_session): def test_get_archive(test_data, db_session): - from app.web.db import crud from app.web.config import ALLOW_ANY_EMAIL # each author's archives work @@ -91,7 +91,6 @@ def test_get_archive(test_data, db_session): def test_search_archives_by_url(test_data, db_session): - from app.web.db import crud from app.web.config import ALLOW_ANY_EMAIL # rick's archives are private @@ -139,7 +138,6 @@ def test_search_archives_by_url(test_data, db_session): def test_search_archives_by_email(test_data, db_session): from app.web.config import ALLOW_ANY_EMAIL - from app.web.db import crud # lower/upper case assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34 @@ -160,7 +158,6 @@ def test_search_archives_by_email(test_data, db_session): @patch("app.web.db.crud.DATABASE_QUERY_LIMIT", new=25) def test_max_query_limit(test_data, db_session): - from app.web.db import crud from app.web.config import ALLOW_ANY_EMAIL assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL)) == 25 @@ -171,8 +168,6 @@ def test_max_query_limit(test_data, db_session): def test_soft_delete(test_data, db_session): - from app.web.db import crud - # none deleted yet assert crud.get_archive(db_session, "archive-id-456-0", "rick@example.com") is not None assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 0 @@ -189,8 +184,6 @@ def test_soft_delete(test_data, db_session): def test_count_archives(test_data, db_session): - from app.web.db import crud - assert crud.count_archives(db_session) == 100 db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete() db_session.commit() @@ -198,8 +191,6 @@ def test_count_archives(test_data, db_session): def test_count_archive_urls(test_data, db_session): - from app.web.db import crud - assert crud.count_archive_urls(db_session) == 1000 db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.url == "https://example-0.com/0").delete() db_session.commit() @@ -213,8 +204,6 @@ def test_count_archive_urls(test_data, db_session): def test_count_users(test_data, db_session): - from app.web.db import crud - assert crud.count_users(db_session) == 3 db_session.query(models.User).filter(models.User.email == "rick@example.com").delete() db_session.commit() @@ -232,8 +221,6 @@ def test_count_by_users_since(test_data, db_session): def test_upsert_group(test_data, db_session): - from app.web.db import crud - assert db_session.query(models.Group).count() == 4 repeatable_params = ["desc 1", "orch.yaml", "sheet.yaml", "service_account_email@example.com", {"read": ["all"]}, ["example.com"]] @@ -262,8 +249,6 @@ def test_upsert_group(test_data, db_session): def test_upsert_user_groups(db_session): - from app.web.db import crud - @patch('app.web.db.crud.get_settings', new=lambda: bad_setings) def test_missing_yaml(db_session): with pytest.raises(FileNotFoundError): @@ -284,8 +269,6 @@ def test_upsert_user_groups(db_session): def test_create_sheet(db_session): - from app.web.db import crud - assert db_session.query(models.Sheet).count() == 0 s = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly") @@ -305,8 +288,6 @@ def test_create_sheet(db_session): def test_get_user_sheet(test_data, db_session): - from app.web.db import crud - assert crud.get_user_sheet(db_session, "", "sheet-0") is None assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None @@ -316,8 +297,6 @@ def test_get_user_sheet(test_data, db_session): def test_get_user_sheets(test_data, db_session): - from app.web.db import crud - assert len(crud.get_user_sheets(db_session, "")) == 0 rick_sheets = crud.get_user_sheets(db_session, "rick@example.com") assert len(rick_sheets) == 2 @@ -326,8 +305,156 @@ def test_get_user_sheets(test_data, db_session): def test_delete_sheet(test_data, db_session): - from app.web.db import crud - assert crud.delete_sheet(db_session, "sheet-0", "") == False assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == True assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == False + + +@pytest.mark.asyncio +async def test_find_by_store_until(async_db_session): + # Add archives with different store_until dates + now = datetime.now() + archive1 = models.Archive( + id="archive-expired-1", + url="https://example-expired-1.com", + result={}, + author_id="rick@example.com", + store_until=now - timedelta(days=1) + ) + archive2 = models.Archive( + id="archive-expired-2", + url="https://example-expired-2.com", + result={}, + author_id="rick@example.com", + store_until=now - timedelta(hours=1) + ) + archive3 = models.Archive( + id="archive-active", + url="https://example-active.com", + result={}, + author_id="rick@example.com", + store_until=now + timedelta(days=1) + ) + async_db_session.add_all([archive1, archive2, archive3]) + await async_db_session.commit() + + # Should find 2 expired archives + expired = await crud.find_by_store_until(async_db_session, now) + assert len(list(expired)) == 2 + + # Should find 1 archive expired before 2 hours ago + expired = await crud.find_by_store_until(async_db_session, now - timedelta(hours=2)) + assert len(list(expired)) == 1 + + # Should find no archives expired before 2 days ago + expired = await crud.find_by_store_until(async_db_session, now - timedelta(days=2)) + assert len(list(expired)) == 0 + + # Should not find deleted archives + archive1.deleted = True + await async_db_session.commit() + expired = await crud.find_by_store_until(async_db_session, now) + assert len(list(expired)) == 1 + + +@pytest.mark.asyncio +async def test_get_sheets_by_id_hash(async_db_session): + # Add test data + authors = ["rick@example.com", "morty@example.com", "jerry@example.com"] + sheets = [ + models.Sheet(id="sheet-0", name="sheet-0", author_id=authors[0], group_id=None, frequency="daily"), + models.Sheet(id="sheet-0-2", name="sheet-0-2", author_id=authors[0], group_id="spaceship", frequency="hourly"), + models.Sheet(id="sheet-1", name="sheet-1", author_id=authors[1], group_id=None, frequency="daily"), + models.Sheet(id="sheet-2", name="sheet-2", author_id=authors[2], group_id=None, frequency="daily") + ] + async_db_session.add_all(sheets) + await async_db_session.commit() + + with patch("app.web.db.crud.fnv1a_hash_mod", return_value=1): + # Test retrieving hourly sheets + hourly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "hourly", 4, 1) + assert len(hourly_sheets) == 1 + assert hourly_sheets[0].id == "sheet-0-2" + assert hourly_sheets[0].frequency == "hourly" + + # Test retrieving daily sheets + daily_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 1) + assert len(daily_sheets) == 3 + assert all(sheet.frequency == "daily" for sheet in daily_sheets) + assert {sheet.id for sheet in daily_sheets} == {"sheet-0", "sheet-1", "sheet-2"} + + # Test with non-matching hash + no_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 3) + assert len(no_sheets) == 0 + + # Test with non-existent frequency + weekly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "weekly", 4, 1) + assert len(weekly_sheets) == 0 + + +@pytest.mark.asyncio +async def test_delete_stale_sheets(async_db_session): + from datetime import datetime, timedelta + from sqlalchemy.sql import select + + now = datetime.now() + active_date = now - timedelta(days=5) + stale_date = now - timedelta(days=15) + + # Create test sheets with different last_url_archived_at dates + sheets = [ + models.Sheet( + id="sheet-active-1", + name="Active Sheet 1", + author_id="rick@example.com", + frequency="daily", + last_url_archived_at=active_date + ), + models.Sheet( + id="sheet-active-2", + name="Active Sheet 2", + author_id="morty@example.com", + frequency="hourly", + last_url_archived_at=active_date + ), + models.Sheet( + id="sheet-stale-1", + name="Stale Sheet 1", + author_id="rick@example.com", + frequency="daily", + last_url_archived_at=stale_date + ), + models.Sheet( + id="sheet-stale-2", + name="Stale Sheet 2", + author_id="morty@example.com", + frequency="daily", + last_url_archived_at=stale_date + ) + ] + async_db_session.add_all(sheets) + await async_db_session.commit() + + # Should not delete sheets with 20 days inactivity threshold + deleted = await crud.delete_stale_sheets(async_db_session, 20) + assert len(deleted) == 0 # No sheets should be deleted + result = await async_db_session.execute(select(models.Sheet)) + assert len(list(result.scalars())) == 4 # All sheets should remain + + # Should delete sheets with 7 days inactivity threshold + deleted = await crud.delete_stale_sheets(async_db_session, 7) + assert len(deleted) == 2 # Two authors affected + assert len(deleted["rick@example.com"]) == 1 # One sheet deleted for Rick + assert len(deleted["morty@example.com"]) == 1 # One sheet deleted for Morty + assert deleted["rick@example.com"][0].id == "sheet-stale-1" + assert deleted["morty@example.com"][0].id == "sheet-stale-2" + + # Verify only active sheets remain + result = await async_db_session.execute(select(models.Sheet)) + remaining = list(result.scalars()) + assert len(remaining) == 2 + assert {s.id for s in remaining} == {"sheet-active-1", "sheet-active-2"} + + # Running again should not delete anything + deleted = await crud.delete_stale_sheets(async_db_session, 7) + assert len(deleted) == 0 \ No newline at end of file diff --git a/app/tests/web/endpoints/test_default.py b/app/tests/web/endpoints/test_default.py index b4ed7a5..401a164 100644 --- a/app/tests/web/endpoints/test_default.py +++ b/app/tests/web/endpoints/test_default.py @@ -1,6 +1,8 @@ from unittest.mock import MagicMock from fastapi.testclient import TestClient import pytest +from app.shared.schemas import Usage, UsageResponse +from app.shared.user_groups import GroupInfo from app.web.config import VERSION from app.tests.web.db.test_crud import test_data @@ -13,6 +15,7 @@ def test_endpoint_home(client_with_auth): assert "breakingChanges" in j assert "groups" not in j + def test_endpoint_health(client_with_auth): r = client_with_auth.get("/health") assert r.status_code == 200 @@ -28,7 +31,7 @@ def test_endpoint_active(app): from app.web.security import get_user_state app.dependency_overrides[get_user_state] = lambda: m_user_state - + # inactive user m_user_state.active = False client = TestClient(app) @@ -42,7 +45,6 @@ def test_endpoint_active(app): r = client.get("/user/active") assert r.status_code == 200 assert r.json() == {"active": True} - def test_no_serve_local_archive_by_default(client_with_auth): @@ -100,3 +102,74 @@ async def test_prometheus_metrics(test_data, client_with_token, get_settings): assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r3.text assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r3.text assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r3.text + + +def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth): + test_no_auth(client.get, "/user/permissions") + + +def test_endpoint_get_user_permissions(app): + from app.web.security import get_user_state + + m_user_state = MagicMock() + rv = { + "all": GroupInfo(read=True), + "group1": GroupInfo(archive_url=True), + } + from loguru import logger + logger.info(rv) + m_user_state.permissions = rv + + app.dependency_overrides[get_user_state] = lambda: m_user_state + + client = TestClient(app) + r = client.get("/user/permissions") + assert r.status_code == 200 + response = r.json() + assert response.keys() == {"all", "group1"} + assert response["all"]["read"] + assert response["group1"]["read"] == [] + assert response["group1"]["archive_url"] + assert response["all"]["archive_url"] == False + + +def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth): + test_no_auth(client.get, "/user/usage") + + +def test_endpoint_get_user_usage_inactive(app): + from app.web.security import get_user_state + + m_user_state = MagicMock() + m_user_state.active = False + + app.dependency_overrides[get_user_state] = lambda: m_user_state + + client = TestClient(app) + r = client.get("/user/usage") + assert r.status_code == 403 + assert r.json() == {"detail": "User is not active."} + + +def test_endpoint_get_user_usage_active(app): + from app.web.security import get_user_state + + m_user_state = MagicMock() + m_user_state.active = True + mock_usage = UsageResponse( + monthly_urls=1, + monthly_mbs=2, + total_sheets=3, + groups={ + "group1": Usage(monthly_urls=4, monthly_mbs=5, total_sheets=6), + "group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9) + } + ) + m_user_state.usage.return_value = mock_usage + + app.dependency_overrides[get_user_state] = lambda: m_user_state + + client = TestClient(app) + r = client.get("/user/usage") + assert r.status_code == 200 + assert UsageResponse(**r.json()) == mock_usage diff --git a/app/tests/web/endpoints/test_interoperability.py b/app/tests/web/endpoints/test_interoperability.py index c3f8cb5..edf8c0b 100644 --- a/app/tests/web/endpoints/test_interoperability.py +++ b/app/tests/web/endpoints/test_interoperability.py @@ -32,9 +32,24 @@ def test_submit_manual_archive(m1, client_with_token, db_session): assert [u.url for u in inserted.urls] == ["http://example.s3.com"] assert type(inserted.store_until) == datetime - - # cannot have the same URL twice + # cannot have the same URL twice aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]}) r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "tags": ["test"], "url": "http://example.com"}) assert r.status_code == 422 assert r.json() == {"detail": "Cannot insert into DB due to integrity error, likely duplicate urls."} + + +# test with invalid JSON +def test_submit_manual_archive_invalid_json(client_with_token): + r = client_with_token.post("/interop/submit-archive", json={"result": "invalid json", "public": False, "author_id": "jer", "tags": ["test"], "url": "http://example.com"}) + assert r.status_code == 422 + assert r.json() == {"detail": "Invalid JSON in result field."} + + +@patch("app.web.endpoints.interoperability.business_logic") +def test_submit_manual_archive_no_store_until(m_b, client_with_token, db_session): + m_b.get_store_archive_until.side_effect = AssertionError("AssertionError") + aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]}) + r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"}) + assert r.status_code == 422 + assert r.json() == {"detail": "AssertionError"} diff --git a/app/tests/web/test_security.py b/app/tests/web/test_security.py index 4a46823..1a6c00b 100644 --- a/app/tests/web/test_security.py +++ b/app/tests/web/test_security.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import Mock, patch from fastapi import HTTPException from fastapi.security import HTTPAuthorizationCredentials @@ -101,8 +101,21 @@ async def test_authenticate_user(): @pytest.mark.asyncio async def test_authenticate_user_exception(): from app.web.security import authenticate_user - with patch("app.web.security.requests.get") as mock_get: mock_get.return_value.status_code = 200 mock_get.return_value.json.side_effect = Exception("mocked error") assert authenticate_user("this-will-call-requests") == (False, "exception occurred") + + +def test_get_user_state(): + from app.web.security import get_user_state + from app.web.db.user_state import UserState + + mock_session = Mock() + test_email = "test@example.com" + + state = get_user_state(test_email, mock_session) + + assert isinstance(state, UserState) + assert state.email == test_email + assert state.db == mock_session diff --git a/app/tests/worker/test_worker_main.py b/app/tests/worker/test_worker_main.py index e4dd549..b6ca8f3 100644 --- a/app/tests/worker/test_worker_main.py +++ b/app/tests/worker/test_worker_main.py @@ -10,7 +10,6 @@ from app.shared import schemas from auto_archiver.core import Media, Metadata - class Test_create_archive_task(): URL = "https://example-live.com" archive = schemas.ArchiveCreate(url=URL, tags=["tag-celery"], public=True, author_id="rick@example.com", group_id="interstellar") @@ -19,20 +18,21 @@ class Test_create_archive_task(): @patch("app.worker.main.get_all_urls", return_value=[]) @patch("app.worker.main.insert_result_into_db") @patch("app.worker.main.get_store_until", return_value=datetime.now()) - @patch("app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"]) + @patch("app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"]) @patch("celery.app.task.Task.request") def test_success(self, m_req, m_args, m_store, m_insert, m_urls, m_orchestrator, db_session): from app.worker.main import create_archive_task m_req.id = "this-just-in" - m_orchestrator.run.return_value = Metadata().set_url(self.URL).success() + m_orchestrator.return_value.run.return_value = iter([Metadata().set_url(self.URL).success()]) task = create_archive_task(self.archive.model_dump_json()) m_args.assert_called_once() m_store.assert_called_once_with("interstellar") m_insert.assert_called_once() - m_orchestrator.run.assert_called_once() + m_urls.assert_called_once() + m_orchestrator.return_value.run.assert_called_once() assert task["status"] == "success" assert task["metadata"]["url"] == self.URL @@ -43,56 +43,54 @@ class Test_create_archive_task(): with pytest.raises(Exception): create_archive_task(self.archive.model_dump_json()) - @patch("app.worker.main.insert_result_into_db", side_effect=Exception) + @patch("app.worker.main.ArchivingOrchestrator") @patch("app.worker.main.get_orchestrator_args") - def test_raise_db_error(self, m_args, m_insert): + def test_raise_db_error(self, m_args, m_orchestrator): from app.worker.main import create_archive_task - mock_orchestrator = self.mock_orchestrator_choice(m_args) - - with pytest.raises(Exception): - create_archive_task(self.archive.model_dump_json()) - mock_orchestrator.feed_item.assert_called_once() - - - @patch("app.worker.main.insert_result_into_db", return_value=None) - @patch("app.worker.main.get_orchestrator_args") - def test_raise_empty_result(self, m_args, m_insert): - from app.worker.main import create_archive_task - mock_orchestrator = self.mock_orchestrator_choice(m_args) + m_orchestrator.return_value.run.side_effect = Exception("Orchestrator failed") with pytest.raises(Exception) as e: create_archive_task(self.archive.model_dump_json()) - mock_orchestrator.feed_item.assert_called_once() + assert str(e.value) == "Orchestrator failed" + m_args.assert_called_once() + m_orchestrator.return_value.run.assert_called_once() - def mock_orchestrator_choice(self, m_load): - mock_orchestrator = mock.MagicMock() - mock_orchestrator.configure_mock(feed_item=mock.MagicMock(return_value=Metadata().set_url(self.URL).success())) - m_load.return_value = mock_orchestrator - return mock_orchestrator + @patch("app.worker.main.ArchivingOrchestrator") + @patch("app.worker.main.insert_result_into_db", return_value=None) + @patch("app.worker.main.get_orchestrator_args") + def test_raise_empty_result(self, m_args, m_insert, m_orchestrator): + from app.worker.main import create_archive_task + m_orchestrator.return_value.run.return_value = iter([None]) + + with pytest.raises(Exception) as e: + create_archive_task(self.archive.model_dump_json()) + assert str(e.value) == "UNABLE TO archive: https://example-live.com" + m_orchestrator.return_value.run.assert_called_once() class Test_create_sheet_task(): URL = "https://example-live.com" sheet = schemas.SubmitSheet(sheet_id="123", author_id="rick@example.com", group_id="interstellar", tags=["spaceship"]) + @patch("app.worker.main.get_all_urls", return_value=[]) + @patch("app.worker.main.ArchivingOrchestrator") @patch("app.worker.main.models.generate_uuid", return_value="constant-uuid") @patch("app.worker.main.get_store_until", return_value=datetime.now()) @patch("app.worker.main.get_orchestrator_args") - def test_success(self, m_args, m_store, m_uuid, db_session): + def test_success(self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session): from app.worker.main import create_sheet_task assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0 mock_metadata = Metadata().set_url(self.URL).success() mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"])) - m_orch = MagicMock() - m_orch.feed.return_value = iter([False, mock_metadata, mock_metadata]) - m_args.return_value = m_orch + + m_orchestrator.return_value.run.return_value = iter([False, mock_metadata, mock_metadata]) res = create_sheet_task(self.sheet.model_dump_json()) - m_args.assert_called_once_with("interstellar", True, {'configurations': {'gsheet_feeder': {'sheet_id': '123'}}}) - m_orch.feed.assert_called_once() + m_args.assert_called_once_with("interstellar", True, ["--gsheet_feeder.sheet_id", "123"]) + m_orchestrator.return_value.run.assert_called_once() m_store.assert_called_with("interstellar") m_store.call_count == 2 m_uuid.call_count == 2 diff --git a/app/web/db/crud.py b/app/web/db/crud.py index d3be57c..e70b01c 100644 --- a/app/web/db/crud.py +++ b/app/web/db/crud.py @@ -242,7 +242,7 @@ def get_user_sheets(db: Session, email: str) -> list[models.Sheet]: return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_url_archived_at.desc()).all() -async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: str) -> list[models.Sheet]: +async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: int) -> list[models.Sheet]: result = await db.execute( select(models.Sheet).filter(models.Sheet.frequency == frequency) ) diff --git a/app/worker/main.py b/app/worker/main.py index c8fb585..b5e74a2 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -4,7 +4,6 @@ import traceback, datetime from celery.signals import task_failure from loguru import logger from sqlalchemy import exc -import auto_archiver from auto_archiver.core.orchestrator import ArchivingOrchestrator from app.shared.db import models