mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-12 21:48:35 +03:00
feat: add language detection task
This commit is contained in:
@@ -1,37 +1,37 @@
|
||||
from asyncio.log import logger
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.models as models
|
||||
import app.shared.db.schemas as schemas
|
||||
from app.shared.celery import get_celery_binding
|
||||
from app.shared.db.base import get_session
|
||||
from app.web.dtos import DEFAULT_RESPONSES, DetailResponse, PostJobPayload
|
||||
from app.shared.db.base import SessionLocal, get_session
|
||||
from app.web.dtos import PostJobPayload
|
||||
from app.web.security import authenticate_api_key
|
||||
from app.web.task_queue import task_queue
|
||||
|
||||
DatabaseSession = Annotated[Session, Depends(get_session)]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI):
|
||||
with SessionLocal() as session:
|
||||
task_queue.rehydrate(session)
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
description="whisperbox-transcribe is an async HTTP wrapper for openai/whisper.",
|
||||
lifespan=lifespan,
|
||||
title="whisperbox-transcribe",
|
||||
)
|
||||
celery = get_celery_binding()
|
||||
|
||||
|
||||
def queue_task(job: models.Job) -> None:
|
||||
# queue an async transcription task.
|
||||
# we use a signature here to allow full separation of
|
||||
# worker processes and dependencies.
|
||||
transcribe = celery.signature("app.worker.main.transcribe")
|
||||
# TODO: catch delivery errors.
|
||||
transcribe.delay(job.id)
|
||||
|
||||
|
||||
api_router = APIRouter(
|
||||
prefix="/api/v1",
|
||||
dependencies=[Depends(authenticate_api_key)],
|
||||
responses={**DEFAULT_RESPONSES},
|
||||
)
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ def api_root() -> None:
|
||||
)
|
||||
def create_job(
|
||||
payload: PostJobPayload,
|
||||
session: Session = Depends(get_session),
|
||||
session: DatabaseSession,
|
||||
) -> models.Job:
|
||||
"""
|
||||
Enqueue a new whisper job for processing.
|
||||
@@ -62,6 +62,7 @@ def create_job(
|
||||
consume considerable resources while active.
|
||||
* Once a job is created, you can query its status by its id.
|
||||
"""
|
||||
|
||||
# create a job with status "create" and save it to the database.
|
||||
job = models.Job(
|
||||
url=payload.url,
|
||||
@@ -73,12 +74,7 @@ def create_job(
|
||||
session.add(job)
|
||||
session.commit()
|
||||
|
||||
# queue an async transcription task.
|
||||
# we use a signature here to allow full separation of
|
||||
# worker processes and dependencies.
|
||||
transcribe = celery.signature("app.worker.main.transcribe")
|
||||
# TODO: catch delivery errors.
|
||||
transcribe.delay(job.id)
|
||||
task_queue.queue_task(job)
|
||||
|
||||
return job
|
||||
|
||||
@@ -87,7 +83,8 @@ def create_job(
|
||||
"/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
|
||||
)
|
||||
def get_transcripts(
|
||||
type: schemas.JobType | None = None, session: Session = Depends(get_session)
|
||||
session: DatabaseSession,
|
||||
type: schemas.JobType | None = None,
|
||||
) -> list[models.Job]:
|
||||
"""Get metadata for all jobs."""
|
||||
query = session.query(models.Job)
|
||||
@@ -101,18 +98,20 @@ def get_transcripts(
|
||||
@api_router.get(
|
||||
"/jobs/{id}",
|
||||
response_model=schemas.Job,
|
||||
responses={404: {"model": DetailResponse, "description": "Not found"}},
|
||||
summary="Get metadata for one job",
|
||||
)
|
||||
def get_transcript(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
session: DatabaseSession,
|
||||
id: UUID = Path(),
|
||||
) -> models.Job | None:
|
||||
"""
|
||||
Use this route to check transcription status of any given job.
|
||||
"""
|
||||
job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
return job
|
||||
|
||||
|
||||
@@ -122,10 +121,12 @@ def get_transcript(
|
||||
summary="Get all artifacts for one job",
|
||||
)
|
||||
def get_artifacts_for_job(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
session: DatabaseSession,
|
||||
id: UUID = Path(),
|
||||
) -> list[models.Artifact]:
|
||||
"""
|
||||
Right now, there is only one type of artifact (`raw_transcript`).
|
||||
Returns all artifacts for one job.
|
||||
See the type of `data` for possible data types.
|
||||
Returns an empty array for unfinished or non-existant jobs.
|
||||
"""
|
||||
artifacts = (
|
||||
@@ -139,7 +140,8 @@ def get_artifacts_for_job(
|
||||
"/jobs/{id}", status_code=204, summary="Delete a job with all artifacts"
|
||||
)
|
||||
def delete_transcript(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
session: DatabaseSession,
|
||||
id: UUID = Path(),
|
||||
) -> None:
|
||||
"""Remove metadata and artifacts for a single job."""
|
||||
session.query(models.Job).filter(models.Job.id == str(id)).delete()
|
||||
@@ -147,28 +149,3 @@ def delete_transcript(
|
||||
|
||||
|
||||
app.include_router(api_router)
|
||||
|
||||
|
||||
# TODO: we could use `acks_late` to handle this scenario within celery itself.
|
||||
# the reason this does not work well in our case is that `visibility_timeout`
|
||||
# needs to be very high since whisper workers can be long running.
|
||||
# doing this application-side bears the risk of poison pilling the worker though,
|
||||
# implement a workaround with an acceptable trade-off. (=> retry only once?)
|
||||
@app.on_event("startup")
|
||||
def on_startup() -> None:
|
||||
session = get_session().__next__()
|
||||
|
||||
jobs = (
|
||||
session.query(models.Job)
|
||||
.filter(
|
||||
or_(
|
||||
models.Job.status == schemas.JobStatus.processing,
|
||||
models.Job.status == schemas.JobStatus.create,
|
||||
)
|
||||
)
|
||||
.order_by(models.Job.created_at)
|
||||
).all()
|
||||
|
||||
logger.info(f"Requeueing {len(jobs)} jobs.")
|
||||
for job in jobs:
|
||||
queue_task(job)
|
||||
|
||||
Reference in New Issue
Block a user