From b3f8d5c82a207809105146429ebf4430351a3d3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Sp=C3=B6ttel?= <1682504+fspoettel@users.noreply.github.com> Date: Thu, 29 Jun 2023 12:38:56 +0200 Subject: [PATCH] refactor: remove shared schemas --- app/shared/db/models.py | 52 ++++++++++++++- app/shared/db/schemas.py | 80 ----------------------- app/web/dtos.py | 55 ++++++++++------ app/web/main.py | 114 ++++++++++++++++++++------------- app/web/task_queue.py | 7 +- app/worker/main.py | 25 ++++---- app/worker/strategies/base.py | 17 ++--- app/worker/strategies/local.py | 55 ++++++++-------- 8 files changed, 204 insertions(+), 201 deletions(-) delete mode 100644 app/shared/db/schemas.py diff --git a/app/shared/db/models.py b/app/shared/db/models.py index e4b6f02..761e251 100644 --- a/app/shared/db/models.py +++ b/app/shared/db/models.py @@ -1,12 +1,14 @@ import enum import uuid +from pydantic import BaseModel, Field from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, declarative_base, declarative_mixin, declared_attr Base = declarative_base() + # Enums @@ -32,7 +34,55 @@ class ArtifactType(str, enum.Enum): language_detection = "language_detection" -# SQLAlchemy models +# JSON field types + + +class JobConfig(BaseModel): + """(JSON) Configuration for a job.""" + + language: str | None = Field( + description=( + "Spoken language in the media file. " + "While optional, this can improve output." + ) + ) + + +class JobMeta(BaseModel): + """(JSON) Metadata relating to a job's execution.""" + + error: str | None = Field( + description="Will contain a descriptive error message if processing failed." + ) + + task_id: uuid.UUID | None = Field( + description="Internal celery id of this job submission." 
+ ) + + +class RawTranscript(BaseModel): + """(JSON) A single transcript passage returned by whisper.""" + + id: int + seek: int + start: float + end: float + text: str + tokens: list[int] + temperature: float + avg_logprob: float + compression_ratio: float + no_speech_prob: float + + +class LanguageDetection(BaseModel): + """A language detection""" + + language_code: str + + +# Sum type for all possible artifact data values +ArtifactData = list[RawTranscript] | LanguageDetection | None @declarative_mixin diff --git a/app/shared/db/schemas.py b/app/shared/db/schemas.py deleted file mode 100644 index 839cddc..0000000 --- a/app/shared/db/schemas.py +++ /dev/null @@ -1,80 +0,0 @@ -from datetime import datetime -from uuid import UUID - -from pydantic import AnyHttpUrl, BaseModel, Field - -from app.shared.db.models import ArtifactType, JobStatus, JobType - -# JSON field types - - -class JobConfig(BaseModel): - """Configuration for a job.""" - - language: str | None = Field( - description=( - "Spoken language in the media file. " - "While optional, this can improve output." - ) - ) - - -class JobMeta(BaseModel): - """Metadata relating to a job's execution.""" - - error: str | None = Field( - description="Will contain a descriptive error message if processing failed." - ) - - task_id: UUID | None = Field( - description="Internal celery id of this job submission." 
- ) - - -class RawTranscript(BaseModel): - """A single transcript passage returned by whisper.""" - - id: int - seek: int - start: float - end: float - text: str - tokens: list[int] - temperature: float - avg_logprob: float - compression_ratio: float - no_speech_prob: float - - -class LanguageDetection(BaseModel): - """A language detection""" - - code: str - - -# DB objects - - -class WithDbFields(BaseModel): - id: UUID - created_at: datetime - updated_at: datetime | None - - class Config: - orm_mode = True - - -class Job(WithDbFields): - """A transcription job for one media file.""" - - status: JobStatus - type: JobType - url: AnyHttpUrl - meta: JobMeta | None - config: JobConfig | None - - -class Artifact(WithDbFields): - job_id: UUID - data: LanguageDetection | RawTranscript | None - type: ArtifactType diff --git a/app/web/dtos.py b/app/web/dtos.py index 480ee93..735bd4d 100644 --- a/app/web/dtos.py +++ b/app/web/dtos.py @@ -1,25 +1,40 @@ -from pydantic import AnyHttpUrl, BaseModel, Field +from datetime import datetime +from uuid import UUID -import app.shared.db.models as models +from pydantic import AnyHttpUrl, BaseModel + +from app.shared.db.models import ( + ArtifactData, + ArtifactType, + JobConfig, + JobMeta, + JobStatus, + JobType, +) + +# DB objects -class PostJobPayload(BaseModel): - url: AnyHttpUrl = Field( - description=( - "URL where the media file is available. This needs to be a direct link." - ) - ) +class WithDbFields(BaseModel): + id: UUID + created_at: datetime + updated_at: datetime | None - type: models.JobType = Field( - description="""Type of this job. - `transcript` uses the original language of the audio. - `translation` creates an automatic translation to english. - `language_detection` detects language from the first 30 seconds of audio.""" - ) + class Config: + orm_mode = True - language: str | None = Field( - description=( - "Spoken language in the media file. " - "While optional, this can improve output when set." 
- ) - ) + +class Job(WithDbFields): + """A transcription job for one media file.""" + + status: JobStatus + type: JobType + url: AnyHttpUrl + meta: JobMeta | None + config: JobConfig | None + + +class Artifact(WithDbFields): + job_id: UUID + data: ArtifactData + type: ArtifactType diff --git a/app/web/main.py b/app/web/main.py index 2d1fd53..7aba541 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -3,12 +3,12 @@ from typing import Annotated from uuid import UUID from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path +from pydantic import AnyHttpUrl, BaseModel, Field from sqlalchemy.orm import Session import app.shared.db.models as models -import app.shared.db.schemas as schemas +import app.web.dtos as dtos from app.shared.db.base import SessionLocal, get_session -from app.web.dtos import PostJobPayload from app.web.security import authenticate_api_key from app.web.task_queue import task_queue @@ -40,54 +40,15 @@ def api_root() -> None: return None -@api_router.post( - "/jobs", - response_model=schemas.Job, - status_code=201, - summary="Enqueue a new job", -) -def create_job( - payload: PostJobPayload, - session: DatabaseSession, -) -> models.Job: - """ - Enqueue a new whisper job for processing. - Notes: - * Jobs are processed one-by-one in order of creation. - * `payload.url` needs to point directly to a media file. - * The media file is downloaded to a tmp file for the duration of processing. - enough free space needs to be available on disk. - * Media files ideally are audio files with a sampling rate of 16kHz. - other files will be transcoded automatically via ffmpeg which might - consume considerable resources while active. - * Once a job is created, you can query its status by its id. - """ - - # create a job with status "create" and save it to the database. 
- job = models.Job( - url=payload.url, - status=schemas.JobStatus.create, - type=payload.type, - config={"language": payload.language} if payload.language else None, - ) - - session.add(job) - session.commit() - - task_queue.queue_task(job) - - return job - - @api_router.get( - "/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs" + "/jobs", response_model=list[dtos.Job], summary="Get metadata for all jobs" ) def get_transcripts( session: DatabaseSession, - type: schemas.JobType | None = None, + type: dtos.JobType | None = None, ) -> list[models.Job]: """Get metadata for all jobs.""" - query = session.query(models.Job) + query = session.query(models.Job).order_by(models.Job.created_at.desc()) if type: query = query.filter(models.Job.type == type) @@ -97,7 +58,7 @@ def get_transcripts( @api_router.get( "/jobs/{id}", - response_model=schemas.Job, + response_model=dtos.Job, summary="Get metadata for one job", ) def get_transcript( @@ -117,7 +78,7 @@ def get_transcript( @api_router.get( "/jobs/{id}/artifacts", - response_model=list[schemas.Artifact], + response_model=list[dtos.Artifact], summary="Get all artifacts for one job", ) def get_artifacts_for_job( @@ -148,4 +109,65 @@ def delete_transcript( return None +class PostJobPayload(BaseModel): + url: AnyHttpUrl = Field( + description=( + "URL where the media file is available. This needs to be a direct link." + ) + ) + + type: models.JobType = Field( + description="""Type of this job. + `transcript` uses the original language of the audio. + `translation` creates an automatic translation to english. + `language_detection` detects language from the first 30 seconds of audio.""" + ) + + language: str | None = Field( + description=( + "Spoken language in the media file. " + "While optional, this can improve output when set." 
+ ) + ) + + +@api_router.post( + "/jobs", + response_model=dtos.Job, + status_code=201, + summary="Enqueue a new job", +) +def create_job( + payload: PostJobPayload, + session: DatabaseSession, +) -> models.Job: + """ + Enqueue a new whisper job for processing. + Notes: + * Jobs are processed one-by-one in order of creation. + * `payload.url` needs to point directly to a media file. + * The media file is downloaded to a tmp file for the duration of processing. + enough free space needs to be available on disk. + * Media files ideally are audio files with a sampling rate of 16kHz. + other files will be transcoded automatically via ffmpeg which might + consume considerable resources while active. + * Once a job is created, you can query its status by its id. + """ + + # create a job with status "create" and save it to the database. + job = models.Job( + url=payload.url, + status=dtos.JobStatus.create, + type=payload.type, + config={"language": payload.language} if payload.language else None, + ) + + session.add(job) + session.commit() + + task_queue.queue_task(job) + + return job + + app.include_router(api_router) diff --git a/app/web/task_queue.py b/app/web/task_queue.py index ae4d7f8..582f47f 100644 --- a/app/web/task_queue.py +++ b/app/web/task_queue.py @@ -15,9 +15,10 @@ class TaskQueue: self.celery = get_celery_binding() def queue_task(self, job: models.Job): - # queue an async transcription task. - # we use a signature here to allow full separation of - # worker processes and dependencies. + """ + Queues an async transcription job. We use a celery signature here to + allow for full separation of worker processes and dependencies. 
+ """ transcribe = self.celery.signature("app.worker.main.transcribe") transcribe.delay(job.id) diff --git a/app/worker/main.py b/app/worker/main.py index 26d9170..13eebc4 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -6,7 +6,6 @@ from celery import Task from sqlalchemy.orm import Session import app.shared.db.models as models -import app.shared.db.schemas as schemas from app.shared.celery import get_celery_binding from app.shared.db.base import SessionLocal from app.shared.settings import settings @@ -31,7 +30,7 @@ class TranscribeTask(Task): return self.run(*args, **kwargs) -def select_task_processor(task: Task, job: schemas.Job) -> TaskProtocol: +def select_task_processor(task: Task, job: models.Job) -> TaskProtocol: if job.type == models.JobType.transcript: return task.strategy.transcribe elif job.type == models.JobType.translation: @@ -51,7 +50,7 @@ def transcribe(self: Task, job_id: UUID) -> None: # runs in a separate thread => requires sqlite's WAL mode to be enabled. db: Session = SessionLocal() - # unit of work: set task status to processing. + # check if passed job should be processed. job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none() @@ -63,24 +62,23 @@ def transcribe(self: Task, job_id: UUID) -> None: logger.warn("[{job.id}]: job has already been processed, abort.") return - logger.info(f"[{job.id}]: received eligible job.") + logger.debug(f"[{job.id}]: start processing {job.type} job.") + + # unit of work: set task status to processing. job.meta = {"task_id": self.request.id} job.status = models.JobStatus.processing - db.commit() - logger.info(f"[{job.id}]: finished setting task to status processing.") + + logger.debug(f"[{job.id}]: finished setting task to {job.status}.") # unit of work: process job with whisper. 
- job_record = schemas.Job.from_orm(job) - processor = select_task_processor(self, job_record) + processor = select_task_processor(self, job) - result_type, result = processor( - url=job_record.url, job_id=job_record.id, config=job_record.config - ) + result_type, result = processor(job) - logger.info(f"[{job.id}]: successfully processed audio.") + logger.debug(f"[{job.id}]: successfully processed audio.") artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type) @@ -89,7 +87,8 @@ def transcribe(self: Task, job_id: UUID) -> None: job.status = models.JobStatus.success db.commit() - logger.info(f"[{job.id}]: successfully stored artifact.") + logger.debug(f"[{job.id}]: successfully stored artifact.") + except Exception as e: if job and db: db.rollback() diff --git a/app/worker/strategies/base.py b/app/worker/strategies/base.py index e38c139..8060a89 100644 --- a/app/worker/strategies/base.py +++ b/app/worker/strategies/base.py @@ -3,32 +3,23 @@ from typing import Any, Protocol, Tuple from uuid import UUID import app.shared.db.models as models -import app.shared.db.schemas as schemas TaskReturnValue = Tuple[models.ArtifactType, Any] class TaskProtocol(Protocol): - def __call__( - self, url: str, job_id: UUID, config: schemas.JobConfig | None - ) -> TaskReturnValue: + def __call__(self, job: models.Job) -> TaskReturnValue: ... 
 class BaseStrategy(ABC): - def transcribe( - self, url: str, job_id: UUID, config: schemas.JobConfig | None - ) -> TaskReturnValue: + def transcribe(self, job: models.Job) -> TaskReturnValue: raise NotImplementedError() - def translate( - self, url: str, job_id: UUID, config: schemas.JobConfig | None - ) -> TaskReturnValue: + def translate(self, job: models.Job) -> TaskReturnValue: raise NotImplementedError() - def detect_language( - self, url: str, job_id: UUID, config: schemas.JobConfig | None - ) -> TaskReturnValue: + def detect_language(self, job: models.Job) -> TaskReturnValue: raise NotImplementedError() def cleanup(self, job_id: UUID) -> None: diff --git a/app/worker/strategies/local.py b/app/worker/strategies/local.py index ebf4a97..68f6621 100644 --- a/app/worker/strategies/local.py +++ b/app/worker/strategies/local.py @@ -9,12 +9,18 @@ import requests import torch import whisper from pydantic import BaseModel +from sqlalchemy import JSON, Column -import app.shared.db.schemas as schemas +import app.shared.db.models as models from app.worker.strategies.base import BaseStrategy, TaskReturnValue -class DecodeOptions(BaseModel): +class DecodingOptions(BaseModel): + """ + Options passed to the whisper model. + This mirrors private type `whisper.DecodingOptions`. 
+ """ + language: str | None task: Literal["translate", "transcribe"] @@ -40,27 +46,24 @@ class LocalStrategy(BaseStrategy): except OSError: pass - def transcribe(self, url, job_id, config): - return ( - schemas.ArtifactType.raw_transcript, - self._run_whisper( - self._download(url, job_id), "transcribe", config, job_id - ), + def transcribe(self, job): + result = self._run_whisper( + self._download(job.url, job.id), "transcribe", job.config, job.id ) - def translate(self, url, job_id, config) -> TaskReturnValue: - return ( - schemas.ArtifactType.raw_transcript, - self._run_whisper( - self._download(url, job_id), - "translate", - config, - job_id, - ), - ) + return (models.ArtifactType.raw_transcript, result) - def detect_language(self, url, job_id, config) -> TaskReturnValue: - file = self._download(url, job_id) + def translate(self, job) -> TaskReturnValue: + result = self._run_whisper( + self._download(job.url, job.id), + "translate", + job.config, + job.id, + ) + return (models.ArtifactType.raw_transcript, result) + + def detect_language(self, job) -> TaskReturnValue: + file = self._download(job.url, job.id) audio = whisper.pad_or_trim(whisper.load_audio(file)) @@ -68,7 +71,7 @@ class LocalStrategy(BaseStrategy): _, probs = self.model.detect_language(mel) return ( - schemas.ArtifactType.language_detection, + models.ArtifactType.language_detection, {"code": max(probs, key=probs.get)}, ) @@ -90,16 +93,18 @@ class LocalStrategy(BaseStrategy): self, filepath: str, task: Literal["translate", "transcribe"], - config: schemas.JobConfig | None, + config: Column[JSON], job_id: UUID, ) -> list[Any]: try: - language = config.language if config else None - result = self.model.transcribe( filepath, + # turning this off might make the transcription less accurate, + # but significantly reduces amount of model hallucinations. 
condition_on_previous_text=False, - **DecodeOptions(task=task, language=language).dict(), + **DecodingOptions( + task=task, language=config.language if config else None + ).dict(), ) return result["segments"]