From e995b1f2ffc3437409882b30ba9b4a12dde51141 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Sp=C3=B6ttel?= <1682504+fspoettel@users.noreply.github.com> Date: Sat, 28 Jan 2023 13:41:38 +0100 Subject: [PATCH] feat: retry lost jobs on startup --- app/shared/celery.py | 2 +- app/web/main.py | 31 +++++++++++++++++++++++ app/web/start.sh | 2 +- app/worker/main.py | 39 ++++++++++++++++++++-------- app/worker/strategies/local.py | 46 ++++++++++++++++++---------------- docker/dev.docker-compose.yml | 12 +++++---- 6 files changed, 92 insertions(+), 40 deletions(-) diff --git a/app/shared/celery.py b/app/shared/celery.py index 415f3dc..413700c 100644 --- a/app/shared/celery.py +++ b/app/shared/celery.py @@ -4,6 +4,6 @@ from app.shared.config import settings def get_celery_binding() -> Celery: - celery = Celery("tasks") + celery = Celery() celery.conf.broker_url = settings.BROKER_URL return celery diff --git a/app/web/main.py b/app/web/main.py index 11b563a..a689a00 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -1,8 +1,10 @@ +from asyncio.log import logger from typing import Dict, List, Optional from uuid import UUID from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path from pydantic import AnyHttpUrl, BaseModel +from sqlalchemy import or_ from sqlalchemy.orm import Session import app.shared.db.dtos as dtos @@ -17,6 +19,15 @@ celery = get_celery_binding() api_router = APIRouter(prefix="/api/v1", dependencies=[Depends(authenticate_api_key)]) +def queue_task(job: models.Job) -> None: + # queue an async transcription task. + # we use a signature here to allow full separation of + # worker processes and dependencies. + transcribe = celery.signature("app.worker.main.transcribe") + # TODO: catch delivery errors. + transcribe.delay(job.id) + + @api_router.get("/") def api_root() -> Dict: return {} @@ -99,3 +110,23 @@ def delete_transcript( app.include_router(api_router) + +# TODO: +# we could use `acks_late` to handle this scenario within celery itself. +# the reason this does not work well in our case is that `visibility_timeout` +# needs to be very high since whisper workers can be long running. +# doing this application-side bears the risk of poison pilling the worker though, +# implement a workaround with an acceptable trade-off. (=> retry only once?) +@app.on_event("startup") +def on_startup() -> None: + session = get_session().__next__() + + jobs = ( + session.query(models.Job) + .filter(or_(models.Job.status == dtos.JobStatus.processing, models.Job.status == dtos.JobStatus.create)) + .order_by(models.Job.created_at) + ).all() + + logger.info(f"Re-queueing {len(jobs)} jobs.") + for job in jobs: + queue_task(job) diff --git a/app/web/start.sh b/app/web/start.sh index 46b7684..9a23993 100755 --- a/app/web/start.sh +++ b/app/web/start.sh @@ -6,4 +6,4 @@ set -e alembic upgrade head # start app -uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-80} +uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-80} --log-level info diff --git a/app/worker/main.py b/app/worker/main.py index be70bdb..90933b0 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -1,6 +1,7 @@ from asyncio.log import logger from uuid import UUID +from celery import Task from sqlalchemy.orm import Session import app.shared.db.dtos as dtos @@ -12,18 +13,24 @@ from app.worker.strategies.local import LocalStrategy celery = get_celery_binding() -def update_job_status(db: Session, job: models.Job, status: dtos.JobStatus) -> None: - job.status = status - db.commit() - - -@celery.task() -def transcribe(job_id: UUID) -> None: +@celery.task( + bind=True, + soft_time_limit=2 * 60 * 60 # TODO: make configurable +) +def transcribe(self: Task, job_id: UUID) -> None: try: db: Session = SessionLocal() job = db.query(models.Job).filter(models.Job.id == job_id).one() - update_job_status(db, job, dtos.JobStatus.processing) + if job.status == dtos.JobStatus.error or job.status == dtos.JobStatus.success: + logger.warn("[{job.id}]: Received job that has already been processed, abort.") + return + + job.meta = {"task_id": self.request.id} + job.status = dtos.JobStatus.processing + db.commit() + + logger.info(f"[{job.id}]: set task to status processing.") # pick a transcription strategy. # currently only `local` is supported. @@ -36,20 +43,30 @@ def transcribe(job_id: UUID) -> None: # currently only `transcribe` is supported. if job.type == dtos.JobType.transcript: result = strategy.transcribe() + logger.info(f"[{job.id}]: successfully transcribed audio.") elif job.type == dtos.JobType.translation: result = strategy.translate() + logger.info(f"[{job.id}]: successfully translated audio.") else: result = strategy.detect_language() artifact = models.Artifact( job_id=job.id, data=result, type=dtos.ArtifactType.raw_transcript ) + db.add(artifact) db.commit() + logger.info(f"[{job.id}]: successfully stored artifact.") - update_job_status(db, job, dtos.JobStatus.success) + job.status = dtos.JobStatus.success + db.commit() + + logger.info(f"[{job.id}]: set task to status success.") except Exception as e: - logger.error(e) - update_job_status(db, job, dtos.JobStatus.error) + if job and db: + job.meta = { **job.meta, "error": str(e) } + job.status = dtos.JobStatus.error + db.commit() + raise(e) finally: db.close() diff --git a/app/worker/strategies/local.py b/app/worker/strategies/local.py index 21b9c0b..f1dd22a 100644 --- a/app/worker/strategies/local.py +++ b/app/worker/strategies/local.py @@ -1,6 +1,6 @@ import os -import shutil import tempfile +from asyncio.log import logger from os import path from typing import Any, List, Literal, Optional from uuid import UUID @@ -26,27 +26,21 @@ class LocalStrategy: self.job_id = job_id self.url = url self.config = config + logger.info(f"[{self.job_id}]: initialized local strategy.") def transcribe(self) -> List[Any]: - result = self.run_whisper(self._download(), "transcribe") - self._cleanup() - return result + return self.run_whisper(self._download(), "transcribe") def translate(self) -> List[Any]: - result = self.run_whisper(self._download(), "translate") - self._cleanup() - return result + return self.run_whisper(self._download(), "translate") def detect_language(self) -> List[Any]: raise NotImplementedError("detect_language has not been implemented yet.") def _download(self) -> str: - dirname = self._get_tmp_dir() - filename = path.join(dirname, "media.mp3") - # re-create folder. - shutil.rmtree(dirname, ignore_errors=True) - os.makedirs(dirname) + filename = self._get_tmp_file() + self._cleanup() # stream media to disk. with requests.get(self.url, stream=True) as r: @@ -58,21 +52,29 @@ class LocalStrategy: return filename def run_whisper(self, filepath: str, task: str) -> List[Any]: - language = self.config.language if self.config else None - decode_opts = DecodeOptions(task=task, language=language) - model = load_model("small", download_root="/models") + try: + language = self.config.language if self.config else None + model = load_model("small", download_root="/models") - result = model.transcribe( - filepath, condition_on_previous_text=False, **decode_opts.dict() - ) + result = model.transcribe( + filepath, + condition_on_previous_text=False, + **DecodeOptions(task=task, language=language).dict(), + ) - return result["segments"] + return result["segments"] + finally: + self._cleanup() - def _get_tmp_dir(self) -> str: - return path.join(tempfile.gettempdir(), str(self.job_id)) + def _get_tmp_file(self) -> str: + tmp = tempfile.gettempdir() + return path.join(tmp, str(self.job_id)) def _cleanup(self) -> None: - shutil.rmtree(self._get_tmp_dir(), ignore_errors=True) + try: + os.remove(self._get_tmp_file()) + except OSError: + pass def _convert(self) -> None: pass diff --git a/docker/dev.docker-compose.yml b/docker/dev.docker-compose.yml index 5aeadfd..e686eec 100644 --- a/docker/dev.docker-compose.yml +++ b/docker/dev.docker-compose.yml @@ -29,13 +29,10 @@ services: redis: container_name: whisperbox_redis image: redis:7-alpine - command: ["redis-server", "--save", "60 1"] ports: - 6379:6379 networks: - app - volumes: - - redis-data:/data app: container_name: whisperbox_app @@ -50,6 +47,8 @@ services: volumes: - ../:/code depends_on: + postgres: + condition: service_healthy postgres: condition: service_healthy redis: @@ -64,10 +63,14 @@ services: - ../:/code environment: *app-variables depends_on: - - app - redis networks: - app + healthcheck: + test: ["CMD-SHELL", "celery inspect ping -A app.worker.main.transcribe -d celery@$$HOSTNAME"] + interval: 5s + timeout: 5s + retries: 5 flower: container_name: whisperbox_flower @@ -82,7 +85,6 @@ services: volumes: postgres-data: - redis-data: networks: app: