feat: add language detection task

This commit is contained in:
Felix Spöttel
2023-06-29 09:13:11 +02:00
parent d2223206be
commit 908bd48170
15 changed files with 267 additions and 191 deletions

View File

@@ -1,4 +1,4 @@
API_SECRET="test_secret"
BROKER_URL="memory://"
DATABASE_URI="sqlite:///memory"
DATABASE_URI="sqlite://"
ENVIRONMENT="test"

View File

@@ -1,15 +1,15 @@
"""add_job_tables
"""add_tables
Revision ID: dc8582aea0bc
Revision ID: 0eee2b7913b7
Revises:
Create Date: 2023-02-08 12:12:00.808816
Create Date: 2023-06-29 08:33:26.123728
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "dc8582aea0bc"
revision = "0eee2b7913b7"
down_revision = None
branch_labels = None
depends_on = None
@@ -54,7 +54,7 @@ def upgrade() -> None:
sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
sa.Column(
"type",
sa.Enum("raw_transcript", name="artifacttype"),
sa.Enum("raw_transcript", "language_detection", name="artifacttype"),
nullable=False,
),
sa.Column(

View File

@@ -19,11 +19,8 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def get_session() -> Generator[Session, None, None]:
db: Session = SessionLocal()
session: Session = SessionLocal()
try:
yield db
db.commit()
except Exception:
db.rollback()
yield session
finally:
db.close()
session.close()

View File

@@ -1,13 +1,39 @@
import enum
import uuid
from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, declarative_base, declarative_mixin, declared_attr
from .schemas import ArtifactType, JobStatus, JobType
Base = declarative_base()
# Enums
class JobType(str, enum.Enum):
"""Requested type of a job."""
transcript = "transcribe"
translation = "translate"
language_detection = "detect_language"
class JobStatus(str, enum.Enum):
"""Processing status of a job."""
create = "create"
processing = "processing"
error = "error"
success = "success"
class ArtifactType(str, enum.Enum):
raw_transcript = "transcript_raw"
language_detection = "language_detection"
# SQLAlchemy models
@declarative_mixin
class WithStandardFields:

View File

@@ -1,42 +1,16 @@
import enum
from datetime import datetime
from uuid import UUID
from pydantic import AnyHttpUrl, BaseModel, Field
from app.shared.db.models import ArtifactType, JobStatus, JobType
class WithDbFields(BaseModel):
id: UUID
created_at: datetime
updated_at: datetime | None
class Config:
orm_mode = True
class ArtifactType(str, enum.Enum):
raw_transcript = "raw_transcript"
class JobType(str, enum.Enum):
transcript = "transcript"
translation = "translation"
language_detection = "language_detection"
class JobStatus(str, enum.Enum):
"""Processing status of a job."""
create = "create"
processing = "processing"
error = "error"
success = "success"
# JSON field types
class JobConfig(BaseModel):
"""Configuration for a job."""
# TODO: limit to locales selected by whisper.
language: str | None = Field(
description=(
"Spoken language in the media file. "
@@ -51,21 +25,12 @@ class JobMeta(BaseModel):
error: str | None = Field(
description="Will contain a descriptive error message if processing failed."
)
task_id: UUID | None = Field(
description="Internal celery id of this job submission."
)
class Job(WithDbFields):
"""A transcription job for one media file."""
status: JobStatus
type: JobType
url: AnyHttpUrl
meta: JobMeta | None
config: JobConfig | None
class RawTranscript(BaseModel):
"""A single transcript passage returned by whisper."""
@@ -81,9 +46,35 @@ class RawTranscript(BaseModel):
no_speech_prob: float
class Artifact(WithDbFields):
"""whisper output for one job."""
class LanguageDetection(BaseModel):
"""A language detection"""
data: list[RawTranscript] | None
code: str
# DB objects
class WithDbFields(BaseModel):
id: UUID
created_at: datetime
updated_at: datetime | None
class Config:
orm_mode = True
class Job(WithDbFields):
"""A transcription job for one media file."""
status: JobStatus
type: JobType
url: AnyHttpUrl
meta: JobMeta | None
config: JobConfig | None
class Artifact(WithDbFields):
job_id: UUID
data: LanguageDetection | RawTranscript | None
type: ArtifactType

View File

@@ -16,8 +16,6 @@ class Settings(BaseSettings):
if "pytest" in sys.modules:
settings = Settings(
_env_file=".env.test", _env_file_encoding="utf-8"
) # type: ignore
settings = Settings(_env_file=".env.test") # type: ignore
else:
settings = Settings() # type: ignore

View File

@@ -13,7 +13,6 @@ from app.web.main import app
def pytest_configure() -> None:
if not database_exists(engine.url):
create_database(engine.url)
models.Base.metadata.create_all(engine)
def pytest_unconfigure() -> None:
@@ -21,19 +20,21 @@ def pytest_unconfigure() -> None:
drop_database(engine.url)
@pytest.fixture(name="auth_headers", scope="function")
def auth_header() -> dict[str, str]:
@pytest.fixture(scope="function")
def auth_headers() -> dict[str, str]:
return {"Authorization": f"Bearer {settings.API_SECRET}"}
@pytest.fixture(name="db_session", scope="function", autouse=True)
@pytest.fixture(scope="function", autouse=True)
def db_session() -> Generator[Session, None, None]:
models.Base.metadata.create_all(engine)
connection = engine.connect()
transaction = connection.begin()
with SessionLocal(bind=connection) as session:
app.dependency_overrides[get_session] = lambda: session
yield session
app.dependency_overrides.clear()
transaction.rollback()
connection.close()
models.Base.metadata.drop_all(bind=engine)

View File

@@ -3,8 +3,6 @@ from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
import app.shared.db.models as models
import app.shared.db.schemas as schemas
from app.shared.db.schemas import JobStatus, JobType
from app.web.main import app
client = TestClient(app)
@@ -13,7 +11,9 @@ client = TestClient(app)
@pytest.fixture(name="mock_job", scope="function", autouse=False)
def mock_job(db_session: Session) -> models.Job:
job = models.Job(
url="https://example.com", type=JobType.transcript, status=JobStatus.create
url="https://example.com",
type=models.JobType.transcript,
status=models.JobStatus.create,
)
db_session.add(job)
@@ -28,7 +28,7 @@ def test_create_job_pass(auth_headers: dict[str, str]) -> None:
res = client.post(
"/api/v1/jobs",
headers=auth_headers,
json={"url": "https://example.com", "type": JobType.transcript},
json={"url": "https://example.com", "type": models.JobType.transcript},
)
assert res.status_code == 201
assert isinstance(res.json()["id"], str)
@@ -43,7 +43,7 @@ def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None:
res = client.post(
"/api/v1/jobs",
headers=auth_headers,
json={"url": "example.com", "type": JobType.transcript},
json={"url": "example.com", "type": models.JobType.transcript},
)
assert res.status_code == 422
@@ -52,7 +52,7 @@ def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None:
# ---
def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None:
res = client.get(
"/api/v1/jobs?type=transcript",
"/api/v1/jobs?type=transcribe",
headers=auth_headers,
)
assert len(res.json()) == 1
@@ -85,7 +85,7 @@ def test_get_artifacts_pass(
auth_headers: dict[str, str], db_session: Session, mock_job: models.Job
) -> None:
artifact = models.Artifact(
data=[], job_id=str(mock_job.id), type=schemas.ArtifactType.raw_transcript
data=None, job_id=str(mock_job.id), type=models.ArtifactType.raw_transcript
)
db_session.add(artifact)

View File

@@ -1,17 +1,6 @@
from typing import Any
from pydantic import AnyHttpUrl, BaseModel, Field
import app.shared.db.schemas as schemas
class DetailResponse(BaseModel):
detail: str
DEFAULT_RESPONSES: dict[int | str, dict[str, Any]] = {
401: {"model": DetailResponse, "description": "Not authenticated"}
}
import app.shared.db.models as models
class PostJobPayload(BaseModel):
@@ -21,7 +10,7 @@ class PostJobPayload(BaseModel):
)
)
type: schemas.JobType = Field(
type: models.JobType = Field(
description="""Type of this job.
`transcript` uses the original language of the audio.
`translation` creates an automatic translation to english.

View File

@@ -1,37 +1,37 @@
from asyncio.log import logger
from contextlib import asynccontextmanager
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
from sqlalchemy import or_
from sqlalchemy.orm import Session
import app.shared.db.models as models
import app.shared.db.schemas as schemas
from app.shared.celery import get_celery_binding
from app.shared.db.base import get_session
from app.web.dtos import DEFAULT_RESPONSES, DetailResponse, PostJobPayload
from app.shared.db.base import SessionLocal, get_session
from app.web.dtos import PostJobPayload
from app.web.security import authenticate_api_key
from app.web.task_queue import task_queue
DatabaseSession = Annotated[Session, Depends(get_session)]
@asynccontextmanager
async def lifespan(_: FastAPI):
with SessionLocal() as session:
task_queue.rehydrate(session)
yield
app = FastAPI(
description="whisperbox-transcribe is an async HTTP wrapper for openai/whisper.",
lifespan=lifespan,
title="whisperbox-transcribe",
)
celery = get_celery_binding()
def queue_task(job: models.Job) -> None:
# queue an async transcription task.
# we use a signature here to allow full separation of
# worker processes and dependencies.
transcribe = celery.signature("app.worker.main.transcribe")
# TODO: catch delivery errors.
transcribe.delay(job.id)
api_router = APIRouter(
prefix="/api/v1",
dependencies=[Depends(authenticate_api_key)],
responses={**DEFAULT_RESPONSES},
)
@@ -48,7 +48,7 @@ def api_root() -> None:
)
def create_job(
payload: PostJobPayload,
session: Session = Depends(get_session),
session: DatabaseSession,
) -> models.Job:
"""
Enqueue a new whisper job for processing.
@@ -62,6 +62,7 @@ def create_job(
consume considerable resources while active.
* Once a job is created, you can query its status by its id.
"""
# create a job with status "create" and save it to the database.
job = models.Job(
url=payload.url,
@@ -73,12 +74,7 @@ def create_job(
session.add(job)
session.commit()
# queue an async transcription task.
# we use a signature here to allow full separation of
# worker processes and dependencies.
transcribe = celery.signature("app.worker.main.transcribe")
# TODO: catch delivery errors.
transcribe.delay(job.id)
task_queue.queue_task(job)
return job
@@ -87,7 +83,8 @@ def create_job(
"/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
)
def get_transcripts(
type: schemas.JobType | None = None, session: Session = Depends(get_session)
session: DatabaseSession,
type: schemas.JobType | None = None,
) -> list[models.Job]:
"""Get metadata for all jobs."""
query = session.query(models.Job)
@@ -101,18 +98,20 @@ def get_transcripts(
@api_router.get(
"/jobs/{id}",
response_model=schemas.Job,
responses={404: {"model": DetailResponse, "description": "Not found"}},
summary="Get metadata for one job",
)
def get_transcript(
id: UUID = Path(), session: Session = Depends(get_session)
session: DatabaseSession,
id: UUID = Path(),
) -> models.Job | None:
"""
Use this route to check transcription status of any given job.
"""
job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
if not job:
raise HTTPException(status_code=404)
return job
@@ -122,10 +121,12 @@ def get_transcript(
summary="Get all artifacts for one job",
)
def get_artifacts_for_job(
id: UUID = Path(), session: Session = Depends(get_session)
session: DatabaseSession,
id: UUID = Path(),
) -> list[models.Artifact]:
"""
Right now, there is only one type of artifact (`raw_transcript`).
Returns all artifacts for one job.
See the type of `data` for possible data types.
Returns an empty array for unfinished or non-existant jobs.
"""
artifacts = (
@@ -139,7 +140,8 @@ def get_artifacts_for_job(
"/jobs/{id}", status_code=204, summary="Delete a job with all artifacts"
)
def delete_transcript(
id: UUID = Path(), session: Session = Depends(get_session)
session: DatabaseSession,
id: UUID = Path(),
) -> None:
"""Remove metadata and artifacts for a single job."""
session.query(models.Job).filter(models.Job.id == str(id)).delete()
@@ -147,28 +149,3 @@ def delete_transcript(
app.include_router(api_router)
# TODO: we could use `acks_late` to handle this scenario within celery itself.
# the reason this does not work well in our case is that `visibility_timeout`
# needs to be very high since whisper workers can be long running.
# doing this application-side bears the risk of poison pilling the worker though,
# implement a workaround with an acceptable trade-off. (=> retry only once?)
@app.on_event("startup")
def on_startup() -> None:
session = get_session().__next__()
jobs = (
session.query(models.Job)
.filter(
or_(
models.Job.status == schemas.JobStatus.processing,
models.Job.status == schemas.JobStatus.create,
)
)
.order_by(models.Job.created_at)
).all()
logger.info(f"Requeueing {len(jobs)} jobs.")
for job in jobs:
queue_task(job)

47
app/web/task_queue.py Normal file
View File

@@ -0,0 +1,47 @@
from asyncio.log import logger
from celery import Celery
from sqlalchemy import or_
from sqlalchemy.orm import Session
import app.shared.db.models as models
from app.shared.celery import get_celery_binding
class TaskQueue:
celery: Celery
def __init__(self) -> None:
self.celery = get_celery_binding()
def queue_task(self, job: models.Job):
# queue an async transcription task.
# we use a signature here to allow full separation of
# worker processes and dependencies.
transcribe = self.celery.signature("app.worker.main.transcribe")
transcribe.delay(job.id)
def rehydrate(self, session: Session):
# TODO: we could use `acks_late` to handle this scenario within celery itself.
# the reason this does not work well in our case is that `visibility_timeout`
# needs to be very high since whisper workers can be long running.
# doing this app-side bears the risk of poison pilling the worker though,
# implement a workaround with an acceptable trade-off. (=> retry only once?)
jobs = (
session.query(models.Job)
.filter(
or_(
models.Job.status == models.JobStatus.processing,
models.Job.status == models.JobStatus.create,
)
)
.order_by(models.Job.created_at)
).all()
logger.info(f"Requeueing {len(jobs)} jobs.")
for job in jobs:
self.queue_task(job)
task_queue = TaskQueue()

View File

@@ -10,6 +10,7 @@ import app.shared.db.schemas as schemas
from app.shared.celery import get_celery_binding
from app.shared.db.base import SessionLocal
from app.shared.settings import settings
from app.worker.strategies.base import TaskProtocol
from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding()
@@ -30,10 +31,10 @@ class TranscribeTask(Task):
return self.run(*args, **kwargs)
def select_strategy(task: Task, job: schemas.Job) -> Any:
if job.type == schemas.JobType.transcript:
def select_task_processor(task: Task, job: schemas.Job) -> TaskProtocol:
if job.type == models.JobType.transcript:
return task.strategy.transcribe
elif job.type == schemas.JobType.translation:
elif job.type == models.JobType.translation:
return task.strategy.translate
else:
return task.strategy.detect_language
@@ -50,49 +51,50 @@ def transcribe(self: Task, job_id: UUID) -> None:
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
db: Session = SessionLocal()
job = db.query(models.Job).filter(models.Job.id == job_id).one()
# unit of work: set task status to processing.
if (
job.status == schemas.JobStatus.error
or job.status == schemas.JobStatus.success
):
logger.warn(
"[{job.id}]: Received job that has already been processed, abort."
)
job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none()
if job is None:
logger.warn("[{job.id}]: Received unknown job, abort.")
return
logger.info(f"[{job.id}]: worker received task.")
if job.status in [models.JobStatus.error, models.JobStatus.success]:
logger.warn("[{job.id}]: job has already been processed, abort.")
return
logger.info(f"[{job.id}]: received eligible job.")
job.meta = {"task_id": self.request.id}
job.status = schemas.JobStatus.processing
db.commit()
logger.info(f"[{job.id}]: set task to status processing.")
job.status = models.JobStatus.processing
db.commit()
logger.info(f"[{job.id}]: finished setting task to status processing.")
# unit of work: process job with whisper.
job_record = schemas.Job.from_orm(job)
strategy = select_strategy(self, job_record)
result = strategy(
processor = select_task_processor(self, job_record)
result_type, result = processor(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
logger.info(f"[{job.id}]: successfully processed audio.")
artifact = models.Artifact(
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
)
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
db.add(artifact)
job.status = models.JobStatus.success
db.commit()
logger.info(f"[{job.id}]: successfully stored artifact.")
job.status = schemas.JobStatus.success
db.commit()
logger.info(f"[{job.id}]: set task to status success.")
except Exception as e:
if job and db:
db.rollback()
job.meta = {**job.meta, "error": str(e)} # type: ignore
job.status = schemas.JobStatus.error
job.status = models.JobStatus.error
db.commit()
raise
finally:

View File

@@ -0,0 +1,35 @@
from abc import ABC
from typing import Any, Protocol, Tuple
from uuid import UUID
import app.shared.db.models as models
import app.shared.db.schemas as schemas
TaskReturnValue = Tuple[models.ArtifactType, Any]
class TaskProtocol(Protocol):
def __call__(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> TaskReturnValue:
...
class BaseStrategy(ABC):
def transcribe(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> TaskReturnValue:
raise NotImplementedError()
def translate(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> TaskReturnValue:
raise NotImplementedError()
def detect_language(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> TaskReturnValue:
raise NotImplementedError()
def cleanup(self, job_id: UUID) -> None:
raise NotImplementedError()

View File

@@ -7,10 +7,11 @@ from uuid import UUID
import requests
import torch
import whisper
from pydantic import BaseModel
from whisper import load_model
import app.shared.db.schemas as schemas
from app.worker.strategies.base import BaseStrategy, TaskReturnValue
class DecodeOptions(BaseModel):
@@ -18,40 +19,58 @@ class DecodeOptions(BaseModel):
task: Literal["translate", "transcribe"]
class LocalStrategy:
class LocalStrategy(BaseStrategy):
def __init__(self) -> None:
if torch.cuda.is_available():
logger.info("initializing GPU model.")
self.model = load_model(
self.model = whisper.load_model(
os.environ["WHISPER_MODEL"], download_root="/models"
).cuda()
else:
logger.info("initializing CPU model.")
self.model = load_model(
self.model = whisper.load_model(
os.environ["WHISPER_MODEL"], download_root="/models"
)
logger.info("initialized local strategy.")
def transcribe(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> list[Any]:
return self.run_whisper(
def cleanup(self, job_id) -> None:
try:
os.remove(self._get_tmp_file(job_id))
except OSError:
pass
def transcribe(self, url, job_id, config):
return (
schemas.ArtifactType.raw_transcript,
self._run_whisper(
self._download(url, job_id), "transcribe", config, job_id
),
)
def translate(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> list[Any]:
return self.run_whisper(
def translate(self, url, job_id, config) -> TaskReturnValue:
return (
schemas.ArtifactType.raw_transcript,
self._run_whisper(
self._download(url, job_id),
"translate",
config,
job_id,
),
)
def detect_language(self, url: str, config: schemas.JobConfig | None) -> list[Any]:
raise NotImplementedError("detect_language has not been implemented yet.")
def detect_language(self, url, job_id, config) -> TaskReturnValue:
file = self._download(url, job_id)
audio = whisper.pad_or_trim(whisper.load_audio(file))
mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
_, probs = self.model.detect_language(mel)
return (
schemas.ArtifactType.language_detection,
{"code": max(probs, key=probs.get)},
)
def _download(self, url: str, job_id: UUID) -> str:
# re-create folder.
@@ -67,7 +86,7 @@ class LocalStrategy:
return filename
def run_whisper(
def _run_whisper(
self,
filepath: str,
task: Literal["translate", "transcribe"],
@@ -90,9 +109,3 @@ class LocalStrategy:
def _get_tmp_file(self, job_id: UUID) -> str:
tmp = tempfile.gettempdir()
return path.join(tmp, str(job_id))
def cleanup(self, job_id: UUID) -> None:
try:
os.remove(self._get_tmp_file(job_id))
except OSError:
pass

View File

@@ -1,4 +1,4 @@
[mypy]
plugins = sqlalchemy.ext.mypy.plugin
ignore_missing_imports = True
disallow_untyped_defs = True
disallow_untyped_defs = False