mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-08 03:28:35 +03:00
adds missing tests
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.shared.settings import Settings
|
||||
from app.web.db.user_state import UserState
|
||||
@@ -60,6 +63,49 @@ def db_session(test_db):
|
||||
yield session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def async_test_db(get_settings: Settings):
|
||||
from app.shared.db import models
|
||||
from app.shared.db.database import make_async_engine
|
||||
from app.web.db.crud import get_user_group_names
|
||||
import asyncio
|
||||
|
||||
get_user_group_names.cache_clear()
|
||||
engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH)
|
||||
|
||||
fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "")
|
||||
if not os.path.exists(fs):
|
||||
open(fs, 'w').close()
|
||||
|
||||
async def create_all():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(models.Base.metadata.create_all)
|
||||
|
||||
await create_all()
|
||||
|
||||
yield engine
|
||||
|
||||
async def drop_all():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(models.Base.metadata.drop_all)
|
||||
|
||||
await drop_all()
|
||||
|
||||
engine.dispose()
|
||||
for suffix in ["", "-wal", "-shm"]:
|
||||
new_fs = fs + suffix
|
||||
if os.path.exists(new_fs):
|
||||
os.remove(new_fs)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
|
||||
from app.shared.db.database import make_async_session_local
|
||||
session_local = await make_async_session_local(async_test_db)
|
||||
async with session_local() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app(db_session):
|
||||
from app.web.main import app_factory
|
||||
|
||||
@@ -1,8 +1,27 @@
|
||||
from app.shared.db import models
|
||||
from app.shared.db import worker_crud, models
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
from app.tests.web.db.test_crud import test_data
|
||||
|
||||
def test_update_sheet_last_url_archived_at(db_session):
|
||||
|
||||
# Create test sheet
|
||||
test_sheet = models.Sheet(id="sheet-123")
|
||||
db_session.add(test_sheet)
|
||||
db_session.commit()
|
||||
|
||||
# Test updating existing sheet
|
||||
assert isinstance(test_sheet.last_url_archived_at, datetime)
|
||||
before = test_sheet.last_url_archived_at
|
||||
assert worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123") is True
|
||||
db_session.refresh(test_sheet)
|
||||
assert isinstance(test_sheet.last_url_archived_at, datetime)
|
||||
assert test_sheet.last_url_archived_at > before
|
||||
|
||||
# Test non-existent sheet
|
||||
assert worker_crud.update_sheet_last_url_archived_at(db_session, "non-existent-sheet") is False
|
||||
|
||||
def test_get_group(test_data, db_session):
|
||||
from app.shared.db import worker_crud
|
||||
@@ -95,4 +114,4 @@ def test_create_task(db_session):
|
||||
assert nt.group_id == "spaceship"
|
||||
assert len(nt.tags) == 0
|
||||
assert len(nt.urls) == 0
|
||||
assert nt.created_at is not None
|
||||
assert nt.created_at is not None
|
||||
36
app/tests/shared/test_business_logic.py
Normal file
36
app/tests/shared/test_business_logic.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from app.shared.business_logic import get_store_archive_until
|
||||
|
||||
class Test_get_store_archive_until:
|
||||
GROUP_ID = "test-group"
|
||||
|
||||
def test_group_not_found(self, db_session):
|
||||
with pytest.raises(AssertionError) as exc:
|
||||
get_store_archive_until(db_session, self.GROUP_ID)
|
||||
assert str(exc.value) == f"Group {self.GROUP_ID} not found."
|
||||
|
||||
@patch("app.shared.db.worker_crud.get_group")
|
||||
def test_no_max_lifespan(self, mock_get_group, db_session):
|
||||
group = MagicMock()
|
||||
group.permissions = {"max_archive_lifespan_months": -1}
|
||||
mock_get_group.return_value = group
|
||||
|
||||
result = get_store_archive_until(db_session, self.GROUP_ID)
|
||||
assert result is None
|
||||
mock_get_group.assert_called_once_with(db_session, self.GROUP_ID)
|
||||
|
||||
@patch("app.shared.db.worker_crud.get_group")
|
||||
def test_with_max_lifespan(self, mock_get_group, db_session):
|
||||
group = MagicMock()
|
||||
group.permissions = {"max_archive_lifespan_months": 6}
|
||||
mock_get_group.return_value = group
|
||||
|
||||
result = get_store_archive_until(db_session, self.GROUP_ID)
|
||||
expected = datetime.now() + timedelta(days=180) # 6 months
|
||||
|
||||
assert isinstance(result, datetime)
|
||||
# Allow 1 second difference due to execution time
|
||||
assert abs(result - expected) < timedelta(seconds=1)
|
||||
mock_get_group.assert_called_once_with(db_session, self.GROUP_ID)
|
||||
31
app/tests/shared/utils/test_misc.py
Normal file
31
app/tests/shared/utils/test_misc.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from app.shared.utils.misc import fnv1a_hash_mod
|
||||
|
||||
|
||||
def test_fnv1a_hash_mod():
|
||||
# Test basic string hashing
|
||||
assert fnv1a_hash_mod("test", 10) == fnv1a_hash_mod("test", 10)
|
||||
assert 0 <= fnv1a_hash_mod("test", 10) < 10
|
||||
|
||||
# Test different strings give different hashes
|
||||
assert fnv1a_hash_mod("test1", 100) != fnv1a_hash_mod("test2", 100)
|
||||
|
||||
# Test different modulos
|
||||
hash1 = fnv1a_hash_mod("test", 5)
|
||||
hash2 = fnv1a_hash_mod("test", 10)
|
||||
assert 0 <= hash1 < 5
|
||||
assert 0 <= hash2 < 10
|
||||
|
||||
# Test empty string
|
||||
assert isinstance(fnv1a_hash_mod("", 10), int)
|
||||
assert 0 <= fnv1a_hash_mod("", 10) < 10
|
||||
|
||||
# Test long string
|
||||
long_str = "a" * 1000
|
||||
assert 0 <= fnv1a_hash_mod(long_str, 20) < 20
|
||||
|
||||
# Test unicode string
|
||||
assert isinstance(fnv1a_hash_mod("测试", 10), int)
|
||||
assert 0 <= fnv1a_hash_mod("测试", 10) < 10
|
||||
|
||||
# Test modulo = 1 edge case
|
||||
assert fnv1a_hash_mod("test", 1) == 0
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -6,6 +6,7 @@ import yaml
|
||||
from app.shared.db import models
|
||||
from app.shared.settings import Settings
|
||||
|
||||
from app.web.db import crud
|
||||
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
|
||||
|
||||
|
||||
@@ -62,7 +63,6 @@ def test_data(db_session):
|
||||
|
||||
|
||||
def test_get_archive(test_data, db_session):
|
||||
from app.web.db import crud
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
|
||||
# each author's archives work
|
||||
@@ -91,7 +91,6 @@ def test_get_archive(test_data, db_session):
|
||||
|
||||
|
||||
def test_search_archives_by_url(test_data, db_session):
|
||||
from app.web.db import crud
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
|
||||
# rick's archives are private
|
||||
@@ -139,7 +138,6 @@ def test_search_archives_by_url(test_data, db_session):
|
||||
|
||||
def test_search_archives_by_email(test_data, db_session):
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.web.db import crud
|
||||
|
||||
# lower/upper case
|
||||
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34
|
||||
@@ -160,7 +158,6 @@ def test_search_archives_by_email(test_data, db_session):
|
||||
|
||||
@patch("app.web.db.crud.DATABASE_QUERY_LIMIT", new=25)
|
||||
def test_max_query_limit(test_data, db_session):
|
||||
from app.web.db import crud
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL)) == 25
|
||||
@@ -171,8 +168,6 @@ def test_max_query_limit(test_data, db_session):
|
||||
|
||||
|
||||
def test_soft_delete(test_data, db_session):
|
||||
from app.web.db import crud
|
||||
|
||||
# none deleted yet
|
||||
assert crud.get_archive(db_session, "archive-id-456-0", "rick@example.com") is not None
|
||||
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 0
|
||||
@@ -189,8 +184,6 @@ def test_soft_delete(test_data, db_session):
|
||||
|
||||
|
||||
def test_count_archives(test_data, db_session):
|
||||
from app.web.db import crud
|
||||
|
||||
assert crud.count_archives(db_session) == 100
|
||||
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
|
||||
db_session.commit()
|
||||
@@ -198,8 +191,6 @@ def test_count_archives(test_data, db_session):
|
||||
|
||||
|
||||
def test_count_archive_urls(test_data, db_session):
|
||||
from app.web.db import crud
|
||||
|
||||
assert crud.count_archive_urls(db_session) == 1000
|
||||
db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.url == "https://example-0.com/0").delete()
|
||||
db_session.commit()
|
||||
@@ -213,8 +204,6 @@ def test_count_archive_urls(test_data, db_session):
|
||||
|
||||
|
||||
def test_count_users(test_data, db_session):
|
||||
from app.web.db import crud
|
||||
|
||||
assert crud.count_users(db_session) == 3
|
||||
db_session.query(models.User).filter(models.User.email == "rick@example.com").delete()
|
||||
db_session.commit()
|
||||
@@ -232,8 +221,6 @@ def test_count_by_users_since(test_data, db_session):
|
||||
|
||||
|
||||
def test_upsert_group(test_data, db_session):
|
||||
from app.web.db import crud
|
||||
|
||||
assert db_session.query(models.Group).count() == 4
|
||||
|
||||
repeatable_params = ["desc 1", "orch.yaml", "sheet.yaml", "service_account_email@example.com", {"read": ["all"]}, ["example.com"]]
|
||||
@@ -262,8 +249,6 @@ def test_upsert_group(test_data, db_session):
|
||||
|
||||
|
||||
def test_upsert_user_groups(db_session):
|
||||
from app.web.db import crud
|
||||
|
||||
@patch('app.web.db.crud.get_settings', new=lambda: bad_setings)
|
||||
def test_missing_yaml(db_session):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
@@ -284,8 +269,6 @@ def test_upsert_user_groups(db_session):
|
||||
|
||||
|
||||
def test_create_sheet(db_session):
|
||||
from app.web.db import crud
|
||||
|
||||
assert db_session.query(models.Sheet).count() == 0
|
||||
|
||||
s = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly")
|
||||
@@ -305,8 +288,6 @@ def test_create_sheet(db_session):
|
||||
|
||||
|
||||
def test_get_user_sheet(test_data, db_session):
|
||||
from app.web.db import crud
|
||||
|
||||
assert crud.get_user_sheet(db_session, "", "sheet-0") is None
|
||||
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None
|
||||
|
||||
@@ -316,8 +297,6 @@ def test_get_user_sheet(test_data, db_session):
|
||||
|
||||
|
||||
def test_get_user_sheets(test_data, db_session):
|
||||
from app.web.db import crud
|
||||
|
||||
assert len(crud.get_user_sheets(db_session, "")) == 0
|
||||
rick_sheets = crud.get_user_sheets(db_session, "rick@example.com")
|
||||
assert len(rick_sheets) == 2
|
||||
@@ -326,8 +305,156 @@ def test_get_user_sheets(test_data, db_session):
|
||||
|
||||
|
||||
def test_delete_sheet(test_data, db_session):
|
||||
from app.web.db import crud
|
||||
|
||||
assert crud.delete_sheet(db_session, "sheet-0", "") == False
|
||||
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == True
|
||||
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_by_store_until(async_db_session):
|
||||
# Add archives with different store_until dates
|
||||
now = datetime.now()
|
||||
archive1 = models.Archive(
|
||||
id="archive-expired-1",
|
||||
url="https://example-expired-1.com",
|
||||
result={},
|
||||
author_id="rick@example.com",
|
||||
store_until=now - timedelta(days=1)
|
||||
)
|
||||
archive2 = models.Archive(
|
||||
id="archive-expired-2",
|
||||
url="https://example-expired-2.com",
|
||||
result={},
|
||||
author_id="rick@example.com",
|
||||
store_until=now - timedelta(hours=1)
|
||||
)
|
||||
archive3 = models.Archive(
|
||||
id="archive-active",
|
||||
url="https://example-active.com",
|
||||
result={},
|
||||
author_id="rick@example.com",
|
||||
store_until=now + timedelta(days=1)
|
||||
)
|
||||
async_db_session.add_all([archive1, archive2, archive3])
|
||||
await async_db_session.commit()
|
||||
|
||||
# Should find 2 expired archives
|
||||
expired = await crud.find_by_store_until(async_db_session, now)
|
||||
assert len(list(expired)) == 2
|
||||
|
||||
# Should find 1 archive expired before 2 hours ago
|
||||
expired = await crud.find_by_store_until(async_db_session, now - timedelta(hours=2))
|
||||
assert len(list(expired)) == 1
|
||||
|
||||
# Should find no archives expired before 2 days ago
|
||||
expired = await crud.find_by_store_until(async_db_session, now - timedelta(days=2))
|
||||
assert len(list(expired)) == 0
|
||||
|
||||
# Should not find deleted archives
|
||||
archive1.deleted = True
|
||||
await async_db_session.commit()
|
||||
expired = await crud.find_by_store_until(async_db_session, now)
|
||||
assert len(list(expired)) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sheets_by_id_hash(async_db_session):
|
||||
# Add test data
|
||||
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
|
||||
sheets = [
|
||||
models.Sheet(id="sheet-0", name="sheet-0", author_id=authors[0], group_id=None, frequency="daily"),
|
||||
models.Sheet(id="sheet-0-2", name="sheet-0-2", author_id=authors[0], group_id="spaceship", frequency="hourly"),
|
||||
models.Sheet(id="sheet-1", name="sheet-1", author_id=authors[1], group_id=None, frequency="daily"),
|
||||
models.Sheet(id="sheet-2", name="sheet-2", author_id=authors[2], group_id=None, frequency="daily")
|
||||
]
|
||||
async_db_session.add_all(sheets)
|
||||
await async_db_session.commit()
|
||||
|
||||
with patch("app.web.db.crud.fnv1a_hash_mod", return_value=1):
|
||||
# Test retrieving hourly sheets
|
||||
hourly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "hourly", 4, 1)
|
||||
assert len(hourly_sheets) == 1
|
||||
assert hourly_sheets[0].id == "sheet-0-2"
|
||||
assert hourly_sheets[0].frequency == "hourly"
|
||||
|
||||
# Test retrieving daily sheets
|
||||
daily_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 1)
|
||||
assert len(daily_sheets) == 3
|
||||
assert all(sheet.frequency == "daily" for sheet in daily_sheets)
|
||||
assert {sheet.id for sheet in daily_sheets} == {"sheet-0", "sheet-1", "sheet-2"}
|
||||
|
||||
# Test with non-matching hash
|
||||
no_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 3)
|
||||
assert len(no_sheets) == 0
|
||||
|
||||
# Test with non-existent frequency
|
||||
weekly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "weekly", 4, 1)
|
||||
assert len(weekly_sheets) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_stale_sheets(async_db_session):
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.sql import select
|
||||
|
||||
now = datetime.now()
|
||||
active_date = now - timedelta(days=5)
|
||||
stale_date = now - timedelta(days=15)
|
||||
|
||||
# Create test sheets with different last_url_archived_at dates
|
||||
sheets = [
|
||||
models.Sheet(
|
||||
id="sheet-active-1",
|
||||
name="Active Sheet 1",
|
||||
author_id="rick@example.com",
|
||||
frequency="daily",
|
||||
last_url_archived_at=active_date
|
||||
),
|
||||
models.Sheet(
|
||||
id="sheet-active-2",
|
||||
name="Active Sheet 2",
|
||||
author_id="morty@example.com",
|
||||
frequency="hourly",
|
||||
last_url_archived_at=active_date
|
||||
),
|
||||
models.Sheet(
|
||||
id="sheet-stale-1",
|
||||
name="Stale Sheet 1",
|
||||
author_id="rick@example.com",
|
||||
frequency="daily",
|
||||
last_url_archived_at=stale_date
|
||||
),
|
||||
models.Sheet(
|
||||
id="sheet-stale-2",
|
||||
name="Stale Sheet 2",
|
||||
author_id="morty@example.com",
|
||||
frequency="daily",
|
||||
last_url_archived_at=stale_date
|
||||
)
|
||||
]
|
||||
async_db_session.add_all(sheets)
|
||||
await async_db_session.commit()
|
||||
|
||||
# Should not delete sheets with 20 days inactivity threshold
|
||||
deleted = await crud.delete_stale_sheets(async_db_session, 20)
|
||||
assert len(deleted) == 0 # No sheets should be deleted
|
||||
result = await async_db_session.execute(select(models.Sheet))
|
||||
assert len(list(result.scalars())) == 4 # All sheets should remain
|
||||
|
||||
# Should delete sheets with 7 days inactivity threshold
|
||||
deleted = await crud.delete_stale_sheets(async_db_session, 7)
|
||||
assert len(deleted) == 2 # Two authors affected
|
||||
assert len(deleted["rick@example.com"]) == 1 # One sheet deleted for Rick
|
||||
assert len(deleted["morty@example.com"]) == 1 # One sheet deleted for Morty
|
||||
assert deleted["rick@example.com"][0].id == "sheet-stale-1"
|
||||
assert deleted["morty@example.com"][0].id == "sheet-stale-2"
|
||||
|
||||
# Verify only active sheets remain
|
||||
result = await async_db_session.execute(select(models.Sheet))
|
||||
remaining = list(result.scalars())
|
||||
assert len(remaining) == 2
|
||||
assert {s.id for s in remaining} == {"sheet-active-1", "sheet-active-2"}
|
||||
|
||||
# Running again should not delete anything
|
||||
deleted = await crud.delete_stale_sheets(async_db_session, 7)
|
||||
assert len(deleted) == 0
|
||||
@@ -1,6 +1,8 @@
|
||||
from unittest.mock import MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from app.shared.schemas import Usage, UsageResponse
|
||||
from app.shared.user_groups import GroupInfo
|
||||
from app.web.config import VERSION
|
||||
from app.tests.web.db.test_crud import test_data
|
||||
|
||||
@@ -13,6 +15,7 @@ def test_endpoint_home(client_with_auth):
|
||||
assert "breakingChanges" in j
|
||||
assert "groups" not in j
|
||||
|
||||
|
||||
def test_endpoint_health(client_with_auth):
|
||||
r = client_with_auth.get("/health")
|
||||
assert r.status_code == 200
|
||||
@@ -28,7 +31,7 @@ def test_endpoint_active(app):
|
||||
|
||||
from app.web.security import get_user_state
|
||||
app.dependency_overrides[get_user_state] = lambda: m_user_state
|
||||
|
||||
|
||||
# inactive user
|
||||
m_user_state.active = False
|
||||
client = TestClient(app)
|
||||
@@ -42,7 +45,6 @@ def test_endpoint_active(app):
|
||||
r = client.get("/user/active")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"active": True}
|
||||
|
||||
|
||||
|
||||
def test_no_serve_local_archive_by_default(client_with_auth):
|
||||
@@ -100,3 +102,74 @@ async def test_prometheus_metrics(test_data, client_with_token, get_settings):
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r3.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r3.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r3.text
|
||||
|
||||
|
||||
def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth):
|
||||
test_no_auth(client.get, "/user/permissions")
|
||||
|
||||
|
||||
def test_endpoint_get_user_permissions(app):
|
||||
from app.web.security import get_user_state
|
||||
|
||||
m_user_state = MagicMock()
|
||||
rv = {
|
||||
"all": GroupInfo(read=True),
|
||||
"group1": GroupInfo(archive_url=True),
|
||||
}
|
||||
from loguru import logger
|
||||
logger.info(rv)
|
||||
m_user_state.permissions = rv
|
||||
|
||||
app.dependency_overrides[get_user_state] = lambda: m_user_state
|
||||
|
||||
client = TestClient(app)
|
||||
r = client.get("/user/permissions")
|
||||
assert r.status_code == 200
|
||||
response = r.json()
|
||||
assert response.keys() == {"all", "group1"}
|
||||
assert response["all"]["read"]
|
||||
assert response["group1"]["read"] == []
|
||||
assert response["group1"]["archive_url"]
|
||||
assert response["all"]["archive_url"] == False
|
||||
|
||||
|
||||
def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth):
|
||||
test_no_auth(client.get, "/user/usage")
|
||||
|
||||
|
||||
def test_endpoint_get_user_usage_inactive(app):
|
||||
from app.web.security import get_user_state
|
||||
|
||||
m_user_state = MagicMock()
|
||||
m_user_state.active = False
|
||||
|
||||
app.dependency_overrides[get_user_state] = lambda: m_user_state
|
||||
|
||||
client = TestClient(app)
|
||||
r = client.get("/user/usage")
|
||||
assert r.status_code == 403
|
||||
assert r.json() == {"detail": "User is not active."}
|
||||
|
||||
|
||||
def test_endpoint_get_user_usage_active(app):
|
||||
from app.web.security import get_user_state
|
||||
|
||||
m_user_state = MagicMock()
|
||||
m_user_state.active = True
|
||||
mock_usage = UsageResponse(
|
||||
monthly_urls=1,
|
||||
monthly_mbs=2,
|
||||
total_sheets=3,
|
||||
groups={
|
||||
"group1": Usage(monthly_urls=4, monthly_mbs=5, total_sheets=6),
|
||||
"group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9)
|
||||
}
|
||||
)
|
||||
m_user_state.usage.return_value = mock_usage
|
||||
|
||||
app.dependency_overrides[get_user_state] = lambda: m_user_state
|
||||
|
||||
client = TestClient(app)
|
||||
r = client.get("/user/usage")
|
||||
assert r.status_code == 200
|
||||
assert UsageResponse(**r.json()) == mock_usage
|
||||
|
||||
@@ -32,9 +32,24 @@ def test_submit_manual_archive(m1, client_with_token, db_session):
|
||||
assert [u.url for u in inserted.urls] == ["http://example.s3.com"]
|
||||
assert type(inserted.store_until) == datetime
|
||||
|
||||
|
||||
# cannot have the same URL twice
|
||||
# cannot have the same URL twice
|
||||
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]})
|
||||
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "tags": ["test"], "url": "http://example.com"})
|
||||
assert r.status_code == 422
|
||||
assert r.json() == {"detail": "Cannot insert into DB due to integrity error, likely duplicate urls."}
|
||||
|
||||
|
||||
# test with invalid JSON
|
||||
def test_submit_manual_archive_invalid_json(client_with_token):
|
||||
r = client_with_token.post("/interop/submit-archive", json={"result": "invalid json", "public": False, "author_id": "jer", "tags": ["test"], "url": "http://example.com"})
|
||||
assert r.status_code == 422
|
||||
assert r.json() == {"detail": "Invalid JSON in result field."}
|
||||
|
||||
|
||||
@patch("app.web.endpoints.interoperability.business_logic")
|
||||
def test_submit_manual_archive_no_store_until(m_b, client_with_token, db_session):
|
||||
m_b.get_store_archive_until.side_effect = AssertionError("AssertionError")
|
||||
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]})
|
||||
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"})
|
||||
assert r.status_code == 422
|
||||
assert r.json() == {"detail": "AssertionError"}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
@@ -101,8 +101,21 @@ async def test_authenticate_user():
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_exception():
|
||||
from app.web.security import authenticate_user
|
||||
|
||||
with patch("app.web.security.requests.get") as mock_get:
|
||||
mock_get.return_value.status_code = 200
|
||||
mock_get.return_value.json.side_effect = Exception("mocked error")
|
||||
assert authenticate_user("this-will-call-requests") == (False, "exception occurred")
|
||||
|
||||
|
||||
def test_get_user_state():
|
||||
from app.web.security import get_user_state
|
||||
from app.web.db.user_state import UserState
|
||||
|
||||
mock_session = Mock()
|
||||
test_email = "test@example.com"
|
||||
|
||||
state = get_user_state(test_email, mock_session)
|
||||
|
||||
assert isinstance(state, UserState)
|
||||
assert state.email == test_email
|
||||
assert state.db == mock_session
|
||||
|
||||
@@ -10,7 +10,6 @@ from app.shared import schemas
|
||||
from auto_archiver.core import Media, Metadata
|
||||
|
||||
|
||||
|
||||
class Test_create_archive_task():
|
||||
URL = "https://example-live.com"
|
||||
archive = schemas.ArchiveCreate(url=URL, tags=["tag-celery"], public=True, author_id="rick@example.com", group_id="interstellar")
|
||||
@@ -19,20 +18,21 @@ class Test_create_archive_task():
|
||||
@patch("app.worker.main.get_all_urls", return_value=[])
|
||||
@patch("app.worker.main.insert_result_into_db")
|
||||
@patch("app.worker.main.get_store_until", return_value=datetime.now())
|
||||
@patch("app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"])
|
||||
@patch("app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"])
|
||||
@patch("celery.app.task.Task.request")
|
||||
def test_success(self, m_req, m_args, m_store, m_insert, m_urls, m_orchestrator, db_session):
|
||||
from app.worker.main import create_archive_task
|
||||
|
||||
m_req.id = "this-just-in"
|
||||
m_orchestrator.run.return_value = Metadata().set_url(self.URL).success()
|
||||
m_orchestrator.return_value.run.return_value = iter([Metadata().set_url(self.URL).success()])
|
||||
|
||||
task = create_archive_task(self.archive.model_dump_json())
|
||||
|
||||
m_args.assert_called_once()
|
||||
m_store.assert_called_once_with("interstellar")
|
||||
m_insert.assert_called_once()
|
||||
m_orchestrator.run.assert_called_once()
|
||||
m_urls.assert_called_once()
|
||||
m_orchestrator.return_value.run.assert_called_once()
|
||||
|
||||
assert task["status"] == "success"
|
||||
assert task["metadata"]["url"] == self.URL
|
||||
@@ -43,56 +43,54 @@ class Test_create_archive_task():
|
||||
with pytest.raises(Exception):
|
||||
create_archive_task(self.archive.model_dump_json())
|
||||
|
||||
@patch("app.worker.main.insert_result_into_db", side_effect=Exception)
|
||||
@patch("app.worker.main.ArchivingOrchestrator")
|
||||
@patch("app.worker.main.get_orchestrator_args")
|
||||
def test_raise_db_error(self, m_args, m_insert):
|
||||
def test_raise_db_error(self, m_args, m_orchestrator):
|
||||
from app.worker.main import create_archive_task
|
||||
mock_orchestrator = self.mock_orchestrator_choice(m_args)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
create_archive_task(self.archive.model_dump_json())
|
||||
mock_orchestrator.feed_item.assert_called_once()
|
||||
|
||||
|
||||
@patch("app.worker.main.insert_result_into_db", return_value=None)
|
||||
@patch("app.worker.main.get_orchestrator_args")
|
||||
def test_raise_empty_result(self, m_args, m_insert):
|
||||
from app.worker.main import create_archive_task
|
||||
mock_orchestrator = self.mock_orchestrator_choice(m_args)
|
||||
m_orchestrator.return_value.run.side_effect = Exception("Orchestrator failed")
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
create_archive_task(self.archive.model_dump_json())
|
||||
mock_orchestrator.feed_item.assert_called_once()
|
||||
assert str(e.value) == "Orchestrator failed"
|
||||
m_args.assert_called_once()
|
||||
m_orchestrator.return_value.run.assert_called_once()
|
||||
|
||||
def mock_orchestrator_choice(self, m_load):
|
||||
mock_orchestrator = mock.MagicMock()
|
||||
mock_orchestrator.configure_mock(feed_item=mock.MagicMock(return_value=Metadata().set_url(self.URL).success()))
|
||||
m_load.return_value = mock_orchestrator
|
||||
return mock_orchestrator
|
||||
@patch("app.worker.main.ArchivingOrchestrator")
|
||||
@patch("app.worker.main.insert_result_into_db", return_value=None)
|
||||
@patch("app.worker.main.get_orchestrator_args")
|
||||
def test_raise_empty_result(self, m_args, m_insert, m_orchestrator):
|
||||
from app.worker.main import create_archive_task
|
||||
m_orchestrator.return_value.run.return_value = iter([None])
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
create_archive_task(self.archive.model_dump_json())
|
||||
assert str(e.value) == "UNABLE TO archive: https://example-live.com"
|
||||
m_orchestrator.return_value.run.assert_called_once()
|
||||
|
||||
|
||||
class Test_create_sheet_task():
|
||||
URL = "https://example-live.com"
|
||||
sheet = schemas.SubmitSheet(sheet_id="123", author_id="rick@example.com", group_id="interstellar", tags=["spaceship"])
|
||||
|
||||
@patch("app.worker.main.get_all_urls", return_value=[])
|
||||
@patch("app.worker.main.ArchivingOrchestrator")
|
||||
@patch("app.worker.main.models.generate_uuid", return_value="constant-uuid")
|
||||
@patch("app.worker.main.get_store_until", return_value=datetime.now())
|
||||
@patch("app.worker.main.get_orchestrator_args")
|
||||
def test_success(self, m_args, m_store, m_uuid, db_session):
|
||||
def test_success(self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session):
|
||||
from app.worker.main import create_sheet_task
|
||||
|
||||
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
|
||||
|
||||
mock_metadata = Metadata().set_url(self.URL).success()
|
||||
mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"]))
|
||||
m_orch = MagicMock()
|
||||
m_orch.feed.return_value = iter([False, mock_metadata, mock_metadata])
|
||||
m_args.return_value = m_orch
|
||||
|
||||
m_orchestrator.return_value.run.return_value = iter([False, mock_metadata, mock_metadata])
|
||||
|
||||
res = create_sheet_task(self.sheet.model_dump_json())
|
||||
|
||||
m_args.assert_called_once_with("interstellar", True, {'configurations': {'gsheet_feeder': {'sheet_id': '123'}}})
|
||||
m_orch.feed.assert_called_once()
|
||||
m_args.assert_called_once_with("interstellar", True, ["--gsheet_feeder.sheet_id", "123"])
|
||||
m_orchestrator.return_value.run.assert_called_once()
|
||||
m_store.assert_called_with("interstellar")
|
||||
m_store.call_count == 2
|
||||
m_uuid.call_count == 2
|
||||
|
||||
@@ -242,7 +242,7 @@ def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
|
||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_url_archived_at.desc()).all()
|
||||
|
||||
|
||||
async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: str) -> list[models.Sheet]:
|
||||
async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: int) -> list[models.Sheet]:
|
||||
result = await db.execute(
|
||||
select(models.Sheet).filter(models.Sheet.frequency == frequency)
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ import traceback, datetime
|
||||
from celery.signals import task_failure
|
||||
from loguru import logger
|
||||
from sqlalchemy import exc
|
||||
import auto_archiver
|
||||
from auto_archiver.core.orchestrator import ArchivingOrchestrator
|
||||
|
||||
from app.shared.db import models
|
||||
|
||||
Reference in New Issue
Block a user