mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-11 21:18:36 +03:00
feat: retry lost jobs on startup
This commit is contained in:
@@ -4,6 +4,6 @@ from app.shared.config import settings
|
||||
|
||||
|
||||
def get_celery_binding() -> Celery:
|
||||
celery = Celery("tasks")
|
||||
celery = Celery()
|
||||
celery.conf.broker_url = settings.BROKER_URL
|
||||
return celery
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from asyncio.log import logger
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
|
||||
from pydantic import AnyHttpUrl, BaseModel
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.dtos as dtos
|
||||
@@ -17,6 +19,15 @@ celery = get_celery_binding()
|
||||
api_router = APIRouter(prefix="/api/v1", dependencies=[Depends(authenticate_api_key)])
|
||||
|
||||
|
||||
def queue_task(job: models.Job) -> None:
|
||||
# queue an async transcription task.
|
||||
# we use a signature here to allow full separation of
|
||||
# worker processes and dependencies.
|
||||
transcribe = celery.signature("app.worker.main.transcribe")
|
||||
# TODO: catch delivery errors.
|
||||
transcribe.delay(job.id)
|
||||
|
||||
|
||||
@api_router.get("/")
|
||||
def api_root() -> Dict:
|
||||
return {}
|
||||
@@ -99,3 +110,23 @@ def delete_transcript(
|
||||
|
||||
|
||||
app.include_router(api_router)
|
||||
|
||||
# TODO:
|
||||
# we could use `acks_late` to handle this scenario within celery itself.
|
||||
# the reason this does not work well in our case is that `visibility_timeout`
|
||||
# needs to be very high since whisper workers can be long running.
|
||||
# doing this application-side bears the risk of poison pilling the worker though,
|
||||
# implement a workaround with an acceptable trade-off. (=> retry only once?)
|
||||
@app.on_event("startup")
|
||||
def on_startup() -> None:
|
||||
session = get_session().__next__()
|
||||
|
||||
jobs = (
|
||||
session.query(models.Job)
|
||||
.filter(or_(models.Job.status == dtos.JobStatus.processing, models.Job.status == dtos.JobStatus.create))
|
||||
.order_by(models.Job.created_at)
|
||||
).all()
|
||||
|
||||
logger.info(f"Re-queueing {len(jobs)} jobs.")
|
||||
for job in jobs:
|
||||
queue_task(job)
|
||||
|
||||
@@ -6,4 +6,4 @@ set -e
|
||||
alembic upgrade head
|
||||
|
||||
# start app
|
||||
uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-80}
|
||||
uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-80} --log-level info
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from asyncio.log import logger
|
||||
from uuid import UUID
|
||||
|
||||
from celery import Task
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.dtos as dtos
|
||||
@@ -12,18 +13,24 @@ from app.worker.strategies.local import LocalStrategy
|
||||
celery = get_celery_binding()
|
||||
|
||||
|
||||
def update_job_status(db: Session, job: models.Job, status: dtos.JobStatus) -> None:
|
||||
job.status = status
|
||||
db.commit()
|
||||
|
||||
|
||||
@celery.task()
|
||||
def transcribe(job_id: UUID) -> None:
|
||||
@celery.task(
|
||||
bind=True,
|
||||
soft_time_limit=2 * 60 * 60 # TODO: make configurable
|
||||
)
|
||||
def transcribe(self: Task, job_id: UUID) -> None:
|
||||
try:
|
||||
db: Session = SessionLocal()
|
||||
job = db.query(models.Job).filter(models.Job.id == job_id).one()
|
||||
|
||||
update_job_status(db, job, dtos.JobStatus.processing)
|
||||
if job.status == dtos.JobStatus.error or job.status == dtos.JobStatus.success:
|
||||
logger.warn("[{job.id}]: Received job that has already been processed, abort.")
|
||||
return
|
||||
|
||||
job.meta = {"task_id": self.request.id}
|
||||
job.status = dtos.JobStatus.processing
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[{job.id}]: set task to status processing.")
|
||||
|
||||
# pick a transcription strategy.
|
||||
# currently only `local` is supported.
|
||||
@@ -36,20 +43,30 @@ def transcribe(job_id: UUID) -> None:
|
||||
# currently only `transcribe` is supported.
|
||||
if job.type == dtos.JobType.transcript:
|
||||
result = strategy.transcribe()
|
||||
logger.info(f"[{job.id}]: successfully transcribed audio.")
|
||||
elif job.type == dtos.JobType.translation:
|
||||
result = strategy.translate()
|
||||
logger.info(f"[{job.id}]: successfully translated audio.")
|
||||
else:
|
||||
result = strategy.detect_language()
|
||||
|
||||
artifact = models.Artifact(
|
||||
job_id=job.id, data=result, type=dtos.ArtifactType.raw_transcript
|
||||
)
|
||||
|
||||
db.add(artifact)
|
||||
db.commit()
|
||||
logger.info(f"[{job.id}]: successfully stored artifact.")
|
||||
|
||||
update_job_status(db, job, dtos.JobStatus.success)
|
||||
job.status = dtos.JobStatus.success
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[{job.id}]: set task to status success.")
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
update_job_status(db, job, dtos.JobStatus.error)
|
||||
if job and db:
|
||||
job.meta = { **job.meta, "error": str(e) }
|
||||
job.status = dtos.JobStatus.error
|
||||
db.commit()
|
||||
raise(e)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from asyncio.log import logger
|
||||
from os import path
|
||||
from typing import Any, List, Literal, Optional
|
||||
from uuid import UUID
|
||||
@@ -26,27 +26,21 @@ class LocalStrategy:
|
||||
self.job_id = job_id
|
||||
self.url = url
|
||||
self.config = config
|
||||
logger.info(f"[{self.job_id}]: initialized local strategy.")
|
||||
|
||||
def transcribe(self) -> List[Any]:
|
||||
result = self.run_whisper(self._download(), "transcribe")
|
||||
self._cleanup()
|
||||
return result
|
||||
return self.run_whisper(self._download(), "transcribe")
|
||||
|
||||
def translate(self) -> List[Any]:
|
||||
result = self.run_whisper(self._download(), "translate")
|
||||
self._cleanup()
|
||||
return result
|
||||
return self.run_whisper(self._download(), "translate")
|
||||
|
||||
def detect_language(self) -> List[Any]:
|
||||
raise NotImplementedError("detect_language has not been implemented yet.")
|
||||
|
||||
def _download(self) -> str:
|
||||
dirname = self._get_tmp_dir()
|
||||
filename = path.join(dirname, "media.mp3")
|
||||
|
||||
# re-create folder.
|
||||
shutil.rmtree(dirname, ignore_errors=True)
|
||||
os.makedirs(dirname)
|
||||
filename = self._get_tmp_file()
|
||||
self._cleanup()
|
||||
|
||||
# stream media to disk.
|
||||
with requests.get(self.url, stream=True) as r:
|
||||
@@ -58,21 +52,29 @@ class LocalStrategy:
|
||||
return filename
|
||||
|
||||
def run_whisper(self, filepath: str, task: str) -> List[Any]:
|
||||
language = self.config.language if self.config else None
|
||||
decode_opts = DecodeOptions(task=task, language=language)
|
||||
model = load_model("small", download_root="/models")
|
||||
try:
|
||||
language = self.config.language if self.config else None
|
||||
model = load_model("small", download_root="/models")
|
||||
|
||||
result = model.transcribe(
|
||||
filepath, condition_on_previous_text=False, **decode_opts.dict()
|
||||
)
|
||||
result = model.transcribe(
|
||||
filepath,
|
||||
condition_on_previous_text=False,
|
||||
**DecodeOptions(task=task, language=language).dict(),
|
||||
)
|
||||
|
||||
return result["segments"]
|
||||
return result["segments"]
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _get_tmp_dir(self) -> str:
|
||||
return path.join(tempfile.gettempdir(), str(self.job_id))
|
||||
def _get_tmp_file(self) -> str:
|
||||
tmp = tempfile.gettempdir()
|
||||
return path.join(tmp, str(self.job_id))
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
shutil.rmtree(self._get_tmp_dir(), ignore_errors=True)
|
||||
try:
|
||||
os.remove(self._get_tmp_file())
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _convert(self) -> None:
|
||||
pass
|
||||
|
||||
@@ -29,13 +29,10 @@ services:
|
||||
redis:
|
||||
container_name: whisperbox_redis
|
||||
image: redis:7-alpine
|
||||
command: ["redis-server", "--save", "60 1"]
|
||||
ports:
|
||||
- 6379:6379
|
||||
networks:
|
||||
- app
|
||||
volumes:
|
||||
- redis-data:/data
|
||||
|
||||
app:
|
||||
container_name: whisperbox_app
|
||||
@@ -50,6 +47,8 @@ services:
|
||||
volumes:
|
||||
- ../:/code
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
@@ -64,10 +63,14 @@ services:
|
||||
- ../:/code
|
||||
environment: *app-variables
|
||||
depends_on:
|
||||
- app
|
||||
- redis
|
||||
networks:
|
||||
- app
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "celery inspect ping -A app.worker.main.transcribe -d celery@$$HOSTNAME"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
flower:
|
||||
container_name: whisperbox_flower
|
||||
@@ -82,7 +85,6 @@ services:
|
||||
|
||||
volumes:
|
||||
postgres-data:
|
||||
redis-data:
|
||||
|
||||
networks:
|
||||
app:
|
||||
|
||||
Reference in New Issue
Block a user