mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-07 19:18:35 +03:00
feat: add language detection task
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
API_SECRET="test_secret"
|
API_SECRET="test_secret"
|
||||||
BROKER_URL="memory://"
|
BROKER_URL="memory://"
|
||||||
DATABASE_URI="sqlite:///memory"
|
DATABASE_URI="sqlite://"
|
||||||
ENVIRONMENT="test"
|
ENVIRONMENT="test"
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
"""add_job_tables
|
"""add_tables
|
||||||
|
|
||||||
Revision ID: dc8582aea0bc
|
Revision ID: 0eee2b7913b7
|
||||||
Revises:
|
Revises:
|
||||||
Create Date: 2023-02-08 12:12:00.808816
|
Create Date: 2023-06-29 08:33:26.123728
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = "dc8582aea0bc"
|
revision = "0eee2b7913b7"
|
||||||
down_revision = None
|
down_revision = None
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
@@ -54,7 +54,7 @@ def upgrade() -> None:
|
|||||||
sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
|
sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"type",
|
"type",
|
||||||
sa.Enum("raw_transcript", name="artifacttype"),
|
sa.Enum("raw_transcript", "language_detection", name="artifacttype"),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
@@ -19,11 +19,8 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|||||||
|
|
||||||
|
|
||||||
def get_session() -> Generator[Session, None, None]:
|
def get_session() -> Generator[Session, None, None]:
|
||||||
db: Session = SessionLocal()
|
session: Session = SessionLocal()
|
||||||
try:
|
try:
|
||||||
yield db
|
yield session
|
||||||
db.commit()
|
|
||||||
except Exception:
|
|
||||||
db.rollback()
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
session.close()
|
||||||
|
|||||||
@@ -1,13 +1,39 @@
|
|||||||
|
import enum
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func
|
from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
from sqlalchemy.orm import Mapped, declarative_base, declarative_mixin, declared_attr
|
from sqlalchemy.orm import Mapped, declarative_base, declarative_mixin, declared_attr
|
||||||
|
|
||||||
from .schemas import ArtifactType, JobStatus, JobType
|
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
# Enums
|
||||||
|
|
||||||
|
|
||||||
|
class JobType(str, enum.Enum):
|
||||||
|
"""Requested type of a job."""
|
||||||
|
|
||||||
|
transcript = "transcribe"
|
||||||
|
translation = "translate"
|
||||||
|
language_detection = "detect_language"
|
||||||
|
|
||||||
|
|
||||||
|
class JobStatus(str, enum.Enum):
|
||||||
|
"""Processing status of a job."""
|
||||||
|
|
||||||
|
create = "create"
|
||||||
|
processing = "processing"
|
||||||
|
error = "error"
|
||||||
|
success = "success"
|
||||||
|
|
||||||
|
|
||||||
|
class ArtifactType(str, enum.Enum):
|
||||||
|
raw_transcript = "transcript_raw"
|
||||||
|
language_detection = "language_detection"
|
||||||
|
|
||||||
|
|
||||||
|
# SQLAlchemy models
|
||||||
|
|
||||||
|
|
||||||
@declarative_mixin
|
@declarative_mixin
|
||||||
class WithStandardFields:
|
class WithStandardFields:
|
||||||
|
|||||||
@@ -1,42 +1,16 @@
|
|||||||
import enum
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||||
|
|
||||||
|
from app.shared.db.models import ArtifactType, JobStatus, JobType
|
||||||
|
|
||||||
class WithDbFields(BaseModel):
|
# JSON field types
|
||||||
id: UUID
|
|
||||||
created_at: datetime
|
|
||||||
updated_at: datetime | None
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
orm_mode = True
|
|
||||||
|
|
||||||
|
|
||||||
class ArtifactType(str, enum.Enum):
|
|
||||||
raw_transcript = "raw_transcript"
|
|
||||||
|
|
||||||
|
|
||||||
class JobType(str, enum.Enum):
|
|
||||||
transcript = "transcript"
|
|
||||||
translation = "translation"
|
|
||||||
language_detection = "language_detection"
|
|
||||||
|
|
||||||
|
|
||||||
class JobStatus(str, enum.Enum):
|
|
||||||
"""Processing status of a job."""
|
|
||||||
|
|
||||||
create = "create"
|
|
||||||
processing = "processing"
|
|
||||||
error = "error"
|
|
||||||
success = "success"
|
|
||||||
|
|
||||||
|
|
||||||
class JobConfig(BaseModel):
|
class JobConfig(BaseModel):
|
||||||
"""Configuration for a job."""
|
"""Configuration for a job."""
|
||||||
|
|
||||||
# TODO: limit to locales selected by whisper.
|
|
||||||
language: str | None = Field(
|
language: str | None = Field(
|
||||||
description=(
|
description=(
|
||||||
"Spoken language in the media file. "
|
"Spoken language in the media file. "
|
||||||
@@ -51,21 +25,12 @@ class JobMeta(BaseModel):
|
|||||||
error: str | None = Field(
|
error: str | None = Field(
|
||||||
description="Will contain a descriptive error message if processing failed."
|
description="Will contain a descriptive error message if processing failed."
|
||||||
)
|
)
|
||||||
|
|
||||||
task_id: UUID | None = Field(
|
task_id: UUID | None = Field(
|
||||||
description="Internal celery id of this job submission."
|
description="Internal celery id of this job submission."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Job(WithDbFields):
|
|
||||||
"""A transcription job for one media file."""
|
|
||||||
|
|
||||||
status: JobStatus
|
|
||||||
type: JobType
|
|
||||||
url: AnyHttpUrl
|
|
||||||
meta: JobMeta | None
|
|
||||||
config: JobConfig | None
|
|
||||||
|
|
||||||
|
|
||||||
class RawTranscript(BaseModel):
|
class RawTranscript(BaseModel):
|
||||||
"""A single transcript passage returned by whisper."""
|
"""A single transcript passage returned by whisper."""
|
||||||
|
|
||||||
@@ -81,9 +46,35 @@ class RawTranscript(BaseModel):
|
|||||||
no_speech_prob: float
|
no_speech_prob: float
|
||||||
|
|
||||||
|
|
||||||
class Artifact(WithDbFields):
|
class LanguageDetection(BaseModel):
|
||||||
"""whisper output for one job."""
|
"""A language detection"""
|
||||||
|
|
||||||
data: list[RawTranscript] | None
|
code: str
|
||||||
|
|
||||||
|
|
||||||
|
# DB objects
|
||||||
|
|
||||||
|
|
||||||
|
class WithDbFields(BaseModel):
|
||||||
|
id: UUID
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime | None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
|
||||||
|
|
||||||
|
class Job(WithDbFields):
|
||||||
|
"""A transcription job for one media file."""
|
||||||
|
|
||||||
|
status: JobStatus
|
||||||
|
type: JobType
|
||||||
|
url: AnyHttpUrl
|
||||||
|
meta: JobMeta | None
|
||||||
|
config: JobConfig | None
|
||||||
|
|
||||||
|
|
||||||
|
class Artifact(WithDbFields):
|
||||||
job_id: UUID
|
job_id: UUID
|
||||||
|
data: LanguageDetection | RawTranscript | None
|
||||||
type: ArtifactType
|
type: ArtifactType
|
||||||
|
|||||||
@@ -16,8 +16,6 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
|
|
||||||
if "pytest" in sys.modules:
|
if "pytest" in sys.modules:
|
||||||
settings = Settings(
|
settings = Settings(_env_file=".env.test") # type: ignore
|
||||||
_env_file=".env.test", _env_file_encoding="utf-8"
|
|
||||||
) # type: ignore
|
|
||||||
else:
|
else:
|
||||||
settings = Settings() # type: ignore
|
settings = Settings() # type: ignore
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from app.web.main import app
|
|||||||
def pytest_configure() -> None:
|
def pytest_configure() -> None:
|
||||||
if not database_exists(engine.url):
|
if not database_exists(engine.url):
|
||||||
create_database(engine.url)
|
create_database(engine.url)
|
||||||
models.Base.metadata.create_all(engine)
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_unconfigure() -> None:
|
def pytest_unconfigure() -> None:
|
||||||
@@ -21,19 +20,21 @@ def pytest_unconfigure() -> None:
|
|||||||
drop_database(engine.url)
|
drop_database(engine.url)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="auth_headers", scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def auth_header() -> dict[str, str]:
|
def auth_headers() -> dict[str, str]:
|
||||||
return {"Authorization": f"Bearer {settings.API_SECRET}"}
|
return {"Authorization": f"Bearer {settings.API_SECRET}"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="db_session", scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
def db_session() -> Generator[Session, None, None]:
|
def db_session() -> Generator[Session, None, None]:
|
||||||
|
models.Base.metadata.create_all(engine)
|
||||||
|
|
||||||
connection = engine.connect()
|
connection = engine.connect()
|
||||||
transaction = connection.begin()
|
|
||||||
|
|
||||||
with SessionLocal(bind=connection) as session:
|
with SessionLocal(bind=connection) as session:
|
||||||
app.dependency_overrides[get_session] = lambda: session
|
app.dependency_overrides[get_session] = lambda: session
|
||||||
yield session
|
yield session
|
||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
transaction.rollback()
|
|
||||||
connection.close()
|
connection.close()
|
||||||
|
|
||||||
|
models.Base.metadata.drop_all(bind=engine)
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ from fastapi.testclient import TestClient
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
import app.shared.db.models as models
|
import app.shared.db.models as models
|
||||||
import app.shared.db.schemas as schemas
|
|
||||||
from app.shared.db.schemas import JobStatus, JobType
|
|
||||||
from app.web.main import app
|
from app.web.main import app
|
||||||
|
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
@@ -13,7 +11,9 @@ client = TestClient(app)
|
|||||||
@pytest.fixture(name="mock_job", scope="function", autouse=False)
|
@pytest.fixture(name="mock_job", scope="function", autouse=False)
|
||||||
def mock_job(db_session: Session) -> models.Job:
|
def mock_job(db_session: Session) -> models.Job:
|
||||||
job = models.Job(
|
job = models.Job(
|
||||||
url="https://example.com", type=JobType.transcript, status=JobStatus.create
|
url="https://example.com",
|
||||||
|
type=models.JobType.transcript,
|
||||||
|
status=models.JobStatus.create,
|
||||||
)
|
)
|
||||||
|
|
||||||
db_session.add(job)
|
db_session.add(job)
|
||||||
@@ -28,7 +28,7 @@ def test_create_job_pass(auth_headers: dict[str, str]) -> None:
|
|||||||
res = client.post(
|
res = client.post(
|
||||||
"/api/v1/jobs",
|
"/api/v1/jobs",
|
||||||
headers=auth_headers,
|
headers=auth_headers,
|
||||||
json={"url": "https://example.com", "type": JobType.transcript},
|
json={"url": "https://example.com", "type": models.JobType.transcript},
|
||||||
)
|
)
|
||||||
assert res.status_code == 201
|
assert res.status_code == 201
|
||||||
assert isinstance(res.json()["id"], str)
|
assert isinstance(res.json()["id"], str)
|
||||||
@@ -43,7 +43,7 @@ def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None:
|
|||||||
res = client.post(
|
res = client.post(
|
||||||
"/api/v1/jobs",
|
"/api/v1/jobs",
|
||||||
headers=auth_headers,
|
headers=auth_headers,
|
||||||
json={"url": "example.com", "type": JobType.transcript},
|
json={"url": "example.com", "type": models.JobType.transcript},
|
||||||
)
|
)
|
||||||
assert res.status_code == 422
|
assert res.status_code == 422
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None:
|
|||||||
# ---
|
# ---
|
||||||
def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None:
|
def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None:
|
||||||
res = client.get(
|
res = client.get(
|
||||||
"/api/v1/jobs?type=transcript",
|
"/api/v1/jobs?type=transcribe",
|
||||||
headers=auth_headers,
|
headers=auth_headers,
|
||||||
)
|
)
|
||||||
assert len(res.json()) == 1
|
assert len(res.json()) == 1
|
||||||
@@ -85,7 +85,7 @@ def test_get_artifacts_pass(
|
|||||||
auth_headers: dict[str, str], db_session: Session, mock_job: models.Job
|
auth_headers: dict[str, str], db_session: Session, mock_job: models.Job
|
||||||
) -> None:
|
) -> None:
|
||||||
artifact = models.Artifact(
|
artifact = models.Artifact(
|
||||||
data=[], job_id=str(mock_job.id), type=schemas.ArtifactType.raw_transcript
|
data=None, job_id=str(mock_job.id), type=models.ArtifactType.raw_transcript
|
||||||
)
|
)
|
||||||
|
|
||||||
db_session.add(artifact)
|
db_session.add(artifact)
|
||||||
|
|||||||
@@ -1,17 +1,6 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||||
|
|
||||||
import app.shared.db.schemas as schemas
|
import app.shared.db.models as models
|
||||||
|
|
||||||
|
|
||||||
class DetailResponse(BaseModel):
|
|
||||||
detail: str
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_RESPONSES: dict[int | str, dict[str, Any]] = {
|
|
||||||
401: {"model": DetailResponse, "description": "Not authenticated"}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class PostJobPayload(BaseModel):
|
class PostJobPayload(BaseModel):
|
||||||
@@ -21,7 +10,7 @@ class PostJobPayload(BaseModel):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
type: schemas.JobType = Field(
|
type: models.JobType = Field(
|
||||||
description="""Type of this job.
|
description="""Type of this job.
|
||||||
`transcript` uses the original language of the audio.
|
`transcript` uses the original language of the audio.
|
||||||
`translation` creates an automatic translation to english.
|
`translation` creates an automatic translation to english.
|
||||||
|
|||||||
@@ -1,37 +1,37 @@
|
|||||||
from asyncio.log import logger
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Annotated
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
|
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
|
||||||
from sqlalchemy import or_
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
import app.shared.db.models as models
|
import app.shared.db.models as models
|
||||||
import app.shared.db.schemas as schemas
|
import app.shared.db.schemas as schemas
|
||||||
from app.shared.celery import get_celery_binding
|
from app.shared.db.base import SessionLocal, get_session
|
||||||
from app.shared.db.base import get_session
|
from app.web.dtos import PostJobPayload
|
||||||
from app.web.dtos import DEFAULT_RESPONSES, DetailResponse, PostJobPayload
|
|
||||||
from app.web.security import authenticate_api_key
|
from app.web.security import authenticate_api_key
|
||||||
|
from app.web.task_queue import task_queue
|
||||||
|
|
||||||
|
DatabaseSession = Annotated[Session, Depends(get_session)]
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(_: FastAPI):
|
||||||
|
with SessionLocal() as session:
|
||||||
|
task_queue.rehydrate(session)
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
description="whisperbox-transcribe is an async HTTP wrapper for openai/whisper.",
|
description="whisperbox-transcribe is an async HTTP wrapper for openai/whisper.",
|
||||||
|
lifespan=lifespan,
|
||||||
title="whisperbox-transcribe",
|
title="whisperbox-transcribe",
|
||||||
)
|
)
|
||||||
celery = get_celery_binding()
|
|
||||||
|
|
||||||
|
|
||||||
def queue_task(job: models.Job) -> None:
|
|
||||||
# queue an async transcription task.
|
|
||||||
# we use a signature here to allow full separation of
|
|
||||||
# worker processes and dependencies.
|
|
||||||
transcribe = celery.signature("app.worker.main.transcribe")
|
|
||||||
# TODO: catch delivery errors.
|
|
||||||
transcribe.delay(job.id)
|
|
||||||
|
|
||||||
|
|
||||||
api_router = APIRouter(
|
api_router = APIRouter(
|
||||||
prefix="/api/v1",
|
prefix="/api/v1",
|
||||||
dependencies=[Depends(authenticate_api_key)],
|
dependencies=[Depends(authenticate_api_key)],
|
||||||
responses={**DEFAULT_RESPONSES},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -48,7 +48,7 @@ def api_root() -> None:
|
|||||||
)
|
)
|
||||||
def create_job(
|
def create_job(
|
||||||
payload: PostJobPayload,
|
payload: PostJobPayload,
|
||||||
session: Session = Depends(get_session),
|
session: DatabaseSession,
|
||||||
) -> models.Job:
|
) -> models.Job:
|
||||||
"""
|
"""
|
||||||
Enqueue a new whisper job for processing.
|
Enqueue a new whisper job for processing.
|
||||||
@@ -62,6 +62,7 @@ def create_job(
|
|||||||
consume considerable resources while active.
|
consume considerable resources while active.
|
||||||
* Once a job is created, you can query its status by its id.
|
* Once a job is created, you can query its status by its id.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# create a job with status "create" and save it to the database.
|
# create a job with status "create" and save it to the database.
|
||||||
job = models.Job(
|
job = models.Job(
|
||||||
url=payload.url,
|
url=payload.url,
|
||||||
@@ -73,12 +74,7 @@ def create_job(
|
|||||||
session.add(job)
|
session.add(job)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
# queue an async transcription task.
|
task_queue.queue_task(job)
|
||||||
# we use a signature here to allow full separation of
|
|
||||||
# worker processes and dependencies.
|
|
||||||
transcribe = celery.signature("app.worker.main.transcribe")
|
|
||||||
# TODO: catch delivery errors.
|
|
||||||
transcribe.delay(job.id)
|
|
||||||
|
|
||||||
return job
|
return job
|
||||||
|
|
||||||
@@ -87,7 +83,8 @@ def create_job(
|
|||||||
"/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
|
"/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
|
||||||
)
|
)
|
||||||
def get_transcripts(
|
def get_transcripts(
|
||||||
type: schemas.JobType | None = None, session: Session = Depends(get_session)
|
session: DatabaseSession,
|
||||||
|
type: schemas.JobType | None = None,
|
||||||
) -> list[models.Job]:
|
) -> list[models.Job]:
|
||||||
"""Get metadata for all jobs."""
|
"""Get metadata for all jobs."""
|
||||||
query = session.query(models.Job)
|
query = session.query(models.Job)
|
||||||
@@ -101,18 +98,20 @@ def get_transcripts(
|
|||||||
@api_router.get(
|
@api_router.get(
|
||||||
"/jobs/{id}",
|
"/jobs/{id}",
|
||||||
response_model=schemas.Job,
|
response_model=schemas.Job,
|
||||||
responses={404: {"model": DetailResponse, "description": "Not found"}},
|
|
||||||
summary="Get metadata for one job",
|
summary="Get metadata for one job",
|
||||||
)
|
)
|
||||||
def get_transcript(
|
def get_transcript(
|
||||||
id: UUID = Path(), session: Session = Depends(get_session)
|
session: DatabaseSession,
|
||||||
|
id: UUID = Path(),
|
||||||
) -> models.Job | None:
|
) -> models.Job | None:
|
||||||
"""
|
"""
|
||||||
Use this route to check transcription status of any given job.
|
Use this route to check transcription status of any given job.
|
||||||
"""
|
"""
|
||||||
job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
|
job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
|
||||||
|
|
||||||
if not job:
|
if not job:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
return job
|
return job
|
||||||
|
|
||||||
|
|
||||||
@@ -122,10 +121,12 @@ def get_transcript(
|
|||||||
summary="Get all artifacts for one job",
|
summary="Get all artifacts for one job",
|
||||||
)
|
)
|
||||||
def get_artifacts_for_job(
|
def get_artifacts_for_job(
|
||||||
id: UUID = Path(), session: Session = Depends(get_session)
|
session: DatabaseSession,
|
||||||
|
id: UUID = Path(),
|
||||||
) -> list[models.Artifact]:
|
) -> list[models.Artifact]:
|
||||||
"""
|
"""
|
||||||
Right now, there is only one type of artifact (`raw_transcript`).
|
Returns all artifacts for one job.
|
||||||
|
See the type of `data` for possible data types.
|
||||||
Returns an empty array for unfinished or non-existant jobs.
|
Returns an empty array for unfinished or non-existant jobs.
|
||||||
"""
|
"""
|
||||||
artifacts = (
|
artifacts = (
|
||||||
@@ -139,7 +140,8 @@ def get_artifacts_for_job(
|
|||||||
"/jobs/{id}", status_code=204, summary="Delete a job with all artifacts"
|
"/jobs/{id}", status_code=204, summary="Delete a job with all artifacts"
|
||||||
)
|
)
|
||||||
def delete_transcript(
|
def delete_transcript(
|
||||||
id: UUID = Path(), session: Session = Depends(get_session)
|
session: DatabaseSession,
|
||||||
|
id: UUID = Path(),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Remove metadata and artifacts for a single job."""
|
"""Remove metadata and artifacts for a single job."""
|
||||||
session.query(models.Job).filter(models.Job.id == str(id)).delete()
|
session.query(models.Job).filter(models.Job.id == str(id)).delete()
|
||||||
@@ -147,28 +149,3 @@ def delete_transcript(
|
|||||||
|
|
||||||
|
|
||||||
app.include_router(api_router)
|
app.include_router(api_router)
|
||||||
|
|
||||||
|
|
||||||
# TODO: we could use `acks_late` to handle this scenario within celery itself.
|
|
||||||
# the reason this does not work well in our case is that `visibility_timeout`
|
|
||||||
# needs to be very high since whisper workers can be long running.
|
|
||||||
# doing this application-side bears the risk of poison pilling the worker though,
|
|
||||||
# implement a workaround with an acceptable trade-off. (=> retry only once?)
|
|
||||||
@app.on_event("startup")
|
|
||||||
def on_startup() -> None:
|
|
||||||
session = get_session().__next__()
|
|
||||||
|
|
||||||
jobs = (
|
|
||||||
session.query(models.Job)
|
|
||||||
.filter(
|
|
||||||
or_(
|
|
||||||
models.Job.status == schemas.JobStatus.processing,
|
|
||||||
models.Job.status == schemas.JobStatus.create,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.order_by(models.Job.created_at)
|
|
||||||
).all()
|
|
||||||
|
|
||||||
logger.info(f"Requeueing {len(jobs)} jobs.")
|
|
||||||
for job in jobs:
|
|
||||||
queue_task(job)
|
|
||||||
|
|||||||
47
app/web/task_queue.py
Normal file
47
app/web/task_queue.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
from asyncio.log import logger
|
||||||
|
|
||||||
|
from celery import Celery
|
||||||
|
from sqlalchemy import or_
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
import app.shared.db.models as models
|
||||||
|
from app.shared.celery import get_celery_binding
|
||||||
|
|
||||||
|
|
||||||
|
class TaskQueue:
|
||||||
|
celery: Celery
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.celery = get_celery_binding()
|
||||||
|
|
||||||
|
def queue_task(self, job: models.Job):
|
||||||
|
# queue an async transcription task.
|
||||||
|
# we use a signature here to allow full separation of
|
||||||
|
# worker processes and dependencies.
|
||||||
|
transcribe = self.celery.signature("app.worker.main.transcribe")
|
||||||
|
transcribe.delay(job.id)
|
||||||
|
|
||||||
|
def rehydrate(self, session: Session):
|
||||||
|
# TODO: we could use `acks_late` to handle this scenario within celery itself.
|
||||||
|
# the reason this does not work well in our case is that `visibility_timeout`
|
||||||
|
# needs to be very high since whisper workers can be long running.
|
||||||
|
# doing this app-side bears the risk of poison pilling the worker though,
|
||||||
|
# implement a workaround with an acceptable trade-off. (=> retry only once?)
|
||||||
|
jobs = (
|
||||||
|
session.query(models.Job)
|
||||||
|
.filter(
|
||||||
|
or_(
|
||||||
|
models.Job.status == models.JobStatus.processing,
|
||||||
|
models.Job.status == models.JobStatus.create,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.order_by(models.Job.created_at)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
logger.info(f"Requeueing {len(jobs)} jobs.")
|
||||||
|
|
||||||
|
for job in jobs:
|
||||||
|
self.queue_task(job)
|
||||||
|
|
||||||
|
|
||||||
|
task_queue = TaskQueue()
|
||||||
@@ -10,6 +10,7 @@ import app.shared.db.schemas as schemas
|
|||||||
from app.shared.celery import get_celery_binding
|
from app.shared.celery import get_celery_binding
|
||||||
from app.shared.db.base import SessionLocal
|
from app.shared.db.base import SessionLocal
|
||||||
from app.shared.settings import settings
|
from app.shared.settings import settings
|
||||||
|
from app.worker.strategies.base import TaskProtocol
|
||||||
from app.worker.strategies.local import LocalStrategy
|
from app.worker.strategies.local import LocalStrategy
|
||||||
|
|
||||||
celery = get_celery_binding()
|
celery = get_celery_binding()
|
||||||
@@ -30,10 +31,10 @@ class TranscribeTask(Task):
|
|||||||
return self.run(*args, **kwargs)
|
return self.run(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def select_strategy(task: Task, job: schemas.Job) -> Any:
|
def select_task_processor(task: Task, job: schemas.Job) -> TaskProtocol:
|
||||||
if job.type == schemas.JobType.transcript:
|
if job.type == models.JobType.transcript:
|
||||||
return task.strategy.transcribe
|
return task.strategy.transcribe
|
||||||
elif job.type == schemas.JobType.translation:
|
elif job.type == models.JobType.translation:
|
||||||
return task.strategy.translate
|
return task.strategy.translate
|
||||||
else:
|
else:
|
||||||
return task.strategy.detect_language
|
return task.strategy.detect_language
|
||||||
@@ -50,49 +51,50 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
|||||||
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
|
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
|
||||||
db: Session = SessionLocal()
|
db: Session = SessionLocal()
|
||||||
|
|
||||||
job = db.query(models.Job).filter(models.Job.id == job_id).one()
|
# unit of work: set task status to processing.
|
||||||
|
|
||||||
if (
|
job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none()
|
||||||
job.status == schemas.JobStatus.error
|
|
||||||
or job.status == schemas.JobStatus.success
|
if job is None:
|
||||||
):
|
logger.warn("[{job.id}]: Received unknown job, abort.")
|
||||||
logger.warn(
|
|
||||||
"[{job.id}]: Received job that has already been processed, abort."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"[{job.id}]: worker received task.")
|
if job.status in [models.JobStatus.error, models.JobStatus.success]:
|
||||||
|
logger.warn("[{job.id}]: job has already been processed, abort.")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"[{job.id}]: received eligible job.")
|
||||||
|
|
||||||
job.meta = {"task_id": self.request.id}
|
job.meta = {"task_id": self.request.id}
|
||||||
job.status = schemas.JobStatus.processing
|
job.status = models.JobStatus.processing
|
||||||
db.commit()
|
|
||||||
logger.info(f"[{job.id}]: set task to status processing.")
|
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
logger.info(f"[{job.id}]: finished setting task to status processing.")
|
||||||
|
|
||||||
|
# unit of work: process job with whisper.
|
||||||
job_record = schemas.Job.from_orm(job)
|
job_record = schemas.Job.from_orm(job)
|
||||||
|
|
||||||
strategy = select_strategy(self, job_record)
|
processor = select_task_processor(self, job_record)
|
||||||
result = strategy(
|
|
||||||
|
result_type, result = processor(
|
||||||
url=job_record.url, job_id=job_record.id, config=job_record.config
|
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"[{job.id}]: successfully processed audio.")
|
logger.info(f"[{job.id}]: successfully processed audio.")
|
||||||
|
|
||||||
artifact = models.Artifact(
|
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
|
||||||
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(artifact)
|
db.add(artifact)
|
||||||
|
|
||||||
|
job.status = models.JobStatus.success
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"[{job.id}]: successfully stored artifact.")
|
logger.info(f"[{job.id}]: successfully stored artifact.")
|
||||||
|
|
||||||
job.status = schemas.JobStatus.success
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
logger.info(f"[{job.id}]: set task to status success.")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if job and db:
|
if job and db:
|
||||||
|
db.rollback()
|
||||||
job.meta = {**job.meta, "error": str(e)} # type: ignore
|
job.meta = {**job.meta, "error": str(e)} # type: ignore
|
||||||
job.status = schemas.JobStatus.error
|
job.status = models.JobStatus.error
|
||||||
db.commit()
|
db.commit()
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
35
app/worker/strategies/base.py
Normal file
35
app/worker/strategies/base.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from abc import ABC
|
||||||
|
from typing import Any, Protocol, Tuple
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import app.shared.db.models as models
|
||||||
|
import app.shared.db.schemas as schemas
|
||||||
|
|
||||||
|
TaskReturnValue = Tuple[models.ArtifactType, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class TaskProtocol(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
||||||
|
) -> TaskReturnValue:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStrategy(ABC):
|
||||||
|
def transcribe(
|
||||||
|
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
||||||
|
) -> TaskReturnValue:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def translate(
|
||||||
|
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
||||||
|
) -> TaskReturnValue:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def detect_language(
|
||||||
|
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
||||||
|
) -> TaskReturnValue:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def cleanup(self, job_id: UUID) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
@@ -7,10 +7,11 @@ from uuid import UUID
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
import whisper
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from whisper import load_model
|
|
||||||
|
|
||||||
import app.shared.db.schemas as schemas
|
import app.shared.db.schemas as schemas
|
||||||
|
from app.worker.strategies.base import BaseStrategy, TaskReturnValue
|
||||||
|
|
||||||
|
|
||||||
class DecodeOptions(BaseModel):
|
class DecodeOptions(BaseModel):
|
||||||
@@ -18,40 +19,58 @@ class DecodeOptions(BaseModel):
|
|||||||
task: Literal["translate", "transcribe"]
|
task: Literal["translate", "transcribe"]
|
||||||
|
|
||||||
|
|
||||||
class LocalStrategy:
|
class LocalStrategy(BaseStrategy):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
logger.info("initializing GPU model.")
|
logger.info("initializing GPU model.")
|
||||||
self.model = load_model(
|
self.model = whisper.load_model(
|
||||||
os.environ["WHISPER_MODEL"], download_root="/models"
|
os.environ["WHISPER_MODEL"], download_root="/models"
|
||||||
).cuda()
|
).cuda()
|
||||||
else:
|
else:
|
||||||
logger.info("initializing CPU model.")
|
logger.info("initializing CPU model.")
|
||||||
self.model = load_model(
|
self.model = whisper.load_model(
|
||||||
os.environ["WHISPER_MODEL"], download_root="/models"
|
os.environ["WHISPER_MODEL"], download_root="/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("initialized local strategy.")
|
logger.info("initialized local strategy.")
|
||||||
|
|
||||||
def transcribe(
|
def cleanup(self, job_id) -> None:
|
||||||
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
try:
|
||||||
) -> list[Any]:
|
os.remove(self._get_tmp_file(job_id))
|
||||||
return self.run_whisper(
|
except OSError:
|
||||||
self._download(url, job_id), "transcribe", config, job_id
|
pass
|
||||||
|
|
||||||
|
def transcribe(self, url, job_id, config):
|
||||||
|
return (
|
||||||
|
schemas.ArtifactType.raw_transcript,
|
||||||
|
self._run_whisper(
|
||||||
|
self._download(url, job_id), "transcribe", config, job_id
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def translate(
|
def translate(self, url, job_id, config) -> TaskReturnValue:
|
||||||
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
return (
|
||||||
) -> list[Any]:
|
schemas.ArtifactType.raw_transcript,
|
||||||
return self.run_whisper(
|
self._run_whisper(
|
||||||
self._download(url, job_id),
|
self._download(url, job_id),
|
||||||
"translate",
|
"translate",
|
||||||
config,
|
config,
|
||||||
job_id,
|
job_id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def detect_language(self, url: str, config: schemas.JobConfig | None) -> list[Any]:
|
def detect_language(self, url, job_id, config) -> TaskReturnValue:
|
||||||
raise NotImplementedError("detect_language has not been implemented yet.")
|
file = self._download(url, job_id)
|
||||||
|
|
||||||
|
audio = whisper.pad_or_trim(whisper.load_audio(file))
|
||||||
|
|
||||||
|
mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
|
||||||
|
_, probs = self.model.detect_language(mel)
|
||||||
|
|
||||||
|
return (
|
||||||
|
schemas.ArtifactType.language_detection,
|
||||||
|
{"code": max(probs, key=probs.get)},
|
||||||
|
)
|
||||||
|
|
||||||
def _download(self, url: str, job_id: UUID) -> str:
|
def _download(self, url: str, job_id: UUID) -> str:
|
||||||
# re-create folder.
|
# re-create folder.
|
||||||
@@ -67,7 +86,7 @@ class LocalStrategy:
|
|||||||
|
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
def run_whisper(
|
def _run_whisper(
|
||||||
self,
|
self,
|
||||||
filepath: str,
|
filepath: str,
|
||||||
task: Literal["translate", "transcribe"],
|
task: Literal["translate", "transcribe"],
|
||||||
@@ -90,9 +109,3 @@ class LocalStrategy:
|
|||||||
def _get_tmp_file(self, job_id: UUID) -> str:
|
def _get_tmp_file(self, job_id: UUID) -> str:
|
||||||
tmp = tempfile.gettempdir()
|
tmp = tempfile.gettempdir()
|
||||||
return path.join(tmp, str(job_id))
|
return path.join(tmp, str(job_id))
|
||||||
|
|
||||||
def cleanup(self, job_id: UUID) -> None:
|
|
||||||
try:
|
|
||||||
os.remove(self._get_tmp_file(job_id))
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|||||||
Reference in New Issue
Block a user