feat: retry lost jobs on startup

This commit is contained in:
Felix Spöttel
2023-01-28 13:41:38 +01:00
parent a7ce71ed33
commit e995b1f2ff
6 changed files with 92 additions and 40 deletions

View File

@@ -4,6 +4,6 @@ from app.shared.config import settings
def get_celery_binding() -> Celery:
    """Create the Celery application object used to talk to the broker.

    Returns:
        A ``Celery`` instance whose ``broker_url`` is set from
        ``settings.BROKER_URL``.
    """
    # Instantiate exactly once; a second instantiation would only discard
    # the first object.
    celery = Celery()
    celery.conf.broker_url = settings.BROKER_URL
    return celery

View File

@@ -1,8 +1,10 @@
from asyncio.log import logger
from typing import Dict, List, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
from pydantic import AnyHttpUrl, BaseModel
from sqlalchemy import or_
from sqlalchemy.orm import Session
import app.shared.db.dtos as dtos
@@ -17,6 +19,15 @@ celery = get_celery_binding()
api_router = APIRouter(prefix="/api/v1", dependencies=[Depends(authenticate_api_key)])
def queue_task(job: models.Job) -> None:
    """Queue an asynchronous transcription task for ``job``.

    The task is addressed by name through a signature instead of importing
    the task function, which allows full separation of worker processes and
    their dependencies from this module.

    Args:
        job: The job whose ``id`` is passed to the worker task.
    """
    # TODO: catch delivery errors.
    celery.signature("app.worker.main.transcribe").delay(job.id)
@api_router.get("/")
def api_root() -> Dict:
    """Handle ``GET /`` on the API router by returning an empty mapping."""
    empty_payload: Dict = {}
    return empty_payload
@@ -99,3 +110,23 @@ def delete_transcript(
app.include_router(api_router)
# TODO:
# we could use `acks_late` to handle this scenario within celery itself.
# this does not work well in our case because `visibility_timeout` would
# need to be very high, since whisper workers can be long-running.
# re-queueing application-side bears the risk of poison-pilling the worker, though.
# implement a workaround with an acceptable trade-off (=> retry only once?).
@app.on_event("startup")
def on_startup() -> None:
    """Re-queue jobs that were left unfinished when the app last stopped.

    Every job whose status is ``processing`` or ``create`` is passed to
    ``queue_task`` again, ordered by ``created_at``. The database session is
    always closed afterwards, even if queueing fails.
    """
    # Keep a reference to whatever `get_session()` returns for the whole call.
    # NOTE(review): presumably this is a generator-style dependency that
    # closes the session when it is finalized; dropping it right after the
    # first `next()` could trigger that before the query runs -- confirm.
    session_source = get_session()
    session = next(session_source)
    try:
        jobs = (
            session.query(models.Job)
            .filter(
                or_(
                    models.Job.status == dtos.JobStatus.processing,
                    models.Job.status == dtos.JobStatus.create,
                )
            )
            .order_by(models.Job.created_at)
        ).all()

        logger.info(f"Re-queueing {len(jobs)} jobs.")

        for job in jobs:
            queue_task(job)
    finally:
        # Release the session's connection regardless of the outcome.
        session.close()

View File

@@ -6,4 +6,4 @@ set -e
alembic upgrade head
# start app
uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-80}
uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-80} --log-level info

View File

@@ -1,6 +1,7 @@
from asyncio.log import logger
from uuid import UUID
from celery import Task
from sqlalchemy.orm import Session
import app.shared.db.dtos as dtos
@@ -12,18 +13,24 @@ from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding()
def update_job_status(db: Session, job: models.Job, status: dtos.JobStatus) -> None:
    """Assign ``status`` to ``job.status`` and commit the session.

    Args:
        db: Session on which ``commit()`` is called.
        job: Job whose ``status`` attribute is overwritten.
        status: The new status value.
    """
    job.status = status
    # Persist the change right away.
    db.commit()
@celery.task()
def transcribe(job_id: UUID) -> None:
@celery.task(
bind=True,
soft_time_limit=2 * 60 * 60 # TODO: make configurable
)
def transcribe(self: Task, job_id: UUID) -> None:
try:
db: Session = SessionLocal()
job = db.query(models.Job).filter(models.Job.id == job_id).one()
update_job_status(db, job, dtos.JobStatus.processing)
if job.status == dtos.JobStatus.error or job.status == dtos.JobStatus.success:
logger.warn("[{job.id}]: Received job that has already been processed, abort.")
return
job.meta = {"task_id": self.request.id}
job.status = dtos.JobStatus.processing
db.commit()
logger.info(f"[{job.id}]: set task to status processing.")
# pick a transcription strategy.
# currently only `local` is supported.
@@ -36,20 +43,30 @@ def transcribe(job_id: UUID) -> None:
# currently only `transcribe` is supported.
if job.type == dtos.JobType.transcript:
result = strategy.transcribe()
logger.info(f"[{job.id}]: successfully transcribed audio.")
elif job.type == dtos.JobType.translation:
result = strategy.translate()
logger.info(f"[{job.id}]: successfully translated audio.")
else:
result = strategy.detect_language()
artifact = models.Artifact(
job_id=job.id, data=result, type=dtos.ArtifactType.raw_transcript
)
db.add(artifact)
db.commit()
logger.info(f"[{job.id}]: successfully stored artifact.")
update_job_status(db, job, dtos.JobStatus.success)
job.status = dtos.JobStatus.success
db.commit()
logger.info(f"[{job.id}]: set task to status success.")
except Exception as e:
logger.error(e)
update_job_status(db, job, dtos.JobStatus.error)
if job and db:
job.meta = { **job.meta, "error": str(e) }
job.status = dtos.JobStatus.error
db.commit()
raise(e)
finally:
db.close()

View File

@@ -1,6 +1,6 @@
import os
import shutil
import tempfile
from asyncio.log import logger
from os import path
from typing import Any, List, Literal, Optional
from uuid import UUID
@@ -26,27 +26,21 @@ class LocalStrategy:
self.job_id = job_id
self.url = url
self.config = config
logger.info(f"[{self.job_id}]: initialized local strategy.")
def transcribe(self) -> List[Any]:
    """Download the media and run whisper on it with the "transcribe" task.

    Returns:
        Whatever ``run_whisper`` returns for the downloaded file.

    Raises:
        Any exception raised by ``_download`` or ``run_whisper``; temporary
        data is cleaned up before the exception propagates.
    """
    try:
        return self.run_whisper(self._download(), "transcribe")
    finally:
        # Clean up on success and on failure alike.
        self._cleanup()
def translate(self) -> List[Any]:
    """Download the media and run whisper on it with the "translate" task.

    Returns:
        Whatever ``run_whisper`` returns for the downloaded file.

    Raises:
        Any exception raised by ``_download`` or ``run_whisper``; temporary
        data is cleaned up before the exception propagates.
    """
    try:
        return self.run_whisper(self._download(), "translate")
    finally:
        # Clean up on success and on failure alike.
        self._cleanup()
def detect_language(self) -> List[Any]:
    """Not available yet: always raises ``NotImplementedError``."""
    message = "detect_language has not been implemented yet."
    raise NotImplementedError(message)
def _download(self) -> str:
dirname = self._get_tmp_dir()
filename = path.join(dirname, "media.mp3")
# re-create folder.
shutil.rmtree(dirname, ignore_errors=True)
os.makedirs(dirname)
filename = self._get_tmp_file()
self._cleanup()
# stream media to disk.
with requests.get(self.url, stream=True) as r:
@@ -58,21 +52,29 @@ class LocalStrategy:
return filename
def run_whisper(self, filepath: str, task: str) -> List[Any]:
    """Run the whisper model on a media file and return its segments.

    Args:
        filepath: Path of the media file handed to ``model.transcribe``.
        task: Task name forwarded to ``DecodeOptions`` (callers in this
            class pass ``"transcribe"`` or ``"translate"``).

    Returns:
        The ``"segments"`` entry of the result returned by the model.
    """
    try:
        # Use the configured language when a config is present.
        language = self.config.language if self.config else None
        decode_opts = DecodeOptions(task=task, language=language)

        # Load the model and run it exactly once per call.
        model = load_model("small", download_root="/models")
        result = model.transcribe(
            filepath,
            condition_on_previous_text=False,
            **decode_opts.dict(),
        )

        return result["segments"]
    finally:
        # Remove temporary data whether or not the run succeeded.
        self._cleanup()
def _get_tmp_dir(self) -> str:
    """Return the path named after ``self.job_id`` inside the system temp dir."""
    base = tempfile.gettempdir()
    return path.join(base, str(self.job_id))
def _get_tmp_file(self) -> str:
    """Return the file path named after ``self.job_id`` in the system temp dir."""
    return path.join(tempfile.gettempdir(), str(self.job_id))
def _cleanup(self) -> None:
    """Best-effort removal of this job's temporary directory and file.

    Failures are ignored in both steps, so calling this when nothing exists
    is harmless.
    """
    tmp_dir = self._get_tmp_dir()
    shutil.rmtree(tmp_dir, ignore_errors=True)

    tmp_file = self._get_tmp_file()
    try:
        os.remove(tmp_file)
    except OSError:
        # Nothing to remove, or not removable: cleanup is best-effort.
        pass
def _convert(self) -> None:
    """Placeholder hook; currently a no-op."""
    return None

View File

@@ -29,13 +29,10 @@ services:
redis:
container_name: whisperbox_redis
image: redis:7-alpine
command: ["redis-server", "--save", "60 1"]
ports:
- 6379:6379
networks:
- app
volumes:
- redis-data:/data
app:
container_name: whisperbox_app
@@ -50,6 +47,8 @@ services:
volumes:
- ../:/code
depends_on:
postgres:
condition: service_healthy
postgres:
condition: service_healthy
redis:
@@ -64,10 +63,14 @@ services:
- ../:/code
environment: *app-variables
depends_on:
- app
- redis
networks:
- app
healthcheck:
test: ["CMD-SHELL", "celery inspect ping -A app.worker.main.transcribe -d celery@$$HOSTNAME"]
interval: 5s
timeout: 5s
retries: 5
flower:
container_name: whisperbox_flower
@@ -82,7 +85,6 @@ services:
volumes:
postgres-data:
redis-data:
networks:
app: