mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-07 19:18:35 +03:00
refactor: move file handling logic to base strategy
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
import os
|
||||
import tempfile
|
||||
from abc import ABC
|
||||
from typing import Any, Protocol, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
import app.shared.db.models as models
|
||||
|
||||
TaskReturnValue = Tuple[models.ArtifactType, Any]
|
||||
@@ -23,4 +27,25 @@ class BaseStrategy(ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
def cleanup(self, job_id: UUID) -> None:
|
||||
raise NotImplementedError()
|
||||
try:
|
||||
os.remove(self._get_tmp_file(job_id))
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _get_tmp_file(self, job_id: UUID) -> str:
|
||||
tmp = tempfile.gettempdir()
|
||||
return os.path.join(tmp, str(job_id))
|
||||
|
||||
def _download(self, url: str, job_id: UUID) -> str:
|
||||
# re-create folder.
|
||||
filename = self._get_tmp_file(job_id)
|
||||
self.cleanup(job_id)
|
||||
|
||||
# stream media to disk.
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(filename, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
return filename
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
import os
|
||||
import tempfile
|
||||
from asyncio.log import logger
|
||||
from os import path
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import whisper
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON, Column
|
||||
|
||||
import app.shared.db.models as models
|
||||
from app.worker.strategies.base import BaseStrategy, TaskReturnValue
|
||||
@@ -28,23 +24,17 @@ class DecodingOptions(BaseModel):
|
||||
class LocalStrategy(BaseStrategy):
|
||||
def __init__(self) -> None:
|
||||
if torch.cuda.is_available():
|
||||
logger.info("initializing GPU model.")
|
||||
logger.debug("initializing GPU model.")
|
||||
self.model = whisper.load_model(
|
||||
os.environ["WHISPER_MODEL"], download_root="/models"
|
||||
).cuda()
|
||||
else:
|
||||
logger.info("initializing CPU model.")
|
||||
logger.debug("initializing CPU model.")
|
||||
self.model = whisper.load_model(
|
||||
os.environ["WHISPER_MODEL"], download_root="/models"
|
||||
)
|
||||
|
||||
logger.info("initialized local strategy.")
|
||||
|
||||
def cleanup(self, job_id) -> None:
|
||||
try:
|
||||
os.remove(self._get_tmp_file(job_id))
|
||||
except OSError:
|
||||
pass
|
||||
logger.debug("initialized local strategy.")
|
||||
|
||||
def transcribe(self, job):
|
||||
result = self._run_whisper(
|
||||
@@ -65,8 +55,8 @@ class LocalStrategy(BaseStrategy):
|
||||
def detect_language(self, job) -> TaskReturnValue:
|
||||
file = self._download(job.url, job.id)
|
||||
|
||||
# see: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/README.md?plain=1#L114
|
||||
audio = whisper.pad_or_trim(whisper.load_audio(file))
|
||||
|
||||
mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
|
||||
_, probs = self.model.detect_language(mel)
|
||||
|
||||
@@ -75,42 +65,22 @@ class LocalStrategy(BaseStrategy):
|
||||
{"code": max(probs, key=probs.get)},
|
||||
)
|
||||
|
||||
def _download(self, url: str, job_id: UUID) -> str:
|
||||
# re-create folder.
|
||||
filename = self._get_tmp_file(job_id)
|
||||
self.cleanup(job_id)
|
||||
|
||||
# stream media to disk.
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(filename, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
return filename
|
||||
|
||||
def _run_whisper(
|
||||
self,
|
||||
filepath: str,
|
||||
task: Literal["translate", "transcribe"],
|
||||
config: Column[JSON],
|
||||
config: dict[str, Any],
|
||||
job_id: UUID,
|
||||
) -> list[Any]:
|
||||
try:
|
||||
result = self.model.transcribe(
|
||||
filepath,
|
||||
# turning this off might make the transcription less accurate,
|
||||
# but significantly reduces amount of model halucinations.
|
||||
condition_on_previous_text=False,
|
||||
**DecodingOptions(
|
||||
task=task, language=config.language if config else None
|
||||
).dict(),
|
||||
)
|
||||
result = self.model.transcribe(
|
||||
filepath,
|
||||
# turning this off might make the transcription less accurate,
|
||||
# but significantly reduces amount of model halucinations.
|
||||
condition_on_previous_text=False,
|
||||
**DecodingOptions(
|
||||
task=task,
|
||||
language=models.JobConfig(**config).language if config else None,
|
||||
).dict(),
|
||||
)
|
||||
|
||||
return result["segments"]
|
||||
finally:
|
||||
self.cleanup(job_id)
|
||||
|
||||
def _get_tmp_file(self, job_id: UUID) -> str:
|
||||
tmp = tempfile.gettempdir()
|
||||
return path.join(tmp, str(job_id))
|
||||
return result["segments"]
|
||||
|
||||
Reference in New Issue
Block a user