From 85c3f1fc4491e595eba8e4dd25daec6962e9ced5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Sp=C3=B6ttel?= <1682504+fspoettel@users.noreply.github.com> Date: Thu, 29 Jun 2023 16:31:57 +0200 Subject: [PATCH] refactor: add `strategy.process()` to BaseStrategy --- app/web/task_queue.py | 1 + app/worker/main.py | 29 ++++++++++------------------- app/worker/strategies/base.py | 20 ++++++++++++++------ 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/app/web/task_queue.py b/app/web/task_queue.py index 03dd01d..589d465 100644 --- a/app/web/task_queue.py +++ b/app/web/task_queue.py @@ -20,6 +20,7 @@ class TaskQueue: allow for full separation of worker processes and dependencies. """ transcribe = self.celery.signature("app.worker.main.transcribe") + # TODO: catch delivery errors? transcribe.delay(job.id) def rehydrate(self, session: Session): diff --git a/app/worker/main.py b/app/worker/main.py index 13eebc4..7957971 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -9,13 +9,17 @@ import app.shared.db.models as models from app.shared.celery import get_celery_binding from app.shared.db.base import SessionLocal from app.shared.settings import settings -from app.worker.strategies.base import TaskProtocol from app.worker.strategies.local import LocalStrategy celery = get_celery_binding() class TranscribeTask(Task): + """ + Decorate the transcribe task with an instance of the transcription strategy. + This is important for the local strategy, where loading the model is expensive. + """ + abstract = True def __init__(self) -> None: @@ -30,15 +34,6 @@ class TranscribeTask(Task): return self.run(*args, **kwargs) -def select_task_processor(task: Task, job: models.Job) -> TaskProtocol: - if job.type == models.JobType.transcript: - return task.strategy.transcribe - elif job.type == models.JobType.translation: - return task.strategy.translate - else: - return task.strategy.detect_language - - @celery.task( base=TranscribeTask, bind=True, @@ -73,29 +68,25 @@ def transcribe(self: Task, job_id: UUID) -> None: logger.debug(f"[{job.id}]: finished setting task to {job.status}.") # unit of work: process job with whisper. - - processor = select_task_processor(self, job) - - result_type, result = processor(job) - + result_type, result = self.strategy.process(job) logger.debug(f"[{job.id}]: successfully processed audio.") artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type) - db.add(artifact) job.status = models.JobStatus.success - db.commit() + logger.debug(f"[{job.id}]: successfully stored artifact.") except Exception as e: if job and db: - db.rollback() + if db.in_transaction(): + db.rollback() job.meta = {**job.meta, "error": str(e)} # type: ignore job.status = models.JobStatus.error db.commit() raise finally: - self.strategy.cleanup(job_id=job_id) + self.strategy.cleanup(job_id) db.close() diff --git a/app/worker/strategies/base.py b/app/worker/strategies/base.py index 259cc10..1243582 100644 --- a/app/worker/strategies/base.py +++ b/app/worker/strategies/base.py @@ -17,6 +17,20 @@ class TaskProtocol(Protocol): class BaseStrategy(ABC): + def process(self, job: models.Job) -> TaskReturnValue: + if job.type == models.JobType.transcript: + return self.transcribe(job) + elif job.type == models.JobType.translation: + return self.translate(job) + else: + return self.detect_language(job) + + def cleanup(self, job_id: UUID) -> None: + try: + os.remove(self._get_tmp_file(job_id)) + except OSError: + ... + def transcribe(self, job: models.Job) -> TaskReturnValue: raise NotImplementedError() @@ -26,12 +40,6 @@ class BaseStrategy(ABC): def detect_language(self, job: models.Job) -> TaskReturnValue: raise NotImplementedError() - def cleanup(self, job_id: UUID) -> None: - try: - os.remove(self._get_tmp_file(job_id)) - except OSError: - pass - def _get_tmp_file(self, job_id: UUID) -> str: tmp = tempfile.gettempdir() return os.path.join(tmp, str(job_id))