refactor: use Depends for settings and session_local (#59)

This commit is contained in:
Felix Spöttel
2024-01-01 14:28:19 +01:00
committed by GitHub
parent 557de5a442
commit 3559aa5936
20 changed files with 207 additions and 122 deletions

View File

@@ -1,13 +1,9 @@
from celery import Celery from celery import Celery
from app.shared.settings import settings
def get_celery_binding(broker_url: str) -> Celery:
def get_celery_binding() -> Celery: return Celery(
celery = Celery( broker_url=broker_url,
broker_url=settings.BROKER_URL,
broker_connection_retry=False, broker_connection_retry=False,
broker_connection_retry_on_startup=False, broker_connection_retry_on_startup=False,
) )
return celery

View File

@@ -4,7 +4,9 @@ from alembic import context
from sqlalchemy import engine_from_config, pool from sqlalchemy import engine_from_config, pool
from app.shared.db.models import Base from app.shared.db.models import Base
from app.shared.settings import settings from app.shared.settings import Settings
settings = Settings() # type: ignore
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.

View File

@@ -1,26 +1,21 @@
from typing import Any, Generator from typing import Any
from sqlalchemy import create_engine, event from sqlalchemy import Engine, create_engine, event
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import sessionmaker
from app.shared.settings import settings
engine = create_engine(settings.DATABASE_URI, connect_args={"check_same_thread": False})
@event.listens_for(engine, "connect") def make_engine(database_url: str):
def set_sqlite_pragma(conn: Any, _: Any) -> None: engine = create_engine(database_url, connect_args={"check_same_thread": False})
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL") @event.listens_for(engine, "connect")
cursor.close() def set_sqlite_pragma(conn: Any, _: Any) -> None:
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.close()
return engine
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) def make_session_local(engine: Engine):
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return session_local
def get_session() -> Generator[Session, None, None]:
session: Session = SessionLocal()
try:
yield session
finally:
session.close()

7
app/shared/logger.py Normal file
View File

@@ -0,0 +1,7 @@
import logging
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

View File

@@ -1,5 +1,3 @@
import sys
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@@ -13,9 +11,3 @@ class Settings(BaseSettings):
TASK_HARD_TIME_LIMIT: int = 4 * 60 * 60 TASK_HARD_TIME_LIMIT: int = 4 * 60 * 60
ENABLE_SHARING: bool = False ENABLE_SHARING: bool = False
if "pytest" in sys.modules:
settings = Settings(_env_file=".env.test") # type: ignore
else:
settings = Settings() # type: ignore

View File

@@ -3,44 +3,57 @@ from fastapi.testclient import TestClient
from sqlalchemy_utils import create_database, database_exists, drop_database from sqlalchemy_utils import create_database, database_exists, drop_database
import app.shared.db.models as models import app.shared.db.models as models
from app.shared.db.base import SessionLocal, engine from app.shared.db.base import make_engine, make_session_local
from app.shared.settings import settings from app.shared.settings import Settings
from app.web.injections.db import get_session
from app.web.injections.settings import get_settings
from app.web.main import app_factory from app.web.main import app_factory
def pytest_configure() -> None: @pytest.fixture()
if not database_exists(engine.url): def settings():
create_database(engine.url) return Settings(_env_file=".env.test") # type: ignore
def pytest_unconfigure() -> None:
if database_exists(engine.url):
drop_database(engine.url)
@pytest.fixture() @pytest.fixture()
def auth_headers() -> dict[str, str]: def auth_headers(settings) -> dict[str, str]:
return {"Authorization": f"Bearer {settings.API_SECRET}"} return {"Authorization": f"Bearer {settings.API_SECRET}"}
@pytest.fixture() @pytest.fixture()
def test_db(): def test_db(settings):
engine = make_engine(settings.DATABASE_URI)
if not database_exists(engine.url):
create_database(engine.url)
models.Base.metadata.create_all(engine) models.Base.metadata.create_all(engine)
connection = engine.connect() connection = engine.connect()
yield connection yield connection
connection.close() connection.close()
models.Base.metadata.drop_all(bind=engine) models.Base.metadata.drop_all(bind=engine)
drop_database(engine.url)
@pytest.fixture() @pytest.fixture()
def db_session(test_db): def db_session(test_db):
with SessionLocal(bind=test_db) as session: session_local = make_session_local(test_db)
with session_local() as session:
yield session yield session
@pytest.fixture() @pytest.fixture()
def client(db_session): def app(db_session, settings):
app = app_factory(lambda: db_session) app = app_factory()
app.dependency_overrides[get_settings] = lambda: settings
app.dependency_overrides[get_session] = lambda: db_session
return app
@pytest.fixture()
def client(app):
client = TestClient(app) client = TestClient(app)
return client return client
@@ -66,10 +79,3 @@ def mock_artifact(db_session, mock_job):
db_session.add(artifact) db_session.add(artifact)
db_session.commit() db_session.commit()
return artifact return artifact
@pytest.fixture()
def sharing_enabled():
settings.ENABLE_SHARING = True
yield
settings.ENABLE_SHARING = False

View File

@@ -1,7 +1,6 @@
from fastapi.testclient import TestClient
import app.shared.db.models as models import app.shared.db.models as models
from app.web.main import app_factory from app.shared.settings import Settings
from app.web.injections.settings import get_settings
# POST /api/v1/jobs # POST /api/v1/jobs
@@ -69,9 +68,10 @@ def test_get_job_sharing_disabled(client, mock_job):
assert res.status_code == 401 assert res.status_code == 401
def test_get_job_sharing_enabled(db_session, mock_job, sharing_enabled): def test_get_job_sharing_enabled(client, app, mock_job):
# HACK: delay construction until settings are patched. app.dependency_overrides[get_settings] = lambda: Settings(
client = TestClient(app_factory(lambda: db_session)) _env_file=".env.test", ENABLE_SHARING=True # type: ignore
)
res = client.get( res = client.get(
f"/api/v1/jobs/{mock_job.id}", f"/api/v1/jobs/{mock_job.id}",

View File

@@ -1,4 +1,3 @@
from app.shared.db.base import get_session
from app.web.main import app_factory from app.web.main import app_factory
app = app_factory(get_session) app = app_factory

View File

26
app/web/injections/db.py Normal file
View File

@@ -0,0 +1,26 @@
from functools import lru_cache
from typing import Generator
from fastapi import Depends
from sqlalchemy.orm import Session
from app.shared.db.base import make_engine, make_session_local
from app.shared.settings import Settings
from app.web.injections.settings import get_settings
@lru_cache
def session_local(database_url: str):
engine = make_engine(database_url)
return make_session_local(engine)
def get_session_local(settings: Settings = Depends(get_settings)):
return session_local(settings.DATABASE_URI)
def get_session(
session_local=Depends(get_session_local),
) -> Generator[Session, None, None]:
with session_local() as session:
yield session

View File

@@ -0,0 +1,39 @@
from hmac import compare_digest
from typing import Annotated
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.shared.settings import Settings
from app.web.injections.settings import get_settings
def api_key_auth(
credentials: Annotated[
HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
],
settings: Annotated[Settings, Depends(get_settings)],
):
validate_credentials(credentials, settings.API_SECRET)
def sharing_auth(
credentials: Annotated[
HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
],
settings: Annotated[Settings, Depends(get_settings)],
):
if settings.ENABLE_SHARING:
pass
else:
validate_credentials(credentials, settings.API_SECRET)
def validate_credentials(credentials: HTTPAuthorizationCredentials, secret: str):
# use compare_digest to counter timing attacks.
if (
not credentials
or not secret
or not compare_digest(secret, credentials.credentials)
):
raise HTTPException(status_code=401)

View File

@@ -0,0 +1,8 @@
from functools import lru_cache
from app.shared.settings import Settings
@lru_cache
def get_settings():
return Settings() # type: ignore

View File

@@ -0,0 +1,16 @@
from functools import lru_cache
from fastapi import Depends
from app.shared.settings import Settings
from app.web.injections.settings import get_settings
from app.web.task_queue import TaskQueue
@lru_cache
def task_queue(broker_url: str):
return TaskQueue(broker_url)
def get_task_queue(settings: Settings = Depends(get_settings)):
return task_queue(settings.BROKER_URL)

View File

@@ -1,4 +1,4 @@
from typing import Annotated, Callable, Generator from typing import Annotated
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
@@ -7,18 +7,15 @@ from sqlalchemy.orm import Session
import app.shared.db.models as models import app.shared.db.models as models
import app.web.dtos as dtos import app.web.dtos as dtos
from app.shared.settings import settings from app.web.injections.db import get_session
from app.web.security import authenticate_api_key from app.web.injections.security import api_key_auth, sharing_auth
from app.web.injections.task_queue import get_task_queue
from app.web.task_queue import TaskQueue from app.web.task_queue import TaskQueue
DatabaseSession = Annotated[Session, Depends(get_session)]
def app_factory(
session_getter: Callable[[], Generator[Session, None, None]]
) -> FastAPI:
DatabaseSession = Annotated[Session, Depends(session_getter)]
task_queue = TaskQueue()
def app_factory():
app = FastAPI( app = FastAPI(
description=( description=(
"whisperbox-transcribe is an async HTTP wrapper for openai/whisper." "whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
@@ -28,13 +25,13 @@ def app_factory(
api_router = APIRouter(prefix="/api/v1") api_router = APIRouter(prefix="/api/v1")
@api_router.get("/", response_model=None, status_code=204) @api_router.get("/", status_code=204)
def api_root() -> None: def api_root():
return None return None
@api_router.get( @api_router.get(
"/jobs", "/jobs",
dependencies=[Depends(authenticate_api_key)], dependencies=[Depends(api_key_auth)],
response_model=list[dtos.Job], response_model=list[dtos.Job],
summary="Get metadata for all jobs", summary="Get metadata for all jobs",
) )
@@ -52,7 +49,7 @@ def app_factory(
@api_router.get( @api_router.get(
"/jobs/{id}", "/jobs/{id}",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], dependencies=[Depends(sharing_auth)],
response_model=dtos.Job, response_model=dtos.Job,
summary="Get metadata for one job", summary="Get metadata for one job",
) )
@@ -72,7 +69,7 @@ def app_factory(
@api_router.get( @api_router.get(
"/jobs/{id}/artifacts", "/jobs/{id}/artifacts",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], dependencies=[Depends(api_key_auth)],
response_model=list[dtos.Artifact], response_model=list[dtos.Artifact],
summary="Get all artifacts for one job", summary="Get all artifacts for one job",
) )
@@ -93,7 +90,7 @@ def app_factory(
@api_router.delete( @api_router.delete(
"/jobs/{id}", "/jobs/{id}",
dependencies=[Depends(authenticate_api_key)], dependencies=[Depends(sharing_auth)],
status_code=204, status_code=204,
summary="Delete a job with all artifacts", summary="Delete a job with all artifacts",
) )
@@ -130,7 +127,7 @@ def app_factory(
@api_router.post( @api_router.post(
"/jobs", "/jobs",
dependencies=[Depends(authenticate_api_key)], dependencies=[Depends(api_key_auth)],
response_model=dtos.Job, response_model=dtos.Job,
status_code=201, status_code=201,
summary="Enqueue a new job", summary="Enqueue a new job",
@@ -138,6 +135,7 @@ def app_factory(
def create_job( def create_job(
payload: PostJobPayload, payload: PostJobPayload,
session: DatabaseSession, session: DatabaseSession,
task_queue: Annotated[TaskQueue, Depends(get_task_queue)],
) -> models.Job: ) -> models.Job:
""" """
Enqueue a new whisper job for processing. Enqueue a new whisper job for processing.

View File

@@ -1,16 +0,0 @@
from hmac import compare_digest
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.shared.settings import settings
def authenticate_api_key(
credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
) -> None:
# use compare_digest to counter timing attacks.
if not credentials or not compare_digest(
settings.API_SECRET, credentials.credentials
):
raise HTTPException(status_code=401)

View File

@@ -7,8 +7,8 @@ from app.shared.celery import get_celery_binding
class TaskQueue: class TaskQueue:
celery: Celery celery: Celery
def __init__(self) -> None: def __init__(self, broker_url: str) -> None:
self.celery = get_celery_binding() self.celery = get_celery_binding(broker_url=broker_url)
def queue_task(self, job: models.Job): def queue_task(self, job: models.Job):
""" """

View File

@@ -1,4 +1,3 @@
from asyncio.log import logger
from typing import Any from typing import Any
from uuid import UUID from uuid import UUID
@@ -7,11 +6,16 @@ from sqlalchemy.orm import Session
import app.shared.db.models as models import app.shared.db.models as models
from app.shared.celery import get_celery_binding from app.shared.celery import get_celery_binding
from app.shared.db.base import SessionLocal from app.shared.db.base import make_engine, make_session_local
from app.shared.settings import settings from app.shared.logger import logger
from app.shared.settings import Settings
from app.worker.strategies.local import LocalStrategy from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding() # TODO: refactor to be part of a Task instance.
settings = Settings() # type: ignore
celery = get_celery_binding(settings.BROKER_URL)
engine = make_engine(settings.DATABASE_URI)
SessionLocal = make_session_local(engine)
class TranscribeTask(Task): class TranscribeTask(Task):
@@ -43,14 +47,24 @@ class TranscribeTask(Task):
task_acks_on_failure_or_timeout=True, task_acks_on_failure_or_timeout=True,
task_reject_on_worker_lost=True, task_reject_on_worker_lost=True,
) )
def transcribe(self: Task, job_id: UUID) -> None: def transcribe(self: TranscribeTask, job_id: UUID) -> None:
session: Session | None = None
job: models.Job | None = None
try: try:
if not self.strategy:
raise Exception("expected a transcription strategy to be defined.")
# runs in a separate thread => requires sqlite's WAL mode to be enabled. # runs in a separate thread => requires sqlite's WAL mode to be enabled.
db: Session = SessionLocal() session = SessionLocal()
# work around mypy not inferring the sum type correctly.
if not session:
raise Exception("failed to acquire a session.")
# check if passed job should be processed. # check if passed job should be processed.
job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none() job = session.query(models.Job).filter(models.Job.id == job_id).one_or_none()
if job is None: if job is None:
logger.warn("[{job.id}]: Received unknown job, abort.") logger.warn("[{job.id}]: Received unknown job, abort.")
@@ -62,7 +76,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.debug(f"[{job.id}]: start processing {job.type} job.") logger.debug(f"[{job.id}]: start processing {job.type} job.")
if job.meta: if job.meta is not None:
attempts = 1 + (job.meta.get("attempts") or 0) attempts = 1 + (job.meta.get("attempts") or 0)
else: else:
attempts = 1 attempts = 1
@@ -77,7 +91,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
job.meta = {"task_id": self.request.id, "attempts": attempts} job.meta = {"task_id": self.request.id, "attempts": attempts}
job.status = models.JobStatus.processing job.status = models.JobStatus.processing
db.commit() session.commit()
logger.debug(f"[{job.id}]: finished setting task to {job.status}.") logger.debug(f"[{job.id}]: finished setting task to {job.status}.")
@@ -86,25 +100,27 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.debug(f"[{job.id}]: successfully processed audio.") logger.debug(f"[{job.id}]: successfully processed audio.")
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type) artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
db.add(artifact) session.add(artifact)
job.status = models.JobStatus.success job.status = models.JobStatus.success
db.commit() session.commit()
logger.debug(f"[{job.id}]: successfully stored artifact.") logger.debug(f"[{job.id}]: successfully stored artifact.")
except Exception as e: except Exception as e:
if job and db: if job and session:
if db.in_transaction(): if session.in_transaction():
db.rollback() session.rollback()
if job.meta: if job.meta is not None:
job.meta = {**job.meta, "error": str(e)} # type: ignore job.meta = {**job.meta, "error": str(e)}
else: else:
job.meta = {"error": str(e)} job.meta = {"error": str(e)}
job.status = models.JobStatus.error job.status = models.JobStatus.error
db.commit() session.commit()
raise raise
finally: finally:
self.strategy.cleanup(job_id) if self.strategy:
db.close() self.strategy.cleanup(job_id)
if session:
session.close()

View File

@@ -12,7 +12,7 @@ services:
- "--entrypoints.web.address=:80" - "--entrypoints.web.address=:80"
web: web:
command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info" command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --factory"
# NOTE: the docker on mac mount adapter (virtioFS) does not support flock. # NOTE: the docker on mac mount adapter (virtioFS) does not support flock.
# this can cause the sqlite database to corrupt when written from worker <> api simultaneously. # this can cause the sqlite database to corrupt when written from worker <> api simultaneously.
volumes: volumes:

View File

@@ -2,3 +2,4 @@
plugins = sqlalchemy.ext.mypy.plugin plugins = sqlalchemy.ext.mypy.plugin
ignore_missing_imports = True ignore_missing_imports = True
disallow_untyped_defs = False disallow_untyped_defs = False
check_untyped_defs = True

View File

@@ -20,4 +20,4 @@ COPY alembic.ini .
ENV VIRTUAL_ENV /opt/venv ENV VIRTUAL_ENV /opt/venv
ENV PATH /opt/venv/bin:$PATH ENV PATH /opt/venv/bin:$PATH
CMD alembic upgrade head && uvicorn app.web:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers CMD alembic upgrade head && uvicorn app.web:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers --factory