mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-07 19:18:35 +03:00
refactor: use Depends for settings and session_local (#59)
This commit is contained in:
@@ -1,13 +1,9 @@
|
||||
from celery import Celery
|
||||
|
||||
from app.shared.settings import settings
|
||||
|
||||
|
||||
def get_celery_binding() -> Celery:
|
||||
celery = Celery(
|
||||
broker_url=settings.BROKER_URL,
|
||||
def get_celery_binding(broker_url: str) -> Celery:
|
||||
return Celery(
|
||||
broker_url=broker_url,
|
||||
broker_connection_retry=False,
|
||||
broker_connection_retry_on_startup=False,
|
||||
)
|
||||
|
||||
return celery
|
||||
|
||||
@@ -4,7 +4,9 @@ from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from app.shared.db.models import Base
|
||||
from app.shared.settings import settings
|
||||
from app.shared.settings import Settings
|
||||
|
||||
settings = Settings() # type: ignore
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from typing import Any, Generator
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy import Engine, create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.shared.settings import settings
|
||||
|
||||
engine = create_engine(settings.DATABASE_URI, connect_args={"check_same_thread": False})
|
||||
|
||||
def make_engine(database_url: str):
|
||||
engine = create_engine(database_url, connect_args={"check_same_thread": False})
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(conn: Any, _: Any) -> None:
|
||||
@@ -14,13 +13,9 @@ def set_sqlite_pragma(conn: Any, _: Any) -> None:
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.close()
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
return engine
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
session: Session = SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
def make_session_local(engine: Engine):
|
||||
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
return session_local
|
||||
|
||||
7
app/shared/logger.py
Normal file
7
app/shared/logger.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Shared logger for the application, configured once at import time."""
import logging

# Install the default root handler so records are actually emitted.
logging.basicConfig()

# Module-level logger shared by importers; INFO and above are let through.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
||||
@@ -1,5 +1,3 @@
|
||||
import sys
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@@ -13,9 +11,3 @@ class Settings(BaseSettings):
|
||||
TASK_HARD_TIME_LIMIT: int = 4 * 60 * 60
|
||||
|
||||
ENABLE_SHARING: bool = False
|
||||
|
||||
|
||||
if "pytest" in sys.modules:
|
||||
settings = Settings(_env_file=".env.test") # type: ignore
|
||||
else:
|
||||
settings = Settings() # type: ignore
|
||||
|
||||
@@ -3,44 +3,57 @@ from fastapi.testclient import TestClient
|
||||
from sqlalchemy_utils import create_database, database_exists, drop_database
|
||||
|
||||
import app.shared.db.models as models
|
||||
from app.shared.db.base import SessionLocal, engine
|
||||
from app.shared.settings import settings
|
||||
from app.shared.db.base import make_engine, make_session_local
|
||||
from app.shared.settings import Settings
|
||||
from app.web.injections.db import get_session
|
||||
from app.web.injections.settings import get_settings
|
||||
from app.web.main import app_factory
|
||||
|
||||
|
||||
def pytest_configure() -> None:
|
||||
if not database_exists(engine.url):
|
||||
create_database(engine.url)
|
||||
|
||||
|
||||
def pytest_unconfigure() -> None:
|
||||
if database_exists(engine.url):
|
||||
drop_database(engine.url)
|
||||
@pytest.fixture()
|
||||
def settings():
|
||||
return Settings(_env_file=".env.test") # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def auth_headers() -> dict[str, str]:
|
||||
def auth_headers(settings) -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {settings.API_SECRET}"}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_db():
|
||||
def test_db(settings):
|
||||
engine = make_engine(settings.DATABASE_URI)
|
||||
|
||||
if not database_exists(engine.url):
|
||||
create_database(engine.url)
|
||||
|
||||
models.Base.metadata.create_all(engine)
|
||||
|
||||
connection = engine.connect()
|
||||
yield connection
|
||||
connection.close()
|
||||
|
||||
models.Base.metadata.drop_all(bind=engine)
|
||||
drop_database(engine.url)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session(test_db):
|
||||
with SessionLocal(bind=test_db) as session:
|
||||
session_local = make_session_local(test_db)
|
||||
with session_local() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(db_session):
|
||||
app = app_factory(lambda: db_session)
|
||||
def app(db_session, settings):
|
||||
app = app_factory()
|
||||
app.dependency_overrides[get_settings] = lambda: settings
|
||||
app.dependency_overrides[get_session] = lambda: db_session
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(app):
|
||||
client = TestClient(app)
|
||||
return client
|
||||
|
||||
@@ -66,10 +79,3 @@ def mock_artifact(db_session, mock_job):
|
||||
db_session.add(artifact)
|
||||
db_session.commit()
|
||||
return artifact
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def sharing_enabled():
|
||||
settings.ENABLE_SHARING = True
|
||||
yield
|
||||
settings.ENABLE_SHARING = False
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import app.shared.db.models as models
|
||||
from app.web.main import app_factory
|
||||
from app.shared.settings import Settings
|
||||
from app.web.injections.settings import get_settings
|
||||
|
||||
|
||||
# POST /api/v1/jobs
|
||||
@@ -69,9 +68,10 @@ def test_get_job_sharing_disabled(client, mock_job):
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_get_job_sharing_enabled(db_session, mock_job, sharing_enabled):
|
||||
# HACK: delay construction until settings are patched.
|
||||
client = TestClient(app_factory(lambda: db_session))
|
||||
def test_get_job_sharing_enabled(client, app, mock_job):
|
||||
app.dependency_overrides[get_settings] = lambda: Settings(
|
||||
_env_file=".env.test", ENABLE_SHARING=True # type: ignore
|
||||
)
|
||||
|
||||
res = client.get(
|
||||
f"/api/v1/jobs/{mock_job.id}",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from app.shared.db.base import get_session
|
||||
from app.web.main import app_factory
|
||||
|
||||
app = app_factory(get_session)
|
||||
app = app_factory
|
||||
|
||||
0
app/web/injections/__init__.py
Normal file
0
app/web/injections/__init__.py
Normal file
26
app/web/injections/db.py
Normal file
26
app/web/injections/db.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from functools import lru_cache
|
||||
from typing import Generator
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.db.base import make_engine, make_session_local
|
||||
from app.shared.settings import Settings
|
||||
from app.web.injections.settings import get_settings
|
||||
|
||||
|
||||
@lru_cache
def session_local(database_url: str):
    """Return the session factory bound to `database_url`.

    The result is memoized per URL by `lru_cache`, so repeated calls with the
    same URL reuse the previously built factory instead of building a new
    engine each time.
    """
    return make_session_local(make_engine(database_url))
|
||||
|
||||
|
||||
def get_session_local(settings: Settings = Depends(get_settings)):
    """Dependency: look up the cached session factory for the configured database."""
    database_url = settings.DATABASE_URI
    return session_local(database_url)
|
||||
|
||||
|
||||
def get_session(
    session_local=Depends(get_session_local),
) -> Generator[Session, None, None]:
    """Dependency: yield a database session created from the injected factory.

    The session lives inside a `with` block, so it is released when the
    consumer of this generator is finished with it.
    """
    with session_local() as db:
        yield db
|
||||
39
app/web/injections/security.py
Normal file
39
app/web/injections/security.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from hmac import compare_digest
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from app.shared.settings import Settings
|
||||
from app.web.injections.settings import get_settings
|
||||
|
||||
|
||||
def api_key_auth(
    credentials: Annotated[
        HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
    ],
    settings: Annotated[Settings, Depends(get_settings)],
):
    """Dependency: require a bearer token that matches the configured API secret.

    The bearer scheme is created with `auto_error=False`, so the check for
    missing credentials is left to `validate_credentials`.
    """
    expected_secret = settings.API_SECRET
    validate_credentials(credentials, expected_secret)
|
||||
|
||||
|
||||
def sharing_auth(
    credentials: Annotated[
        HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
    ],
    settings: Annotated[Settings, Depends(get_settings)],
):
    """Dependency: require the API secret unless sharing is enabled.

    When `settings.ENABLE_SHARING` is truthy the request passes without any
    credential check; otherwise the bearer token is validated against
    `settings.API_SECRET`.
    """
    if not settings.ENABLE_SHARING:
        validate_credentials(credentials, settings.API_SECRET)
|
||||
|
||||
|
||||
def validate_credentials(credentials: "HTTPAuthorizationCredentials", secret: str):
    """Reject the request unless the bearer token equals `secret`.

    Args:
        credentials: Parsed bearer credentials; falsy when the request
            carried no usable Authorization header.
        secret: The expected API secret. An empty secret never authenticates.

    Raises:
        HTTPException: With status 401 when credentials are missing, the
            secret is empty, or the token does not match.
    """
    if not credentials or not secret:
        raise HTTPException(status_code=401)

    # use compare_digest to counter timing attacks.
    # It only accepts str arguments that are pure ASCII and raises TypeError
    # otherwise, so compare the encoded bytes: a non-ASCII token must end in
    # a 401 rather than an unhandled error.
    token_matches = compare_digest(
        secret.encode("utf-8"), credentials.credentials.encode("utf-8")
    )
    if not token_matches:
        raise HTTPException(status_code=401)
|
||||
8
app/web/injections/settings.py
Normal file
8
app/web/injections/settings.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from functools import lru_cache
|
||||
|
||||
from app.shared.settings import Settings
|
||||
|
||||
|
||||
@lru_cache
def get_settings():
    """Dependency: build the application `Settings` once and reuse the instance.

    `lru_cache` on a zero-argument function means the first call constructs
    the object and every later call returns that same object.
    """
    app_settings = Settings()  # type: ignore
    return app_settings
|
||||
16
app/web/injections/task_queue.py
Normal file
16
app/web/injections/task_queue.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from functools import lru_cache
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from app.shared.settings import Settings
|
||||
from app.web.injections.settings import get_settings
|
||||
from app.web.task_queue import TaskQueue
|
||||
|
||||
|
||||
@lru_cache
def task_queue(broker_url: str):
    """Return the `TaskQueue` for `broker_url`, memoized per URL by `lru_cache`."""
    queue = TaskQueue(broker_url)
    return queue
|
||||
|
||||
|
||||
def get_task_queue(settings: Settings = Depends(get_settings)):
    """Dependency: look up the cached task queue for the configured broker."""
    broker_url = settings.BROKER_URL
    return task_queue(broker_url)
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Callable, Generator
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
|
||||
@@ -7,18 +7,15 @@ from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.models as models
|
||||
import app.web.dtos as dtos
|
||||
from app.shared.settings import settings
|
||||
from app.web.security import authenticate_api_key
|
||||
from app.web.injections.db import get_session
|
||||
from app.web.injections.security import api_key_auth, sharing_auth
|
||||
from app.web.injections.task_queue import get_task_queue
|
||||
from app.web.task_queue import TaskQueue
|
||||
|
||||
DatabaseSession = Annotated[Session, Depends(get_session)]
|
||||
|
||||
def app_factory(
|
||||
session_getter: Callable[[], Generator[Session, None, None]]
|
||||
) -> FastAPI:
|
||||
DatabaseSession = Annotated[Session, Depends(session_getter)]
|
||||
|
||||
task_queue = TaskQueue()
|
||||
|
||||
def app_factory():
|
||||
app = FastAPI(
|
||||
description=(
|
||||
"whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
|
||||
@@ -28,13 +25,13 @@ def app_factory(
|
||||
|
||||
api_router = APIRouter(prefix="/api/v1")
|
||||
|
||||
@api_router.get("/", response_model=None, status_code=204)
|
||||
def api_root() -> None:
|
||||
@api_router.get("/", status_code=204)
|
||||
def api_root():
|
||||
return None
|
||||
|
||||
@api_router.get(
|
||||
"/jobs",
|
||||
dependencies=[Depends(authenticate_api_key)],
|
||||
dependencies=[Depends(api_key_auth)],
|
||||
response_model=list[dtos.Job],
|
||||
summary="Get metadata for all jobs",
|
||||
)
|
||||
@@ -52,7 +49,7 @@ def app_factory(
|
||||
|
||||
@api_router.get(
|
||||
"/jobs/{id}",
|
||||
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
|
||||
dependencies=[Depends(sharing_auth)],
|
||||
response_model=dtos.Job,
|
||||
summary="Get metadata for one job",
|
||||
)
|
||||
@@ -72,7 +69,7 @@ def app_factory(
|
||||
|
||||
@api_router.get(
|
||||
"/jobs/{id}/artifacts",
|
||||
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
|
||||
dependencies=[Depends(api_key_auth)],
|
||||
response_model=list[dtos.Artifact],
|
||||
summary="Get all artifacts for one job",
|
||||
)
|
||||
@@ -93,7 +90,7 @@ def app_factory(
|
||||
|
||||
@api_router.delete(
|
||||
"/jobs/{id}",
|
||||
dependencies=[Depends(authenticate_api_key)],
|
||||
dependencies=[Depends(sharing_auth)],
|
||||
status_code=204,
|
||||
summary="Delete a job with all artifacts",
|
||||
)
|
||||
@@ -130,7 +127,7 @@ def app_factory(
|
||||
|
||||
@api_router.post(
|
||||
"/jobs",
|
||||
dependencies=[Depends(authenticate_api_key)],
|
||||
dependencies=[Depends(api_key_auth)],
|
||||
response_model=dtos.Job,
|
||||
status_code=201,
|
||||
summary="Enqueue a new job",
|
||||
@@ -138,6 +135,7 @@ def app_factory(
|
||||
def create_job(
|
||||
payload: PostJobPayload,
|
||||
session: DatabaseSession,
|
||||
task_queue: Annotated[TaskQueue, Depends(get_task_queue)],
|
||||
) -> models.Job:
|
||||
"""
|
||||
Enqueue a new whisper job for processing.
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
from hmac import compare_digest
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from app.shared.settings import settings
|
||||
|
||||
|
||||
def authenticate_api_key(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
|
||||
) -> None:
|
||||
# use compare_digest to counter timing attacks.
|
||||
if not credentials or not compare_digest(
|
||||
settings.API_SECRET, credentials.credentials
|
||||
):
|
||||
raise HTTPException(status_code=401)
|
||||
@@ -7,8 +7,8 @@ from app.shared.celery import get_celery_binding
|
||||
class TaskQueue:
|
||||
celery: Celery
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.celery = get_celery_binding()
|
||||
def __init__(self, broker_url: str) -> None:
|
||||
self.celery = get_celery_binding(broker_url=broker_url)
|
||||
|
||||
def queue_task(self, job: models.Job):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from asyncio.log import logger
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
@@ -7,11 +6,16 @@ from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.models as models
|
||||
from app.shared.celery import get_celery_binding
|
||||
from app.shared.db.base import SessionLocal
|
||||
from app.shared.settings import settings
|
||||
from app.shared.db.base import make_engine, make_session_local
|
||||
from app.shared.logger import logger
|
||||
from app.shared.settings import Settings
|
||||
from app.worker.strategies.local import LocalStrategy
|
||||
|
||||
celery = get_celery_binding()
|
||||
# TODO: refactor to be part of a Task instance.
|
||||
settings = Settings() # type: ignore
|
||||
celery = get_celery_binding(settings.BROKER_URL)
|
||||
engine = make_engine(settings.DATABASE_URI)
|
||||
SessionLocal = make_session_local(engine)
|
||||
|
||||
|
||||
class TranscribeTask(Task):
|
||||
@@ -43,14 +47,24 @@ class TranscribeTask(Task):
|
||||
task_acks_on_failure_or_timeout=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
)
|
||||
def transcribe(self: Task, job_id: UUID) -> None:
|
||||
def transcribe(self: TranscribeTask, job_id: UUID) -> None:
|
||||
session: Session | None = None
|
||||
job: models.Job | None = None
|
||||
|
||||
try:
|
||||
if not self.strategy:
|
||||
raise Exception("expected a transcription strategy to be defined.")
|
||||
|
||||
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
|
||||
db: Session = SessionLocal()
|
||||
session = SessionLocal()
|
||||
|
||||
# work around mypy not inferring the sum type correctly.
|
||||
if not session:
|
||||
raise Exception("failed to acquire a session.")
|
||||
|
||||
# check if passed job should be processed.
|
||||
|
||||
job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none()
|
||||
job = session.query(models.Job).filter(models.Job.id == job_id).one_or_none()
|
||||
|
||||
if job is None:
|
||||
logger.warn("[{job.id}]: Received unknown job, abort.")
|
||||
@@ -62,7 +76,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
|
||||
logger.debug(f"[{job.id}]: start processing {job.type} job.")
|
||||
|
||||
if job.meta:
|
||||
if job.meta is not None:
|
||||
attempts = 1 + (job.meta.get("attempts") or 0)
|
||||
else:
|
||||
attempts = 1
|
||||
@@ -77,7 +91,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
job.meta = {"task_id": self.request.id, "attempts": attempts}
|
||||
|
||||
job.status = models.JobStatus.processing
|
||||
db.commit()
|
||||
session.commit()
|
||||
|
||||
logger.debug(f"[{job.id}]: finished setting task to {job.status}.")
|
||||
|
||||
@@ -86,25 +100,27 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
logger.debug(f"[{job.id}]: successfully processed audio.")
|
||||
|
||||
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
|
||||
db.add(artifact)
|
||||
session.add(artifact)
|
||||
|
||||
job.status = models.JobStatus.success
|
||||
db.commit()
|
||||
session.commit()
|
||||
|
||||
logger.debug(f"[{job.id}]: successfully stored artifact.")
|
||||
|
||||
except Exception as e:
|
||||
if job and db:
|
||||
if db.in_transaction():
|
||||
db.rollback()
|
||||
if job.meta:
|
||||
job.meta = {**job.meta, "error": str(e)} # type: ignore
|
||||
if job and session:
|
||||
if session.in_transaction():
|
||||
session.rollback()
|
||||
if job.meta is not None:
|
||||
job.meta = {**job.meta, "error": str(e)}
|
||||
else:
|
||||
job.meta = {"error": str(e)}
|
||||
|
||||
job.status = models.JobStatus.error
|
||||
db.commit()
|
||||
session.commit()
|
||||
raise
|
||||
finally:
|
||||
if self.strategy:
|
||||
self.strategy.cleanup(job_id)
|
||||
db.close()
|
||||
if session:
|
||||
session.close()
|
||||
|
||||
@@ -12,7 +12,7 @@ services:
|
||||
- "--entrypoints.web.address=:80"
|
||||
|
||||
web:
|
||||
command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info"
|
||||
command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --factory"
|
||||
# NOTE: the docker on mac mount adapter (virtioFS) does not support flock.
|
||||
# this can cause the sqlite database to corrupt when written from worker <> api simultaneously.
|
||||
volumes:
|
||||
|
||||
1
mypy.ini
1
mypy.ini
@@ -2,3 +2,4 @@
|
||||
plugins = sqlalchemy.ext.mypy.plugin
|
||||
ignore_missing_imports = True
|
||||
disallow_untyped_defs = False
|
||||
check_untyped_defs = True
|
||||
|
||||
@@ -20,4 +20,4 @@ COPY alembic.ini .
|
||||
ENV VIRTUAL_ENV /opt/venv
|
||||
ENV PATH /opt/venv/bin:$PATH
|
||||
|
||||
CMD alembic upgrade head && uvicorn app.web:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers
|
||||
CMD alembic upgrade head && uvicorn app.web:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers --factory
|
||||
|
||||
Reference in New Issue
Block a user