mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-13 05:58:35 +03:00
feat: add language detection task
This commit is contained in:
@@ -10,6 +10,7 @@ import app.shared.db.schemas as schemas
|
||||
from app.shared.celery import get_celery_binding
|
||||
from app.shared.db.base import SessionLocal
|
||||
from app.shared.settings import settings
|
||||
from app.worker.strategies.base import TaskProtocol
|
||||
from app.worker.strategies.local import LocalStrategy
|
||||
|
||||
celery = get_celery_binding()
|
||||
@@ -30,10 +31,10 @@ class TranscribeTask(Task):
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
|
||||
def select_strategy(task: Task, job: schemas.Job) -> Any:
|
||||
if job.type == schemas.JobType.transcript:
|
||||
def select_task_processor(task: Task, job: schemas.Job) -> TaskProtocol:
|
||||
if job.type == models.JobType.transcript:
|
||||
return task.strategy.transcribe
|
||||
elif job.type == schemas.JobType.translation:
|
||||
elif job.type == models.JobType.translation:
|
||||
return task.strategy.translate
|
||||
else:
|
||||
return task.strategy.detect_language
|
||||
@@ -50,49 +51,50 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
|
||||
db: Session = SessionLocal()
|
||||
|
||||
job = db.query(models.Job).filter(models.Job.id == job_id).one()
|
||||
# unit of work: set task status to processing.
|
||||
|
||||
if (
|
||||
job.status == schemas.JobStatus.error
|
||||
or job.status == schemas.JobStatus.success
|
||||
):
|
||||
logger.warn(
|
||||
"[{job.id}]: Received job that has already been processed, abort."
|
||||
)
|
||||
job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none()
|
||||
|
||||
if job is None:
|
||||
logger.warn("[{job.id}]: Received unknown job, abort.")
|
||||
return
|
||||
|
||||
logger.info(f"[{job.id}]: worker received task.")
|
||||
if job.status in [models.JobStatus.error, models.JobStatus.success]:
|
||||
logger.warn("[{job.id}]: job has already been processed, abort.")
|
||||
return
|
||||
|
||||
logger.info(f"[{job.id}]: received eligible job.")
|
||||
|
||||
job.meta = {"task_id": self.request.id}
|
||||
job.status = schemas.JobStatus.processing
|
||||
db.commit()
|
||||
logger.info(f"[{job.id}]: set task to status processing.")
|
||||
job.status = models.JobStatus.processing
|
||||
|
||||
db.commit()
|
||||
logger.info(f"[{job.id}]: finished setting task to status processing.")
|
||||
|
||||
# unit of work: process job with whisper.
|
||||
job_record = schemas.Job.from_orm(job)
|
||||
|
||||
strategy = select_strategy(self, job_record)
|
||||
result = strategy(
|
||||
processor = select_task_processor(self, job_record)
|
||||
|
||||
result_type, result = processor(
|
||||
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||
)
|
||||
|
||||
logger.info(f"[{job.id}]: successfully processed audio.")
|
||||
|
||||
artifact = models.Artifact(
|
||||
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
|
||||
)
|
||||
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
|
||||
|
||||
db.add(artifact)
|
||||
|
||||
job.status = models.JobStatus.success
|
||||
|
||||
db.commit()
|
||||
logger.info(f"[{job.id}]: successfully stored artifact.")
|
||||
|
||||
job.status = schemas.JobStatus.success
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[{job.id}]: set task to status success.")
|
||||
except Exception as e:
|
||||
if job and db:
|
||||
db.rollback()
|
||||
job.meta = {**job.meta, "error": str(e)} # type: ignore
|
||||
job.status = schemas.JobStatus.error
|
||||
job.status = models.JobStatus.error
|
||||
db.commit()
|
||||
raise
|
||||
finally:
|
||||
|
||||
35
app/worker/strategies/base.py
Normal file
35
app/worker/strategies/base.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from abc import ABC
|
||||
from typing import Any, Protocol, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
import app.shared.db.models as models
|
||||
import app.shared.db.schemas as schemas
|
||||
|
||||
TaskReturnValue = Tuple[models.ArtifactType, Any]
|
||||
|
||||
|
||||
class TaskProtocol(Protocol):
    """Structural type of a callable that processes one job.

    Any callable accepting ``url``, ``job_id`` and ``config`` and returning
    a ``TaskReturnValue`` satisfies this protocol.
    """

    def __call__(
        self,
        url: str,
        job_id: UUID,
        config: schemas.JobConfig | None,
    ) -> TaskReturnValue:
        """Process the job and return its result as a ``TaskReturnValue``."""
        ...
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
    """Base class describing the operations a processing strategy offers.

    Every method defined here raises ``NotImplementedError``; a concrete
    strategy overrides them with a real implementation.
    """

    def transcribe(
        self,
        url: str,
        job_id: UUID,
        config: schemas.JobConfig | None,
    ) -> TaskReturnValue:
        """Process a transcription job; not implemented in the base class."""
        raise NotImplementedError

    def translate(
        self,
        url: str,
        job_id: UUID,
        config: schemas.JobConfig | None,
    ) -> TaskReturnValue:
        """Process a translation job; not implemented in the base class."""
        raise NotImplementedError

    def detect_language(
        self,
        url: str,
        job_id: UUID,
        config: schemas.JobConfig | None,
    ) -> TaskReturnValue:
        """Process a language detection job; not implemented in the base class."""
        raise NotImplementedError

    def cleanup(self, job_id: UUID) -> None:
        """Clean up after the job ``job_id``; not implemented in the base class."""
        raise NotImplementedError
|
||||
@@ -7,10 +7,11 @@ from uuid import UUID
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import whisper
|
||||
from pydantic import BaseModel
|
||||
from whisper import load_model
|
||||
|
||||
import app.shared.db.schemas as schemas
|
||||
from app.worker.strategies.base import BaseStrategy, TaskReturnValue
|
||||
|
||||
|
||||
class DecodeOptions(BaseModel):
|
||||
@@ -18,40 +19,58 @@ class DecodeOptions(BaseModel):
|
||||
task: Literal["translate", "transcribe"]
|
||||
|
||||
|
||||
class LocalStrategy:
|
||||
class LocalStrategy(BaseStrategy):
|
||||
def __init__(self) -> None:
|
||||
if torch.cuda.is_available():
|
||||
logger.info("initializing GPU model.")
|
||||
self.model = load_model(
|
||||
self.model = whisper.load_model(
|
||||
os.environ["WHISPER_MODEL"], download_root="/models"
|
||||
).cuda()
|
||||
else:
|
||||
logger.info("initializing CPU model.")
|
||||
self.model = load_model(
|
||||
self.model = whisper.load_model(
|
||||
os.environ["WHISPER_MODEL"], download_root="/models"
|
||||
)
|
||||
|
||||
logger.info("initialized local strategy.")
|
||||
|
||||
def transcribe(
|
||||
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
||||
) -> list[Any]:
|
||||
return self.run_whisper(
|
||||
self._download(url, job_id), "transcribe", config, job_id
|
||||
def cleanup(self, job_id) -> None:
    """Remove the temporary file associated with ``job_id``.

    This is best effort: any ``OSError`` raised while resolving or deleting
    the file (for example because it does not exist) is ignored.
    """
    try:
        tmp_file = self._get_tmp_file(job_id)
        os.remove(tmp_file)
    except OSError:
        # nothing to remove, or removal failed; deliberately ignored.
        return
|
||||
|
||||
def transcribe(self, url, job_id, config) -> TaskReturnValue:
    """Download the file at ``url`` and run the "transcribe" whisper task.

    Args:
        url: location of the file to download.
        job_id: id of the job; passed on to ``_download`` and ``_run_whisper``.
        config: optional job configuration handed to ``_run_whisper``.

    Returns:
        A ``(schemas.ArtifactType.raw_transcript, result)`` pair, where
        ``result`` is whatever ``_run_whisper`` returns.
    """
    artifact_type = schemas.ArtifactType.raw_transcript
    filepath = self._download(url, job_id)
    transcript = self._run_whisper(filepath, "transcribe", config, job_id)
    return (artifact_type, transcript)
|
||||
|
||||
def translate(
|
||||
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
||||
) -> list[Any]:
|
||||
return self.run_whisper(
|
||||
self._download(url, job_id),
|
||||
"translate",
|
||||
config,
|
||||
job_id,
|
||||
def translate(self, url, job_id, config) -> TaskReturnValue:
    """Download the file at ``url`` and run the "translate" whisper task.

    Returns a ``(schemas.ArtifactType.raw_transcript, result)`` pair, where
    ``result`` is whatever ``_run_whisper`` returns.
    """
    artifact_type = schemas.ArtifactType.raw_transcript
    filepath = self._download(url, job_id)
    translation = self._run_whisper(filepath, "translate", config, job_id)
    return (artifact_type, translation)
|
||||
|
||||
def detect_language(self, url: str, config: schemas.JobConfig | None) -> list[Any]:
|
||||
raise NotImplementedError("detect_language has not been implemented yet.")
|
||||
def detect_language(self, url, job_id, config) -> TaskReturnValue:
    """Detect the language of the audio found at ``url``.

    Args:
        url: location of the file to download.
        job_id: id of the job; passed on to ``_download``.
        config: accepted for signature parity with the other tasks;
            currently unused here.

    Returns:
        A ``(schemas.ArtifactType.language_detection, {"code": <language>})``
        pair, where ``<language>`` is the key with the highest value in the
        probabilities returned by the model.
    """
    file = self._download(url, job_id)

    # only a single fixed-length window is analysed: pad_or_trim pads or cuts
    # the decoded audio to whisper's expected input length.
    audio = whisper.pad_or_trim(whisper.load_audio(file))

    # build the spectrogram with the mel-bin count of the loaded model rather
    # than whisper's default, so checkpoints expecting a different number of
    # bins (the model is chosen via WHISPER_MODEL) receive a matching input.
    mel = whisper.log_mel_spectrogram(
        audio, n_mels=self.model.dims.n_mels
    ).to(self.model.device)
    _, probs = self.model.detect_language(mel)

    return (
        schemas.ArtifactType.language_detection,
        {"code": max(probs, key=probs.get)},
    )
|
||||
|
||||
def _download(self, url: str, job_id: UUID) -> str:
|
||||
# re-create folder.
|
||||
@@ -67,7 +86,7 @@ class LocalStrategy:
|
||||
|
||||
return filename
|
||||
|
||||
def run_whisper(
|
||||
def _run_whisper(
|
||||
self,
|
||||
filepath: str,
|
||||
task: Literal["translate", "transcribe"],
|
||||
@@ -90,9 +109,3 @@ class LocalStrategy:
|
||||
def _get_tmp_file(self, job_id: UUID) -> str:
|
||||
tmp = tempfile.gettempdir()
|
||||
return path.join(tmp, str(job_id))
|
||||
|
||||
def cleanup(self, job_id: UUID) -> None:
|
||||
try:
|
||||
os.remove(self._get_tmp_file(job_id))
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user