mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-08 03:28:35 +03:00
refactor: remove shared schemas
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
import enum
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, declarative_base, declarative_mixin, declared_attr
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
# Enums
|
||||
|
||||
|
||||
@@ -32,7 +34,55 @@ class ArtifactType(str, enum.Enum):
|
||||
language_detection = "language_detection"
|
||||
|
||||
|
||||
# SQLAlchemy models
|
||||
# JSON field types
|
||||
|
||||
|
||||
class JobConfig(BaseModel):
|
||||
"""(JSON) Configuration for a job."""
|
||||
|
||||
language: str | None = Field(
|
||||
description=(
|
||||
"Spoken language in the media file. "
|
||||
"While optional, this can improve output."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class JobMeta(BaseModel):
|
||||
"""(JSON) Metadata relating to a job's execution."""
|
||||
|
||||
error: str | None = Field(
|
||||
description="Will contain a descriptive error message if processing failed."
|
||||
)
|
||||
|
||||
task_id: uuid.UUID | None = Field(
|
||||
description="Internal celery id of this job submission."
|
||||
)
|
||||
|
||||
|
||||
class RawTranscript(BaseModel):
|
||||
"""(JSON) A single transcript passage returned by whisper."""
|
||||
|
||||
id: int
|
||||
seek: int
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
tokens: list[int]
|
||||
temperature: float
|
||||
avg_logprob: float
|
||||
compression_ratio: float
|
||||
no_speech_prob: float
|
||||
|
||||
|
||||
class LanguageDetection(BaseModel):
|
||||
"""A language detection"""
|
||||
|
||||
language_code: str
|
||||
|
||||
|
||||
# Sum type for all possible artifact data values
|
||||
ArtifactData = list[RawTranscript] | LanguageDetection | None
|
||||
|
||||
|
||||
@declarative_mixin
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
|
||||
from app.shared.db.models import ArtifactType, JobStatus, JobType
|
||||
|
||||
# JSON field types
|
||||
|
||||
|
||||
class JobConfig(BaseModel):
|
||||
"""Configuration for a job."""
|
||||
|
||||
language: str | None = Field(
|
||||
description=(
|
||||
"Spoken language in the media file. "
|
||||
"While optional, this can improve output."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class JobMeta(BaseModel):
|
||||
"""Metadata relating to a job's execution."""
|
||||
|
||||
error: str | None = Field(
|
||||
description="Will contain a descriptive error message if processing failed."
|
||||
)
|
||||
|
||||
task_id: UUID | None = Field(
|
||||
description="Internal celery id of this job submission."
|
||||
)
|
||||
|
||||
|
||||
class RawTranscript(BaseModel):
|
||||
"""A single transcript passage returned by whisper."""
|
||||
|
||||
id: int
|
||||
seek: int
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
tokens: list[int]
|
||||
temperature: float
|
||||
avg_logprob: float
|
||||
compression_ratio: float
|
||||
no_speech_prob: float
|
||||
|
||||
|
||||
class LanguageDetection(BaseModel):
|
||||
"""A language detection"""
|
||||
|
||||
code: str
|
||||
|
||||
|
||||
# DB objects
|
||||
|
||||
|
||||
class WithDbFields(BaseModel):
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime | None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class Job(WithDbFields):
|
||||
"""A transcription job for one media file."""
|
||||
|
||||
status: JobStatus
|
||||
type: JobType
|
||||
url: AnyHttpUrl
|
||||
meta: JobMeta | None
|
||||
config: JobConfig | None
|
||||
|
||||
|
||||
class Artifact(WithDbFields):
|
||||
job_id: UUID
|
||||
data: LanguageDetection | RawTranscript | None
|
||||
type: ArtifactType
|
||||
@@ -1,25 +1,40 @@
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import app.shared.db.models as models
|
||||
from pydantic import AnyHttpUrl, BaseModel
|
||||
|
||||
from app.shared.db.models import (
|
||||
ArtifactData,
|
||||
ArtifactType,
|
||||
JobConfig,
|
||||
JobMeta,
|
||||
JobStatus,
|
||||
JobType,
|
||||
)
|
||||
|
||||
# DB objects
|
||||
|
||||
|
||||
class PostJobPayload(BaseModel):
|
||||
url: AnyHttpUrl = Field(
|
||||
description=(
|
||||
"URL where the media file is available. This needs to be a direct link."
|
||||
)
|
||||
)
|
||||
class WithDbFields(BaseModel):
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime | None
|
||||
|
||||
type: models.JobType = Field(
|
||||
description="""Type of this job.
|
||||
`transcript` uses the original language of the audio.
|
||||
`translation` creates an automatic translation to english.
|
||||
`language_detection` detects language from the first 30 seconds of audio."""
|
||||
)
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
language: str | None = Field(
|
||||
description=(
|
||||
"Spoken language in the media file. "
|
||||
"While optional, this can improve output when set."
|
||||
)
|
||||
)
|
||||
|
||||
class Job(WithDbFields):
|
||||
"""A transcription job for one media file."""
|
||||
|
||||
status: JobStatus
|
||||
type: JobType
|
||||
url: AnyHttpUrl
|
||||
meta: JobMeta | None
|
||||
config: JobConfig | None
|
||||
|
||||
|
||||
class Artifact(WithDbFields):
|
||||
job_id: UUID
|
||||
data: ArtifactData
|
||||
type: ArtifactType
|
||||
|
||||
114
app/web/main.py
114
app/web/main.py
@@ -3,12 +3,12 @@ from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.models as models
|
||||
import app.shared.db.schemas as schemas
|
||||
import app.web.dtos as dtos
|
||||
from app.shared.db.base import SessionLocal, get_session
|
||||
from app.web.dtos import PostJobPayload
|
||||
from app.web.security import authenticate_api_key
|
||||
from app.web.task_queue import task_queue
|
||||
|
||||
@@ -40,54 +40,15 @@ def api_root() -> None:
|
||||
return None
|
||||
|
||||
|
||||
@api_router.post(
|
||||
"/jobs",
|
||||
response_model=schemas.Job,
|
||||
status_code=201,
|
||||
summary="Enqueue a new job",
|
||||
)
|
||||
def create_job(
|
||||
payload: PostJobPayload,
|
||||
session: DatabaseSession,
|
||||
) -> models.Job:
|
||||
"""
|
||||
Enqueue a new whisper job for processing.
|
||||
Notes:
|
||||
* Jobs are processed one-by-one in order of creation.
|
||||
* `payload.url` needs to point directly to a media file.
|
||||
* The media file is downloaded to a tmp file for the duration of processing.
|
||||
enough free space needs to be available on disk.
|
||||
* Media files ideally are audio files with a sampling rate of 16kHz.
|
||||
other files will be transcoded automatically via ffmpeg which might
|
||||
consume considerable resources while active.
|
||||
* Once a job is created, you can query its status by its id.
|
||||
"""
|
||||
|
||||
# create a job with status "create" and save it to the database.
|
||||
job = models.Job(
|
||||
url=payload.url,
|
||||
status=schemas.JobStatus.create,
|
||||
type=payload.type,
|
||||
config={"language": payload.language} if payload.language else None,
|
||||
)
|
||||
|
||||
session.add(job)
|
||||
session.commit()
|
||||
|
||||
task_queue.queue_task(job)
|
||||
|
||||
return job
|
||||
|
||||
|
||||
@api_router.get(
|
||||
"/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
|
||||
"/jobs", response_model=list[dtos.Job], summary="Get metadata for all jobs"
|
||||
)
|
||||
def get_transcripts(
|
||||
session: DatabaseSession,
|
||||
type: schemas.JobType | None = None,
|
||||
type: dtos.JobType | None = None,
|
||||
) -> list[models.Job]:
|
||||
"""Get metadata for all jobs."""
|
||||
query = session.query(models.Job)
|
||||
query = session.query(models.Job).order_by(models.Job.created_at.desc())
|
||||
|
||||
if type:
|
||||
query = query.filter(models.Job.type == type)
|
||||
@@ -97,7 +58,7 @@ def get_transcripts(
|
||||
|
||||
@api_router.get(
|
||||
"/jobs/{id}",
|
||||
response_model=schemas.Job,
|
||||
response_model=dtos.Job,
|
||||
summary="Get metadata for one job",
|
||||
)
|
||||
def get_transcript(
|
||||
@@ -117,7 +78,7 @@ def get_transcript(
|
||||
|
||||
@api_router.get(
|
||||
"/jobs/{id}/artifacts",
|
||||
response_model=list[schemas.Artifact],
|
||||
response_model=list[dtos.Artifact],
|
||||
summary="Get all artifacts for one job",
|
||||
)
|
||||
def get_artifacts_for_job(
|
||||
@@ -148,4 +109,65 @@ def delete_transcript(
|
||||
return None
|
||||
|
||||
|
||||
class PostJobPayload(BaseModel):
|
||||
url: AnyHttpUrl = Field(
|
||||
description=(
|
||||
"URL where the media file is available. This needs to be a direct link."
|
||||
)
|
||||
)
|
||||
|
||||
type: models.JobType = Field(
|
||||
description="""Type of this job.
|
||||
`transcript` uses the original language of the audio.
|
||||
`translation` creates an automatic translation to english.
|
||||
`language_detection` detects language from the first 30 seconds of audio."""
|
||||
)
|
||||
|
||||
language: str | None = Field(
|
||||
description=(
|
||||
"Spoken language in the media file. "
|
||||
"While optional, this can improve output when set."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@api_router.post(
|
||||
"/jobs",
|
||||
response_model=dtos.Job,
|
||||
status_code=201,
|
||||
summary="Enqueue a new job",
|
||||
)
|
||||
def create_job(
|
||||
payload: PostJobPayload,
|
||||
session: DatabaseSession,
|
||||
) -> models.Job:
|
||||
"""
|
||||
Enqueue a new whisper job for processing.
|
||||
Notes:
|
||||
* Jobs are processed one-by-one in order of creation.
|
||||
* `payload.url` needs to point directly to a media file.
|
||||
* The media file is downloaded to a tmp file for the duration of processing.
|
||||
enough free space needs to be available on disk.
|
||||
* Media files ideally are audio files with a sampling rate of 16kHz.
|
||||
other files will be transcoded automatically via ffmpeg which might
|
||||
consume considerable resources while active.
|
||||
* Once a job is created, you can query its status by its id.
|
||||
"""
|
||||
|
||||
# create a job with status "create" and save it to the database.
|
||||
job = models.Job(
|
||||
url=payload.url,
|
||||
status=dtos.JobStatus.create,
|
||||
type=payload.type,
|
||||
config={"language": payload.language} if payload.language else None,
|
||||
)
|
||||
|
||||
session.add(job)
|
||||
session.commit()
|
||||
|
||||
task_queue.queue_task(job)
|
||||
|
||||
return job
|
||||
|
||||
|
||||
app.include_router(api_router)
|
||||
|
||||
@@ -15,9 +15,10 @@ class TaskQueue:
|
||||
self.celery = get_celery_binding()
|
||||
|
||||
def queue_task(self, job: models.Job):
|
||||
# queue an async transcription task.
|
||||
# we use a signature here to allow full separation of
|
||||
# worker processes and dependencies.
|
||||
"""
|
||||
Queues an async transcription job. We use a celery signature here to
|
||||
allow for full separation of worker processes and dependencies.
|
||||
"""
|
||||
transcribe = self.celery.signature("app.worker.main.transcribe")
|
||||
transcribe.delay(job.id)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from celery import Task
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.models as models
|
||||
import app.shared.db.schemas as schemas
|
||||
from app.shared.celery import get_celery_binding
|
||||
from app.shared.db.base import SessionLocal
|
||||
from app.shared.settings import settings
|
||||
@@ -31,7 +30,7 @@ class TranscribeTask(Task):
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
|
||||
def select_task_processor(task: Task, job: schemas.Job) -> TaskProtocol:
|
||||
def select_task_processor(task: Task, job: models.Job) -> TaskProtocol:
|
||||
if job.type == models.JobType.transcript:
|
||||
return task.strategy.transcribe
|
||||
elif job.type == models.JobType.translation:
|
||||
@@ -51,7 +50,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
|
||||
db: Session = SessionLocal()
|
||||
|
||||
# unit of work: set task status to processing.
|
||||
# check if passed job should be processed.
|
||||
|
||||
job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none()
|
||||
|
||||
@@ -63,24 +62,23 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
logger.warn("[{job.id}]: job has already been processed, abort.")
|
||||
return
|
||||
|
||||
logger.info(f"[{job.id}]: received eligible job.")
|
||||
logger.debug(f"[{job.id}]: start processing {job.type} job.")
|
||||
|
||||
# unit of work: set task status to processing.
|
||||
|
||||
job.meta = {"task_id": self.request.id}
|
||||
job.status = models.JobStatus.processing
|
||||
|
||||
db.commit()
|
||||
logger.info(f"[{job.id}]: finished setting task to status processing.")
|
||||
|
||||
logger.debug(f"[{job.id}]: finished setting task to {job.status}.")
|
||||
|
||||
# unit of work: process job with whisper.
|
||||
job_record = schemas.Job.from_orm(job)
|
||||
|
||||
processor = select_task_processor(self, job_record)
|
||||
processor = select_task_processor(self, job)
|
||||
|
||||
result_type, result = processor(
|
||||
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||
)
|
||||
result_type, result = processor(job)
|
||||
|
||||
logger.info(f"[{job.id}]: successfully processed audio.")
|
||||
logger.debug(f"[{job.id}]: successfully processed audio.")
|
||||
|
||||
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
|
||||
|
||||
@@ -89,7 +87,8 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
job.status = models.JobStatus.success
|
||||
|
||||
db.commit()
|
||||
logger.info(f"[{job.id}]: successfully stored artifact.")
|
||||
logger.debug(f"[{job.id}]: successfully stored artifact.")
|
||||
|
||||
except Exception as e:
|
||||
if job and db:
|
||||
db.rollback()
|
||||
|
||||
@@ -3,32 +3,23 @@ from typing import Any, Protocol, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
import app.shared.db.models as models
|
||||
import app.shared.db.schemas as schemas
|
||||
|
||||
TaskReturnValue = Tuple[models.ArtifactType, Any]
|
||||
|
||||
|
||||
class TaskProtocol(Protocol):
|
||||
def __call__(
|
||||
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
||||
) -> TaskReturnValue:
|
||||
def __call__(self, job: models.Job) -> TaskReturnValue:
|
||||
...
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
def transcribe(
|
||||
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
||||
) -> TaskReturnValue:
|
||||
def transcribe(self, job: models.Job) -> TaskReturnValue:
|
||||
raise NotImplementedError()
|
||||
|
||||
def translate(
|
||||
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
||||
) -> TaskReturnValue:
|
||||
def translate(self, job: models.Job) -> TaskReturnValue:
|
||||
raise NotImplementedError()
|
||||
|
||||
def detect_language(
|
||||
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
||||
) -> TaskReturnValue:
|
||||
def detect_language(self, job: models.Job) -> TaskReturnValue:
|
||||
raise NotImplementedError()
|
||||
|
||||
def cleanup(self, job_id: UUID) -> None:
|
||||
|
||||
@@ -9,12 +9,18 @@ import requests
|
||||
import torch
|
||||
import whisper
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON, Column
|
||||
|
||||
import app.shared.db.schemas as schemas
|
||||
import app.shared.db.models as models
|
||||
from app.worker.strategies.base import BaseStrategy, TaskReturnValue
|
||||
|
||||
|
||||
class DecodeOptions(BaseModel):
|
||||
class DecodingOptions(BaseModel):
|
||||
"""
|
||||
Options passed to the whipser model.
|
||||
This mirrors private type `whisper.DecodingOptions`.
|
||||
"""
|
||||
|
||||
language: str | None
|
||||
task: Literal["translate", "transcribe"]
|
||||
|
||||
@@ -40,27 +46,24 @@ class LocalStrategy(BaseStrategy):
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def transcribe(self, url, job_id, config):
|
||||
return (
|
||||
schemas.ArtifactType.raw_transcript,
|
||||
self._run_whisper(
|
||||
self._download(url, job_id), "transcribe", config, job_id
|
||||
),
|
||||
def transcribe(self, job):
|
||||
result = self._run_whisper(
|
||||
self._download(job.url, job.id), "transcribe", job.config, job.id
|
||||
)
|
||||
|
||||
def translate(self, url, job_id, config) -> TaskReturnValue:
|
||||
return (
|
||||
schemas.ArtifactType.raw_transcript,
|
||||
self._run_whisper(
|
||||
self._download(url, job_id),
|
||||
"translate",
|
||||
config,
|
||||
job_id,
|
||||
),
|
||||
)
|
||||
return (models.ArtifactType.raw_transcript, result)
|
||||
|
||||
def detect_language(self, url, job_id, config) -> TaskReturnValue:
|
||||
file = self._download(url, job_id)
|
||||
def translate(self, job) -> TaskReturnValue:
|
||||
result = self._run_whisper(
|
||||
self._download(job.url, job.id),
|
||||
"translate",
|
||||
job.config,
|
||||
job.id,
|
||||
)
|
||||
return (models.ArtifactType.raw_transcript, result)
|
||||
|
||||
def detect_language(self, job) -> TaskReturnValue:
|
||||
file = self._download(job.url, job.id)
|
||||
|
||||
audio = whisper.pad_or_trim(whisper.load_audio(file))
|
||||
|
||||
@@ -68,7 +71,7 @@ class LocalStrategy(BaseStrategy):
|
||||
_, probs = self.model.detect_language(mel)
|
||||
|
||||
return (
|
||||
schemas.ArtifactType.language_detection,
|
||||
models.ArtifactType.language_detection,
|
||||
{"code": max(probs, key=probs.get)},
|
||||
)
|
||||
|
||||
@@ -90,16 +93,18 @@ class LocalStrategy(BaseStrategy):
|
||||
self,
|
||||
filepath: str,
|
||||
task: Literal["translate", "transcribe"],
|
||||
config: schemas.JobConfig | None,
|
||||
config: Column[JSON],
|
||||
job_id: UUID,
|
||||
) -> list[Any]:
|
||||
try:
|
||||
language = config.language if config else None
|
||||
|
||||
result = self.model.transcribe(
|
||||
filepath,
|
||||
# turning this off might make the transcription less accurate,
|
||||
# but significantly reduces amount of model halucinations.
|
||||
condition_on_previous_text=False,
|
||||
**DecodeOptions(task=task, language=language).dict(),
|
||||
**DecodingOptions(
|
||||
task=task, language=config.language if config else None
|
||||
).dict(),
|
||||
)
|
||||
|
||||
return result["segments"]
|
||||
|
||||
Reference in New Issue
Block a user