feat: add language detection task

This commit is contained in:
Felix Spöttel
2023-06-29 09:13:11 +02:00
parent d2223206be
commit 908bd48170
15 changed files with 267 additions and 191 deletions

View File

@@ -1,17 +1,6 @@
from typing import Any
from pydantic import AnyHttpUrl, BaseModel, Field
import app.shared.db.schemas as schemas
class DetailResponse(BaseModel):
detail: str
DEFAULT_RESPONSES: dict[int | str, dict[str, Any]] = {
401: {"model": DetailResponse, "description": "Not authenticated"}
}
import app.shared.db.models as models
class PostJobPayload(BaseModel):
@@ -21,7 +10,7 @@ class PostJobPayload(BaseModel):
)
)
type: schemas.JobType = Field(
type: models.JobType = Field(
description="""Type of this job.
`transcript` uses the original language of the audio.
`translation` creates an automatic translation to english.

View File

@@ -1,37 +1,37 @@
from asyncio.log import logger
from contextlib import asynccontextmanager
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
from sqlalchemy import or_
from sqlalchemy.orm import Session
import app.shared.db.models as models
import app.shared.db.schemas as schemas
from app.shared.celery import get_celery_binding
from app.shared.db.base import get_session
from app.web.dtos import DEFAULT_RESPONSES, DetailResponse, PostJobPayload
from app.shared.db.base import SessionLocal, get_session
from app.web.dtos import PostJobPayload
from app.web.security import authenticate_api_key
from app.web.task_queue import task_queue
DatabaseSession = Annotated[Session, Depends(get_session)]
@asynccontextmanager
async def lifespan(_: FastAPI):
with SessionLocal() as session:
task_queue.rehydrate(session)
yield
app = FastAPI(
description="whisperbox-transcribe is an async HTTP wrapper for openai/whisper.",
lifespan=lifespan,
title="whisperbox-transcribe",
)
celery = get_celery_binding()
def queue_task(job: models.Job) -> None:
# queue an async transcription task.
# we use a signature here to allow full separation of
# worker processes and dependencies.
transcribe = celery.signature("app.worker.main.transcribe")
# TODO: catch delivery errors.
transcribe.delay(job.id)
api_router = APIRouter(
prefix="/api/v1",
dependencies=[Depends(authenticate_api_key)],
responses={**DEFAULT_RESPONSES},
)
@@ -48,7 +48,7 @@ def api_root() -> None:
)
def create_job(
payload: PostJobPayload,
session: Session = Depends(get_session),
session: DatabaseSession,
) -> models.Job:
"""
Enqueue a new whisper job for processing.
@@ -62,6 +62,7 @@ def create_job(
consume considerable resources while active.
* Once a job is created, you can query its status by its id.
"""
# create a job with status "create" and save it to the database.
job = models.Job(
url=payload.url,
@@ -73,12 +74,7 @@ def create_job(
session.add(job)
session.commit()
# queue an async transcription task.
# we use a signature here to allow full separation of
# worker processes and dependencies.
transcribe = celery.signature("app.worker.main.transcribe")
# TODO: catch delivery errors.
transcribe.delay(job.id)
task_queue.queue_task(job)
return job
@@ -87,7 +83,8 @@ def create_job(
"/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
)
def get_transcripts(
type: schemas.JobType | None = None, session: Session = Depends(get_session)
session: DatabaseSession,
type: schemas.JobType | None = None,
) -> list[models.Job]:
"""Get metadata for all jobs."""
query = session.query(models.Job)
@@ -101,18 +98,20 @@ def get_transcripts(
@api_router.get(
"/jobs/{id}",
response_model=schemas.Job,
responses={404: {"model": DetailResponse, "description": "Not found"}},
summary="Get metadata for one job",
)
def get_transcript(
id: UUID = Path(), session: Session = Depends(get_session)
session: DatabaseSession,
id: UUID = Path(),
) -> models.Job | None:
"""
Use this route to check transcription status of any given job.
"""
job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
if not job:
raise HTTPException(status_code=404)
return job
@@ -122,10 +121,12 @@ def get_transcript(
summary="Get all artifacts for one job",
)
def get_artifacts_for_job(
id: UUID = Path(), session: Session = Depends(get_session)
session: DatabaseSession,
id: UUID = Path(),
) -> list[models.Artifact]:
"""
Right now, there is only one type of artifact (`raw_transcript`).
Returns all artifacts for one job.
See the type of `data` for possible data types.
Returns an empty array for unfinished or non-existant jobs.
"""
artifacts = (
@@ -139,7 +140,8 @@ def get_artifacts_for_job(
"/jobs/{id}", status_code=204, summary="Delete a job with all artifacts"
)
def delete_transcript(
id: UUID = Path(), session: Session = Depends(get_session)
session: DatabaseSession,
id: UUID = Path(),
) -> None:
"""Remove metadata and artifacts for a single job."""
session.query(models.Job).filter(models.Job.id == str(id)).delete()
@@ -147,28 +149,3 @@ def delete_transcript(
app.include_router(api_router)
# TODO: we could use `acks_late` to handle this scenario within celery itself.
# the reason this does not work well in our case is that `visibility_timeout`
# needs to be very high since whisper workers can be long running.
# doing this application-side bears the risk of poison pilling the worker though,
# implement a workaround with an acceptable trade-off. (=> retry only once?)
@app.on_event("startup")
def on_startup() -> None:
session = get_session().__next__()
jobs = (
session.query(models.Job)
.filter(
or_(
models.Job.status == schemas.JobStatus.processing,
models.Job.status == schemas.JobStatus.create,
)
)
.order_by(models.Job.created_at)
).all()
logger.info(f"Requeueing {len(jobs)} jobs.")
for job in jobs:
queue_task(job)

47
app/web/task_queue.py Normal file
View File

@@ -0,0 +1,47 @@
from asyncio.log import logger
from celery import Celery
from sqlalchemy import or_
from sqlalchemy.orm import Session
import app.shared.db.models as models
from app.shared.celery import get_celery_binding
class TaskQueue:
celery: Celery
def __init__(self) -> None:
self.celery = get_celery_binding()
def queue_task(self, job: models.Job):
# queue an async transcription task.
# we use a signature here to allow full separation of
# worker processes and dependencies.
transcribe = self.celery.signature("app.worker.main.transcribe")
transcribe.delay(job.id)
def rehydrate(self, session: Session):
# TODO: we could use `acks_late` to handle this scenario within celery itself.
# the reason this does not work well in our case is that `visibility_timeout`
# needs to be very high since whisper workers can be long running.
# doing this app-side bears the risk of poison pilling the worker though,
# implement a workaround with an acceptable trade-off. (=> retry only once?)
jobs = (
session.query(models.Job)
.filter(
or_(
models.Job.status == models.JobStatus.processing,
models.Job.status == models.JobStatus.create,
)
)
.order_by(models.Job.created_at)
).all()
logger.info(f"Requeueing {len(jobs)} jobs.")
for job in jobs:
self.queue_task(job)
task_queue = TaskQueue()