mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-08 03:28:35 +03:00
138 lines
4.0 KiB
Python
138 lines
4.0 KiB
Python
import logging
from asyncio.log import logger
from typing import List, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
from sqlalchemy import or_
from sqlalchemy.orm import Session

import app.shared.db.models as models
import app.shared.db.schemas as schemas
from app.shared.celery import get_celery_binding
from app.shared.db.base import get_session
from app.web.dtos import DEFAULT_RESPONSES, DetailResponse, PostJobPayload
from app.web.security import authenticate_api_key
|
|
|
|
# ASGI application; the versioned API router is mounted onto it further below.
app = FastAPI()
|
|
# Celery binding. This module only uses it to build task signatures by name
# (`celery.signature(...)`), so it never imports the worker code itself.
celery = get_celery_binding()
|
|
|
|
|
|
def queue_task(job: models.Job) -> None:
    """Send the transcription task for ``job`` to Celery asynchronously.

    The task is addressed by its registered name through a signature rather
    than by importing the worker function, which keeps the web process fully
    separate from the worker processes and their dependencies.
    """
    task_name = "app.worker.main.transcribe"
    # TODO: catch delivery errors.
    celery.signature(task_name).delay(job.id)
|
|
|
|
|
|
# Versioned API router. The router-level dependency authenticates the API key
# on every route registered here, and the shared default error responses are
# copied into each route's OpenAPI documentation.
api_router = APIRouter(
    responses=dict(DEFAULT_RESPONSES),
    dependencies=[Depends(authenticate_api_key)],
    prefix="/api/v1",
)
|
|
|
|
|
|
@api_router.get("/", response_model=None, status_code=204)
def api_root() -> None:
    """Answer with an empty 204 response.

    Every route on ``api_router`` passes through ``authenticate_api_key``, so
    a 204 here can serve as a cheap check that the API key is accepted.
    """
|
|
|
|
|
|
@api_router.post("/jobs", response_model=schemas.Job, status_code=201)
def create_job(
    payload: PostJobPayload,
    session: Session = Depends(get_session),
) -> models.Job:
    """Persist a new transcription job and queue it for processing.

    The job is stored with status ``create``. ``config`` only carries a
    ``language`` override and stays ``None`` when no language was given.

    Returns:
        The newly created job, serialized through ``schemas.Job``.
    """
    job = models.Job(
        url=payload.url,
        status=schemas.JobStatus.create,
        type=payload.type,
        config={"language": payload.language} if payload.language else None,
    )

    session.add(job)
    # Commit before queueing, presumably so the worker can already load the
    # row by id. If queueing fails afterwards, the job remains in `create`
    # status and is picked up again by `on_startup`.
    session.commit()

    # Use the shared helper so both enqueue paths (here and on startup) stay
    # identical instead of duplicating the signature/delay logic.
    queue_task(job)

    return job
|
|
|
|
|
|
@api_router.get("/jobs", response_model=List[schemas.Job])
def get_transcripts(
    type: Optional[schemas.JobType] = None, session: Session = Depends(get_session)
) -> List[models.Job]:
    """List all jobs, optionally restricted to a single job ``type``.

    No ordering is applied; rows come back in whatever order the database
    returns them.
    """
    # An empty criteria list leaves the query unfiltered.
    criteria = [models.Job.type == type] if type else []
    return session.query(models.Job).filter(*criteria).all()
|
|
|
|
|
|
@api_router.get(
    "/jobs/{id}",
    response_model=schemas.Job,
    # 404 is raised below when no job matches the id, so document it as such.
    responses={404: {"model": DetailResponse, "description": "Not found"}},
)
def get_transcript(
    id: UUID = Path(), session: Session = Depends(get_session)
) -> models.Job:
    """Return the job with the given id.

    Raises:
        HTTPException: 404 when no job with this id exists.
    """
    # The UUID is compared as a string; assumes Job.id is stored in text form
    # — TODO confirm against the model definition.
    job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
    if job is None:
        raise HTTPException(status_code=404)
    return job
|
|
|
|
|
|
@api_router.get("/jobs/{id}/artifacts", response_model=List[schemas.Artifact])
def get_artifacts_for_job(
    id: UUID = Path(), session: Session = Depends(get_session)
) -> List[models.Artifact]:
    """List every artifact attached to the job with the given id.

    An unknown job id is not treated as an error; it simply yields an empty
    list.
    """
    belongs_to_job = models.Artifact.job_id == str(id)
    return session.query(models.Artifact).filter(belongs_to_job).all()
|
|
|
|
|
|
@api_router.delete("/jobs/{id}", status_code=204)
def delete_transcript(
    id: UUID = Path(), session: Session = Depends(get_session)
) -> None:
    """Delete the job with the given id.

    Responds 204 whether or not a matching job existed.
    """
    # NOTE(review): Query.delete() is a bulk DELETE that bypasses ORM-level
    # relationship cascades. Artifacts pointing at this job are only removed
    # if the database has an ON DELETE rule for them — confirm against models.
    session.query(models.Job).filter(models.Job.id == str(id)).delete()
    # Commit explicitly, as create_job does. Otherwise the DELETE only lives in
    # the open transaction and is lost when the session is closed, unless
    # get_session happens to commit on teardown.
    session.commit()
    return None
|
|
|
|
|
|
# Register every /api/v1 route declared above on the application.
# NOTE: FastAPI copies the router's routes at include time, so keep this line
# below all `@api_router` route definitions.
app.include_router(api_router)
|
|
|
|
|
|
# TODO: we could use `acks_late` to handle this scenario within celery itself.
# That does not work well in our case because `visibility_timeout` would need
# to be very high, since whisper workers can be long-running.
# Doing this application-side bears the risk of poison-pilling the worker,
# though; implement a workaround with an acceptable trade-off
# (=> retry only once?).
@app.on_event("startup")
def on_startup() -> None:
    """Re-queue every job still in ``processing`` or ``create`` status.

    Jobs are queued oldest first, ordered by ``created_at``.
    """
    # NOTE(review): this selects *all* `processing` jobs. If only the web
    # process restarts while a worker is still busy, those jobs are queued a
    # second time.

    # get_session() is iterated manually here (outside FastAPI's dependency
    # machinery). Close the generator afterwards so its cleanup runs and the
    # session is released, even if querying or queueing raises.
    # Assumes get_session is a generator function — it is advanced with next().
    session_source = get_session()
    session = next(session_source)
    try:
        jobs = (
            session.query(models.Job)
            .filter(
                or_(
                    models.Job.status == schemas.JobStatus.processing,
                    models.Job.status == schemas.JobStatus.create,
                )
            )
            .order_by(models.Job.created_at)
            .all()
        )

        # Use this module's logger rather than asyncio's internal one, so the
        # message is attributed to (and filterable by) this module.
        logging.getLogger(__name__).info("Requeueing %d jobs.", len(jobs))

        for job in jobs:
            queue_task(job)
    finally:
        session_source.close()
|