mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-13 05:58:35 +03:00
feat: add whisper tasks
This commit is contained in:
@@ -1,37 +1,55 @@
|
||||
from time import sleep
|
||||
from asyncio.log import logger
|
||||
from uuid import UUID
|
||||
|
||||
from celery import Celery
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.dtos as dtos
|
||||
import app.shared.db.models as models
|
||||
from app.shared.config import settings
|
||||
from app.shared.celery import get_celery_binding
|
||||
from app.shared.db.base import SessionLocal
|
||||
from app.worker.strategies.local import LocalStrategy
|
||||
|
||||
celery = Celery(__name__)
|
||||
|
||||
celery.conf.broker_url = settings.BROKER_URI
|
||||
celery = get_celery_binding()
|
||||
|
||||
|
||||
def update_job_status(db: Session, job_id: UUID, status: dtos.JobStatus) -> None:
|
||||
db.begin()
|
||||
job = db.query(models.Job).filter(models.Job.id == job_id).one()
|
||||
def update_job_status(db: Session, job: models.Job, status: dtos.JobStatus) -> None:
    """Set ``job.status`` to ``status`` and commit the change.

    Args:
        db: Session that ``job`` was loaded through.
        job: Job row to update.
        status: New status value to store on the job.

    Raises:
        Exception: Whatever ``db.commit()`` raised; the session is rolled back
            before the error is passed on.
    """
    job.status = status
    try:
        db.commit()
    except Exception:
        # a failed commit leaves the session unusable until it is rolled back;
        # reset it so the caller can keep working with `db`.
        db.rollback()
        raise
|
||||
|
||||
|
||||
@celery.task()
def transcribe(job_id: UUID) -> None:
    """Process the job stored under ``job_id`` and save its result.

    Loads the job, marks it as ``processing``, runs the requested task through
    ``LocalStrategy`` and stores the returned data as a ``raw_transcript``
    artifact. The job ends up as ``success``, or as ``error`` when any step
    raises; such a failure is logged instead of being re-raised.

    Args:
        job_id: Id of the job row to process.
    """
    db: Session = SessionLocal()
    # stays `None` until the lookup succeeds, so the error handler can tell
    # whether there is a job row to flag.
    job = None

    try:
        job = db.query(models.Job).filter(models.Job.id == job_id).one()

        update_job_status(db, job, dtos.JobStatus.processing)

        # pick a transcription strategy.
        # currently only `local` is supported.
        job_record = dtos.Job.from_orm(job)
        strategy = LocalStrategy(
            db=db, job_id=job.id, url=job_record.url, config=job_record.config
        )

        # process selected task.
        # any job type other than transcript/translation falls through to
        # language detection.
        if job.type == dtos.JobType.transcript:
            result = strategy.transcribe()
        elif job.type == dtos.JobType.translation:
            result = strategy.translate()
        else:
            result = strategy.detect_language()

        artifact = models.Artifact(
            job_id=job.id, data=result, type=dtos.ArtifactType.raw_transcript
        )
        db.add(artifact)
        db.commit()

        update_job_status(db, job, dtos.JobStatus.success)
    except Exception:
        # `exception` keeps the traceback, which `error(e)` would drop.
        logger.exception("processing of job %s failed", job_id)
        # discard whatever the failed attempt left pending so the session can
        # commit the status change below.
        db.rollback()
        if job is not None:
            update_job_status(db, job, dtos.JobStatus.error)
    finally:
        db.close()
|
||||
|
||||
81
app/worker/strategies/local.py
Normal file
81
app/worker/strategies/local.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from os import path
|
||||
from typing import Any, List, Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from whisper import load_model
|
||||
|
||||
import app.shared.db.dtos as dtos
|
||||
|
||||
|
||||
class DecodeOptions(BaseModel):
    """Options passed as keyword arguments to the Whisper ``transcribe`` call."""

    # language of the media; `None` when the job config does not name one.
    # the default is spelled out so the field stays optional regardless of how
    # the installed pydantic version treats a bare `Optional` annotation.
    language: Optional[str] = None
    # which of the two Whisper tasks to run.
    task: Literal["translate", "transcribe"]
|
||||
|
||||
|
||||
class LocalStrategy:
    """Run a job on the worker itself: download the media, run Whisper on it.

    The media behind ``url`` is streamed into a per-job directory below the
    system temp dir and handed to a locally loaded Whisper model. The
    directory is removed again once the run has ended, whether it succeeded
    or not.
    """

    # (connect, read) timeouts in seconds for the media download. The read
    # timeout bounds the wait for the next bytes, not the whole transfer, so
    # large files still download while a stalled server cannot block the
    # worker forever.
    DOWNLOAD_TIMEOUT = (10, 60)

    def __init__(
        self, db: Session, job_id: UUID, url: str, config: Optional[dtos.JobConfig]
    ):
        """Store the job parameters.

        Args:
            db: Database session of the calling task.
            job_id: Id of the job; also names the temporary directory.
            url: Location the media is downloaded from.
            config: Optional job configuration; its ``language`` is passed
                on to Whisper.
        """
        self.db = db
        self.job_id = job_id
        self.url = url
        self.config = config

    def transcribe(self) -> List[Any]:
        """Download the media and return the segments of its transcription."""
        return self._run_task("transcribe")

    def translate(self) -> List[Any]:
        """Download the media and return the segments of its translation."""
        return self._run_task("translate")

    def detect_language(self) -> List[Any]:
        """Not available yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("detect_language has not been implemented yet.")

    def _run_task(self, task: str) -> List[Any]:
        """Download the media, run Whisper with ``task`` and always clean up."""
        try:
            return self.run_whisper(self._download(), task)
        finally:
            # also reached when the download or Whisper raises, so a failed
            # job does not leave its media behind in the temp dir.
            self._cleanup()

    def _download(self) -> str:
        """Stream the media at ``self.url`` to disk.

        Returns:
            Path of the downloaded file inside the job's temporary directory.

        Raises:
            requests.RequestException: If the request fails, times out or
                returns an error status.
        """
        dirname = self._get_tmp_dir()
        filename = path.join(dirname, "media.mp3")

        # re-create folder.
        shutil.rmtree(dirname, ignore_errors=True)
        os.makedirs(dirname)

        # stream media to disk.
        with requests.get(
            self.url, stream=True, timeout=self.DOWNLOAD_TIMEOUT
        ) as r:
            r.raise_for_status()
            with open(filename, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

        return filename

    def run_whisper(self, filepath: str, task: str) -> List[Any]:
        """Run the ``small`` Whisper model on ``filepath``.

        Args:
            filepath: Path of the media file to process.
            task: Either ``"transcribe"`` or ``"translate"``.

        Returns:
            The ``segments`` entry of the Whisper result.
        """
        language = self.config.language if self.config else None
        decode_opts = DecodeOptions(task=task, language=language)
        model = load_model("small", download_root="/models")

        result = model.transcribe(
            filepath, condition_on_previous_text=False, **decode_opts.dict()
        )

        return result["segments"]

    def _get_tmp_dir(self) -> str:
        """Return the job's directory below the system temp dir."""
        return path.join(tempfile.gettempdir(), str(self.job_id))

    def _cleanup(self) -> None:
        """Remove the job's temporary directory; a missing one is ignored."""
        shutil.rmtree(self._get_tmp_dir(), ignore_errors=True)

    def _convert(self) -> None:
        """Placeholder; does nothing yet."""
        pass

    def _transcribe(self) -> None:
        """Placeholder; does nothing yet."""
        pass
|
||||
Reference in New Issue
Block a user