diff --git a/app/shared/celery.py b/app/shared/celery.py index 71c1342..7c216bd 100644 --- a/app/shared/celery.py +++ b/app/shared/celery.py @@ -1,13 +1,9 @@ from celery import Celery -from app.shared.settings import settings - -def get_celery_binding() -> Celery: - celery = Celery( - broker_url=settings.BROKER_URL, +def get_celery_binding(broker_url: str) -> Celery: + return Celery( + broker_url=broker_url, broker_connection_retry=False, broker_connection_retry_on_startup=False, ) - - return celery diff --git a/app/shared/db/alembic/env.py b/app/shared/db/alembic/env.py index 53ffd1f..97bce90 100644 --- a/app/shared/db/alembic/env.py +++ b/app/shared/db/alembic/env.py @@ -4,7 +4,9 @@ from alembic import context from sqlalchemy import engine_from_config, pool from app.shared.db.models import Base -from app.shared.settings import settings +from app.shared.settings import Settings + +settings = Settings() # type: ignore # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/app/shared/db/base.py b/app/shared/db/base.py index a717f39..00eaeba 100644 --- a/app/shared/db/base.py +++ b/app/shared/db/base.py @@ -1,26 +1,21 @@ -from typing import Any, Generator +from typing import Any -from sqlalchemy import create_engine, event -from sqlalchemy.orm import Session, sessionmaker - -from app.shared.settings import settings - -engine = create_engine(settings.DATABASE_URI, connect_args={"check_same_thread": False}) +from sqlalchemy import Engine, create_engine, event +from sqlalchemy.orm import sessionmaker -@event.listens_for(engine, "connect") -def set_sqlite_pragma(conn: Any, _: Any) -> None: - cursor = conn.cursor() - cursor.execute("PRAGMA journal_mode=WAL") - cursor.close() +def make_engine(database_url: str): + engine = create_engine(database_url, connect_args={"check_same_thread": False}) + + @event.listens_for(engine, "connect") + def set_sqlite_pragma(conn: Any, _: Any) -> None: + cursor = conn.cursor() + cursor.execute("PRAGMA journal_mode=WAL") + cursor.close() + + return engine -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - - -def get_session() -> Generator[Session, None, None]: - session: Session = SessionLocal() - try: - yield session - finally: - session.close() +def make_session_local(engine: Engine): + session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + return session_local diff --git a/app/shared/logger.py b/app/shared/logger.py new file mode 100644 index 0000000..b89b9a7 --- /dev/null +++ b/app/shared/logger.py @@ -0,0 +1,7 @@ +import logging + +logging.basicConfig() + +logger = logging.getLogger(__name__) + +logger.setLevel(logging.INFO) diff --git a/app/shared/settings.py b/app/shared/settings.py index 2d1f283..47b2545 100644 --- a/app/shared/settings.py +++ b/app/shared/settings.py @@ -1,5 +1,3 @@ -import sys - from pydantic_settings import BaseSettings @@ -13,9 +11,3 @@ class Settings(BaseSettings): TASK_HARD_TIME_LIMIT: int = 4 * 60 * 60 ENABLE_SHARING: bool = False - - -if "pytest" in sys.modules: - settings = Settings(_env_file=".env.test") # type: ignore -else: - settings = Settings() # type: ignore diff --git a/app/tests/conftest.py b/app/tests/conftest.py index b2b08d9..3b50d2e 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -3,44 +3,57 @@ from fastapi.testclient import TestClient from sqlalchemy_utils import create_database, database_exists, drop_database import app.shared.db.models as models -from app.shared.db.base import SessionLocal, engine -from app.shared.settings import settings +from app.shared.db.base import make_engine, make_session_local +from app.shared.settings import Settings +from app.web.injections.db import get_session +from app.web.injections.settings import get_settings from app.web.main import app_factory -def pytest_configure() -> None: - if not database_exists(engine.url): - create_database(engine.url) - - -def pytest_unconfigure() -> None: - if database_exists(engine.url): - drop_database(engine.url) +@pytest.fixture() +def settings(): + return Settings(_env_file=".env.test") # type: ignore @pytest.fixture() -def auth_headers() -> dict[str, str]: +def auth_headers(settings) -> dict[str, str]: return {"Authorization": f"Bearer {settings.API_SECRET}"} @pytest.fixture() -def test_db(): +def test_db(settings): + engine = make_engine(settings.DATABASE_URI) + + if not database_exists(engine.url): + create_database(engine.url) + models.Base.metadata.create_all(engine) + connection = engine.connect() yield connection connection.close() + models.Base.metadata.drop_all(bind=engine) + drop_database(engine.url) @pytest.fixture() def db_session(test_db): - with SessionLocal(bind=test_db) as session: + session_local = make_session_local(test_db) + with session_local() as session: yield session @pytest.fixture() -def client(db_session): - app = app_factory(lambda: db_session) +def app(db_session, settings): + app = app_factory() + app.dependency_overrides[get_settings] = lambda: settings + app.dependency_overrides[get_session] = lambda: db_session + return app + + +@pytest.fixture() +def client(app): client = TestClient(app) return client @@ -66,10 +79,3 @@ def mock_artifact(db_session, mock_job): db_session.add(artifact) db_session.commit() return artifact - - -@pytest.fixture() -def sharing_enabled(): - settings.ENABLE_SHARING = True - yield - settings.ENABLE_SHARING = False diff --git a/app/tests/test_api.py b/app/tests/test_api.py index 8d86677..86bd515 100644 --- a/app/tests/test_api.py +++ b/app/tests/test_api.py @@ -1,7 +1,6 @@ -from fastapi.testclient import TestClient - import app.shared.db.models as models -from app.web.main import app_factory +from app.shared.settings import Settings +from app.web.injections.settings import get_settings # POST /api/v1/jobs @@ -69,9 +68,10 @@ def test_get_job_sharing_disabled(client, mock_job): assert res.status_code == 401 -def test_get_job_sharing_enabled(db_session, mock_job, sharing_enabled): - # HACK: delay construction until settings are patched. - client = TestClient(app_factory(lambda: db_session)) +def test_get_job_sharing_enabled(client, app, mock_job): + app.dependency_overrides[get_settings] = lambda: Settings( + _env_file=".env.test", ENABLE_SHARING=True # type: ignore + ) res = client.get( f"/api/v1/jobs/{mock_job.id}", diff --git a/app/web/__init__.py b/app/web/__init__.py index 61dd17a..04518a2 100644 --- a/app/web/__init__.py +++ b/app/web/__init__.py @@ -1,4 +1,3 @@ -from app.shared.db.base import get_session from app.web.main import app_factory -app = app_factory(get_session) +app = app_factory diff --git a/app/web/injections/__init__.py b/app/web/injections/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/web/injections/db.py b/app/web/injections/db.py new file mode 100644 index 0000000..6662933 --- /dev/null +++ b/app/web/injections/db.py @@ -0,0 +1,26 @@ +from functools import lru_cache +from typing import Generator + +from fastapi import Depends +from sqlalchemy.orm import Session + +from app.shared.db.base import make_engine, make_session_local +from app.shared.settings import Settings +from app.web.injections.settings import get_settings + + +@lru_cache +def session_local(database_url: str): + engine = make_engine(database_url) + return make_session_local(engine) + + +def get_session_local(settings: Settings = Depends(get_settings)): + return session_local(settings.DATABASE_URI) + + +def get_session( + session_local=Depends(get_session_local), +) -> Generator[Session, None, None]: + with session_local() as session: + yield session diff --git a/app/web/injections/security.py b/app/web/injections/security.py new file mode 100644 index 0000000..af8e538 --- /dev/null +++ b/app/web/injections/security.py @@ -0,0 +1,39 @@ +from hmac import compare_digest +from typing import Annotated + +from fastapi import Depends, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from app.shared.settings import Settings +from app.web.injections.settings import get_settings + + +def api_key_auth( + credentials: Annotated[ + HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False)) + ], + settings: Annotated[Settings, Depends(get_settings)], +): + validate_credentials(credentials, settings.API_SECRET) + + +def sharing_auth( + credentials: Annotated[ + HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False)) + ], + settings: Annotated[Settings, Depends(get_settings)], +): + if settings.ENABLE_SHARING: + pass + else: + validate_credentials(credentials, settings.API_SECRET) + + +def validate_credentials(credentials: HTTPAuthorizationCredentials, secret: str): + # use compare_digest to counter timing attacks. + if ( + not credentials + or not secret + or not compare_digest(secret, credentials.credentials) + ): + raise HTTPException(status_code=401) diff --git a/app/web/injections/settings.py b/app/web/injections/settings.py new file mode 100644 index 0000000..30b64a0 --- /dev/null +++ b/app/web/injections/settings.py @@ -0,0 +1,8 @@ +from functools import lru_cache + +from app.shared.settings import Settings + + +@lru_cache +def get_settings(): + return Settings() # type: ignore diff --git a/app/web/injections/task_queue.py b/app/web/injections/task_queue.py new file mode 100644 index 0000000..f8d9adc --- /dev/null +++ b/app/web/injections/task_queue.py @@ -0,0 +1,16 @@ +from functools import lru_cache + +from fastapi import Depends + +from app.shared.settings import Settings +from app.web.injections.settings import get_settings +from app.web.task_queue import TaskQueue + + +@lru_cache +def task_queue(broker_url: str): + return TaskQueue(broker_url) + + +def get_task_queue(settings: Settings = Depends(get_settings)): + return task_queue(settings.BROKER_URL) diff --git a/app/web/main.py b/app/web/main.py index 5738b9b..a5b0262 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -1,4 +1,4 @@ -from typing import Annotated, Callable, Generator +from typing import Annotated from uuid import UUID from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path @@ -7,18 +7,15 @@ from sqlalchemy.orm import Session import app.shared.db.models as models import app.web.dtos as dtos -from app.shared.settings import settings -from app.web.security import authenticate_api_key +from app.web.injections.db import get_session +from app.web.injections.security import api_key_auth, sharing_auth +from app.web.injections.task_queue import get_task_queue from app.web.task_queue import TaskQueue +DatabaseSession = Annotated[Session, Depends(get_session)] -def app_factory( - session_getter: Callable[[], Generator[Session, None, None]] -) -> FastAPI: - DatabaseSession = Annotated[Session, Depends(session_getter)] - - task_queue = TaskQueue() +def app_factory(): app = FastAPI( description=( "whisperbox-transcribe is an async HTTP wrapper for openai/whisper." @@ -28,13 +25,13 @@ def app_factory( api_router = APIRouter(prefix="/api/v1") - @api_router.get("/", response_model=None, status_code=204) - def api_root() -> None: + @api_router.get("/", status_code=204) + def api_root(): return None @api_router.get( "/jobs", - dependencies=[Depends(authenticate_api_key)], + dependencies=[Depends(api_key_auth)], response_model=list[dtos.Job], summary="Get metadata for all jobs", ) @@ -52,7 +49,7 @@ def app_factory( @api_router.get( "/jobs/{id}", - dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], + dependencies=[Depends(sharing_auth)], response_model=dtos.Job, summary="Get metadata for one job", ) @@ -72,7 +69,7 @@ def app_factory( @api_router.get( "/jobs/{id}/artifacts", - dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], + dependencies=[Depends(api_key_auth)], response_model=list[dtos.Artifact], summary="Get all artifacts for one job", ) @@ -93,7 +90,7 @@ def app_factory( @api_router.delete( "/jobs/{id}", - dependencies=[Depends(authenticate_api_key)], + dependencies=[Depends(sharing_auth)], status_code=204, summary="Delete a job with all artifacts", ) @@ -130,7 +127,7 @@ def app_factory( @api_router.post( "/jobs", - dependencies=[Depends(authenticate_api_key)], + dependencies=[Depends(api_key_auth)], response_model=dtos.Job, status_code=201, summary="Enqueue a new job", @@ -138,6 +135,7 @@ def app_factory( def create_job( payload: PostJobPayload, session: DatabaseSession, + task_queue: Annotated[TaskQueue, Depends(get_task_queue)], ) -> models.Job: """ Enqueue a new whisper job for processing. diff --git a/app/web/security.py b/app/web/security.py deleted file mode 100644 index 6d66ef1..0000000 --- a/app/web/security.py +++ /dev/null @@ -1,16 +0,0 @@ -from hmac import compare_digest - -from fastapi import Depends, HTTPException -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer - -from app.shared.settings import settings - - -def authenticate_api_key( - credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), -) -> None: - # use compare_digest to counter timing attacks. - if not credentials or not compare_digest( - settings.API_SECRET, credentials.credentials - ): - raise HTTPException(status_code=401) diff --git a/app/web/task_queue.py b/app/web/task_queue.py index 1d630ab..77d1353 100644 --- a/app/web/task_queue.py +++ b/app/web/task_queue.py @@ -7,8 +7,8 @@ from app.shared.celery import get_celery_binding class TaskQueue: celery: Celery - def __init__(self) -> None: - self.celery = get_celery_binding() + def __init__(self, broker_url: str) -> None: + self.celery = get_celery_binding(broker_url=broker_url) def queue_task(self, job: models.Job): """ diff --git a/app/worker/main.py b/app/worker/main.py index 0e98169..edff4f4 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -1,4 +1,3 @@ -from asyncio.log import logger from typing import Any from uuid import UUID @@ -7,11 +6,16 @@ from sqlalchemy.orm import Session import app.shared.db.models as models from app.shared.celery import get_celery_binding -from app.shared.db.base import SessionLocal -from app.shared.settings import settings +from app.shared.db.base import make_engine, make_session_local +from app.shared.logger import logger +from app.shared.settings import Settings from app.worker.strategies.local import LocalStrategy -celery = get_celery_binding() +# TODO: refactor to be part of a Task instance. +settings = Settings() # type: ignore +celery = get_celery_binding(settings.BROKER_URL) +engine = make_engine(settings.DATABASE_URI) +SessionLocal = make_session_local(engine) class TranscribeTask(Task): @@ -43,14 +47,24 @@ class TranscribeTask(Task): task_acks_on_failure_or_timeout=True, task_reject_on_worker_lost=True, ) -def transcribe(self: Task, job_id: UUID) -> None: +def transcribe(self: TranscribeTask, job_id: UUID) -> None: + session: Session | None = None + job: models.Job | None = None + try: + if not self.strategy: + raise Exception("expected a transcription strategy to be defined.") + # runs in a separate thread => requires sqlite's WAL mode to be enabled. - db: Session = SessionLocal() + session = SessionLocal() + + # work around mypy not inferring the sum type correctly. + if not session: + raise Exception("failed to acquire a session.") # check if passed job should be processed. - job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none() + job = session.query(models.Job).filter(models.Job.id == job_id).one_or_none() if job is None: logger.warn("[{job.id}]: Received unknown job, abort.") @@ -62,7 +76,7 @@ def transcribe(self: Task, job_id: UUID) -> None: logger.debug(f"[{job.id}]: start processing {job.type} job.") - if job.meta: + if job.meta is not None: attempts = 1 + (job.meta.get("attempts") or 0) else: attempts = 1 @@ -77,7 +91,7 @@ def transcribe(self: Task, job_id: UUID) -> None: job.meta = {"task_id": self.request.id, "attempts": attempts} job.status = models.JobStatus.processing - db.commit() + session.commit() logger.debug(f"[{job.id}]: finished setting task to {job.status}.") @@ -86,25 +100,27 @@ def transcribe(self: Task, job_id: UUID) -> None: logger.debug(f"[{job.id}]: successfully processed audio.") artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type) - db.add(artifact) + session.add(artifact) job.status = models.JobStatus.success - db.commit() + session.commit() logger.debug(f"[{job.id}]: successfully stored artifact.") except Exception as e: - if job and db: - if db.in_transaction(): - db.rollback() - if job.meta: - job.meta = {**job.meta, "error": str(e)} # type: ignore + if job and session: + if session.in_transaction(): + session.rollback() + if job.meta is not None: + job.meta = {**job.meta, "error": str(e)} else: job.meta = {"error": str(e)} job.status = models.JobStatus.error - db.commit() + session.commit() raise finally: - self.strategy.cleanup(job_id) - db.close() + if self.strategy: + self.strategy.cleanup(job_id) + if session: + session.close() diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 41d6c28..04fe4c5 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -12,7 +12,7 @@ services: - "--entrypoints.web.address=:80" web: - command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info" + command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --factory" # NOTE: the docker on mac mount adapter (virtioFS) does not support flock. # this can cause the sqlite database to corrupt when written from worker <> api simultaneously. volumes: diff --git a/mypy.ini b/mypy.ini index fa12c48..aa36381 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,3 +2,4 @@ plugins = sqlalchemy.ext.mypy.plugin ignore_missing_imports = True disallow_untyped_defs = False +check_untyped_defs = True diff --git a/web.Dockerfile b/web.Dockerfile index e96f386..0ca9e6f 100644 --- a/web.Dockerfile +++ b/web.Dockerfile @@ -20,4 +20,4 @@ COPY alembic.ini . ENV VIRTUAL_ENV /opt/venv ENV PATH /opt/venv/bin:$PATH -CMD alembic upgrade head && uvicorn app.web:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers +CMD alembic upgrade head && uvicorn app.web:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers --factory