Files
whisperbox-transcribe/app/web/main.py
2023-06-29 09:13:11 +02:00

152 lines
3.9 KiB
Python

from contextlib import asynccontextmanager
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
from sqlalchemy.orm import Session
import app.shared.db.models as models
import app.shared.db.schemas as schemas
from app.shared.db.base import SessionLocal, get_session
from app.web.dtos import PostJobPayload
from app.web.security import authenticate_api_key
from app.web.task_queue import task_queue
# Request-scoped SQLAlchemy session provided through the `get_session`
# dependency. Annotating a route parameter with this alias makes FastAPI
# inject the session produced by `get_session`.
DatabaseSession = Annotated[Session, Depends(dependency=get_session)]
@asynccontextmanager
async def lifespan(_: FastAPI):
    """
    Application lifespan hook for the FastAPI app.

    On startup, opens a short-lived database session and passes it to
    `task_queue.rehydrate`; the session is closed again before control is
    yielded to the server. Nothing runs on shutdown.
    """
    # NOTE(review): presumably `rehydrate` re-enqueues jobs left unfinished
    # by a previous process — confirm in app.web.task_queue.
    with SessionLocal() as startup_session:
        task_queue.rehydrate(startup_session)
    # Code after this yield would run on shutdown; there is no teardown.
    yield
# The ASGI application; `lifespan` performs its startup work before
# requests are served.
app = FastAPI(
    title="whisperbox-transcribe",
    description="whisperbox-transcribe is an async HTTP wrapper for openai/whisper.",
    lifespan=lifespan,
)

# All versioned API routes live under /api/v1, and every one of them runs the
# `authenticate_api_key` dependency before its handler.
api_router = APIRouter(
    dependencies=[Depends(authenticate_api_key)],
    prefix="/api/v1",
)
@api_router.get("/", response_model=None, status_code=204)
def api_root() -> None:
    # Intentionally empty 204 response. Because `api_router` carries the
    # `authenticate_api_key` dependency, this route is presumably useful for
    # checking that a client's credentials are accepted.
    # No docstring on purpose: FastAPI would publish it as the OpenAPI
    # description of this endpoint.
    return
@api_router.post(
    "/jobs",
    response_model=schemas.Job,
    status_code=201,
    summary="Enqueue a new job",
)
def create_job(
    payload: PostJobPayload,
    session: DatabaseSession,
) -> models.Job:
    """
    Enqueue a new whisper job for processing.

    Notes:

    * Jobs are processed one-by-one in order of creation.
    * `payload.url` needs to point directly to a media file.
    * The media file is downloaded to a tmp file for the duration of processing;
      enough free space needs to be available on disk.
    * Media files ideally are audio files with a sampling rate of 16kHz;
      other files will be transcoded automatically via ffmpeg, which might
      consume considerable resources while active.
    * Once a job is created, you can query its status by its id.
    """
    # Create the job with status "create" and commit it before queueing, so
    # the row exists in the database by the time the task queue sees it.
    job = models.Job(
        url=payload.url,
        status=schemas.JobStatus.create,
        type=payload.type,
        # Only store a config when the caller asked for a specific language.
        config={"language": payload.language} if payload.language else None,
    )
    session.add(job)
    session.commit()
    task_queue.queue_task(job)
    return job
@api_router.get(
    "/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
)
def get_transcripts(
    session: DatabaseSession,
    type: schemas.JobType | None = None,
) -> list[models.Job]:
    """Get metadata for all jobs."""
    # `type` shadows the builtin, but FastAPI derives the public `?type=`
    # query parameter from this name, so it is kept as-is.
    query = session.query(models.Job)
    # Compare against None explicitly rather than relying on truthiness: if
    # JobType mixes in `str`, a member with a falsy value would otherwise
    # silently disable the filter.
    if type is not None:
        query = query.filter(models.Job.type == type)
    return query.all()
@api_router.get(
    "/jobs/{id}",
    response_model=schemas.Job,
    summary="Get metadata for one job",
)
def get_transcript(
    session: DatabaseSession,
    id: UUID = Path(),
) -> models.Job:
    """
    Use this route to check transcription status of any given job.
    """
    # Job ids are compared as strings, so the parsed UUID is converted back.
    job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
    if not job:
        # Never return None: an unknown id is reported as 404 Not Found.
        raise HTTPException(status_code=404)
    return job
@api_router.get(
    "/jobs/{id}/artifacts",
    response_model=list[schemas.Artifact],
    summary="Get all artifacts for one job",
)
def get_artifacts_for_job(
    session: DatabaseSession,
    id: UUID = Path(),
) -> list[models.Artifact]:
    """
    Returns all artifacts for one job.
    See the type of `data` for possible data types.
    Returns an empty array for unfinished or non-existent jobs.
    """
    # No existence check on the job itself: an unknown id simply matches no
    # artifacts and yields an empty list (not a 404), as documented above.
    return (
        session.query(models.Artifact)
        .filter(models.Artifact.job_id == str(id))
        .all()
    )
@api_router.delete(
    "/jobs/{id}", status_code=204, summary="Delete a job with all artifacts"
)
def delete_transcript(
    session: DatabaseSession,
    id: UUID = Path(),
) -> None:
    """Remove metadata and artifacts for a single job."""
    job_id = str(id)
    # Bulk `Query.delete()` bypasses ORM relationship cascades, so the job's
    # artifacts are removed explicitly rather than relying on a cascade that
    # may or may not be configured. Deleting children before the parent also
    # keeps a foreign-key constraint (if any) from rejecting the job delete.
    session.query(models.Artifact).filter(models.Artifact.job_id == job_id).delete()
    session.query(models.Job).filter(models.Job.id == job_id).delete()
    # Persist both deletes in one transaction. create_job commits explicitly,
    # which suggests the session dependency does not auto-commit; without this
    # commit the deletion would be discarded when the session is closed.
    session.commit()
    # NOTE(review): a job still queued or running in `task_queue` is not
    # cancelled here — confirm the worker tolerates a missing row.
    # Unknown ids are a no-op and still return 204.
    return None
app.include_router(api_router)