mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-13 14:08:34 +03:00
refactor: remove shared schemas
This commit is contained in:
@@ -1,25 +1,40 @@
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import app.shared.db.models as models
|
||||
from pydantic import AnyHttpUrl, BaseModel
|
||||
|
||||
from app.shared.db.models import (
|
||||
ArtifactData,
|
||||
ArtifactType,
|
||||
JobConfig,
|
||||
JobMeta,
|
||||
JobStatus,
|
||||
JobType,
|
||||
)
|
||||
|
||||
# DB objects
|
||||
|
||||
|
||||
# Request body accepted by the `POST /jobs` endpoint.
class PostJobPayload(BaseModel):
    # Direct link to the media file that should be processed.
    url: AnyHttpUrl = Field(
        description=(
            "URL where the media file is available. This needs to be a direct link."
        )
    )

    # Kind of processing requested. The continuation lines of the description
    # are kept flush-left so the published text carries no stray indentation.
    type: models.JobType = Field(
        description="""Type of this job.
`transcript` uses the original language of the audio.
`translation` creates an automatic translation to english.
`language_detection` detects language from the first 30 seconds of audio."""
    )

    # Optional hint naming the language spoken in the media file.
    language: str | None = Field(
        description=(
            "Spoken language in the media file. "
            "While optional, this can improve output when set."
        )
    )


# Fields shared by the DTOs that mirror database objects.
class WithDbFields(BaseModel):
    id: UUID
    created_at: datetime
    updated_at: datetime | None

    class Config:
        # Allow the model to be populated from ORM object attributes.
        orm_mode = True
|
||||
|
||||
class Job(WithDbFields):
    """A transcription job for one media file."""

    # Current status and requested kind of work for this job.
    status: JobStatus
    type: JobType
    # URL of the media file the job refers to.
    url: AnyHttpUrl
    # Both payloads are optional and may be None.
    meta: JobMeta | None
    config: JobConfig | None
|
||||
|
||||
|
||||
# An artifact belonging to one job (referenced through `job_id`).
class Artifact(WithDbFields):
    job_id: UUID
    data: ArtifactData
    type: ArtifactType
|
||||
|
||||
114
app/web/main.py
114
app/web/main.py
@@ -3,12 +3,12 @@ from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.models as models
|
||||
import app.shared.db.schemas as schemas
|
||||
import app.web.dtos as dtos
|
||||
from app.shared.db.base import SessionLocal, get_session
|
||||
from app.web.dtos import PostJobPayload
|
||||
from app.web.security import authenticate_api_key
|
||||
from app.web.task_queue import task_queue
|
||||
|
||||
@@ -40,54 +40,15 @@ def api_root() -> None:
|
||||
return None
|
||||
|
||||
|
||||
@api_router.post(
    "/jobs",
    response_model=schemas.Job,
    status_code=201,
    summary="Enqueue a new job",
)
def create_job(
    payload: PostJobPayload,
    session: DatabaseSession,
) -> models.Job:
    """
    Enqueue a new whisper job for processing.
    Notes:
    * Jobs are processed one-by-one in order of creation.
    * `payload.url` needs to point directly to a media file.
    * The media file is downloaded to a tmp file for the duration of processing.
    enough free space needs to be available on disk.
    * Media files ideally are audio files with a sampling rate of 16kHz.
    other files will be transcoded automatically via ffmpeg which might
    consume considerable resources while active.
    * Once a job is created, you can query its status by its id.
    """

    # Only attach a config when the caller actually supplied a language.
    job_config = None
    if payload.language:
        job_config = {"language": payload.language}

    # Persist the job in its initial "create" status before queueing any work.
    new_job = models.Job(
        url=payload.url,
        status=schemas.JobStatus.create,
        type=payload.type,
        config=job_config,
    )
    session.add(new_job)
    session.commit()

    # Hand the stored job over to the task queue.
    task_queue.queue_task(new_job)

    return new_job
|
||||
|
||||
|
||||
@api_router.get(
|
||||
"/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
|
||||
"/jobs", response_model=list[dtos.Job], summary="Get metadata for all jobs"
|
||||
)
|
||||
def get_transcripts(
|
||||
session: DatabaseSession,
|
||||
type: schemas.JobType | None = None,
|
||||
type: dtos.JobType | None = None,
|
||||
) -> list[models.Job]:
|
||||
"""Get metadata for all jobs."""
|
||||
query = session.query(models.Job)
|
||||
query = session.query(models.Job).order_by(models.Job.created_at.desc())
|
||||
|
||||
if type:
|
||||
query = query.filter(models.Job.type == type)
|
||||
@@ -97,7 +58,7 @@ def get_transcripts(
|
||||
|
||||
@api_router.get(
|
||||
"/jobs/{id}",
|
||||
response_model=schemas.Job,
|
||||
response_model=dtos.Job,
|
||||
summary="Get metadata for one job",
|
||||
)
|
||||
def get_transcript(
|
||||
@@ -117,7 +78,7 @@ def get_transcript(
|
||||
|
||||
@api_router.get(
|
||||
"/jobs/{id}/artifacts",
|
||||
response_model=list[schemas.Artifact],
|
||||
response_model=list[dtos.Artifact],
|
||||
summary="Get all artifacts for one job",
|
||||
)
|
||||
def get_artifacts_for_job(
|
||||
@@ -148,4 +109,65 @@ def delete_transcript(
|
||||
return None
|
||||
|
||||
|
||||
# Request body accepted by the `POST /jobs` endpoint (`create_job`).
class PostJobPayload(BaseModel):
    # Direct link to the media file that should be processed.
    url: AnyHttpUrl = Field(
        description=(
            "URL where the media file is available. This needs to be a direct link."
        )
    )

    # Kind of processing requested. The continuation lines of the description
    # are kept flush-left so the published text carries no stray indentation.
    type: models.JobType = Field(
        description="""Type of this job.
`transcript` uses the original language of the audio.
`translation` creates an automatic translation to english.
`language_detection` detects language from the first 30 seconds of audio."""
    )

    # Optional language hint. The explicit `default=None` keeps the field
    # optional without relying on pydantic's implicit handling of `X | None`
    # annotations that carry no default.
    language: str | None = Field(
        default=None,
        description=(
            "Spoken language in the media file. "
            "While optional, this can improve output when set."
        ),
    )
|
||||
|
||||
|
||||
@api_router.post(
    "/jobs",
    response_model=dtos.Job,
    status_code=201,
    summary="Enqueue a new job",
)
def create_job(
    payload: PostJobPayload,
    session: DatabaseSession,
) -> models.Job:
    """
    Enqueue a new whisper job for processing.
    Notes:
    * Jobs are processed one-by-one in order of creation.
    * `payload.url` needs to point directly to a media file.
    * The media file is downloaded to a tmp file for the duration of processing.
    enough free space needs to be available on disk.
    * Media files ideally are audio files with a sampling rate of 16kHz.
    other files will be transcoded automatically via ffmpeg which might
    consume considerable resources while active.
    * Once a job is created, you can query its status by its id.
    """

    # Only attach a config when the caller actually supplied a language.
    job_config = None
    if payload.language:
        job_config = {"language": payload.language}

    # Persist the job in its initial "create" status before queueing any work.
    new_job = models.Job(
        url=payload.url,
        status=dtos.JobStatus.create,
        type=payload.type,
        config=job_config,
    )
    session.add(new_job)
    session.commit()

    # Hand the stored job over to the task queue.
    task_queue.queue_task(new_job)

    return new_job
|
||||
|
||||
|
||||
app.include_router(api_router)
|
||||
|
||||
@@ -15,9 +15,10 @@ class TaskQueue:
|
||||
self.celery = get_celery_binding()
|
||||
|
||||
def queue_task(self, job: models.Job):
    """
    Enqueue an asynchronous transcription task for `job`.

    The task is addressed through a celery signature (by its dotted name)
    instead of an imported function, which allows full separation of worker
    processes and their dependencies.
    """
    # Only the job's id is handed to the task.
    self.celery.signature("app.worker.main.transcribe").delay(job.id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user