diff --git a/.env.test b/.env.test index f67e6ff..615456b 100644 --- a/.env.test +++ b/.env.test @@ -1,4 +1,4 @@ API_SECRET="test_secret" BROKER_URL="memory://" -DATABASE_URI="sqlite:///memory" +DATABASE_URI="sqlite://" ENVIRONMENT="test" diff --git a/app/shared/db/alembic/versions/dc8582aea0bc_add_job_tables.py b/app/shared/db/alembic/versions/0eee2b7913b7_add_tables.py similarity index 92% rename from app/shared/db/alembic/versions/dc8582aea0bc_add_job_tables.py rename to app/shared/db/alembic/versions/0eee2b7913b7_add_tables.py index 95cacb5..afb9cc8 100644 --- a/app/shared/db/alembic/versions/dc8582aea0bc_add_job_tables.py +++ b/app/shared/db/alembic/versions/0eee2b7913b7_add_tables.py @@ -1,15 +1,15 @@ -"""add_job_tables +"""add_tables -Revision ID: dc8582aea0bc +Revision ID: 0eee2b7913b7 Revises: -Create Date: 2023-02-08 12:12:00.808816 +Create Date: 2023-06-29 08:33:26.123728 """ import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = "dc8582aea0bc" +revision = "0eee2b7913b7" down_revision = None branch_labels = None depends_on = None @@ -54,7 +54,7 @@ def upgrade() -> None: sa.Column("data", sa.JSON(none_as_null=True), nullable=True), sa.Column( "type", - sa.Enum("raw_transcript", name="artifacttype"), + sa.Enum("raw_transcript", "language_detection", name="artifacttype"), nullable=False, ), sa.Column( diff --git a/app/shared/db/base.py b/app/shared/db/base.py index e949340..a717f39 100644 --- a/app/shared/db/base.py +++ b/app/shared/db/base.py @@ -19,11 +19,8 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) def get_session() -> Generator[Session, None, None]: - db: Session = SessionLocal() + session: Session = SessionLocal() try: - yield db - db.commit() - except Exception: - db.rollback() + yield session finally: - db.close() + session.close() diff --git a/app/shared/db/models.py b/app/shared/db/models.py index 30c008a..e4b6f02 100644 --- a/app/shared/db/models.py +++ b/app/shared/db/models.py @@ -1,13 +1,39 @@ +import enum import uuid from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, declarative_base, declarative_mixin, declared_attr -from .schemas import ArtifactType, JobStatus, JobType - Base = declarative_base() +# Enums + + +class JobType(str, enum.Enum): + """Requested type of a job.""" + + transcript = "transcribe" + translation = "translate" + language_detection = "detect_language" + + +class JobStatus(str, enum.Enum): + """Processing status of a job.""" + + create = "create" + processing = "processing" + error = "error" + success = "success" + + +class ArtifactType(str, enum.Enum): + raw_transcript = "transcript_raw" + language_detection = "language_detection" + + +# SQLAlchemy models + @declarative_mixin class WithStandardFields: diff --git a/app/shared/db/schemas.py b/app/shared/db/schemas.py index fd7266c..839cddc 100644 --- a/app/shared/db/schemas.py +++ b/app/shared/db/schemas.py @@ -1,42 +1,16 @@ -import enum from datetime import datetime from uuid import UUID from pydantic import AnyHttpUrl, BaseModel, Field +from app.shared.db.models import ArtifactType, JobStatus, JobType -class WithDbFields(BaseModel): - id: UUID - created_at: datetime - updated_at: datetime | None - - class Config: - orm_mode = True - - -class ArtifactType(str, enum.Enum): - raw_transcript = "raw_transcript" - - -class JobType(str, enum.Enum): - transcript = "transcript" - translation = "translation" - language_detection = "language_detection" - - -class JobStatus(str, enum.Enum): - """Processing status of a job.""" - - create = "create" - processing = "processing" - error = "error" - success = "success" +# JSON field types class JobConfig(BaseModel): """Configuration for a job.""" - # TODO: limit to locales selected by whisper. language: str | None = Field( description=( "Spoken language in the media file. " @@ -51,21 +25,12 @@ class JobMeta(BaseModel): error: str | None = Field( description="Will contain a descriptive error message if processing failed." ) + task_id: UUID | None = Field( description="Internal celery id of this job submission." ) -class Job(WithDbFields): - """A transcription job for one media file.""" - - status: JobStatus - type: JobType - url: AnyHttpUrl - meta: JobMeta | None - config: JobConfig | None - - class RawTranscript(BaseModel): """A single transcript passage returned by whisper.""" @@ -81,9 +46,35 @@ class RawTranscript(BaseModel): no_speech_prob: float -class Artifact(WithDbFields): - """whisper output for one job.""" +class LanguageDetection(BaseModel): + """A language detection""" - data: list[RawTranscript] | None + code: str + + +# DB objects + + +class WithDbFields(BaseModel): + id: UUID + created_at: datetime + updated_at: datetime | None + + class Config: + orm_mode = True + + +class Job(WithDbFields): + """A transcription job for one media file.""" + + status: JobStatus + type: JobType + url: AnyHttpUrl + meta: JobMeta | None + config: JobConfig | None + + +class Artifact(WithDbFields): job_id: UUID + data: LanguageDetection | RawTranscript | None type: ArtifactType diff --git a/app/shared/settings.py b/app/shared/settings.py index 39d11e2..3c8ee15 100644 --- a/app/shared/settings.py +++ b/app/shared/settings.py @@ -16,8 +16,6 @@ class Settings(BaseSettings): if "pytest" in sys.modules: - settings = Settings( - _env_file=".env.test", _env_file_encoding="utf-8" - ) # type: ignore + settings = Settings(_env_file=".env.test") # type: ignore else: settings = Settings() # type: ignore diff --git a/app/tests/conftest.py b/app/tests/conftest.py index d64f9e3..52bab7b 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -13,7 +13,6 @@ from app.web.main import app def pytest_configure() -> None: if not database_exists(engine.url): create_database(engine.url) - models.Base.metadata.create_all(engine) def pytest_unconfigure() -> None: @@ -21,19 +20,21 @@ def pytest_unconfigure() -> None: drop_database(engine.url) -@pytest.fixture(name="auth_headers", scope="function") -def auth_header() -> dict[str, str]: +@pytest.fixture(scope="function") +def auth_headers() -> dict[str, str]: return {"Authorization": f"Bearer {settings.API_SECRET}"} -@pytest.fixture(name="db_session", scope="function", autouse=True) +@pytest.fixture(scope="function", autouse=True) def db_session() -> Generator[Session, None, None]: + models.Base.metadata.create_all(engine) + connection = engine.connect() - transaction = connection.begin() with SessionLocal(bind=connection) as session: app.dependency_overrides[get_session] = lambda: session yield session app.dependency_overrides.clear() - transaction.rollback() connection.close() + + models.Base.metadata.drop_all(bind=engine) diff --git a/app/tests/test_api.py b/app/tests/test_api.py index 63fca22..3d7c0ad 100644 --- a/app/tests/test_api.py +++ b/app/tests/test_api.py @@ -3,8 +3,6 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session import app.shared.db.models as models -import app.shared.db.schemas as schemas -from app.shared.db.schemas import JobStatus, JobType from app.web.main import app client = TestClient(app) @@ -13,7 +11,9 @@ client = TestClient(app) @pytest.fixture(name="mock_job", scope="function", autouse=False) def mock_job(db_session: Session) -> models.Job: job = models.Job( - url="https://example.com", type=JobType.transcript, status=JobStatus.create + url="https://example.com", + type=models.JobType.transcript, + status=models.JobStatus.create, ) db_session.add(job) @@ -28,7 +28,7 @@ def test_create_job_pass(auth_headers: dict[str, str]) -> None: res = client.post( "/api/v1/jobs", headers=auth_headers, - json={"url": "https://example.com", "type": JobType.transcript}, + json={"url": "https://example.com", "type": models.JobType.transcript}, ) assert res.status_code == 201 assert isinstance(res.json()["id"], str) @@ -43,7 +43,7 @@ def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None: res = client.post( "/api/v1/jobs", headers=auth_headers, - json={"url": "example.com", "type": JobType.transcript}, + json={"url": "example.com", "type": models.JobType.transcript}, ) assert res.status_code == 422 @@ -52,7 +52,7 @@ def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None: # --- def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None: res = client.get( - "/api/v1/jobs?type=transcript", + "/api/v1/jobs?type=transcribe", headers=auth_headers, ) assert len(res.json()) == 1 @@ -85,7 +85,7 @@ def test_get_artifacts_pass( auth_headers: dict[str, str], db_session: Session, mock_job: models.Job ) -> None: artifact = models.Artifact( - data=[], job_id=str(mock_job.id), type=schemas.ArtifactType.raw_transcript + data=None, job_id=str(mock_job.id), type=models.ArtifactType.raw_transcript ) db_session.add(artifact) diff --git a/app/web/dtos.py b/app/web/dtos.py index 9ef63cd..480ee93 100644 --- a/app/web/dtos.py +++ b/app/web/dtos.py @@ -1,17 +1,6 @@ -from typing import Any - from pydantic import AnyHttpUrl, BaseModel, Field -import app.shared.db.schemas as schemas - - -class DetailResponse(BaseModel): - detail: str - - -DEFAULT_RESPONSES: dict[int | str, dict[str, Any]] = { - 401: {"model": DetailResponse, "description": "Not authenticated"} -} +import app.shared.db.models as models class PostJobPayload(BaseModel): @@ -21,7 +10,7 @@ class PostJobPayload(BaseModel): ) ) - type: schemas.JobType = Field( + type: models.JobType = Field( description="""Type of this job. `transcript` uses the original language of the audio. `translation` creates an automatic translation to english. diff --git a/app/web/main.py b/app/web/main.py index d440d52..2d1fd53 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -1,37 +1,37 @@ -from asyncio.log import logger +from contextlib import asynccontextmanager +from typing import Annotated from uuid import UUID from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path -from sqlalchemy import or_ from sqlalchemy.orm import Session import app.shared.db.models as models import app.shared.db.schemas as schemas -from app.shared.celery import get_celery_binding -from app.shared.db.base import get_session -from app.web.dtos import DEFAULT_RESPONSES, DetailResponse, PostJobPayload +from app.shared.db.base import SessionLocal, get_session +from app.web.dtos import PostJobPayload from app.web.security import authenticate_api_key +from app.web.task_queue import task_queue + +DatabaseSession = Annotated[Session, Depends(get_session)] + + +@asynccontextmanager +async def lifespan(_: FastAPI): + with SessionLocal() as session: + task_queue.rehydrate(session) + yield + app = FastAPI( description="whisperbox-transcribe is an async HTTP wrapper for openai/whisper.", + lifespan=lifespan, title="whisperbox-transcribe", ) -celery = get_celery_binding() - - -def queue_task(job: models.Job) -> None: - # queue an async transcription task. - # we use a signature here to allow full separation of - # worker processes and dependencies. - transcribe = celery.signature("app.worker.main.transcribe") - # TODO: catch delivery errors. - transcribe.delay(job.id) api_router = APIRouter( prefix="/api/v1", dependencies=[Depends(authenticate_api_key)], - responses={**DEFAULT_RESPONSES}, ) @@ -48,7 +48,7 @@ def api_root() -> None: ) def create_job( payload: PostJobPayload, - session: Session = Depends(get_session), + session: DatabaseSession, ) -> models.Job: """ Enqueue a new whisper job for processing. @@ -62,6 +62,7 @@ def create_job( consume considerable resources while active. * Once a job is created, you can query its status by its id. """ + # create a job with status "create" and save it to the database. job = models.Job( url=payload.url, @@ -73,12 +74,7 @@ def create_job( session.add(job) session.commit() - # queue an async transcription task. - # we use a signature here to allow full separation of - # worker processes and dependencies. - transcribe = celery.signature("app.worker.main.transcribe") - # TODO: catch delivery errors. - transcribe.delay(job.id) + task_queue.queue_task(job) return job @@ -87,7 +83,8 @@ def create_job( "/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs" ) def get_transcripts( - type: schemas.JobType | None = None, session: Session = Depends(get_session) + session: DatabaseSession, + type: schemas.JobType | None = None, ) -> list[models.Job]: """Get metadata for all jobs.""" query = session.query(models.Job) @@ -101,18 +98,20 @@ def get_transcripts( @api_router.get( "/jobs/{id}", response_model=schemas.Job, - responses={404: {"model": DetailResponse, "description": "Not found"}}, summary="Get metadata for one job", ) def get_transcript( - id: UUID = Path(), session: Session = Depends(get_session) + session: DatabaseSession, + id: UUID = Path(), ) -> models.Job | None: """ Use this route to check transcription status of any given job. """ job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none() + if not job: raise HTTPException(status_code=404) + return job @@ -122,10 +121,12 @@ def get_transcript( summary="Get all artifacts for one job", ) def get_artifacts_for_job( - id: UUID = Path(), session: Session = Depends(get_session) + session: DatabaseSession, + id: UUID = Path(), ) -> list[models.Artifact]: """ - Right now, there is only one type of artifact (`raw_transcript`). + Returns all artifacts for one job. + See the type of `data` for possible data types. Returns an empty array for unfinished or non-existant jobs. """ artifacts = ( @@ -139,7 +140,8 @@ def get_artifacts_for_job( "/jobs/{id}", status_code=204, summary="Delete a job with all artifacts" ) def delete_transcript( - id: UUID = Path(), session: Session = Depends(get_session) + session: DatabaseSession, + id: UUID = Path(), ) -> None: """Remove metadata and artifacts for a single job.""" session.query(models.Job).filter(models.Job.id == str(id)).delete() @@ -147,28 +149,3 @@ def delete_transcript( app.include_router(api_router) - - -# TODO: we could use `acks_late` to handle this scenario within celery itself. -# the reason this does not work well in our case is that `visibility_timeout` -# needs to be very high since whisper workers can be long running. -# doing this application-side bears the risk of poison pilling the worker though, -# implement a workaround with an acceptable trade-off. (=> retry only once?) -@app.on_event("startup") -def on_startup() -> None: - session = get_session().__next__() - - jobs = ( - session.query(models.Job) - .filter( - or_( - models.Job.status == schemas.JobStatus.processing, - models.Job.status == schemas.JobStatus.create, - ) - ) - .order_by(models.Job.created_at) - ).all() - - logger.info(f"Requeueing {len(jobs)} jobs.") - for job in jobs: - queue_task(job) diff --git a/app/web/task_queue.py b/app/web/task_queue.py new file mode 100644 index 0000000..ae4d7f8 --- /dev/null +++ b/app/web/task_queue.py @@ -0,0 +1,47 @@ +from asyncio.log import logger + +from celery import Celery +from sqlalchemy import or_ +from sqlalchemy.orm import Session + +import app.shared.db.models as models +from app.shared.celery import get_celery_binding + + +class TaskQueue: + celery: Celery + + def __init__(self) -> None: + self.celery = get_celery_binding() + + def queue_task(self, job: models.Job): + # queue an async transcription task. + # we use a signature here to allow full separation of + # worker processes and dependencies. + transcribe = self.celery.signature("app.worker.main.transcribe") + transcribe.delay(job.id) + + def rehydrate(self, session: Session): + # TODO: we could use `acks_late` to handle this scenario within celery itself. + # the reason this does not work well in our case is that `visibility_timeout` + # needs to be very high since whisper workers can be long running. + # doing this app-side bears the risk of poison pilling the worker though, + # implement a workaround with an acceptable trade-off. (=> retry only once?) + jobs = ( + session.query(models.Job) + .filter( + or_( + models.Job.status == models.JobStatus.processing, + models.Job.status == models.JobStatus.create, + ) + ) + .order_by(models.Job.created_at) + ).all() + + logger.info(f"Requeueing {len(jobs)} jobs.") + + for job in jobs: + self.queue_task(job) + + +task_queue = TaskQueue() diff --git a/app/worker/main.py b/app/worker/main.py index 9a46c76..26d9170 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -10,6 +10,7 @@ import app.shared.db.schemas as schemas from app.shared.celery import get_celery_binding from app.shared.db.base import SessionLocal from app.shared.settings import settings +from app.worker.strategies.base import TaskProtocol from app.worker.strategies.local import LocalStrategy celery = get_celery_binding() @@ -30,10 +31,10 @@ class TranscribeTask(Task): return self.run(*args, **kwargs) -def select_strategy(task: Task, job: schemas.Job) -> Any: - if job.type == schemas.JobType.transcript: +def select_task_processor(task: Task, job: schemas.Job) -> TaskProtocol: + if job.type == models.JobType.transcript: return task.strategy.transcribe - elif job.type == schemas.JobType.translation: + elif job.type == models.JobType.translation: return task.strategy.translate else: return task.strategy.detect_language @@ -50,49 +51,50 @@ def transcribe(self: Task, job_id: UUID) -> None: # runs in a separate thread => requires sqlite's WAL mode to be enabled. db: Session = SessionLocal() - job = db.query(models.Job).filter(models.Job.id == job_id).one() + # unit of work: set task status to processing. - if ( - job.status == schemas.JobStatus.error - or job.status == schemas.JobStatus.success - ): - logger.warn( - "[{job.id}]: Received job that has already been processed, abort." - ) + job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none() + + if job is None: + logger.warn("[{job.id}]: Received unknown job, abort.") return - logger.info(f"[{job.id}]: worker received task.") + if job.status in [models.JobStatus.error, models.JobStatus.success]: + logger.warn("[{job.id}]: job has already been processed, abort.") + return + + logger.info(f"[{job.id}]: received eligible job.") job.meta = {"task_id": self.request.id} - job.status = schemas.JobStatus.processing - db.commit() - logger.info(f"[{job.id}]: set task to status processing.") + job.status = models.JobStatus.processing + db.commit() + logger.info(f"[{job.id}]: finished setting task to status processing.") + + # unit of work: process job with whisper. job_record = schemas.Job.from_orm(job) - strategy = select_strategy(self, job_record) - result = strategy( + processor = select_task_processor(self, job_record) + + result_type, result = processor( url=job_record.url, job_id=job_record.id, config=job_record.config ) logger.info(f"[{job.id}]: successfully processed audio.") - artifact = models.Artifact( - job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript - ) + artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type) db.add(artifact) + + job.status = models.JobStatus.success + db.commit() logger.info(f"[{job.id}]: successfully stored artifact.") - - job.status = schemas.JobStatus.success - db.commit() - - logger.info(f"[{job.id}]: set task to status success.") except Exception as e: if job and db: + db.rollback() job.meta = {**job.meta, "error": str(e)} # type: ignore - job.status = schemas.JobStatus.error + job.status = models.JobStatus.error db.commit() raise finally: diff --git a/app/worker/strategies/base.py b/app/worker/strategies/base.py new file mode 100644 index 0000000..e38c139 --- /dev/null +++ b/app/worker/strategies/base.py @@ -0,0 +1,35 @@ +from abc import ABC +from typing import Any, Protocol, Tuple +from uuid import UUID + +import app.shared.db.models as models +import app.shared.db.schemas as schemas + +TaskReturnValue = Tuple[models.ArtifactType, Any] + + +class TaskProtocol(Protocol): + def __call__( + self, url: str, job_id: UUID, config: schemas.JobConfig | None + ) -> TaskReturnValue: + ... + + +class BaseStrategy(ABC): + def transcribe( + self, url: str, job_id: UUID, config: schemas.JobConfig | None + ) -> TaskReturnValue: + raise NotImplementedError() + + def translate( + self, url: str, job_id: UUID, config: schemas.JobConfig | None + ) -> TaskReturnValue: + raise NotImplementedError() + + def detect_language( + self, url: str, job_id: UUID, config: schemas.JobConfig | None + ) -> TaskReturnValue: + raise NotImplementedError() + + def cleanup(self, job_id: UUID) -> None: + raise NotImplementedError() diff --git a/app/worker/strategies/local.py b/app/worker/strategies/local.py index db3007d..ebf4a97 100644 --- a/app/worker/strategies/local.py +++ b/app/worker/strategies/local.py @@ -7,10 +7,11 @@ from uuid import UUID import requests import torch +import whisper from pydantic import BaseModel -from whisper import load_model import app.shared.db.schemas as schemas +from app.worker.strategies.base import BaseStrategy, TaskReturnValue class DecodeOptions(BaseModel): @@ -18,40 +19,58 @@ class DecodeOptions(BaseModel): task: Literal["translate", "transcribe"] -class LocalStrategy: +class LocalStrategy(BaseStrategy): def __init__(self) -> None: if torch.cuda.is_available(): logger.info("initializing GPU model.") - self.model = load_model( + self.model = whisper.load_model( os.environ["WHISPER_MODEL"], download_root="/models" ).cuda() else: logger.info("initializing CPU model.") - self.model = load_model( + self.model = whisper.load_model( os.environ["WHISPER_MODEL"], download_root="/models" ) logger.info("initialized local strategy.") - def transcribe( - self, url: str, job_id: UUID, config: schemas.JobConfig | None - ) -> list[Any]: - return self.run_whisper( - self._download(url, job_id), "transcribe", config, job_id + def cleanup(self, job_id) -> None: + try: + os.remove(self._get_tmp_file(job_id)) + except OSError: + pass + + def transcribe(self, url, job_id, config): + return ( + schemas.ArtifactType.raw_transcript, + self._run_whisper( + self._download(url, job_id), "transcribe", config, job_id + ), ) - def translate( - self, url: str, job_id: UUID, config: schemas.JobConfig | None - ) -> list[Any]: - return self.run_whisper( - self._download(url, job_id), - "translate", - config, - job_id, + def translate(self, url, job_id, config) -> TaskReturnValue: + return ( + schemas.ArtifactType.raw_transcript, + self._run_whisper( + self._download(url, job_id), + "translate", + config, + job_id, + ), ) - def detect_language(self, url: str, config: schemas.JobConfig | None) -> list[Any]: - raise NotImplementedError("detect_language has not been implemented yet.") + def detect_language(self, url, job_id, config) -> TaskReturnValue: + file = self._download(url, job_id) + + audio = whisper.pad_or_trim(whisper.load_audio(file)) + + mel = whisper.log_mel_spectrogram(audio).to(self.model.device) + _, probs = self.model.detect_language(mel) + + return ( + schemas.ArtifactType.language_detection, + {"code": max(probs, key=probs.get)}, + ) def _download(self, url: str, job_id: UUID) -> str: # re-create folder. @@ -67,7 +86,7 @@ class LocalStrategy: return filename - def run_whisper( + def _run_whisper( self, filepath: str, task: Literal["translate", "transcribe"], @@ -90,9 +109,3 @@ class LocalStrategy: def _get_tmp_file(self, job_id: UUID) -> str: tmp = tempfile.gettempdir() return path.join(tmp, str(job_id)) - - def cleanup(self, job_id: UUID) -> None: - try: - os.remove(self._get_tmp_file(job_id)) - except OSError: - pass diff --git a/mypy.ini b/mypy.ini index a3136a3..fa12c48 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,4 +1,4 @@ [mypy] plugins = sqlalchemy.ext.mypy.plugin ignore_missing_imports = True -disallow_untyped_defs = True +disallow_untyped_defs = False