refactor: remove shared schemas

This commit is contained in:
Felix Spöttel
2023-06-29 12:38:56 +02:00
parent 908bd48170
commit b3f8d5c82a
8 changed files with 204 additions and 201 deletions

View File

@@ -1,12 +1,14 @@
import enum
import uuid
from pydantic import BaseModel, Field
from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, declarative_base, declarative_mixin, declared_attr
Base = declarative_base()
# Enums
@@ -32,7 +34,55 @@ class ArtifactType(str, enum.Enum):
language_detection = "language_detection"
# SQLAlchemy models
# JSON field types
class JobConfig(BaseModel):
"""(JSON) Configuration for a job."""
language: str | None = Field(
description=(
"Spoken language in the media file. "
"While optional, this can improve output."
)
)
class JobMeta(BaseModel):
"""(JSON) Metadata relating to a job's execution."""
error: str | None = Field(
description="Will contain a descriptive error message if processing failed."
)
task_id: uuid.UUID | None = Field(
description="Internal celery id of this job submission."
)
class RawTranscript(BaseModel):
"""(JSON) A single transcript passage returned by whisper."""
id: int
seek: int
start: float
end: float
text: str
tokens: list[int]
temperature: float
avg_logprob: float
compression_ratio: float
no_speech_prob: float
class LanguageDetection(BaseModel):
"""A language detection"""
language_code: str
# Sum type for all possible artifact data values
ArtifactData = list[RawTranscript] | LanguageDetection | None
@declarative_mixin

View File

@@ -1,80 +0,0 @@
from datetime import datetime
from uuid import UUID
from pydantic import AnyHttpUrl, BaseModel, Field
from app.shared.db.models import ArtifactType, JobStatus, JobType
# JSON field types
class JobConfig(BaseModel):
"""Configuration for a job."""
language: str | None = Field(
description=(
"Spoken language in the media file. "
"While optional, this can improve output."
)
)
class JobMeta(BaseModel):
"""Metadata relating to a job's execution."""
error: str | None = Field(
description="Will contain a descriptive error message if processing failed."
)
task_id: UUID | None = Field(
description="Internal celery id of this job submission."
)
class RawTranscript(BaseModel):
"""A single transcript passage returned by whisper."""
id: int
seek: int
start: float
end: float
text: str
tokens: list[int]
temperature: float
avg_logprob: float
compression_ratio: float
no_speech_prob: float
class LanguageDetection(BaseModel):
"""A language detection"""
code: str
# DB objects
class WithDbFields(BaseModel):
id: UUID
created_at: datetime
updated_at: datetime | None
class Config:
orm_mode = True
class Job(WithDbFields):
"""A transcription job for one media file."""
status: JobStatus
type: JobType
url: AnyHttpUrl
meta: JobMeta | None
config: JobConfig | None
class Artifact(WithDbFields):
job_id: UUID
data: LanguageDetection | RawTranscript | None
type: ArtifactType

View File

@@ -1,25 +1,40 @@
from pydantic import AnyHttpUrl, BaseModel, Field
from datetime import datetime
from uuid import UUID
import app.shared.db.models as models
from pydantic import AnyHttpUrl, BaseModel
from app.shared.db.models import (
ArtifactData,
ArtifactType,
JobConfig,
JobMeta,
JobStatus,
JobType,
)
# DB objects
class PostJobPayload(BaseModel):
url: AnyHttpUrl = Field(
description=(
"URL where the media file is available. This needs to be a direct link."
)
)
class WithDbFields(BaseModel):
id: UUID
created_at: datetime
updated_at: datetime | None
type: models.JobType = Field(
description="""Type of this job.
`transcript` uses the original language of the audio.
`translation` creates an automatic translation to english.
`language_detection` detects language from the first 30 seconds of audio."""
)
class Config:
orm_mode = True
language: str | None = Field(
description=(
"Spoken language in the media file. "
"While optional, this can improve output when set."
)
)
class Job(WithDbFields):
"""A transcription job for one media file."""
status: JobStatus
type: JobType
url: AnyHttpUrl
meta: JobMeta | None
config: JobConfig | None
class Artifact(WithDbFields):
job_id: UUID
data: ArtifactData
type: ArtifactType

View File

@@ -3,12 +3,12 @@ from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
from pydantic import AnyHttpUrl, BaseModel, Field
from sqlalchemy.orm import Session
import app.shared.db.models as models
import app.shared.db.schemas as schemas
import app.web.dtos as dtos
from app.shared.db.base import SessionLocal, get_session
from app.web.dtos import PostJobPayload
from app.web.security import authenticate_api_key
from app.web.task_queue import task_queue
@@ -40,54 +40,15 @@ def api_root() -> None:
return None
@api_router.post(
"/jobs",
response_model=schemas.Job,
status_code=201,
summary="Enqueue a new job",
)
def create_job(
payload: PostJobPayload,
session: DatabaseSession,
) -> models.Job:
"""
Enqueue a new whisper job for processing.
Notes:
* Jobs are processed one-by-one in order of creation.
* `payload.url` needs to point directly to a media file.
* The media file is downloaded to a tmp file for the duration of processing.
enough free space needs to be available on disk.
* Media files ideally are audio files with a sampling rate of 16kHz.
other files will be transcoded automatically via ffmpeg which might
consume considerable resources while active.
* Once a job is created, you can query its status by its id.
"""
# create a job with status "create" and save it to the database.
job = models.Job(
url=payload.url,
status=schemas.JobStatus.create,
type=payload.type,
config={"language": payload.language} if payload.language else None,
)
session.add(job)
session.commit()
task_queue.queue_task(job)
return job
@api_router.get(
"/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
"/jobs", response_model=list[dtos.Job], summary="Get metadata for all jobs"
)
def get_transcripts(
session: DatabaseSession,
type: schemas.JobType | None = None,
type: dtos.JobType | None = None,
) -> list[models.Job]:
"""Get metadata for all jobs."""
query = session.query(models.Job)
query = session.query(models.Job).order_by(models.Job.created_at.desc())
if type:
query = query.filter(models.Job.type == type)
@@ -97,7 +58,7 @@ def get_transcripts(
@api_router.get(
"/jobs/{id}",
response_model=schemas.Job,
response_model=dtos.Job,
summary="Get metadata for one job",
)
def get_transcript(
@@ -117,7 +78,7 @@ def get_transcript(
@api_router.get(
"/jobs/{id}/artifacts",
response_model=list[schemas.Artifact],
response_model=list[dtos.Artifact],
summary="Get all artifacts for one job",
)
def get_artifacts_for_job(
@@ -148,4 +109,65 @@ def delete_transcript(
return None
class PostJobPayload(BaseModel):
url: AnyHttpUrl = Field(
description=(
"URL where the media file is available. This needs to be a direct link."
)
)
type: models.JobType = Field(
description="""Type of this job.
`transcript` uses the original language of the audio.
`translation` creates an automatic translation to english.
`language_detection` detects language from the first 30 seconds of audio."""
)
language: str | None = Field(
description=(
"Spoken language in the media file. "
"While optional, this can improve output when set."
)
)
@api_router.post(
"/jobs",
response_model=dtos.Job,
status_code=201,
summary="Enqueue a new job",
)
def create_job(
payload: PostJobPayload,
session: DatabaseSession,
) -> models.Job:
"""
Enqueue a new whisper job for processing.
Notes:
* Jobs are processed one-by-one in order of creation.
* `payload.url` needs to point directly to a media file.
* The media file is downloaded to a tmp file for the duration of processing.
enough free space needs to be available on disk.
* Media files ideally are audio files with a sampling rate of 16kHz.
other files will be transcoded automatically via ffmpeg which might
consume considerable resources while active.
* Once a job is created, you can query its status by its id.
"""
# create a job with status "create" and save it to the database.
job = models.Job(
url=payload.url,
status=dtos.JobStatus.create,
type=payload.type,
config={"language": payload.language} if payload.language else None,
)
session.add(job)
session.commit()
task_queue.queue_task(job)
return job
app.include_router(api_router)

View File

@@ -15,9 +15,10 @@ class TaskQueue:
self.celery = get_celery_binding()
def queue_task(self, job: models.Job):
# queue an async transcription task.
# we use a signature here to allow full separation of
# worker processes and dependencies.
"""
Queues an async transcription job. We use a celery signature here to
allow for full separation of worker processes and dependencies.
"""
transcribe = self.celery.signature("app.worker.main.transcribe")
transcribe.delay(job.id)

View File

@@ -6,7 +6,6 @@ from celery import Task
from sqlalchemy.orm import Session
import app.shared.db.models as models
import app.shared.db.schemas as schemas
from app.shared.celery import get_celery_binding
from app.shared.db.base import SessionLocal
from app.shared.settings import settings
@@ -31,7 +30,7 @@ class TranscribeTask(Task):
return self.run(*args, **kwargs)
def select_task_processor(task: Task, job: schemas.Job) -> TaskProtocol:
def select_task_processor(task: Task, job: models.Job) -> TaskProtocol:
if job.type == models.JobType.transcript:
return task.strategy.transcribe
elif job.type == models.JobType.translation:
@@ -51,7 +50,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
db: Session = SessionLocal()
# unit of work: set task status to processing.
# check if passed job should be processed.
job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none()
@@ -63,24 +62,23 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.warn("[{job.id}]: job has already been processed, abort.")
return
logger.info(f"[{job.id}]: received eligible job.")
logger.debug(f"[{job.id}]: start processing {job.type} job.")
# unit of work: set task status to processing.
job.meta = {"task_id": self.request.id}
job.status = models.JobStatus.processing
db.commit()
logger.info(f"[{job.id}]: finished setting task to status processing.")
logger.debug(f"[{job.id}]: finished setting task to {job.status}.")
# unit of work: process job with whisper.
job_record = schemas.Job.from_orm(job)
processor = select_task_processor(self, job_record)
processor = select_task_processor(self, job)
result_type, result = processor(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
result_type, result = processor(job)
logger.info(f"[{job.id}]: successfully processed audio.")
logger.debug(f"[{job.id}]: successfully processed audio.")
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
@@ -89,7 +87,8 @@ def transcribe(self: Task, job_id: UUID) -> None:
job.status = models.JobStatus.success
db.commit()
logger.info(f"[{job.id}]: successfully stored artifact.")
logger.debug(f"[{job.id}]: successfully stored artifact.")
except Exception as e:
if job and db:
db.rollback()

View File

@@ -3,32 +3,23 @@ from typing import Any, Protocol, Tuple
from uuid import UUID
import app.shared.db.models as models
import app.shared.db.schemas as schemas
TaskReturnValue = Tuple[models.ArtifactType, Any]
class TaskProtocol(Protocol):
def __call__(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> TaskReturnValue:
def __call__(self, job: models.Job) -> TaskReturnValue:
...
class BaseStrategy(ABC):
def transcribe(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> TaskReturnValue:
def transcribe(self, job: models.Job) -> TaskReturnValue:
raise NotImplementedError()
def translate(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> TaskReturnValue:
def translate(self, job: models.Job) -> TaskReturnValue:
raise NotImplementedError()
def detect_language(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> TaskReturnValue:
def detect_language(self, job: models.Job) -> TaskReturnValue:
raise NotImplementedError()
def cleanup(self, job_id: UUID) -> None:

View File

@@ -9,12 +9,18 @@ import requests
import torch
import whisper
from pydantic import BaseModel
from sqlalchemy import JSON, Column
import app.shared.db.schemas as schemas
import app.shared.db.models as models
from app.worker.strategies.base import BaseStrategy, TaskReturnValue
class DecodeOptions(BaseModel):
class DecodingOptions(BaseModel):
"""
Options passed to the whipser model.
This mirrors private type `whisper.DecodingOptions`.
"""
language: str | None
task: Literal["translate", "transcribe"]
@@ -40,27 +46,24 @@ class LocalStrategy(BaseStrategy):
except OSError:
pass
def transcribe(self, url, job_id, config):
return (
schemas.ArtifactType.raw_transcript,
self._run_whisper(
self._download(url, job_id), "transcribe", config, job_id
),
def transcribe(self, job):
result = self._run_whisper(
self._download(job.url, job.id), "transcribe", job.config, job.id
)
def translate(self, url, job_id, config) -> TaskReturnValue:
return (
schemas.ArtifactType.raw_transcript,
self._run_whisper(
self._download(url, job_id),
"translate",
config,
job_id,
),
)
return (models.ArtifactType.raw_transcript, result)
def detect_language(self, url, job_id, config) -> TaskReturnValue:
file = self._download(url, job_id)
def translate(self, job) -> TaskReturnValue:
result = self._run_whisper(
self._download(job.url, job.id),
"translate",
job.config,
job.id,
)
return (models.ArtifactType.raw_transcript, result)
def detect_language(self, job) -> TaskReturnValue:
file = self._download(job.url, job.id)
audio = whisper.pad_or_trim(whisper.load_audio(file))
@@ -68,7 +71,7 @@ class LocalStrategy(BaseStrategy):
_, probs = self.model.detect_language(mel)
return (
schemas.ArtifactType.language_detection,
models.ArtifactType.language_detection,
{"code": max(probs, key=probs.get)},
)
@@ -90,16 +93,18 @@ class LocalStrategy(BaseStrategy):
self,
filepath: str,
task: Literal["translate", "transcribe"],
config: schemas.JobConfig | None,
config: Column[JSON],
job_id: UUID,
) -> list[Any]:
try:
language = config.language if config else None
result = self.model.transcribe(
filepath,
# turning this off might make the transcription less accurate,
# but significantly reduces amount of model halucinations.
condition_on_previous_text=False,
**DecodeOptions(task=task, language=language).dict(),
**DecodingOptions(
task=task, language=config.language if config else None
).dict(),
)
return result["segments"]