diff --git a/.env.example b/.env.example index becac1f..76ceb1d 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,4 @@ API_SECRET="change_me" WHISPER_MODEL="small" DOMAIN="whisperbox.localhost" +DATABASE_URI="sqlite:///etc/whisperbox/data/whisperbox.sqlite" diff --git a/app/worker/main.py b/app/worker/main.py index c6acc09..5c7536d 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -1,4 +1,5 @@ from asyncio.log import logger +from typing import Any, Optional from uuid import UUID from celery import Task @@ -12,11 +13,30 @@ from app.worker.strategies.local import LocalStrategy celery = get_celery_binding() +class TranscribeTask(Task): + abstract = True -@celery.task(bind=True, soft_time_limit=2 * 60 * 60) # TODO: make configurable + def __init__(self) -> None: + super().__init__() + # currently only `LocalStrategy` is implemented. + # TODO: implement remote processing strategy. + self.strategy: Optional[LocalStrategy] = None + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # load model into memory once when the first task is processed. + if not self.strategy: + self.strategy = LocalStrategy() + return self.run(*args, **kwargs) + + +@celery.task( + base=TranscribeTask, bind=True, soft_time_limit=2 * 60 * 60 +) def transcribe(self: Task, job_id: UUID) -> None: try: + # runs in a separate thread => requires sqlite's WAL mode to be enabled. db: Session = SessionLocal() + job = db.query(models.Job).filter(models.Job.id == job_id).one() if ( @@ -34,23 +54,23 @@ def transcribe(self: Task, job_id: UUID) -> None: logger.info(f"[{job.id}]: set task to status processing.") - # pick a transcription strategy. - # currently only `local` is supported. job_record = schemas.Job.from_orm(job) - strategy = LocalStrategy( - db=db, job_id=job.id, url=job_record.url, config=job_record.config - ) # process selected task. - # currently only `transcribe` is supported. if job.type == schemas.JobType.transcript: - result = strategy.transcribe() - logger.info(f"[{job.id}]: successfully transcribed audio.") + result = self.strategy.transcribe( + url=job_record.url, job_id=job_record.id, config=job_record.config + ) elif job.type == schemas.JobType.translation: - result = strategy.translate() - logger.info(f"[{job.id}]: successfully translated audio.") + result = self.strategy.translate( + url=job_record.url, job_id=job_record.id, config=job_record.config + ) else: - result = strategy.detect_language() + result = self.strategy.detect_language( + url=job_record.url, job_id=job_record.id, config=job_record.config + ) + + logger.info(f"[{job.id}]: successfully processed audio.") artifact = models.Artifact( job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript @@ -66,7 +86,7 @@ def transcribe(self: Task, job_id: UUID) -> None: logger.info(f"[{job.id}]: set task to status success.") except Exception as e: if job and db: - job.meta = {**job.meta.__dict__, "error": str(e)} + job.meta = {**job.meta, "error": str(e)} # type: ignore job.status = schemas.JobStatus.error db.commit() raise (e) diff --git a/app/worker/strategies/local.py b/app/worker/strategies/local.py index 42cfbf1..09fa0fc 100644 --- a/app/worker/strategies/local.py +++ b/app/worker/strategies/local.py @@ -6,9 +6,8 @@ from typing import Any, List, Literal, Optional from uuid import UUID import requests -from pydantic import BaseModel -from sqlalchemy.orm import Session import torch +from pydantic import BaseModel from whisper import load_model import app.shared.db.schemas as schemas @@ -20,43 +19,49 @@ class DecodeOptions(BaseModel): class LocalStrategy: - def __init__( - self, db: Session, job_id: UUID, url: str, config: Optional[schemas.JobConfig] - ): - self.db = db - self.job_id = job_id - self.url = url - self.config = config - + def __init__(self) -> None: if torch.cuda.is_available(): + logger.info("initializing GPU model.") self.model = load_model( - os.environ["WHISPER_MODEL"], - download_root="/models" + os.environ["WHISPER_MODEL"], download_root="/models" ).cuda() else: + logger.info("initializing CPU model.") self.model = load_model( - os.environ["WHISPER_MODEL"], - download_root="/models" + os.environ["WHISPER_MODEL"], download_root="/models" ) - logger.info(f"[{self.job_id}]: initialized local strategy.") + logger.info("initialized local strategy.") - def transcribe(self) -> List[Any]: - return self.run_whisper(self._download(), "transcribe") + def transcribe( + self, url: str, job_id: UUID, config: Optional[schemas.JobConfig] + ) -> List[Any]: + return self.run_whisper( + self._download(url, job_id), "transcribe", config, job_id + ) - def translate(self) -> List[Any]: - return self.run_whisper(self._download(), "translate") + def translate( + self, url: str, job_id: UUID, config: Optional[schemas.JobConfig] + ) -> List[Any]: + return self.run_whisper( + self._download(url, job_id), + "translate", + config, + job_id, + ) - def detect_language(self) -> List[Any]: + def detect_language( + self, url: str, config: Optional[schemas.JobConfig] + ) -> List[Any]: raise NotImplementedError("detect_language has not been implemented yet.") - def _download(self) -> str: + def _download(self, url: str, job_id: UUID) -> str: # re-create folder. - filename = self._get_tmp_file() - self._cleanup() + filename = self._get_tmp_file(job_id) + self._cleanup(job_id) # stream media to disk. - with requests.get(self.url, stream=True) as r: + with requests.get(url, stream=True) as r: r.raise_for_status() with open(filename, "wb") as f: for chunk in r.iter_content(chunk_size=8192): @@ -64,11 +69,17 @@ class LocalStrategy: return filename - def run_whisper(self, filepath: str, task: str) -> List[Any]: + def run_whisper( + self, + filepath: str, + task: str, + config: Optional[schemas.JobConfig], + job_id: UUID, + ) -> List[Any]: try: - language = self.config.language if self.config else None + language = config.language if config else None - result = model.transcribe( + result = self.model.transcribe( filepath, condition_on_previous_text=False, **DecodeOptions(task=task, language=language).dict(), @@ -76,20 +87,14 @@ class LocalStrategy: return result["segments"] finally: - self._cleanup() + self._cleanup(job_id) - def _get_tmp_file(self) -> str: + def _get_tmp_file(self, job_id: UUID) -> str: tmp = tempfile.gettempdir() - return path.join(tmp, str(self.job_id)) + return path.join(tmp, str(job_id)) - def _cleanup(self) -> None: + def _cleanup(self, job_id: UUID) -> None: try: - os.remove(self._get_tmp_file()) + os.remove(self._get_tmp_file(job_id)) except OSError: pass - - def _convert(self) -> None: - pass - - def _transcribe(self) -> None: - pass diff --git a/docker-compose.base.yml b/docker-compose.base.yml index b9a15ff..ee43414 100644 --- a/docker-compose.base.yml +++ b/docker-compose.base.yml @@ -64,6 +64,3 @@ networks: driver: bridge traefik: driver: bridge - -volumes: - whisperbox-data: diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml index 35c6e41..2b6e793 100644 --- a/docker-compose.prod.yml +++ b/docker-compose.prod.yml @@ -7,9 +7,12 @@ services: worker: container_name: whisperbox_worker env_file: .env + # + # build: + # dockerfile: worker.gpu.Dockerfile volumes: - whisperbox-data:/etc/whisperbox/data - # + # # deploy: # resources: # reservations: @@ -23,3 +26,8 @@ services: env_file: .env volumes: - whisperbox-data:/etc/whisperbox/data + labels: + - "traefik.http.routers.web.entrypoints=web" + +volumes: + whisperbox-data: diff --git a/worker.gpu.Dockerfile b/worker.gpu.Dockerfile new file mode 100644 index 0000000..ddf4764 --- /dev/null +++ b/worker.gpu.Dockerfile @@ -0,0 +1,39 @@ +# TODO: clean up lol +FROM nvidia/cuda:11.8.0-base-ubuntu22.04 AS python-deploy + +ENV PYTHON_VERSION=3.10 + +ARG WHISPER_MODEL + +WORKDIR /etc/whisperbox + +RUN export DEBIAN_FRONTEND=noninteractive \ + && apt-get -qq update \ + && apt-get -qq install --no-install-recommends \ + python${PYTHON_VERSION} \ + python${PYTHON_VERSION}-venv \ + python3-pip \ + && rm -rf /var/lib/apt/lists/* + +RUN ln -s -f /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 && \ + ln -s -f /usr/bin/python${PYTHON_VERSION} /usr/bin/python && \ + ln -s -f /usr/bin/pip3 /usr/bin/pip + +COPY pyproject.toml . + +RUN python -m venv /opt/venv && \ + /opt/venv/bin/pip install -U pip wheel && \ + /opt/venv/bin/pip install -U .[worker] + +COPY --from=mwader/static-ffmpeg:latest /ffmpeg /usr/local/bin/ +COPY --from=mwader/static-ffmpeg:latest /ffprobe /usr/local/bin/ + +COPY app ./app + +ENV VIRTUAL_ENV /opt/venv +ENV PATH /opt/venv/bin:$PATH + +COPY scripts/download_models.py . +RUN python download_models.py ${WHISPER_MODEL} + +CMD celery --app=app.worker.main.celery worker --loglevel=info --concurrency=1 --pool=solo