Format and lint the tests directory (#58)

This commit is contained in:
Michael Plunkett
2025-02-27 12:35:23 -06:00
committed by GitHub
parent 229db7dd5c
commit d575b6f9af
15 changed files with 1894 additions and 585 deletions

View File

@@ -1,4 +1,6 @@
import os
from datetime import datetime
from http import HTTPStatus
from typing import AsyncGenerator
from unittest.mock import patch
@@ -7,15 +9,31 @@ import pytest_asyncio
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from app.shared.db import models
from app.shared.db.database import (
make_async_engine,
make_async_session_local,
make_engine,
make_session_local,
)
from app.shared.settings import Settings
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud
from app.web.db.crud import get_user_group_names
from app.web.db.user_state import UserState
from app.web.main import app_factory
from app.web.security import (
get_token_or_user_auth,
get_user_auth,
get_user_state,
token_api_key_auth,
)
@pytest.fixture(autouse=True)
def mock_logger_add():
"""Fixture to mock loguru.logger.add for all tests."""
with patch('loguru.logger.add') as mock_add:
with patch("loguru.logger.add") as mock_add:
yield mock_add # This makes the mock available to tests
@@ -26,23 +44,22 @@ def get_settings():
@pytest.fixture(autouse=True)
def mock_settings():
with patch('app.shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings:
with patch(
"app.shared.settings.Settings",
return_value=Settings(_env_file=".env.test"),
) as mock_settings:
yield mock_settings
@pytest.fixture()
def test_db(get_settings: Settings):
from app.shared.db import models
from app.shared.db.database import make_engine
from app.web.db.crud import get_user_group_names
get_user_group_names.cache_clear()
make_engine.cache_clear()
engine = make_engine(get_settings.DATABASE_PATH)
fs = get_settings.DATABASE_PATH.replace("sqlite:///", "")
if not os.path.exists(fs):
open(fs, 'w').close()
open(fs, "w").close()
models.Base.metadata.create_all(engine)
@@ -59,7 +76,6 @@ def test_db(get_settings: Settings):
@pytest.fixture()
def db_session(test_db):
from app.shared.db.database import make_session_local
session_local = make_session_local(test_db)
with session_local() as session:
yield session
@@ -67,18 +83,12 @@ def db_session(test_db):
@pytest_asyncio.fixture()
async def async_test_db(get_settings: Settings):
import asyncio
from app.shared.db import models
from app.shared.db.database import make_async_engine
from app.web.db.crud import get_user_group_names
get_user_group_names.cache_clear()
engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH)
fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "")
if not os.path.exists(fs):
open(fs, 'w').close()
open(fs, "w").close()
async def create_all():
async with engine.begin() as conn:
@@ -102,8 +112,9 @@ async def async_test_db(get_settings: Settings):
@pytest_asyncio.fixture()
async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
from app.shared.db.database import make_async_session_local
async def async_db_session(
async_test_db: AsyncEngine,
) -> AsyncGenerator[AsyncSession, None]:
session_local = await make_async_session_local(async_test_db)
async with session_local() as session:
yield session
@@ -111,8 +122,6 @@ async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSe
@pytest.fixture()
def app(db_session):
from app.web.db import crud
from app.web.main import app_factory
app = app_factory()
crud.upsert_user_groups(db_session)
return app
@@ -126,14 +135,13 @@ def client(app):
@pytest.fixture()
def app_with_auth(app, db_session):
from app.web.security import (
get_token_or_user_auth,
get_user_auth,
get_user_state,
app.dependency_overrides[get_token_or_user_auth] = (
lambda: "rick@example.com"
)
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
app.dependency_overrides[get_user_state] = lambda: UserState(db_session, "MORTY@example.com")
app.dependency_overrides[get_user_state] = lambda: UserState(
db_session, "MORTY@example.com"
)
return app
@@ -145,7 +153,6 @@ def client_with_auth(app_with_auth):
@pytest.fixture()
def app_with_token(app):
from app.web.security import get_token_or_user_auth, token_api_key_auth
app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL
app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL
return app
@@ -162,6 +169,93 @@ def test_no_auth():
# reusable code to ensure a method/endpoint combination is unauthorized
def no_auth(http_method, endpoint):
response = http_method(endpoint)
assert response.status_code == 403
assert response.status_code == HTTPStatus.FORBIDDEN
assert response.json() == {"detail": "Not authenticated"}
return no_auth
@pytest.fixture()
def test_data(db_session):
author_emails = [
"rick@example.com",
"morty@example.com",
"jerry@example.com",
]
# creates 3 users
for email in author_emails:
db_session.add(models.User(email=email))
db_session.commit()
assert db_session.query(models.User).count() == 3
# creates 100 archives for 3 users over 2 months with repeating URLs
for i in range(100):
author = author_emails[i % 3]
archive = models.Archive(
id=f"archive-id-456-{i}",
url=f"https://example-{i % 3}.com",
result={},
public=author == "jerry@example.com",
author_id=author,
group_id="spaceship"
if author == "morty@example.com" and i % 2 == 0
else None,
created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1),
)
if i % 5 == 0:
archive.tags.append(models.Tag(id=f"tag-{i}"))
if i % 10 == 0:
archive.tags.append(models.Tag(id=f"tag-second-{i}"))
if i % 4 == 0:
archive.tags.append(models.Tag(id=f"tag-third-{i}"))
for j in range(10):
archive.urls.append(
models.ArchiveUrl(
url=f"https://example-{i}.com/{j}", key=f"media_{j}"
)
)
db_session.add(archive)
# creates a sheet for each user
for i, email in enumerate(
["rick@example.com", "morty@example.com", "jerry@example.com"]
):
db_session.add(
models.Sheet(
id=f"sheet-{i}",
name=f"sheet-{i}",
author_id=email,
group_id=None,
frequency="daily",
)
)
if email == "rick@example.com":
db_session.add(
models.Sheet(
id=f"sheet-{i}-2",
name=f"sheet-{i}-2",
author_id=email,
group_id="spaceship",
frequency="hourly",
)
)
db_session.commit()
assert db_session.query(models.Archive).count() == 100
assert db_session.query(models.Tag).count() == 20 + 10 + 25
assert db_session.query(models.ArchiveUrl).count() == 1000
assert (
db_session.query(models.ArchiveUrl)
.filter(models.ArchiveUrl.archive_id == "archive-id-456-0")
.count()
== 10
)
# setup groups
assert db_session.query(models.Group).count() == 0
crud.upsert_user_groups(db_session)
assert db_session.query(models.Group).count() == 4
assert db_session.query(models.User).count() == 3