feat: add gpu support

This commit is contained in:
Felix Spöttel
2023-03-01 16:04:05 +01:00
parent f27fe02958
commit 7ece7944bf
6 changed files with 125 additions and 55 deletions

View File

@@ -1,3 +1,4 @@
API_SECRET="change_me"
WHISPER_MODEL="small"
DOMAIN="whisperbox.localhost"
DATABASE_URI="sqlite:///etc/whisperbox/data/whisperbox.sqlite"

View File

@@ -1,4 +1,5 @@
from asyncio.log import logger
from typing import Any, Optional
from uuid import UUID
from celery import Task
@@ -12,11 +13,30 @@ from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding()
class TranscribeTask(Task):
abstract = True
@celery.task(bind=True, soft_time_limit=2 * 60 * 60) # TODO: make configurable
def __init__(self) -> None:
super().__init__()
# currently only `LocalStrategy` is implemented.
# TODO: implement remote processing strategy.
self.strategy: Optional[LocalStrategy] = None
def __call__(self, *args: Any, **kwargs: Any) -> Any:
# load model into memory once when the first task is processed.
if not self.strategy:
self.strategy = LocalStrategy()
return self.run(*args, **kwargs)
@celery.task(
base=TranscribeTask, bind=True, soft_time_limit=2 * 60 * 60
)
def transcribe(self: Task, job_id: UUID) -> None:
try:
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
db: Session = SessionLocal()
job = db.query(models.Job).filter(models.Job.id == job_id).one()
if (
@@ -34,23 +54,23 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.info(f"[{job.id}]: set task to status processing.")
# pick a transcription strategy.
# currently only `local` is supported.
job_record = schemas.Job.from_orm(job)
strategy = LocalStrategy(
db=db, job_id=job.id, url=job_record.url, config=job_record.config
)
# process selected task.
# currently only `transcribe` is supported.
if job.type == schemas.JobType.transcript:
result = strategy.transcribe()
logger.info(f"[{job.id}]: successfully transcribed audio.")
result = self.strategy.transcribe(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
elif job.type == schemas.JobType.translation:
result = strategy.translate()
logger.info(f"[{job.id}]: successfully translated audio.")
result = self.strategy.translate(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
else:
result = strategy.detect_language()
result = self.strategy.detect_language(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
logger.info(f"[{job.id}]: successfully processed audio.")
artifact = models.Artifact(
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
@@ -66,7 +86,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.info(f"[{job.id}]: set task to status success.")
except Exception as e:
if job and db:
job.meta = {**job.meta.__dict__, "error": str(e)}
job.meta = {**job.meta, "error": str(e)} # type: ignore
job.status = schemas.JobStatus.error
db.commit()
raise (e)

View File

@@ -6,9 +6,8 @@ from typing import Any, List, Literal, Optional
from uuid import UUID
import requests
from pydantic import BaseModel
from sqlalchemy.orm import Session
import torch
from pydantic import BaseModel
from whisper import load_model
import app.shared.db.schemas as schemas
@@ -20,43 +19,49 @@ class DecodeOptions(BaseModel):
class LocalStrategy:
def __init__(
self, db: Session, job_id: UUID, url: str, config: Optional[schemas.JobConfig]
):
self.db = db
self.job_id = job_id
self.url = url
self.config = config
def __init__(self) -> None:
if torch.cuda.is_available():
logger.info("initializing GPU model.")
self.model = load_model(
os.environ["WHISPER_MODEL"],
download_root="/models"
os.environ["WHISPER_MODEL"], download_root="/models"
).cuda()
else:
logger.info("initializing CPU model.")
self.model = load_model(
os.environ["WHISPER_MODEL"],
download_root="/models"
os.environ["WHISPER_MODEL"], download_root="/models"
)
logger.info(f"[{self.job_id}]: initialized local strategy.")
logger.info("initialized local strategy.")
def transcribe(self) -> List[Any]:
return self.run_whisper(self._download(), "transcribe")
def transcribe(
self, url: str, job_id: UUID, config: Optional[schemas.JobConfig]
) -> List[Any]:
return self.run_whisper(
self._download(url, job_id), "transcribe", config, job_id
)
def translate(self) -> List[Any]:
return self.run_whisper(self._download(), "translate")
def translate(
self, url: str, job_id: UUID, config: Optional[schemas.JobConfig]
) -> List[Any]:
return self.run_whisper(
self._download(url, job_id),
"translate",
config,
job_id,
)
def detect_language(self) -> List[Any]:
def detect_language(
self, url: str, config: Optional[schemas.JobConfig]
) -> List[Any]:
raise NotImplementedError("detect_language has not been implemented yet.")
def _download(self) -> str:
def _download(self, url: str, job_id: UUID) -> str:
# re-create folder.
filename = self._get_tmp_file()
self._cleanup()
filename = self._get_tmp_file(job_id)
self._cleanup(job_id)
# stream media to disk.
with requests.get(self.url, stream=True) as r:
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(filename, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
@@ -64,11 +69,17 @@ class LocalStrategy:
return filename
def run_whisper(self, filepath: str, task: str) -> List[Any]:
def run_whisper(
self,
filepath: str,
task: str,
config: Optional[schemas.JobConfig],
job_id: UUID,
) -> List[Any]:
try:
language = self.config.language if self.config else None
language = config.language if config else None
result = model.transcribe(
result = self.model.transcribe(
filepath,
condition_on_previous_text=False,
**DecodeOptions(task=task, language=language).dict(),
@@ -76,20 +87,14 @@ class LocalStrategy:
return result["segments"]
finally:
self._cleanup()
self._cleanup(job_id)
def _get_tmp_file(self) -> str:
def _get_tmp_file(self, job_id: UUID) -> str:
tmp = tempfile.gettempdir()
return path.join(tmp, str(self.job_id))
return path.join(tmp, str(job_id))
def _cleanup(self) -> None:
def _cleanup(self, job_id: UUID) -> None:
try:
os.remove(self._get_tmp_file())
os.remove(self._get_tmp_file(job_id))
except OSError:
pass
def _convert(self) -> None:
pass
def _transcribe(self) -> None:
pass

View File

@@ -64,6 +64,3 @@ networks:
driver: bridge
traefik:
driver: bridge
volumes:
whisperbox-data:

View File

@@ -7,9 +7,12 @@ services:
worker:
container_name: whisperbox_worker
env_file: .env
# <GPU SUPPORT>
# build:
# dockerfile: worker.gpu.Dockerfile
volumes:
- whisperbox-data:/etc/whisperbox/data
# <ENABLE GPU SUPPORT>
# <GPU SUPPORT>
# deploy:
# resources:
# reservations:
@@ -23,3 +26,8 @@ services:
env_file: .env
volumes:
- whisperbox-data:/etc/whisperbox/data
labels:
- "traefik.http.routers.web.entrypoints=web"
volumes:
whisperbox-data:

39
worker.gpu.Dockerfile Normal file
View File

@@ -0,0 +1,39 @@
# TODO: clean up lol
FROM nvidia/cuda:11.8.0-base-ubuntu22.04 AS python-deploy
ENV PYTHON_VERSION=3.10
ARG WHISPER_MODEL
WORKDIR /etc/whisperbox
RUN export DEBIAN_FRONTEND=noninteractive \
&& apt-get -qq update \
&& apt-get -qq install --no-install-recommends \
python${PYTHON_VERSION} \
python${PYTHON_VERSION}-venv \
python3-pip \
&& rm -rf /var/lib/apt/lists/*
RUN ln -s -f /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 && \
ln -s -f /usr/bin/python${PYTHON_VERSION} /usr/bin/python && \
ln -s -f /usr/bin/pip3 /usr/bin/pip
COPY pyproject.toml .
RUN python -m venv /opt/venv && \
/opt/venv/bin/pip install -U pip wheel && \
/opt/venv/bin/pip install -U .[worker]
COPY --from=mwader/static-ffmpeg:latest /ffmpeg /usr/local/bin/
COPY --from=mwader/static-ffmpeg:latest /ffprobe /usr/local/bin/
COPY app ./app
ENV VIRTUAL_ENV /opt/venv
ENV PATH /opt/venv/bin:$PATH
COPY scripts/download_models.py .
RUN python download_models.py ${WHISPER_MODEL}
CMD celery --app=app.worker.main.celery worker --loglevel=info --concurrency=1 --pool=solo