From 692a2adb6af77623f0a207dee146125a7206ff7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Sp=C3=B6ttel?= <1682504+fspoettel@users.noreply.github.com> Date: Wed, 28 Jun 2023 10:01:16 +0200 Subject: [PATCH] fix: always clean up temp files in worker closes #27 --- app/worker/main.py | 32 +++++++++++++++++--------------- app/worker/strategies/local.py | 6 +++--- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/app/worker/main.py b/app/worker/main.py index c15fb57..391ad66 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -30,6 +30,15 @@ class TranscribeTask(Task): return self.run(*args, **kwargs) +def select_strategy(task: Task, job: schemas.Job) -> Any: + if job.type == schemas.JobType.transcript: + return task.strategy.transcribe + elif job.type == schemas.JobType.translation: + return task.strategy.translate + else: + return task.strategy.detect_language + + @celery.task( base=TranscribeTask, bind=True, @@ -52,27 +61,19 @@ def transcribe(self: Task, job_id: UUID) -> None: ) return + logger.info(f"[{job.id}]: worker received task.") + job.meta = {"task_id": self.request.id} job.status = schemas.JobStatus.processing db.commit() - logger.info(f"[{job.id}]: set task to status processing.") job_record = schemas.Job.from_orm(job) - # process selected task. - if job.type == schemas.JobType.transcript: - result = self.strategy.transcribe( - url=job_record.url, job_id=job_record.id, config=job_record.config - ) - elif job.type == schemas.JobType.translation: - result = self.strategy.translate( - url=job_record.url, job_id=job_record.id, config=job_record.config - ) - else: - result = self.strategy.detect_language( - url=job_record.url, job_id=job_record.id, config=job_record.config - ) + strategy = select_strategy(self, job_record) + result = strategy( + url=job_record.url, job_id=job_record.id, config=job_record.config + ) logger.info(f"[{job.id}]: successfully processed audio.") @@ -93,6 +94,7 @@ def transcribe(self: Task, job_id: UUID) -> None: job.meta = {**job.meta, "error": str(e)} # type: ignore job.status = schemas.JobStatus.error db.commit() - raise (e) + raise (e) finally: + self.strategy.cleanup(job_id=job_id) db.close() diff --git a/app/worker/strategies/local.py b/app/worker/strategies/local.py index 09fa0fc..15a8ec2 100644 --- a/app/worker/strategies/local.py +++ b/app/worker/strategies/local.py @@ -58,7 +58,7 @@ class LocalStrategy: def _download(self, url: str, job_id: UUID) -> str: # re-create folder. filename = self._get_tmp_file(job_id) - self._cleanup(job_id) + self.cleanup(job_id) # stream media to disk. with requests.get(url, stream=True) as r: @@ -87,13 +87,13 @@ class LocalStrategy: return result["segments"] finally: - self._cleanup(job_id) + self.cleanup(job_id) def _get_tmp_file(self, job_id: UUID) -> str: tmp = tempfile.gettempdir() return path.join(tmp, str(job_id)) - def _cleanup(self, job_id: UUID) -> None: + def cleanup(self, job_id: UUID) -> None: try: os.remove(self._get_tmp_file(job_id)) except OSError: