feat: add whisper tasks

This commit is contained in:
Felix Spöttel
2023-01-28 12:30:02 +01:00
parent 8669a18110
commit a7ce71ed33
17 changed files with 209 additions and 88 deletions

View File

@@ -1,37 +1,55 @@
from time import sleep
from asyncio.log import logger
from uuid import UUID
from celery import Celery
from sqlalchemy.orm import Session
import app.shared.db.dtos as dtos
import app.shared.db.models as models
from app.shared.config import settings
from app.shared.celery import get_celery_binding
from app.shared.db.base import SessionLocal
from app.worker.strategies.local import LocalStrategy
celery = Celery(__name__)
celery.conf.broker_url = settings.BROKER_URI
celery = get_celery_binding()
def update_job_status(db: Session, job_id: UUID, status: dtos.JobStatus) -> None:
db.begin()
job = db.query(models.Job).filter(models.Job.id == job_id).one()
def update_job_status(db: Session, job: models.Job, status: dtos.JobStatus) -> None:
job.status = status
db.commit()
@celery.task()
def transcribe(job_id: UUID) -> int:
def transcribe(job_id: UUID) -> None:
try:
db: Session = SessionLocal()
update_job_status(db, job_id, dtos.JobStatus.processing)
sleep(60)
update_job_status(db, job_id, dtos.JobStatus.success)
job = db.query(models.Job).filter(models.Job.id == job_id).one()
update_job_status(db, job, dtos.JobStatus.processing)
# pick a transcription strategy.
# currently only `local` is supported.
job_record = dtos.Job.from_orm(job)
strategy = LocalStrategy(
db=db, job_id=job.id, url=job_record.url, config=job_record.config
)
# process selected task.
# currently only `transcribe` is supported.
if job.type == dtos.JobType.transcript:
result = strategy.transcribe()
elif job.type == dtos.JobType.translation:
result = strategy.translate()
else:
result = strategy.detect_language()
artifact = models.Artifact(
job_id=job.id, data=result, type=dtos.ArtifactType.raw_transcript
)
db.add(artifact)
db.commit()
except Exception:
update_job_status(db, job_id, dtos.JobStatus.error)
update_job_status(db, job, dtos.JobStatus.success)
except Exception as e:
logger.error(e)
update_job_status(db, job, dtos.JobStatus.error)
finally:
db.close()
return 0

View File

@@ -0,0 +1,81 @@
import os
import shutil
import tempfile
from os import path
from typing import Any, List, Literal, Optional
from uuid import UUID
import requests
from pydantic import BaseModel
from sqlalchemy.orm import Session
from whisper import load_model
import app.shared.db.dtos as dtos
class DecodeOptions(BaseModel):
language: Optional[str]
task: Literal["translate", "transcribe"]
class LocalStrategy:
def __init__(
self, db: Session, job_id: UUID, url: str, config: Optional[dtos.JobConfig]
):
self.db = db
self.job_id = job_id
self.url = url
self.config = config
def transcribe(self) -> List[Any]:
result = self.run_whisper(self._download(), "transcribe")
self._cleanup()
return result
def translate(self) -> List[Any]:
result = self.run_whisper(self._download(), "translate")
self._cleanup()
return result
def detect_language(self) -> List[Any]:
raise NotImplementedError("detect_language has not been implemented yet.")
def _download(self) -> str:
dirname = self._get_tmp_dir()
filename = path.join(dirname, "media.mp3")
# re-create folder.
shutil.rmtree(dirname, ignore_errors=True)
os.makedirs(dirname)
# stream media to disk.
with requests.get(self.url, stream=True) as r:
r.raise_for_status()
with open(filename, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
return filename
def run_whisper(self, filepath: str, task: str) -> List[Any]:
language = self.config.language if self.config else None
decode_opts = DecodeOptions(task=task, language=language)
model = load_model("small", download_root="/models")
result = model.transcribe(
filepath, condition_on_previous_text=False, **decode_opts.dict()
)
return result["segments"]
def _get_tmp_dir(self) -> str:
return path.join(tempfile.gettempdir(), str(self.job_id))
def _cleanup(self) -> None:
shutil.rmtree(self._get_tmp_dir(), ignore_errors=True)
def _convert(self) -> None:
pass
def _transcribe(self) -> None:
pass