feat: add language detection task

This commit is contained in:
Felix Spöttel
2023-06-29 09:13:11 +02:00
parent d2223206be
commit 908bd48170
15 changed files with 267 additions and 191 deletions

View File

@@ -1,4 +1,4 @@
API_SECRET="test_secret" API_SECRET="test_secret"
BROKER_URL="memory://" BROKER_URL="memory://"
DATABASE_URI="sqlite:///memory" DATABASE_URI="sqlite://"
ENVIRONMENT="test" ENVIRONMENT="test"

View File

@@ -1,15 +1,15 @@
"""add_job_tables """add_tables
Revision ID: dc8582aea0bc Revision ID: 0eee2b7913b7
Revises: Revises:
Create Date: 2023-02-08 12:12:00.808816 Create Date: 2023-06-29 08:33:26.123728
""" """
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = "dc8582aea0bc" revision = "0eee2b7913b7"
down_revision = None down_revision = None
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -54,7 +54,7 @@ def upgrade() -> None:
sa.Column("data", sa.JSON(none_as_null=True), nullable=True), sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
sa.Column( sa.Column(
"type", "type",
sa.Enum("raw_transcript", name="artifacttype"), sa.Enum("raw_transcript", "language_detection", name="artifacttype"),
nullable=False, nullable=False,
), ),
sa.Column( sa.Column(

View File

@@ -19,11 +19,8 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def get_session() -> Generator[Session, None, None]: def get_session() -> Generator[Session, None, None]:
db: Session = SessionLocal() session: Session = SessionLocal()
try: try:
yield db yield session
db.commit()
except Exception:
db.rollback()
finally: finally:
db.close() session.close()

View File

@@ -1,13 +1,39 @@
import enum
import uuid import uuid
from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, declarative_base, declarative_mixin, declared_attr from sqlalchemy.orm import Mapped, declarative_base, declarative_mixin, declared_attr
from .schemas import ArtifactType, JobStatus, JobType
Base = declarative_base() Base = declarative_base()
# Enums
class JobType(str, enum.Enum):
"""Requested type of a job."""
transcript = "transcribe"
translation = "translate"
language_detection = "detect_language"
class JobStatus(str, enum.Enum):
"""Processing status of a job."""
create = "create"
processing = "processing"
error = "error"
success = "success"
class ArtifactType(str, enum.Enum):
raw_transcript = "transcript_raw"
language_detection = "language_detection"
# SQLAlchemy models
@declarative_mixin @declarative_mixin
class WithStandardFields: class WithStandardFields:

View File

@@ -1,42 +1,16 @@
import enum
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
from pydantic import AnyHttpUrl, BaseModel, Field from pydantic import AnyHttpUrl, BaseModel, Field
from app.shared.db.models import ArtifactType, JobStatus, JobType
class WithDbFields(BaseModel): # JSON field types
id: UUID
created_at: datetime
updated_at: datetime | None
class Config:
orm_mode = True
class ArtifactType(str, enum.Enum):
raw_transcript = "raw_transcript"
class JobType(str, enum.Enum):
transcript = "transcript"
translation = "translation"
language_detection = "language_detection"
class JobStatus(str, enum.Enum):
"""Processing status of a job."""
create = "create"
processing = "processing"
error = "error"
success = "success"
class JobConfig(BaseModel): class JobConfig(BaseModel):
"""Configuration for a job.""" """Configuration for a job."""
# TODO: limit to locales selected by whisper.
language: str | None = Field( language: str | None = Field(
description=( description=(
"Spoken language in the media file. " "Spoken language in the media file. "
@@ -51,21 +25,12 @@ class JobMeta(BaseModel):
error: str | None = Field( error: str | None = Field(
description="Will contain a descriptive error message if processing failed." description="Will contain a descriptive error message if processing failed."
) )
task_id: UUID | None = Field( task_id: UUID | None = Field(
description="Internal celery id of this job submission." description="Internal celery id of this job submission."
) )
class Job(WithDbFields):
"""A transcription job for one media file."""
status: JobStatus
type: JobType
url: AnyHttpUrl
meta: JobMeta | None
config: JobConfig | None
class RawTranscript(BaseModel): class RawTranscript(BaseModel):
"""A single transcript passage returned by whisper.""" """A single transcript passage returned by whisper."""
@@ -81,9 +46,35 @@ class RawTranscript(BaseModel):
no_speech_prob: float no_speech_prob: float
class Artifact(WithDbFields): class LanguageDetection(BaseModel):
"""whisper output for one job.""" """A language detection"""
data: list[RawTranscript] | None code: str
# DB objects
class WithDbFields(BaseModel):
id: UUID
created_at: datetime
updated_at: datetime | None
class Config:
orm_mode = True
class Job(WithDbFields):
"""A transcription job for one media file."""
status: JobStatus
type: JobType
url: AnyHttpUrl
meta: JobMeta | None
config: JobConfig | None
class Artifact(WithDbFields):
job_id: UUID job_id: UUID
data: LanguageDetection | RawTranscript | None
type: ArtifactType type: ArtifactType

View File

@@ -16,8 +16,6 @@ class Settings(BaseSettings):
if "pytest" in sys.modules: if "pytest" in sys.modules:
settings = Settings( settings = Settings(_env_file=".env.test") # type: ignore
_env_file=".env.test", _env_file_encoding="utf-8"
) # type: ignore
else: else:
settings = Settings() # type: ignore settings = Settings() # type: ignore

View File

@@ -13,7 +13,6 @@ from app.web.main import app
def pytest_configure() -> None: def pytest_configure() -> None:
if not database_exists(engine.url): if not database_exists(engine.url):
create_database(engine.url) create_database(engine.url)
models.Base.metadata.create_all(engine)
def pytest_unconfigure() -> None: def pytest_unconfigure() -> None:
@@ -21,19 +20,21 @@ def pytest_unconfigure() -> None:
drop_database(engine.url) drop_database(engine.url)
@pytest.fixture(name="auth_headers", scope="function") @pytest.fixture(scope="function")
def auth_header() -> dict[str, str]: def auth_headers() -> dict[str, str]:
return {"Authorization": f"Bearer {settings.API_SECRET}"} return {"Authorization": f"Bearer {settings.API_SECRET}"}
@pytest.fixture(name="db_session", scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def db_session() -> Generator[Session, None, None]: def db_session() -> Generator[Session, None, None]:
models.Base.metadata.create_all(engine)
connection = engine.connect() connection = engine.connect()
transaction = connection.begin()
with SessionLocal(bind=connection) as session: with SessionLocal(bind=connection) as session:
app.dependency_overrides[get_session] = lambda: session app.dependency_overrides[get_session] = lambda: session
yield session yield session
app.dependency_overrides.clear() app.dependency_overrides.clear()
transaction.rollback()
connection.close() connection.close()
models.Base.metadata.drop_all(bind=engine)

View File

@@ -3,8 +3,6 @@ from fastapi.testclient import TestClient
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
import app.shared.db.models as models import app.shared.db.models as models
import app.shared.db.schemas as schemas
from app.shared.db.schemas import JobStatus, JobType
from app.web.main import app from app.web.main import app
client = TestClient(app) client = TestClient(app)
@@ -13,7 +11,9 @@ client = TestClient(app)
@pytest.fixture(name="mock_job", scope="function", autouse=False) @pytest.fixture(name="mock_job", scope="function", autouse=False)
def mock_job(db_session: Session) -> models.Job: def mock_job(db_session: Session) -> models.Job:
job = models.Job( job = models.Job(
url="https://example.com", type=JobType.transcript, status=JobStatus.create url="https://example.com",
type=models.JobType.transcript,
status=models.JobStatus.create,
) )
db_session.add(job) db_session.add(job)
@@ -28,7 +28,7 @@ def test_create_job_pass(auth_headers: dict[str, str]) -> None:
res = client.post( res = client.post(
"/api/v1/jobs", "/api/v1/jobs",
headers=auth_headers, headers=auth_headers,
json={"url": "https://example.com", "type": JobType.transcript}, json={"url": "https://example.com", "type": models.JobType.transcript},
) )
assert res.status_code == 201 assert res.status_code == 201
assert isinstance(res.json()["id"], str) assert isinstance(res.json()["id"], str)
@@ -43,7 +43,7 @@ def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None:
res = client.post( res = client.post(
"/api/v1/jobs", "/api/v1/jobs",
headers=auth_headers, headers=auth_headers,
json={"url": "example.com", "type": JobType.transcript}, json={"url": "example.com", "type": models.JobType.transcript},
) )
assert res.status_code == 422 assert res.status_code == 422
@@ -52,7 +52,7 @@ def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None:
# --- # ---
def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None: def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None:
res = client.get( res = client.get(
"/api/v1/jobs?type=transcript", "/api/v1/jobs?type=transcribe",
headers=auth_headers, headers=auth_headers,
) )
assert len(res.json()) == 1 assert len(res.json()) == 1
@@ -85,7 +85,7 @@ def test_get_artifacts_pass(
auth_headers: dict[str, str], db_session: Session, mock_job: models.Job auth_headers: dict[str, str], db_session: Session, mock_job: models.Job
) -> None: ) -> None:
artifact = models.Artifact( artifact = models.Artifact(
data=[], job_id=str(mock_job.id), type=schemas.ArtifactType.raw_transcript data=None, job_id=str(mock_job.id), type=models.ArtifactType.raw_transcript
) )
db_session.add(artifact) db_session.add(artifact)

View File

@@ -1,17 +1,6 @@
from typing import Any
from pydantic import AnyHttpUrl, BaseModel, Field from pydantic import AnyHttpUrl, BaseModel, Field
import app.shared.db.schemas as schemas import app.shared.db.models as models
class DetailResponse(BaseModel):
detail: str
DEFAULT_RESPONSES: dict[int | str, dict[str, Any]] = {
401: {"model": DetailResponse, "description": "Not authenticated"}
}
class PostJobPayload(BaseModel): class PostJobPayload(BaseModel):
@@ -21,7 +10,7 @@ class PostJobPayload(BaseModel):
) )
) )
type: schemas.JobType = Field( type: models.JobType = Field(
description="""Type of this job. description="""Type of this job.
`transcript` uses the original language of the audio. `transcript` uses the original language of the audio.
`translation` creates an automatic translation to english. `translation` creates an automatic translation to english.

View File

@@ -1,37 +1,37 @@
from asyncio.log import logger from contextlib import asynccontextmanager
from typing import Annotated
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
from sqlalchemy import or_
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
import app.shared.db.models as models import app.shared.db.models as models
import app.shared.db.schemas as schemas import app.shared.db.schemas as schemas
from app.shared.celery import get_celery_binding from app.shared.db.base import SessionLocal, get_session
from app.shared.db.base import get_session from app.web.dtos import PostJobPayload
from app.web.dtos import DEFAULT_RESPONSES, DetailResponse, PostJobPayload
from app.web.security import authenticate_api_key from app.web.security import authenticate_api_key
from app.web.task_queue import task_queue
DatabaseSession = Annotated[Session, Depends(get_session)]
@asynccontextmanager
async def lifespan(_: FastAPI):
with SessionLocal() as session:
task_queue.rehydrate(session)
yield
app = FastAPI( app = FastAPI(
description="whisperbox-transcribe is an async HTTP wrapper for openai/whisper.", description="whisperbox-transcribe is an async HTTP wrapper for openai/whisper.",
lifespan=lifespan,
title="whisperbox-transcribe", title="whisperbox-transcribe",
) )
celery = get_celery_binding()
def queue_task(job: models.Job) -> None:
# queue an async transcription task.
# we use a signature here to allow full separation of
# worker processes and dependencies.
transcribe = celery.signature("app.worker.main.transcribe")
# TODO: catch delivery errors.
transcribe.delay(job.id)
api_router = APIRouter( api_router = APIRouter(
prefix="/api/v1", prefix="/api/v1",
dependencies=[Depends(authenticate_api_key)], dependencies=[Depends(authenticate_api_key)],
responses={**DEFAULT_RESPONSES},
) )
@@ -48,7 +48,7 @@ def api_root() -> None:
) )
def create_job( def create_job(
payload: PostJobPayload, payload: PostJobPayload,
session: Session = Depends(get_session), session: DatabaseSession,
) -> models.Job: ) -> models.Job:
""" """
Enqueue a new whisper job for processing. Enqueue a new whisper job for processing.
@@ -62,6 +62,7 @@ def create_job(
consume considerable resources while active. consume considerable resources while active.
* Once a job is created, you can query its status by its id. * Once a job is created, you can query its status by its id.
""" """
# create a job with status "create" and save it to the database. # create a job with status "create" and save it to the database.
job = models.Job( job = models.Job(
url=payload.url, url=payload.url,
@@ -73,12 +74,7 @@ def create_job(
session.add(job) session.add(job)
session.commit() session.commit()
# queue an async transcription task. task_queue.queue_task(job)
# we use a signature here to allow full separation of
# worker processes and dependencies.
transcribe = celery.signature("app.worker.main.transcribe")
# TODO: catch delivery errors.
transcribe.delay(job.id)
return job return job
@@ -87,7 +83,8 @@ def create_job(
"/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs" "/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
) )
def get_transcripts( def get_transcripts(
type: schemas.JobType | None = None, session: Session = Depends(get_session) session: DatabaseSession,
type: schemas.JobType | None = None,
) -> list[models.Job]: ) -> list[models.Job]:
"""Get metadata for all jobs.""" """Get metadata for all jobs."""
query = session.query(models.Job) query = session.query(models.Job)
@@ -101,18 +98,20 @@ def get_transcripts(
@api_router.get( @api_router.get(
"/jobs/{id}", "/jobs/{id}",
response_model=schemas.Job, response_model=schemas.Job,
responses={404: {"model": DetailResponse, "description": "Not found"}},
summary="Get metadata for one job", summary="Get metadata for one job",
) )
def get_transcript( def get_transcript(
id: UUID = Path(), session: Session = Depends(get_session) session: DatabaseSession,
id: UUID = Path(),
) -> models.Job | None: ) -> models.Job | None:
""" """
Use this route to check transcription status of any given job. Use this route to check transcription status of any given job.
""" """
job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none() job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
if not job: if not job:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
return job return job
@@ -122,10 +121,12 @@ def get_transcript(
summary="Get all artifacts for one job", summary="Get all artifacts for one job",
) )
def get_artifacts_for_job( def get_artifacts_for_job(
id: UUID = Path(), session: Session = Depends(get_session) session: DatabaseSession,
id: UUID = Path(),
) -> list[models.Artifact]: ) -> list[models.Artifact]:
""" """
Right now, there is only one type of artifact (`raw_transcript`). Returns all artifacts for one job.
See the type of `data` for possible data types.
Returns an empty array for unfinished or non-existant jobs. Returns an empty array for unfinished or non-existant jobs.
""" """
artifacts = ( artifacts = (
@@ -139,7 +140,8 @@ def get_artifacts_for_job(
"/jobs/{id}", status_code=204, summary="Delete a job with all artifacts" "/jobs/{id}", status_code=204, summary="Delete a job with all artifacts"
) )
def delete_transcript( def delete_transcript(
id: UUID = Path(), session: Session = Depends(get_session) session: DatabaseSession,
id: UUID = Path(),
) -> None: ) -> None:
"""Remove metadata and artifacts for a single job.""" """Remove metadata and artifacts for a single job."""
session.query(models.Job).filter(models.Job.id == str(id)).delete() session.query(models.Job).filter(models.Job.id == str(id)).delete()
@@ -147,28 +149,3 @@ def delete_transcript(
app.include_router(api_router) app.include_router(api_router)
# TODO: we could use `acks_late` to handle this scenario within celery itself.
# the reason this does not work well in our case is that `visibility_timeout`
# needs to be very high since whisper workers can be long running.
# doing this application-side bears the risk of poison pilling the worker though,
# implement a workaround with an acceptable trade-off. (=> retry only once?)
@app.on_event("startup")
def on_startup() -> None:
session = get_session().__next__()
jobs = (
session.query(models.Job)
.filter(
or_(
models.Job.status == schemas.JobStatus.processing,
models.Job.status == schemas.JobStatus.create,
)
)
.order_by(models.Job.created_at)
).all()
logger.info(f"Requeueing {len(jobs)} jobs.")
for job in jobs:
queue_task(job)

47
app/web/task_queue.py Normal file
View File

@@ -0,0 +1,47 @@
from asyncio.log import logger
from celery import Celery
from sqlalchemy import or_
from sqlalchemy.orm import Session
import app.shared.db.models as models
from app.shared.celery import get_celery_binding
class TaskQueue:
celery: Celery
def __init__(self) -> None:
self.celery = get_celery_binding()
def queue_task(self, job: models.Job):
# queue an async transcription task.
# we use a signature here to allow full separation of
# worker processes and dependencies.
transcribe = self.celery.signature("app.worker.main.transcribe")
transcribe.delay(job.id)
def rehydrate(self, session: Session):
# TODO: we could use `acks_late` to handle this scenario within celery itself.
# the reason this does not work well in our case is that `visibility_timeout`
# needs to be very high since whisper workers can be long running.
# doing this app-side bears the risk of poison pilling the worker though,
# implement a workaround with an acceptable trade-off. (=> retry only once?)
jobs = (
session.query(models.Job)
.filter(
or_(
models.Job.status == models.JobStatus.processing,
models.Job.status == models.JobStatus.create,
)
)
.order_by(models.Job.created_at)
).all()
logger.info(f"Requeueing {len(jobs)} jobs.")
for job in jobs:
self.queue_task(job)
task_queue = TaskQueue()

View File

@@ -10,6 +10,7 @@ import app.shared.db.schemas as schemas
from app.shared.celery import get_celery_binding from app.shared.celery import get_celery_binding
from app.shared.db.base import SessionLocal from app.shared.db.base import SessionLocal
from app.shared.settings import settings from app.shared.settings import settings
from app.worker.strategies.base import TaskProtocol
from app.worker.strategies.local import LocalStrategy from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding() celery = get_celery_binding()
@@ -30,10 +31,10 @@ class TranscribeTask(Task):
return self.run(*args, **kwargs) return self.run(*args, **kwargs)
def select_strategy(task: Task, job: schemas.Job) -> Any: def select_task_processor(task: Task, job: schemas.Job) -> TaskProtocol:
if job.type == schemas.JobType.transcript: if job.type == models.JobType.transcript:
return task.strategy.transcribe return task.strategy.transcribe
elif job.type == schemas.JobType.translation: elif job.type == models.JobType.translation:
return task.strategy.translate return task.strategy.translate
else: else:
return task.strategy.detect_language return task.strategy.detect_language
@@ -50,49 +51,50 @@ def transcribe(self: Task, job_id: UUID) -> None:
# runs in a separate thread => requires sqlite's WAL mode to be enabled. # runs in a separate thread => requires sqlite's WAL mode to be enabled.
db: Session = SessionLocal() db: Session = SessionLocal()
job = db.query(models.Job).filter(models.Job.id == job_id).one() # unit of work: set task status to processing.
if ( job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none()
job.status == schemas.JobStatus.error
or job.status == schemas.JobStatus.success if job is None:
): logger.warn("[{job.id}]: Received unknown job, abort.")
logger.warn(
"[{job.id}]: Received job that has already been processed, abort."
)
return return
logger.info(f"[{job.id}]: worker received task.") if job.status in [models.JobStatus.error, models.JobStatus.success]:
logger.warn("[{job.id}]: job has already been processed, abort.")
return
logger.info(f"[{job.id}]: received eligible job.")
job.meta = {"task_id": self.request.id} job.meta = {"task_id": self.request.id}
job.status = schemas.JobStatus.processing job.status = models.JobStatus.processing
db.commit()
logger.info(f"[{job.id}]: set task to status processing.")
db.commit()
logger.info(f"[{job.id}]: finished setting task to status processing.")
# unit of work: process job with whisper.
job_record = schemas.Job.from_orm(job) job_record = schemas.Job.from_orm(job)
strategy = select_strategy(self, job_record) processor = select_task_processor(self, job_record)
result = strategy(
result_type, result = processor(
url=job_record.url, job_id=job_record.id, config=job_record.config url=job_record.url, job_id=job_record.id, config=job_record.config
) )
logger.info(f"[{job.id}]: successfully processed audio.") logger.info(f"[{job.id}]: successfully processed audio.")
artifact = models.Artifact( artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
)
db.add(artifact) db.add(artifact)
job.status = models.JobStatus.success
db.commit() db.commit()
logger.info(f"[{job.id}]: successfully stored artifact.") logger.info(f"[{job.id}]: successfully stored artifact.")
job.status = schemas.JobStatus.success
db.commit()
logger.info(f"[{job.id}]: set task to status success.")
except Exception as e: except Exception as e:
if job and db: if job and db:
db.rollback()
job.meta = {**job.meta, "error": str(e)} # type: ignore job.meta = {**job.meta, "error": str(e)} # type: ignore
job.status = schemas.JobStatus.error job.status = models.JobStatus.error
db.commit() db.commit()
raise raise
finally: finally:

View File

@@ -0,0 +1,35 @@
from abc import ABC
from typing import Any, Protocol, Tuple
from uuid import UUID
import app.shared.db.models as models
import app.shared.db.schemas as schemas
TaskReturnValue = Tuple[models.ArtifactType, Any]
class TaskProtocol(Protocol):
def __call__(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> TaskReturnValue:
...
class BaseStrategy(ABC):
def transcribe(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> TaskReturnValue:
raise NotImplementedError()
def translate(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> TaskReturnValue:
raise NotImplementedError()
def detect_language(
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> TaskReturnValue:
raise NotImplementedError()
def cleanup(self, job_id: UUID) -> None:
raise NotImplementedError()

View File

@@ -7,10 +7,11 @@ from uuid import UUID
import requests import requests
import torch import torch
import whisper
from pydantic import BaseModel from pydantic import BaseModel
from whisper import load_model
import app.shared.db.schemas as schemas import app.shared.db.schemas as schemas
from app.worker.strategies.base import BaseStrategy, TaskReturnValue
class DecodeOptions(BaseModel): class DecodeOptions(BaseModel):
@@ -18,40 +19,58 @@ class DecodeOptions(BaseModel):
task: Literal["translate", "transcribe"] task: Literal["translate", "transcribe"]
class LocalStrategy: class LocalStrategy(BaseStrategy):
def __init__(self) -> None: def __init__(self) -> None:
if torch.cuda.is_available(): if torch.cuda.is_available():
logger.info("initializing GPU model.") logger.info("initializing GPU model.")
self.model = load_model( self.model = whisper.load_model(
os.environ["WHISPER_MODEL"], download_root="/models" os.environ["WHISPER_MODEL"], download_root="/models"
).cuda() ).cuda()
else: else:
logger.info("initializing CPU model.") logger.info("initializing CPU model.")
self.model = load_model( self.model = whisper.load_model(
os.environ["WHISPER_MODEL"], download_root="/models" os.environ["WHISPER_MODEL"], download_root="/models"
) )
logger.info("initialized local strategy.") logger.info("initialized local strategy.")
def transcribe( def cleanup(self, job_id) -> None:
self, url: str, job_id: UUID, config: schemas.JobConfig | None try:
) -> list[Any]: os.remove(self._get_tmp_file(job_id))
return self.run_whisper( except OSError:
self._download(url, job_id), "transcribe", config, job_id pass
def transcribe(self, url, job_id, config):
return (
schemas.ArtifactType.raw_transcript,
self._run_whisper(
self._download(url, job_id), "transcribe", config, job_id
),
) )
def translate( def translate(self, url, job_id, config) -> TaskReturnValue:
self, url: str, job_id: UUID, config: schemas.JobConfig | None return (
) -> list[Any]: schemas.ArtifactType.raw_transcript,
return self.run_whisper( self._run_whisper(
self._download(url, job_id), self._download(url, job_id),
"translate", "translate",
config, config,
job_id, job_id,
),
) )
def detect_language(self, url: str, config: schemas.JobConfig | None) -> list[Any]: def detect_language(self, url, job_id, config) -> TaskReturnValue:
raise NotImplementedError("detect_language has not been implemented yet.") file = self._download(url, job_id)
audio = whisper.pad_or_trim(whisper.load_audio(file))
mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
_, probs = self.model.detect_language(mel)
return (
schemas.ArtifactType.language_detection,
{"code": max(probs, key=probs.get)},
)
def _download(self, url: str, job_id: UUID) -> str: def _download(self, url: str, job_id: UUID) -> str:
# re-create folder. # re-create folder.
@@ -67,7 +86,7 @@ class LocalStrategy:
return filename return filename
def run_whisper( def _run_whisper(
self, self,
filepath: str, filepath: str,
task: Literal["translate", "transcribe"], task: Literal["translate", "transcribe"],
@@ -90,9 +109,3 @@ class LocalStrategy:
def _get_tmp_file(self, job_id: UUID) -> str: def _get_tmp_file(self, job_id: UUID) -> str:
tmp = tempfile.gettempdir() tmp = tempfile.gettempdir()
return path.join(tmp, str(job_id)) return path.join(tmp, str(job_id))
def cleanup(self, job_id: UUID) -> None:
try:
os.remove(self._get_tmp_file(job_id))
except OSError:
pass

View File

@@ -1,4 +1,4 @@
[mypy] [mypy]
plugins = sqlalchemy.ext.mypy.plugin plugins = sqlalchemy.ext.mypy.plugin
ignore_missing_imports = True ignore_missing_imports = True
disallow_untyped_defs = True disallow_untyped_defs = False