Files
2024-04-19 21:51:19 +02:00

127 lines
4.0 KiB
Python

from typing import Any
from uuid import UUID
from celery import Task
from sqlalchemy.orm import Session
import app.shared.db.models as models
from app.shared.celery import get_celery_binding
from app.shared.db.base import make_engine, make_session_local
from app.shared.logger import logger
from app.shared.settings import Settings
from app.worker.strategies.local import LocalStrategy
# TODO: refactor to be part of a Task instance.
settings = Settings() # type: ignore
celery = get_celery_binding(settings.BROKER_URL)
engine = make_engine(settings.DATABASE_URI)
SessionLocal = make_session_local(engine)
class TranscribeTask(Task):
"""
Decorate the transcribe task with an instance of the transcription strategy.
This is important for the local strategy, where loading the model is expensive.
"""
abstract = True
def __init__(self) -> None:
super().__init__()
# currently only `LocalStrategy` is implemented.
self.strategy: LocalStrategy | None = None
def __call__(self, *args: Any, **kwargs: Any) -> Any:
# load model into memory once when the first task is processed.
if not self.strategy:
self.strategy = LocalStrategy()
return self.run(*args, **kwargs)
@celery.task(
base=TranscribeTask,
bind=True,
soft_time_limit=settings.TASK_SOFT_TIME_LIMIT,
time_limit=settings.TASK_HARD_TIME_LIMIT,
task_acks_late=True,
task_acks_on_failure_or_timeout=True,
task_reject_on_worker_lost=True,
)
def transcribe(self: TranscribeTask, job_id: UUID) -> None:
session: Session | None = None
job: models.Job | None = None
try:
if not self.strategy:
raise Exception("expected a transcription strategy to be defined.")
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
session = SessionLocal()
# work around mypy not inferring the sum type correctly.
if not session:
raise Exception("failed to acquire a session.")
# check if passed job should be processed.
job = session.query(models.Job).filter(models.Job.id == job_id).one_or_none()
if job is None:
logger.warn("[unknown]: Received unknown job, abort.")
return
if job.status in [models.JobStatus.error, models.JobStatus.success]:
logger.warn(f"[{job.id}]: job has already been processed, abort.")
return
logger.debug(f"[{job.id}]: start processing {job.type} job.")
if job.meta is not None:
attempts = 1 + (job.meta.get("attempts") or 0)
else:
attempts = 1
# SAFEGUARD: celery's retry policies do not handle lost workers, retry once.
# @see https://github.com/celery/celery/pull/6103
if attempts > 2:
raise Exception("Maximum number of retries exceeded for killed worker.")
# unit of work: set task status to processing.
job.meta = {"task_id": self.request.id, "attempts": attempts}
job.status = models.JobStatus.processing
session.commit()
logger.debug(f"[{job.id}]: finished setting task to {job.status}.")
# unit of work: process job with whisper.
result_type, result = self.strategy.process(job)
logger.debug(f"[{job.id}]: successfully processed audio.")
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
session.add(artifact)
job.status = models.JobStatus.success
session.commit()
logger.debug(f"[{job.id}]: successfully stored artifact.")
except Exception as e:
if job and session:
if session.in_transaction():
session.rollback()
if job.meta is not None:
job.meta = {**job.meta, "error": str(e)}
else:
job.meta = {"error": str(e)}
job.status = models.JobStatus.error
session.commit()
raise
finally:
if self.strategy:
self.strategy.cleanup(job_id)
if session:
session.close()