feat: retry lost jobs on startup

This commit is contained in:
Felix Spöttel
2023-01-28 13:41:38 +01:00
parent a7ce71ed33
commit e995b1f2ff
6 changed files with 92 additions and 40 deletions

View File

@@ -1,6 +1,7 @@
from asyncio.log import logger
from uuid import UUID
from celery import Task
from sqlalchemy.orm import Session
import app.shared.db.dtos as dtos
@@ -12,18 +13,24 @@ from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding()
def update_job_status(db: Session, job: models.Job, status: dtos.JobStatus) -> None:
job.status = status
db.commit()
@celery.task()
def transcribe(job_id: UUID) -> None:
@celery.task(
bind=True,
soft_time_limit=2 * 60 * 60 # TODO: make configurable
)
def transcribe(self: Task, job_id: UUID) -> None:
try:
db: Session = SessionLocal()
job = db.query(models.Job).filter(models.Job.id == job_id).one()
update_job_status(db, job, dtos.JobStatus.processing)
if job.status == dtos.JobStatus.error or job.status == dtos.JobStatus.success:
logger.warn("[{job.id}]: Received job that has already been processed, abort.")
return
job.meta = {"task_id": self.request.id}
job.status = dtos.JobStatus.processing
db.commit()
logger.info(f"[{job.id}]: set task to status processing.")
# pick a transcription strategy.
# currently only `local` is supported.
@@ -36,20 +43,30 @@ def transcribe(job_id: UUID) -> None:
# currently only `transcribe` is supported.
if job.type == dtos.JobType.transcript:
result = strategy.transcribe()
logger.info(f"[{job.id}]: successfully transcribed audio.")
elif job.type == dtos.JobType.translation:
result = strategy.translate()
logger.info(f"[{job.id}]: successfully translated audio.")
else:
result = strategy.detect_language()
artifact = models.Artifact(
job_id=job.id, data=result, type=dtos.ArtifactType.raw_transcript
)
db.add(artifact)
db.commit()
logger.info(f"[{job.id}]: successfully stored artifact.")
update_job_status(db, job, dtos.JobStatus.success)
job.status = dtos.JobStatus.success
db.commit()
logger.info(f"[{job.id}]: set task to status success.")
except Exception as e:
logger.error(e)
update_job_status(db, job, dtos.JobStatus.error)
if job and db:
job.meta = { **job.meta, "error": str(e) }
job.status = dtos.JobStatus.error
db.commit()
raise(e)
finally:
db.close()

View File

@@ -1,6 +1,6 @@
import os
import shutil
import tempfile
from asyncio.log import logger
from os import path
from typing import Any, List, Literal, Optional
from uuid import UUID
@@ -26,27 +26,21 @@ class LocalStrategy:
self.job_id = job_id
self.url = url
self.config = config
logger.info(f"[{self.job_id}]: initialized local strategy.")
def transcribe(self) -> List[Any]:
result = self.run_whisper(self._download(), "transcribe")
self._cleanup()
return result
return self.run_whisper(self._download(), "transcribe")
def translate(self) -> List[Any]:
result = self.run_whisper(self._download(), "translate")
self._cleanup()
return result
return self.run_whisper(self._download(), "translate")
def detect_language(self) -> List[Any]:
raise NotImplementedError("detect_language has not been implemented yet.")
def _download(self) -> str:
dirname = self._get_tmp_dir()
filename = path.join(dirname, "media.mp3")
# re-create folder.
shutil.rmtree(dirname, ignore_errors=True)
os.makedirs(dirname)
filename = self._get_tmp_file()
self._cleanup()
# stream media to disk.
with requests.get(self.url, stream=True) as r:
@@ -58,21 +52,29 @@ class LocalStrategy:
return filename
def run_whisper(self, filepath: str, task: str) -> List[Any]:
language = self.config.language if self.config else None
decode_opts = DecodeOptions(task=task, language=language)
model = load_model("small", download_root="/models")
try:
language = self.config.language if self.config else None
model = load_model("small", download_root="/models")
result = model.transcribe(
filepath, condition_on_previous_text=False, **decode_opts.dict()
)
result = model.transcribe(
filepath,
condition_on_previous_text=False,
**DecodeOptions(task=task, language=language).dict(),
)
return result["segments"]
return result["segments"]
finally:
self._cleanup()
def _get_tmp_dir(self) -> str:
return path.join(tempfile.gettempdir(), str(self.job_id))
def _get_tmp_file(self) -> str:
tmp = tempfile.gettempdir()
return path.join(tmp, str(self.job_id))
def _cleanup(self) -> None:
shutil.rmtree(self._get_tmp_dir(), ignore_errors=True)
try:
os.remove(self._get_tmp_file())
except OSError:
pass
def _convert(self) -> None:
pass