feat: add gpu support

This commit is contained in:
Felix Spöttel
2023-03-01 16:04:05 +01:00
parent f27fe02958
commit 7ece7944bf
6 changed files with 125 additions and 55 deletions

View File

@@ -1,3 +1,4 @@
API_SECRET="change_me" API_SECRET="change_me"
WHISPER_MODEL="small" WHISPER_MODEL="small"
DOMAIN="whisperbox.localhost" DOMAIN="whisperbox.localhost"
DATABASE_URI="sqlite:///etc/whisperbox/data/whisperbox.sqlite"

View File

@@ -1,4 +1,5 @@
from asyncio.log import logger from asyncio.log import logger
from typing import Any, Optional
from uuid import UUID from uuid import UUID
from celery import Task from celery import Task
@@ -12,11 +13,30 @@ from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding() celery = get_celery_binding()
class TranscribeTask(Task):
abstract = True
@celery.task(bind=True, soft_time_limit=2 * 60 * 60) # TODO: make configurable def __init__(self) -> None:
super().__init__()
# currently only `LocalStrategy` is implemented.
# TODO: implement remote processing strategy.
self.strategy: Optional[LocalStrategy] = None
def __call__(self, *args: Any, **kwargs: Any) -> Any:
# load model into memory once when the first task is processed.
if not self.strategy:
self.strategy = LocalStrategy()
return self.run(*args, **kwargs)
@celery.task(
base=TranscribeTask, bind=True, soft_time_limit=2 * 60 * 60
)
def transcribe(self: Task, job_id: UUID) -> None: def transcribe(self: Task, job_id: UUID) -> None:
try: try:
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
db: Session = SessionLocal() db: Session = SessionLocal()
job = db.query(models.Job).filter(models.Job.id == job_id).one() job = db.query(models.Job).filter(models.Job.id == job_id).one()
if ( if (
@@ -34,23 +54,23 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.info(f"[{job.id}]: set task to status processing.") logger.info(f"[{job.id}]: set task to status processing.")
# pick a transcription strategy.
# currently only `local` is supported.
job_record = schemas.Job.from_orm(job) job_record = schemas.Job.from_orm(job)
strategy = LocalStrategy(
db=db, job_id=job.id, url=job_record.url, config=job_record.config
)
# process selected task. # process selected task.
# currently only `transcribe` is supported.
if job.type == schemas.JobType.transcript: if job.type == schemas.JobType.transcript:
result = strategy.transcribe() result = self.strategy.transcribe(
logger.info(f"[{job.id}]: successfully transcribed audio.") url=job_record.url, job_id=job_record.id, config=job_record.config
)
elif job.type == schemas.JobType.translation: elif job.type == schemas.JobType.translation:
result = strategy.translate() result = self.strategy.translate(
logger.info(f"[{job.id}]: successfully translated audio.") url=job_record.url, job_id=job_record.id, config=job_record.config
)
else: else:
result = strategy.detect_language() result = self.strategy.detect_language(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
logger.info(f"[{job.id}]: successfully processed audio.")
artifact = models.Artifact( artifact = models.Artifact(
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
@@ -66,7 +86,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.info(f"[{job.id}]: set task to status success.") logger.info(f"[{job.id}]: set task to status success.")
except Exception as e: except Exception as e:
if job and db: if job and db:
job.meta = {**job.meta.__dict__, "error": str(e)} job.meta = {**job.meta, "error": str(e)} # type: ignore
job.status = schemas.JobStatus.error job.status = schemas.JobStatus.error
db.commit() db.commit()
raise (e) raise (e)

View File

@@ -6,9 +6,8 @@ from typing import Any, List, Literal, Optional
from uuid import UUID from uuid import UUID
import requests import requests
from pydantic import BaseModel
from sqlalchemy.orm import Session
import torch import torch
from pydantic import BaseModel
from whisper import load_model from whisper import load_model
import app.shared.db.schemas as schemas import app.shared.db.schemas as schemas
@@ -20,43 +19,49 @@ class DecodeOptions(BaseModel):
class LocalStrategy: class LocalStrategy:
def __init__( def __init__(self) -> None:
self, db: Session, job_id: UUID, url: str, config: Optional[schemas.JobConfig]
):
self.db = db
self.job_id = job_id
self.url = url
self.config = config
if torch.cuda.is_available(): if torch.cuda.is_available():
logger.info("initializing GPU model.")
self.model = load_model( self.model = load_model(
os.environ["WHISPER_MODEL"], os.environ["WHISPER_MODEL"], download_root="/models"
download_root="/models"
).cuda() ).cuda()
else: else:
logger.info("initializing CPU model.")
self.model = load_model( self.model = load_model(
os.environ["WHISPER_MODEL"], os.environ["WHISPER_MODEL"], download_root="/models"
download_root="/models"
) )
logger.info(f"[{self.job_id}]: initialized local strategy.") logger.info("initialized local strategy.")
def transcribe(self) -> List[Any]: def transcribe(
return self.run_whisper(self._download(), "transcribe") self, url: str, job_id: UUID, config: Optional[schemas.JobConfig]
) -> List[Any]:
return self.run_whisper(
self._download(url, job_id), "transcribe", config, job_id
)
def translate(self) -> List[Any]: def translate(
return self.run_whisper(self._download(), "translate") self, url: str, job_id: UUID, config: Optional[schemas.JobConfig]
) -> List[Any]:
return self.run_whisper(
self._download(url, job_id),
"translate",
config,
job_id,
)
def detect_language(self) -> List[Any]: def detect_language(
self, url: str, config: Optional[schemas.JobConfig]
) -> List[Any]:
raise NotImplementedError("detect_language has not been implemented yet.") raise NotImplementedError("detect_language has not been implemented yet.")
def _download(self) -> str: def _download(self, url: str, job_id: UUID) -> str:
# re-create folder. # re-create folder.
filename = self._get_tmp_file() filename = self._get_tmp_file(job_id)
self._cleanup() self._cleanup(job_id)
# stream media to disk. # stream media to disk.
with requests.get(self.url, stream=True) as r: with requests.get(url, stream=True) as r:
r.raise_for_status() r.raise_for_status()
with open(filename, "wb") as f: with open(filename, "wb") as f:
for chunk in r.iter_content(chunk_size=8192): for chunk in r.iter_content(chunk_size=8192):
@@ -64,11 +69,17 @@ class LocalStrategy:
return filename return filename
def run_whisper(self, filepath: str, task: str) -> List[Any]: def run_whisper(
self,
filepath: str,
task: str,
config: Optional[schemas.JobConfig],
job_id: UUID,
) -> List[Any]:
try: try:
language = self.config.language if self.config else None language = config.language if config else None
result = model.transcribe( result = self.model.transcribe(
filepath, filepath,
condition_on_previous_text=False, condition_on_previous_text=False,
**DecodeOptions(task=task, language=language).dict(), **DecodeOptions(task=task, language=language).dict(),
@@ -76,20 +87,14 @@ class LocalStrategy:
return result["segments"] return result["segments"]
finally: finally:
self._cleanup() self._cleanup(job_id)
def _get_tmp_file(self) -> str: def _get_tmp_file(self, job_id: UUID) -> str:
tmp = tempfile.gettempdir() tmp = tempfile.gettempdir()
return path.join(tmp, str(self.job_id)) return path.join(tmp, str(job_id))
def _cleanup(self) -> None: def _cleanup(self, job_id: UUID) -> None:
try: try:
os.remove(self._get_tmp_file()) os.remove(self._get_tmp_file(job_id))
except OSError: except OSError:
pass pass
def _convert(self) -> None:
pass
def _transcribe(self) -> None:
pass

View File

@@ -64,6 +64,3 @@ networks:
driver: bridge driver: bridge
traefik: traefik:
driver: bridge driver: bridge
volumes:
whisperbox-data:

View File

@@ -7,9 +7,12 @@ services:
worker: worker:
container_name: whisperbox_worker container_name: whisperbox_worker
env_file: .env env_file: .env
# <GPU SUPPORT>
# build:
# dockerfile: worker.gpu.Dockerfile
volumes: volumes:
- whisperbox-data:/etc/whisperbox/data - whisperbox-data:/etc/whisperbox/data
# <ENABLE GPU SUPPORT> # <GPU SUPPORT>
# deploy: # deploy:
# resources: # resources:
# reservations: # reservations:
@@ -23,3 +26,8 @@ services:
env_file: .env env_file: .env
volumes: volumes:
- whisperbox-data:/etc/whisperbox/data - whisperbox-data:/etc/whisperbox/data
labels:
- "traefik.http.routers.web.entrypoints=web"
volumes:
whisperbox-data:

39
worker.gpu.Dockerfile Normal file
View File

@@ -0,0 +1,39 @@
# TODO: clean up lol
FROM nvidia/cuda:11.8.0-base-ubuntu22.04 AS python-deploy
ENV PYTHON_VERSION=3.10
ARG WHISPER_MODEL
WORKDIR /etc/whisperbox
RUN export DEBIAN_FRONTEND=noninteractive \
&& apt-get -qq update \
&& apt-get -qq install --no-install-recommends \
python${PYTHON_VERSION} \
python${PYTHON_VERSION}-venv \
python3-pip \
&& rm -rf /var/lib/apt/lists/*
RUN ln -s -f /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 && \
ln -s -f /usr/bin/python${PYTHON_VERSION} /usr/bin/python && \
ln -s -f /usr/bin/pip3 /usr/bin/pip
COPY pyproject.toml .
RUN python -m venv /opt/venv && \
/opt/venv/bin/pip install -U pip wheel && \
/opt/venv/bin/pip install -U .[worker]
COPY --from=mwader/static-ffmpeg:latest /ffmpeg /usr/local/bin/
COPY --from=mwader/static-ffmpeg:latest /ffprobe /usr/local/bin/
COPY app ./app
ENV VIRTUAL_ENV /opt/venv
ENV PATH /opt/venv/bin:$PATH
COPY scripts/download_models.py .
RUN python download_models.py ${WHISPER_MODEL}
CMD celery --app=app.worker.main.celery worker --loglevel=info --concurrency=1 --pool=solo