From 238a694f72c0a997b7b6db8d8ffcf522b7a180bb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Felix=20Sp=C3=B6ttel?= <1682504+fspoettel@users.noreply.github.com>
Date: Thu, 29 Jun 2023 12:56:33 +0200
Subject: [PATCH] refactor: move file handling logic to base strategy

---
Review notes (text between "---" and the diffstat is ignored by `git am`):
- `_run_whisper` no longer wraps the model call in try/finally, so it no
  longer calls `cleanup(job_id)` and its `job_id` parameter is now unused.
  Presumably the caller invokes `cleanup` after each task -- TODO confirm,
  otherwise downloaded media accumulates in the temp directory.
- `_download` now passes a (connect, read) timeout to `requests.get`; without
  one, requests waits indefinitely on an unresponsive server.
- Line structure was restored from a whitespace-collapsed copy using the hunk
  line counts. In the `detect_language` hunk the counts cannot distinguish a
  blank line before the `def` from one after `file = ...`; the latter was
  assumed -- verify against the tree before applying.
- The `index` blob ids below describe the original commit and predate the
  timeout/comment touch-ups.

 app/worker/strategies/base.py  | 27 ++++++++++++++-
 app/worker/strategies/local.py | 62 +++++++++-------------------------
 2 files changed, 42 insertions(+), 47 deletions(-)

diff --git a/app/worker/strategies/base.py b/app/worker/strategies/base.py
index 8060a89..259cc10 100644
--- a/app/worker/strategies/base.py
+++ b/app/worker/strategies/base.py
@@ -1,7 +1,11 @@
+import os
+import tempfile
 from abc import ABC
 from typing import Any, Protocol, Tuple
 from uuid import UUID
 
+import requests
+
 import app.shared.db.models as models
 
 TaskReturnValue = Tuple[models.ArtifactType, Any]
@@ -23,4 +27,25 @@ class BaseStrategy(ABC):
         raise NotImplementedError()
 
     def cleanup(self, job_id: UUID) -> None:
-        raise NotImplementedError()
+        try:
+            os.remove(self._get_tmp_file(job_id))
+        except OSError:
+            pass
+
+    def _get_tmp_file(self, job_id: UUID) -> str:
+        tmp = tempfile.gettempdir()
+        return os.path.join(tmp, str(job_id))
+
+    def _download(self, url: str, job_id: UUID) -> str:
+        # remove any existing file at the target path first.
+        filename = self._get_tmp_file(job_id)
+        self.cleanup(job_id)
+
+        # stream media to disk; timeout=(connect, read) guards against stalled hosts.
+        with requests.get(url, stream=True, timeout=(10, 60)) as r:
+            r.raise_for_status()
+            with open(filename, "wb") as f:
+                for chunk in r.iter_content(chunk_size=8192):
+                    f.write(chunk)
+
+        return filename
diff --git a/app/worker/strategies/local.py b/app/worker/strategies/local.py
index 68f6621..a9abb9c 100644
--- a/app/worker/strategies/local.py
+++ b/app/worker/strategies/local.py
@@ -1,15 +1,11 @@
 import os
-import tempfile
 from asyncio.log import logger
-from os import path
 from typing import Any, Literal
 from uuid import UUID
 
-import requests
 import torch
 import whisper
 from pydantic import BaseModel
-from sqlalchemy import JSON, Column
 
 import app.shared.db.models as models
 from app.worker.strategies.base import BaseStrategy, TaskReturnValue
@@ -28,23 +24,17 @@ class DecodingOptions(BaseModel):
 class LocalStrategy(BaseStrategy):
     def __init__(self) -> None:
         if torch.cuda.is_available():
-            logger.info("initializing GPU model.")
+            logger.debug("initializing GPU model.")
             self.model = whisper.load_model(
                 os.environ["WHISPER_MODEL"], download_root="/models"
             ).cuda()
         else:
-            logger.info("initializing CPU model.")
+            logger.debug("initializing CPU model.")
             self.model = whisper.load_model(
                 os.environ["WHISPER_MODEL"], download_root="/models"
             )
 
-        logger.info("initialized local strategy.")
-
-    def cleanup(self, job_id) -> None:
-        try:
-            os.remove(self._get_tmp_file(job_id))
-        except OSError:
-            pass
+        logger.debug("initialized local strategy.")
 
     def transcribe(self, job):
         result = self._run_whisper(
@@ -65,8 +55,8 @@ class LocalStrategy(BaseStrategy):
     def detect_language(self, job) -> TaskReturnValue:
         file = self._download(job.url, job.id)
 
+        # see: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/README.md?plain=1#L114
         audio = whisper.pad_or_trim(whisper.load_audio(file))
-
         mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
 
         _, probs = self.model.detect_language(mel)
@@ -75,42 +65,22 @@ class LocalStrategy(BaseStrategy):
             {"code": max(probs, key=probs.get)},
         )
 
-    def _download(self, url: str, job_id: UUID) -> str:
-        # re-create folder.
-        filename = self._get_tmp_file(job_id)
-        self.cleanup(job_id)
-
-        # stream media to disk.
-        with requests.get(url, stream=True) as r:
-            r.raise_for_status()
-            with open(filename, "wb") as f:
-                for chunk in r.iter_content(chunk_size=8192):
-                    f.write(chunk)
-
-        return filename
-
     def _run_whisper(
         self,
         filepath: str,
         task: Literal["translate", "transcribe"],
-        config: Column[JSON],
+        config: dict[str, Any],
         job_id: UUID,
     ) -> list[Any]:
-        try:
-            result = self.model.transcribe(
-                filepath,
-                # turning this off might make the transcription less accurate,
-                # but significantly reduces amount of model halucinations.
-                condition_on_previous_text=False,
-                **DecodingOptions(
-                    task=task, language=config.language if config else None
-                ).dict(),
-            )
+        result = self.model.transcribe(
+            filepath,
+            # turning this off might make the transcription less accurate,
+            # but significantly reduces amount of model hallucinations.
+            condition_on_previous_text=False,
+            **DecodingOptions(
+                task=task,
+                language=models.JobConfig(**config).language if config else None,
+            ).dict(),
+        )
 
-            return result["segments"]
-        finally:
-            self.cleanup(job_id)
-
-    def _get_tmp_file(self, job_id: UUID) -> str:
-        tmp = tempfile.gettempdir()
-        return path.join(tmp, str(job_id))
+        return result["segments"]