Files
auto-archiver-api/app/tests/conftest.py
2025-03-03 13:20:50 -06:00

262 lines
7.1 KiB
Python

import os
from datetime import datetime
from http import HTTPStatus
from typing import AsyncGenerator
from unittest.mock import patch
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from app.shared.db import models
from app.shared.db.database import (
make_async_engine,
make_async_session_local,
make_engine,
make_session_local,
)
from app.shared.settings import Settings
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud
from app.web.db.crud import get_user_group_names
from app.web.db.user_state import UserState
from app.web.main import app_factory
from app.web.security import (
get_token_or_user_auth,
get_user_auth,
get_user_state,
token_api_key_auth,
)
@pytest.fixture(autouse=True)
def mock_logger_add():
    """Replace ``loguru.logger.add`` with a mock for every test.

    Autouse, so it applies without being requested. The mock is yielded so
    a test that does request the fixture can inspect the recorded calls.
    """
    with patch("loguru.logger.add") as patched_add:
        yield patched_add
@pytest.fixture()
def get_settings():
    """Return a ``Settings`` instance loaded from the ``.env.test`` file."""
    test_settings = Settings(_env_file=".env.test")
    return test_settings
@pytest.fixture(autouse=True)
def mock_settings():
    """Patch ``app.shared.settings.Settings`` to return test settings.

    Autouse. While the patch is active, calling the patched name returns a
    single ``Settings`` object built from ``.env.test``. The patch mock is
    yielded to the test.
    """
    # Built before the patch is applied, so the real class is still used here.
    test_settings = Settings(_env_file=".env.test")
    with patch(
        "app.shared.settings.Settings", return_value=test_settings
    ) as patched_settings:
        yield patched_settings
@pytest.fixture()
def test_db(get_settings: Settings):
    """Create a fresh SQLite test database and yield an open connection.

    Setup clears the ``get_user_group_names`` and ``make_engine`` caches,
    makes sure the database file exists and creates all model tables.
    Teardown closes the connection, drops the tables, disposes the engine
    and deletes the database file together with its ``-wal`` / ``-shm``
    sidecar files.

    Args:
        get_settings: test settings providing ``DATABASE_PATH``.

    Yields:
        The connection obtained from ``engine.connect()``.
    """
    get_user_group_names.cache_clear()
    make_engine.cache_clear()

    engine = make_engine(get_settings.DATABASE_PATH)
    # The filesystem path is the URL without its SQLite scheme prefix.
    fs = get_settings.DATABASE_PATH.replace("sqlite:///", "")
    if not os.path.exists(fs):
        # Create an empty file; the context manager guarantees it is closed.
        with open(fs, "w"):
            pass

    models.Base.metadata.create_all(engine)
    connection = engine.connect()

    yield connection

    connection.close()
    models.Base.metadata.drop_all(bind=engine)
    # Release any pooled connections before deleting the files they point
    # at, so no handle stays open on a removed database file.
    engine.dispose()

    for suffix in ["", "-wal", "-shm"]:
        new_fs = fs + suffix
        if os.path.exists(new_fs):
            os.remove(new_fs)
@pytest.fixture()
def db_session(test_db):
    """Yield a session from a session factory built on ``test_db``.

    The session is used as a context manager, so it is exited once the test
    that requested the fixture has finished.
    """
    session_factory = make_session_local(test_db)
    with session_factory() as active_session:
        yield active_session
@pytest_asyncio.fixture()
async def async_test_db(get_settings: Settings):
    """Create a fresh async SQLite test database and yield its engine.

    Setup clears the ``get_user_group_names`` cache, makes sure the
    database file exists and creates all model tables. Teardown drops the
    tables, disposes the engine and deletes the database file together with
    its ``-wal`` / ``-shm`` sidecar files.

    Args:
        get_settings: test settings providing ``async_database_path``.

    Yields:
        The async engine returned by ``make_async_engine``.
    """
    get_user_group_names.cache_clear()

    engine = await make_async_engine(get_settings.async_database_path)
    # The filesystem path is the URL without its aiosqlite scheme prefix.
    fs = get_settings.async_database_path.replace("sqlite+aiosqlite:///", "")
    if not os.path.exists(fs):
        # Create an empty file; the context manager guarantees it is closed.
        with open(fs, "w"):
            pass

    async with engine.begin() as conn:
        await conn.run_sync(models.Base.metadata.create_all)

    yield engine

    async with engine.begin() as conn:
        await conn.run_sync(models.Base.metadata.drop_all)

    # AsyncEngine.dispose() is a coroutine: without ``await`` it never runs,
    # the pool is left open and Python warns about an un-awaited coroutine.
    await engine.dispose()

    for suffix in ["", "-wal", "-shm"]:
        new_fs = fs + suffix
        if os.path.exists(new_fs):
            os.remove(new_fs)
@pytest_asyncio.fixture()
async def async_db_session(
    async_test_db: AsyncEngine,
) -> AsyncGenerator[AsyncSession, None]:
    """Yield an async session from a factory built on ``async_test_db``.

    The session is used as an async context manager, so it is exited once
    the test that requested the fixture has finished.
    """
    session_factory = await make_async_session_local(async_test_db)
    async with session_factory() as active_session:
        yield active_session
@pytest.fixture()
def app(db_session):
    """Build the application via ``app_factory`` and upsert user groups.

    The application is created first, then ``crud.upsert_user_groups`` is
    run against ``db_session`` before the application is returned.
    """
    application = app_factory()
    crud.upsert_user_groups(db_session)
    return application
@pytest.fixture()
def client(app):
    """Return a ``TestClient`` for the plain ``app`` (no auth overrides)."""
    return TestClient(app)
@pytest.fixture()
def app_with_auth(app, db_session):
    """Return ``app`` with its user-auth dependencies overridden.

    Overrides installed on ``app.dependency_overrides``:
      * ``get_token_or_user_auth`` -> ``"rick@example.com"``
      * ``get_user_auth``          -> ``"morty@example.com"``
      * ``get_user_state``         -> a ``UserState`` built from
        ``db_session`` and ``"MORTY@example.com"`` (upper-case as written).
    """
    app.dependency_overrides.update(
        {
            get_token_or_user_auth: lambda: "rick@example.com",
            get_user_auth: lambda: "morty@example.com",
            get_user_state: lambda: UserState(db_session, "MORTY@example.com"),
        }
    )
    return app
@pytest.fixture()
def client_with_auth(app_with_auth):
    """Return a ``TestClient`` for the app with user-auth overrides."""
    return TestClient(app_with_auth)
@pytest.fixture()
def app_with_token(app):
    """Return ``app`` with its token-auth dependencies overridden.

    Both ``token_api_key_auth`` and ``get_token_or_user_auth`` are replaced
    by callables that return ``ALLOW_ANY_EMAIL``.
    """
    app.dependency_overrides.update(
        {
            token_api_key_auth: lambda: ALLOW_ANY_EMAIL,
            get_token_or_user_auth: lambda: ALLOW_ANY_EMAIL,
        }
    )
    return app
@pytest.fixture()
def client_with_token(app_with_token):
    """Return a ``TestClient`` for the app with token-auth overrides."""
    return TestClient(app_with_token)
@pytest.fixture()
def test_no_auth():
    """Provide a reusable check that a method/endpoint pair is unauthorized.

    Returns:
        A callable ``no_auth(http_method, endpoint)`` that calls
        ``http_method(endpoint)`` and asserts the response status is
        ``HTTPStatus.FORBIDDEN`` and its JSON body is
        ``{"detail": "Not authenticated"}``.
    """

    def no_auth(http_method, endpoint):
        reply = http_method(endpoint)
        assert reply.status_code == HTTPStatus.FORBIDDEN
        assert reply.json() == {"detail": "Not authenticated"}

    return no_auth
@pytest.fixture()
def test_data(db_session):
    """Seed the test database with users, archives, sheets and groups.

    Inserts, through ``db_session``:
      * 3 users (rick, morty, jerry);
      * 100 archives assigned round-robin to those users, each with 10
        archive URLs and a varying set of tags;
      * one sheet per user, plus a second "spaceship" sheet for rick;
      * user groups via ``crud.upsert_user_groups``.

    Row counts are asserted along the way so a broken seed fails early.
    """
    author_emails = [
        "rick@example.com",
        "morty@example.com",
        "jerry@example.com",
    ]

    # creates 3 users
    db_session.add_all([models.User(email=email) for email in author_emails])
    db_session.commit()
    assert db_session.query(models.User).count() == 3

    # creates 100 archives for 3 users over 2 months with repeating URLs
    for i in range(100):
        author = author_emails[i % 3]

        # only morty's even-numbered archives belong to the group
        if author == "morty@example.com" and i % 2 == 0:
            group = "spaceship"
        else:
            group = None

        archive = models.Archive(
            id=f"archive-id-456-{i}",
            url=f"https://example-{i % 3}.com",
            result={},
            public=author == "jerry@example.com",
            author_id=author,
            group_id=group,
            created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1),
        )

        # tags: every 5th, every 10th and every 4th archive get one each
        if i % 5 == 0:
            archive.tags.append(models.Tag(id=f"tag-{i}"))
        if i % 10 == 0:
            archive.tags.append(models.Tag(id=f"tag-second-{i}"))
        if i % 4 == 0:
            archive.tags.append(models.Tag(id=f"tag-third-{i}"))

        archive.urls.extend(
            [
                models.ArchiveUrl(
                    url=f"https://example-{i}.com/{j}", key=f"media_{j}"
                )
                for j in range(10)
            ]
        )
        db_session.add(archive)

    # creates a sheet for each user (rick additionally gets a group sheet)
    for i, email in enumerate(author_emails):
        sheets = [
            models.Sheet(
                id=f"sheet-{i}",
                name=f"sheet-{i}",
                author_id=email,
                group_id=None,
                frequency="daily",
            )
        ]
        if email == "rick@example.com":
            sheets.append(
                models.Sheet(
                    id=f"sheet-{i}-2",
                    name=f"sheet-{i}-2",
                    author_id=email,
                    group_id="spaceship",
                    frequency="hourly",
                )
            )
        db_session.add_all(sheets)
    db_session.commit()

    assert db_session.query(models.Archive).count() == 100
    # 20 archives with i % 5 == 0, 10 with i % 10 == 0, 25 with i % 4 == 0
    assert db_session.query(models.Tag).count() == 20 + 10 + 25
    assert db_session.query(models.ArchiveUrl).count() == 1000
    first_archive_urls = db_session.query(models.ArchiveUrl).filter(
        models.ArchiveUrl.archive_id == "archive-id-456-0"
    )
    assert first_archive_urls.count() == 10

    # setup groups
    assert db_session.query(models.Group).count() == 0
    crud.upsert_user_groups(db_session)
    assert db_session.query(models.Group).count() == 4
    assert db_session.query(models.User).count() == 3