refactor: move file handling logic to base strategy

This commit is contained in:
Felix Spöttel
2023-06-29 12:56:33 +02:00
parent b3f8d5c82a
commit 238a694f72
2 changed files with 42 additions and 47 deletions

View File

@@ -1,7 +1,11 @@
import os
import tempfile
from abc import ABC
from typing import Any, Protocol, Tuple
from uuid import UUID
import requests
import app.shared.db.models as models
TaskReturnValue = Tuple[models.ArtifactType, Any]
@@ -23,4 +27,25 @@ class BaseStrategy(ABC):
raise NotImplementedError()
def cleanup(self, job_id: UUID) -> None:
    """Delete the temporary file belonging to ``job_id``, if there is one.

    Removal is best-effort: an ``OSError`` raised by ``os.remove`` (for
    example because the file does not exist) is ignored, so calling this
    more than once for the same job is safe.

    Args:
        job_id: Job whose temporary file should be removed.
    """
    try:
        os.remove(self._get_tmp_file(job_id))
    except OSError:
        # Nothing to delete, or the file could not be deleted; cleanup
        # is deliberately not allowed to fail the caller.
        pass
def _get_tmp_file(self, job_id: UUID) -> str:
tmp = tempfile.gettempdir()
return os.path.join(tmp, str(job_id))
def _download(self, url: str, job_id: UUID, timeout: float = 60.0) -> str:
    """Stream the media at ``url`` into the job's temporary file.

    Args:
        url: Location of the media to fetch.
        job_id: Job whose temporary file receives the data.
        timeout: Seconds ``requests`` waits for the connection and for
            each read before giving up. It is not a limit on the total
            download time.

    Returns:
        Path of the file the response body was written to.

    Raises:
        requests.RequestException: If the request fails, times out, or
            the server answers with an error status.
        OSError: If the file cannot be written.
    """
    filename = self._get_tmp_file(job_id)
    # Drop any stale file left over from an earlier attempt for this job.
    self.cleanup(job_id)

    try:
        # Stream to disk in chunks so the media never sits fully in memory.
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(filename, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
    except Exception:
        # Do not leave a truncated file behind for later steps to pick up.
        self.cleanup(job_id)
        raise
    return filename

View File

@@ -1,15 +1,11 @@
import os
import tempfile
from asyncio.log import logger
from os import path
from typing import Any, Literal
from uuid import UUID
import requests
import torch
import whisper
from pydantic import BaseModel
from sqlalchemy import JSON, Column
import app.shared.db.models as models
from app.worker.strategies.base import BaseStrategy, TaskReturnValue
@@ -28,23 +24,17 @@ class DecodingOptions(BaseModel):
class LocalStrategy(BaseStrategy):
def __init__(self) -> None:
    """Load the Whisper model named by the ``WHISPER_MODEL`` variable.

    The model is loaded with ``/models`` as its download root and is
    moved to the GPU when ``torch.cuda.is_available()`` reports one.

    Raises:
        KeyError: If the ``WHISPER_MODEL`` environment variable is unset.
    """
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        logger.debug("initializing GPU model.")
    else:
        logger.debug("initializing CPU model.")

    # Loading is identical for both devices; only the placement differs.
    model = whisper.load_model(
        os.environ["WHISPER_MODEL"], download_root="/models"
    )
    self.model = model.cuda() if use_gpu else model
    logger.debug("initialized local strategy.")

def cleanup(self, job_id: UUID) -> None:
    """Delete the temporary file belonging to ``job_id``, if there is one.

    Removal is best-effort: an ``OSError`` raised by ``os.remove`` (for
    example because the file does not exist) is ignored.
    """
    try:
        os.remove(self._get_tmp_file(job_id))
    except OSError:
        # Nothing to delete, or the file could not be deleted; cleanup
        # is deliberately not allowed to fail the caller.
        pass
def transcribe(self, job):
result = self._run_whisper(
@@ -65,8 +55,8 @@ class LocalStrategy(BaseStrategy):
def detect_language(self, job) -> TaskReturnValue:
file = self._download(job.url, job.id)
# see: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/README.md?plain=1#L114
audio = whisper.pad_or_trim(whisper.load_audio(file))
mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
_, probs = self.model.detect_language(mel)
@@ -75,42 +65,22 @@ class LocalStrategy(BaseStrategy):
{"code": max(probs, key=probs.get)},
)
def _download(self, url: str, job_id: UUID, timeout: float = 60.0) -> str:
    """Stream the media at ``url`` into the job's temporary file.

    Args:
        url: Location of the media to fetch.
        job_id: Job whose temporary file receives the data.
        timeout: Seconds ``requests`` waits for the connection and for
            each read before giving up. It is not a limit on the total
            download time.

    Returns:
        Path of the file the response body was written to.

    Raises:
        requests.RequestException: If the request fails, times out, or
            the server answers with an error status.
        OSError: If the file cannot be written.
    """
    filename = self._get_tmp_file(job_id)
    # Drop any stale file left over from an earlier attempt for this job.
    self.cleanup(job_id)

    try:
        # Stream to disk in chunks so the media never sits fully in memory.
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(filename, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
    except Exception:
        # Do not leave a truncated file behind for later steps to pick up.
        self.cleanup(job_id)
        raise
    return filename
def _run_whisper(
    self,
    filepath: str,
    task: Literal["translate", "transcribe"],
    config: dict[str, Any],
    job_id: UUID,
) -> list[Any]:
    """Run the loaded model on ``filepath`` and return its segments.

    Args:
        filepath: Path of the media file handed to ``self.model``.
        task: Whisper task to perform.
        config: Job configuration as a plain mapping. When it is
            non-empty it is parsed with ``models.JobConfig`` and its
            ``language`` is passed on; otherwise no language is given.
        job_id: Id of the job being processed. Not used in this method;
            kept so existing callers keep working.

    Returns:
        The ``"segments"`` entry of the model's transcription result.
    """
    # NOTE(review): the temporary file is not removed here; presumably
    # the caller invokes ``cleanup(job_id)`` afterwards -- confirm.
    language = models.JobConfig(**config).language if config else None
    options = DecodingOptions(task=task, language=language).dict()
    result = self.model.transcribe(
        filepath,
        # Turning this off might make the transcription less accurate,
        # but significantly reduces the amount of model hallucinations.
        condition_on_previous_text=False,
        **options,
    )
    return result["segments"]

def _get_tmp_file(self, job_id: UUID) -> str:
    """Return the path of the scratch file used for ``job_id``.

    The file name is the string form of the job id, placed directly in
    the directory reported by ``tempfile.gettempdir()``.
    """
    tmp = tempfile.gettempdir()
    return path.join(tmp, str(job_id))