mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-13 05:58:35 +03:00
feat: add gpu support
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
API_SECRET="change_me"
|
API_SECRET="change_me"
|
||||||
WHISPER_MODEL="small"
|
WHISPER_MODEL="small"
|
||||||
DOMAIN="whisperbox.localhost"
|
DOMAIN="whisperbox.localhost"
|
||||||
|
DATABASE_URI="sqlite:///etc/whisperbox/data/whisperbox.sqlite"
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from asyncio.log import logger
|
from asyncio.log import logger
|
||||||
|
from typing import Any, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from celery import Task
|
from celery import Task
|
||||||
@@ -12,11 +13,30 @@ from app.worker.strategies.local import LocalStrategy
|
|||||||
|
|
||||||
celery = get_celery_binding()
|
celery = get_celery_binding()
|
||||||
|
|
||||||
|
class TranscribeTask(Task):
|
||||||
|
abstract = True
|
||||||
|
|
||||||
@celery.task(bind=True, soft_time_limit=2 * 60 * 60) # TODO: make configurable
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
# currently only `LocalStrategy` is implemented.
|
||||||
|
# TODO: implement remote processing strategy.
|
||||||
|
self.strategy: Optional[LocalStrategy] = None
|
||||||
|
|
||||||
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
# load model into memory once when the first task is processed.
|
||||||
|
if not self.strategy:
|
||||||
|
self.strategy = LocalStrategy()
|
||||||
|
return self.run(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@celery.task(
|
||||||
|
base=TranscribeTask, bind=True, soft_time_limit=2 * 60 * 60
|
||||||
|
)
|
||||||
def transcribe(self: Task, job_id: UUID) -> None:
|
def transcribe(self: Task, job_id: UUID) -> None:
|
||||||
try:
|
try:
|
||||||
|
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
|
||||||
db: Session = SessionLocal()
|
db: Session = SessionLocal()
|
||||||
|
|
||||||
job = db.query(models.Job).filter(models.Job.id == job_id).one()
|
job = db.query(models.Job).filter(models.Job.id == job_id).one()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -34,23 +54,23 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
|||||||
|
|
||||||
logger.info(f"[{job.id}]: set task to status processing.")
|
logger.info(f"[{job.id}]: set task to status processing.")
|
||||||
|
|
||||||
# pick a transcription strategy.
|
|
||||||
# currently only `local` is supported.
|
|
||||||
job_record = schemas.Job.from_orm(job)
|
job_record = schemas.Job.from_orm(job)
|
||||||
strategy = LocalStrategy(
|
|
||||||
db=db, job_id=job.id, url=job_record.url, config=job_record.config
|
|
||||||
)
|
|
||||||
|
|
||||||
# process selected task.
|
# process selected task.
|
||||||
# currently only `transcribe` is supported.
|
|
||||||
if job.type == schemas.JobType.transcript:
|
if job.type == schemas.JobType.transcript:
|
||||||
result = strategy.transcribe()
|
result = self.strategy.transcribe(
|
||||||
logger.info(f"[{job.id}]: successfully transcribed audio.")
|
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||||
|
)
|
||||||
elif job.type == schemas.JobType.translation:
|
elif job.type == schemas.JobType.translation:
|
||||||
result = strategy.translate()
|
result = self.strategy.translate(
|
||||||
logger.info(f"[{job.id}]: successfully translated audio.")
|
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
result = strategy.detect_language()
|
result = self.strategy.detect_language(
|
||||||
|
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"[{job.id}]: successfully processed audio.")
|
||||||
|
|
||||||
artifact = models.Artifact(
|
artifact = models.Artifact(
|
||||||
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
|
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
|
||||||
@@ -66,7 +86,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
|||||||
logger.info(f"[{job.id}]: set task to status success.")
|
logger.info(f"[{job.id}]: set task to status success.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if job and db:
|
if job and db:
|
||||||
job.meta = {**job.meta.__dict__, "error": str(e)}
|
job.meta = {**job.meta, "error": str(e)} # type: ignore
|
||||||
job.status = schemas.JobStatus.error
|
job.status = schemas.JobStatus.error
|
||||||
db.commit()
|
db.commit()
|
||||||
raise (e)
|
raise (e)
|
||||||
|
|||||||
@@ -6,9 +6,8 @@ from typing import Any, List, Literal, Optional
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
import torch
|
import torch
|
||||||
|
from pydantic import BaseModel
|
||||||
from whisper import load_model
|
from whisper import load_model
|
||||||
|
|
||||||
import app.shared.db.schemas as schemas
|
import app.shared.db.schemas as schemas
|
||||||
@@ -20,43 +19,49 @@ class DecodeOptions(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class LocalStrategy:
|
class LocalStrategy:
|
||||||
def __init__(
|
def __init__(self) -> None:
|
||||||
self, db: Session, job_id: UUID, url: str, config: Optional[schemas.JobConfig]
|
|
||||||
):
|
|
||||||
self.db = db
|
|
||||||
self.job_id = job_id
|
|
||||||
self.url = url
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
logger.info("initializing GPU model.")
|
||||||
self.model = load_model(
|
self.model = load_model(
|
||||||
os.environ["WHISPER_MODEL"],
|
os.environ["WHISPER_MODEL"], download_root="/models"
|
||||||
download_root="/models"
|
|
||||||
).cuda()
|
).cuda()
|
||||||
else:
|
else:
|
||||||
|
logger.info("initializing CPU model.")
|
||||||
self.model = load_model(
|
self.model = load_model(
|
||||||
os.environ["WHISPER_MODEL"],
|
os.environ["WHISPER_MODEL"], download_root="/models"
|
||||||
download_root="/models"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"[{self.job_id}]: initialized local strategy.")
|
logger.info("initialized local strategy.")
|
||||||
|
|
||||||
def transcribe(self) -> List[Any]:
|
def transcribe(
|
||||||
return self.run_whisper(self._download(), "transcribe")
|
self, url: str, job_id: UUID, config: Optional[schemas.JobConfig]
|
||||||
|
) -> List[Any]:
|
||||||
|
return self.run_whisper(
|
||||||
|
self._download(url, job_id), "transcribe", config, job_id
|
||||||
|
)
|
||||||
|
|
||||||
def translate(self) -> List[Any]:
|
def translate(
|
||||||
return self.run_whisper(self._download(), "translate")
|
self, url: str, job_id: UUID, config: Optional[schemas.JobConfig]
|
||||||
|
) -> List[Any]:
|
||||||
|
return self.run_whisper(
|
||||||
|
self._download(url, job_id),
|
||||||
|
"translate",
|
||||||
|
config,
|
||||||
|
job_id,
|
||||||
|
)
|
||||||
|
|
||||||
def detect_language(self) -> List[Any]:
|
def detect_language(
|
||||||
|
self, url: str, config: Optional[schemas.JobConfig]
|
||||||
|
) -> List[Any]:
|
||||||
raise NotImplementedError("detect_language has not been implemented yet.")
|
raise NotImplementedError("detect_language has not been implemented yet.")
|
||||||
|
|
||||||
def _download(self) -> str:
|
def _download(self, url: str, job_id: UUID) -> str:
|
||||||
# re-create folder.
|
# re-create folder.
|
||||||
filename = self._get_tmp_file()
|
filename = self._get_tmp_file(job_id)
|
||||||
self._cleanup()
|
self._cleanup(job_id)
|
||||||
|
|
||||||
# stream media to disk.
|
# stream media to disk.
|
||||||
with requests.get(self.url, stream=True) as r:
|
with requests.get(url, stream=True) as r:
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
with open(filename, "wb") as f:
|
with open(filename, "wb") as f:
|
||||||
for chunk in r.iter_content(chunk_size=8192):
|
for chunk in r.iter_content(chunk_size=8192):
|
||||||
@@ -64,11 +69,17 @@ class LocalStrategy:
|
|||||||
|
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
def run_whisper(self, filepath: str, task: str) -> List[Any]:
|
def run_whisper(
|
||||||
|
self,
|
||||||
|
filepath: str,
|
||||||
|
task: str,
|
||||||
|
config: Optional[schemas.JobConfig],
|
||||||
|
job_id: UUID,
|
||||||
|
) -> List[Any]:
|
||||||
try:
|
try:
|
||||||
language = self.config.language if self.config else None
|
language = config.language if config else None
|
||||||
|
|
||||||
result = model.transcribe(
|
result = self.model.transcribe(
|
||||||
filepath,
|
filepath,
|
||||||
condition_on_previous_text=False,
|
condition_on_previous_text=False,
|
||||||
**DecodeOptions(task=task, language=language).dict(),
|
**DecodeOptions(task=task, language=language).dict(),
|
||||||
@@ -76,20 +87,14 @@ class LocalStrategy:
|
|||||||
|
|
||||||
return result["segments"]
|
return result["segments"]
|
||||||
finally:
|
finally:
|
||||||
self._cleanup()
|
self._cleanup(job_id)
|
||||||
|
|
||||||
def _get_tmp_file(self) -> str:
|
def _get_tmp_file(self, job_id: UUID) -> str:
|
||||||
tmp = tempfile.gettempdir()
|
tmp = tempfile.gettempdir()
|
||||||
return path.join(tmp, str(self.job_id))
|
return path.join(tmp, str(job_id))
|
||||||
|
|
||||||
def _cleanup(self) -> None:
|
def _cleanup(self, job_id: UUID) -> None:
|
||||||
try:
|
try:
|
||||||
os.remove(self._get_tmp_file())
|
os.remove(self._get_tmp_file(job_id))
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _convert(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _transcribe(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -64,6 +64,3 @@ networks:
|
|||||||
driver: bridge
|
driver: bridge
|
||||||
traefik:
|
traefik:
|
||||||
driver: bridge
|
driver: bridge
|
||||||
|
|
||||||
volumes:
|
|
||||||
whisperbox-data:
|
|
||||||
|
|||||||
@@ -7,9 +7,12 @@ services:
|
|||||||
worker:
|
worker:
|
||||||
container_name: whisperbox_worker
|
container_name: whisperbox_worker
|
||||||
env_file: .env
|
env_file: .env
|
||||||
|
# <GPU SUPPORT>
|
||||||
|
# build:
|
||||||
|
# dockerfile: worker.gpu.Dockerfile
|
||||||
volumes:
|
volumes:
|
||||||
- whisperbox-data:/etc/whisperbox/data
|
- whisperbox-data:/etc/whisperbox/data
|
||||||
# <ENABLE GPU SUPPORT>
|
# <GPU SUPPORT>
|
||||||
# deploy:
|
# deploy:
|
||||||
# resources:
|
# resources:
|
||||||
# reservations:
|
# reservations:
|
||||||
@@ -23,3 +26,8 @@ services:
|
|||||||
env_file: .env
|
env_file: .env
|
||||||
volumes:
|
volumes:
|
||||||
- whisperbox-data:/etc/whisperbox/data
|
- whisperbox-data:/etc/whisperbox/data
|
||||||
|
labels:
|
||||||
|
- "traefik.http.routers.web.entrypoints=web"
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
whisperbox-data:
|
||||||
|
|||||||
39
worker.gpu.Dockerfile
Normal file
39
worker.gpu.Dockerfile
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# TODO: clean up lol
|
||||||
|
FROM nvidia/cuda:11.8.0-base-ubuntu22.04 AS python-deploy
|
||||||
|
|
||||||
|
ENV PYTHON_VERSION=3.10
|
||||||
|
|
||||||
|
ARG WHISPER_MODEL
|
||||||
|
|
||||||
|
WORKDIR /etc/whisperbox
|
||||||
|
|
||||||
|
RUN export DEBIAN_FRONTEND=noninteractive \
|
||||||
|
&& apt-get -qq update \
|
||||||
|
&& apt-get -qq install --no-install-recommends \
|
||||||
|
python${PYTHON_VERSION} \
|
||||||
|
python${PYTHON_VERSION}-venv \
|
||||||
|
python3-pip \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
RUN ln -s -f /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 && \
|
||||||
|
ln -s -f /usr/bin/python${PYTHON_VERSION} /usr/bin/python && \
|
||||||
|
ln -s -f /usr/bin/pip3 /usr/bin/pip
|
||||||
|
|
||||||
|
COPY pyproject.toml .
|
||||||
|
|
||||||
|
RUN python -m venv /opt/venv && \
|
||||||
|
/opt/venv/bin/pip install -U pip wheel && \
|
||||||
|
/opt/venv/bin/pip install -U .[worker]
|
||||||
|
|
||||||
|
COPY --from=mwader/static-ffmpeg:latest /ffmpeg /usr/local/bin/
|
||||||
|
COPY --from=mwader/static-ffmpeg:latest /ffprobe /usr/local/bin/
|
||||||
|
|
||||||
|
COPY app ./app
|
||||||
|
|
||||||
|
ENV VIRTUAL_ENV /opt/venv
|
||||||
|
ENV PATH /opt/venv/bin:$PATH
|
||||||
|
|
||||||
|
COPY scripts/download_models.py .
|
||||||
|
RUN python download_models.py ${WHISPER_MODEL}
|
||||||
|
|
||||||
|
CMD celery --app=app.worker.main.celery worker --loglevel=info --concurrency=1 --pool=solo
|
||||||
Reference in New Issue
Block a user