From a7ce71ed336708098a5079e0e12e6e40c6bdb8d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Sp=C3=B6ttel?= <1682504+fspoettel@users.noreply.github.com> Date: Sat, 28 Jan 2023 12:30:02 +0100 Subject: [PATCH] feat: add whisper tasks --- .env.test | 2 +- .flake8 | 2 +- app/shared/celery.py | 9 +++ app/shared/config.py | 6 +- ...bles.py => 426b6bdc3360_add_job_tables.py} | 30 +++++-- .../684a5e546314_add_job_meta_field.py | 27 ------- app/shared/db/dtos.py | 13 ++- app/shared/db/models.py | 1 + app/tests/test_api.py | 2 +- app/web/main.py | 27 +++++-- app/worker/main.py | 52 ++++++++---- app/worker/strategies/local.py | 81 +++++++++++++++++++ docker/app.dev.Dockerfile | 11 +-- docker/dev.docker-compose.yml | 2 +- docker/worker.dev.Dockerfile | 22 +++-- pyproject.toml | 3 +- scripts/download_model.py | 7 ++ 17 files changed, 209 insertions(+), 88 deletions(-) create mode 100644 app/shared/celery.py rename app/shared/db/alembic/versions/{bb249ed79907_add_job_tables.py => 426b6bdc3360_add_job_tables.py} (67%) delete mode 100644 app/shared/db/alembic/versions/684a5e546314_add_job_meta_field.py create mode 100644 app/worker/strategies/local.py create mode 100644 scripts/download_model.py diff --git a/.env.test b/.env.test index b166cd1..08a3511 100644 --- a/.env.test +++ b/.env.test @@ -1,4 +1,4 @@ DATABASE_URI="postgresql://postgres:postgres@localhost:5432/whisperbox_test" ENVIRONMENT="development" API_SECRET="foo" -BROKER_URI="redis://localhost:6379/0" +BROKER_URL="redis://localhost:6379/0" diff --git a/.flake8 b/.flake8 index 5579ce6..c7c7565 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ [flake8] -max-line-length = 90 +max-line-length = 88 extend-ignore = E203 exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,app/alembic/versions diff --git a/app/shared/celery.py b/app/shared/celery.py new file mode 100644 index 0000000..415f3dc --- /dev/null +++ b/app/shared/celery.py @@ -0,0 +1,9 @@ +from celery import Celery + +from app.shared.config import settings + + +def get_celery_binding() -> Celery: + celery = Celery("tasks") + celery.conf.broker_url = settings.BROKER_URL + return celery diff --git a/app/shared/config.py b/app/shared/config.py index 214d6ac..70300ea 100644 --- a/app/shared/config.py +++ b/app/shared/config.py @@ -9,10 +9,12 @@ class Settings(BaseSettings): ENVIRONMENT: str # derived settings - BROKER_URI: str + BROKER_URL: str if "pytest" in sys.modules: - settings = Settings(_env_file=".env.test", _env_file_encoding="utf-8") # type: ignore + settings = Settings( + _env_file=".env.test", _env_file_encoding="utf-8" + ) # type: ignore else: settings = Settings() diff --git a/app/shared/db/alembic/versions/bb249ed79907_add_job_tables.py b/app/shared/db/alembic/versions/426b6bdc3360_add_job_tables.py similarity index 67% rename from app/shared/db/alembic/versions/bb249ed79907_add_job_tables.py rename to app/shared/db/alembic/versions/426b6bdc3360_add_job_tables.py index 7fb73e3..0c31a99 100644 --- a/app/shared/db/alembic/versions/bb249ed79907_add_job_tables.py +++ b/app/shared/db/alembic/versions/426b6bdc3360_add_job_tables.py @@ -1,8 +1,8 @@ """add_job_tables -Revision ID: bb249ed79907 +Revision ID: 426b6bdc3360 Revises: -Create Date: 2023-01-17 14:30:30.920466 +Create Date: 2023-01-27 17:55:21.758828 """ import sqlalchemy as sa @@ -10,7 +10,7 @@ from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. -revision = "bb249ed79907" +revision = "426b6bdc3360" down_revision = None branch_labels = None depends_on = None @@ -23,12 +23,21 @@ def upgrade() -> None: sa.Column("url", sa.String(length=2048), nullable=True), sa.Column( "status", - sa.Enum("create", "error", "processing", "success", name="jobstatus"), + sa.Enum("create", "processing", "error", "success", name="jobstatus"), nullable=False, ), - sa.Column("type", sa.Enum("transcript", name="jobtype"), nullable=False), + sa.Column("config", sa.JSON(none_as_null=True), nullable=True), + sa.Column("meta", sa.JSON(none_as_null=True), nullable=True), sa.Column( - "created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False + "type", + sa.Enum("transcript", "translation", "language_detection", name="jobtype"), + nullable=False, + ), + sa.Column( + "created_at", + sa.DateTime(), + server_default=sa.text("now()"), + nullable=False, ), sa.Column("updated_at", sa.DateTime(), nullable=True), sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), @@ -40,10 +49,15 @@ def upgrade() -> None: sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False), sa.Column("data", sa.JSON(none_as_null=True), nullable=True), sa.Column( - "type", sa.Enum("raw_transcript", name="artifacttype"), nullable=False + "type", + sa.Enum("raw_transcript", name="artifacttype"), + nullable=False, ), sa.Column( - "created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False + "created_at", + sa.DateTime(), + server_default=sa.text("now()"), + nullable=False, ), sa.Column("updated_at", sa.DateTime(), nullable=True), sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), diff --git a/app/shared/db/alembic/versions/684a5e546314_add_job_meta_field.py b/app/shared/db/alembic/versions/684a5e546314_add_job_meta_field.py deleted file mode 100644 index 4337f57..0000000 --- a/app/shared/db/alembic/versions/684a5e546314_add_job_meta_field.py +++ /dev/null @@ -1,27 +0,0 @@ -"""add_job_meta_field - -Revision ID: 684a5e546314 -Revises: bb249ed79907 -Create Date: 2023-01-18 13:38:07.692830 - -""" -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision = "684a5e546314" -down_revision = "bb249ed79907" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.add_column("jobs", sa.Column("meta", sa.JSON(none_as_null=True), nullable=True)) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("jobs", "meta") - # ### end Alembic commands ### diff --git a/app/shared/db/dtos.py b/app/shared/db/dtos.py index 0213d92..8bd5197 100644 --- a/app/shared/db/dtos.py +++ b/app/shared/db/dtos.py @@ -1,6 +1,6 @@ import enum from datetime import datetime -from typing import Any, Optional +from typing import Any, List, Optional from uuid import UUID from pydantic import AnyHttpUrl, BaseModel @@ -21,6 +21,8 @@ class ArtifactType(str, enum.Enum): class JobType(str, enum.Enum): transcript = "transcript" + translation = "translation" + language_detection = "language_detection" class JobStatus(str, enum.Enum): @@ -30,8 +32,12 @@ class JobStatus(str, enum.Enum): success = "success" -class JobMeta(BaseModel): +class JobConfig(BaseModel): language: Optional[str] + + +class JobMeta(BaseModel): + error: Optional[str] task_id: Optional[UUID] @@ -40,10 +46,11 @@ class Job(WithDbFields): type: JobType url: AnyHttpUrl meta: Optional[JobMeta] + config: Optional[JobConfig] class Artifact(WithDbFields): # TODO: narrow type - data: Optional[Any] + data: Optional[List[Any]] job_id: UUID type: ArtifactType diff --git a/app/shared/db/models.py b/app/shared/db/models.py index 1ac52c6..166a579 100644 --- a/app/shared/db/models.py +++ b/app/shared/db/models.py @@ -35,6 +35,7 @@ class Job(Base, WithStandardFields): url = Column(String(length=2048)) status = Column(Enum(JobStatus), nullable=False) + config = Column(JSON(none_as_null=True)) meta = Column(JSON(none_as_null=True)) type = Column(Enum(JobType), nullable=False) diff --git a/app/tests/test_api.py b/app/tests/test_api.py index 8d2d93e..8bdc560 100644 --- a/app/tests/test_api.py +++ b/app/tests/test_api.py @@ -86,7 +86,7 @@ def test_get_artifact_pass( auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job ) -> None: artifact = models.Artifact( - data={}, job_id=mock_job.id, type=dtos.ArtifactType.raw_transcript + data=[], job_id=mock_job.id, type=dtos.ArtifactType.raw_transcript ) db_session.add(artifact) diff --git a/app/web/main.py b/app/web/main.py index 6403fc8..11b563a 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -7,11 +7,12 @@ from sqlalchemy.orm import Session import app.shared.db.dtos as dtos import app.shared.db.models as models +from app.shared.celery import get_celery_binding from app.shared.db.base import get_session from app.web.security import authenticate_api_key -from app.worker.main import transcribe app = FastAPI() +celery = get_celery_binding() api_router = APIRouter(prefix="/api/v1", dependencies=[Depends(authenticate_api_key)]) @@ -21,19 +22,33 @@ def api_root() -> Dict: return {} -class TranscriptPayload(BaseModel): +class PostJobPayload(BaseModel): url: AnyHttpUrl type: dtos.JobType + language: Optional[str] @api_router.post("/jobs", response_model=dtos.Job, status_code=201) def create_job( - payload: TranscriptPayload, session: Session = Depends(get_session) + payload: PostJobPayload, + session: Session = Depends(get_session), ) -> models.Job: - job = models.Job(url=payload.url, status=dtos.JobStatus.create, type=payload.type) - session.add(job) - session.flush() + # create a job with status "create" and save it to the database. + job = models.Job( + url=payload.url, + status=dtos.JobStatus.create, + type=payload.type, + config={"language": payload.language} if payload.language else None, + ) + session.add(job) + session.commit() + + # queue an async transcription task. + # we use a signature here to allow full separation of + # worker processes and dependencies. + transcribe = celery.signature("app.worker.main.transcribe") + # TODO: catch delivery errors. transcribe.delay(job.id) return job diff --git a/app/worker/main.py b/app/worker/main.py index 7ce4791..be70bdb 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -1,37 +1,55 @@ -from time import sleep +from asyncio.log import logger from uuid import UUID -from celery import Celery from sqlalchemy.orm import Session import app.shared.db.dtos as dtos import app.shared.db.models as models -from app.shared.config import settings +from app.shared.celery import get_celery_binding from app.shared.db.base import SessionLocal +from app.worker.strategies.local import LocalStrategy -celery = Celery(__name__) - -celery.conf.broker_url = settings.BROKER_URI +celery = get_celery_binding() -def update_job_status(db: Session, job_id: UUID, status: dtos.JobStatus) -> None: - db.begin() - job = db.query(models.Job).filter(models.Job.id == job_id).one() +def update_job_status(db: Session, job: models.Job, status: dtos.JobStatus) -> None: job.status = status db.commit() @celery.task() -def transcribe(job_id: UUID) -> int: +def transcribe(job_id: UUID) -> None: try: db: Session = SessionLocal() - update_job_status(db, job_id, dtos.JobStatus.processing) - sleep(60) - update_job_status(db, job_id, dtos.JobStatus.success) + job = db.query(models.Job).filter(models.Job.id == job_id).one() + + update_job_status(db, job, dtos.JobStatus.processing) + + # pick a transcription strategy. + # currently only `local` is supported. + job_record = dtos.Job.from_orm(job) + strategy = LocalStrategy( + db=db, job_id=job.id, url=job_record.url, config=job_record.config + ) + + # process selected task. + # currently only `transcribe` is supported. + if job.type == dtos.JobType.transcript: + result = strategy.transcribe() + elif job.type == dtos.JobType.translation: + result = strategy.translate() + else: + result = strategy.detect_language() + + artifact = models.Artifact( + job_id=job.id, data=result, type=dtos.ArtifactType.raw_transcript + ) + db.add(artifact) db.commit() - except Exception: - update_job_status(db, job_id, dtos.JobStatus.error) + + update_job_status(db, job, dtos.JobStatus.success) + except Exception as e: + logger.error(e) + update_job_status(db, job, dtos.JobStatus.error) finally: db.close() - - return 0 diff --git a/app/worker/strategies/local.py b/app/worker/strategies/local.py new file mode 100644 index 0000000..21b9c0b --- /dev/null +++ b/app/worker/strategies/local.py @@ -0,0 +1,81 @@ +import os +import shutil +import tempfile +from os import path +from typing import Any, List, Literal, Optional +from uuid import UUID + +import requests +from pydantic import BaseModel +from sqlalchemy.orm import Session +from whisper import load_model + +import app.shared.db.dtos as dtos + + +class DecodeOptions(BaseModel): + language: Optional[str] + task: Literal["translate", "transcribe"] + + +class LocalStrategy: + def __init__( + self, db: Session, job_id: UUID, url: str, config: Optional[dtos.JobConfig] + ): + self.db = db + self.job_id = job_id + self.url = url + self.config = config + + def transcribe(self) -> List[Any]: + result = self.run_whisper(self._download(), "transcribe") + self._cleanup() + return result + + def translate(self) -> List[Any]: + result = self.run_whisper(self._download(), "translate") + self._cleanup() + return result + + def detect_language(self) -> List[Any]: + raise NotImplementedError("detect_language has not been implemented yet.") + + def _download(self) -> str: + dirname = self._get_tmp_dir() + filename = path.join(dirname, "media.mp3") + + # re-create folder. + shutil.rmtree(dirname, ignore_errors=True) + os.makedirs(dirname) + + # stream media to disk. + with requests.get(self.url, stream=True) as r: + r.raise_for_status() + with open(filename, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + return filename + + def run_whisper(self, filepath: str, task: str) -> List[Any]: + language = self.config.language if self.config else None + decode_opts = DecodeOptions(task=task, language=language) + model = load_model("small", download_root="/models") + + result = model.transcribe( + filepath, condition_on_previous_text=False, **decode_opts.dict() + ) + + return result["segments"] + + def _get_tmp_dir(self) -> str: + return path.join(tempfile.gettempdir(), str(self.job_id)) + + def _cleanup(self) -> None: + shutil.rmtree(self._get_tmp_dir(), ignore_errors=True) + + def _convert(self) -> None: + pass + + def _transcribe(self) -> None: + pass diff --git a/docker/app.dev.Dockerfile b/docker/app.dev.Dockerfile index 11c2f03..f2ed19b 100644 --- a/docker/app.dev.Dockerfile +++ b/docker/app.dev.Dockerfile @@ -1,17 +1,12 @@ FROM python:3.10 AS compile-image -COPY pyproject.toml . -RUN --mount=type=cache,target=/root/.cache \ - pip install --user .[web] - -FROM python:3.10 AS build-image - WORKDIR /code +COPY pyproject.toml . +RUN pip install --no-cache-dir --user .[web] + ENV PYTHONDONTWRITEBYTECODE 1 ENV PYTHONUNBUFFERED 1 - -COPY --from=compile-image /root/.local /root/.local ENV PATH=/root/.local/bin:$PATH ENTRYPOINT ["bash", "./app/web/start.sh"] diff --git a/docker/dev.docker-compose.yml b/docker/dev.docker-compose.yml index c65f7af..5aeadfd 100644 --- a/docker/dev.docker-compose.yml +++ b/docker/dev.docker-compose.yml @@ -4,7 +4,7 @@ x-app-variables: &app-variables API_SECRET: a_very_secret_token DATABASE_URI: postgresql://postgres:postgres@postgres/whisperbox ENVIRONMENT: development - BROKER_URI: redis://redis:6379/0 + BROKER_URL: redis://redis:6379/0 services: postgres: diff --git a/docker/worker.dev.Dockerfile b/docker/worker.dev.Dockerfile index 4d31289..9af2b5d 100644 --- a/docker/worker.dev.Dockerfile +++ b/docker/worker.dev.Dockerfile @@ -1,20 +1,18 @@ FROM python:3.10 AS compile-image -RUN --mount=type=cache,target=/var/cache/apt \ - apt-get update && apt-get install -y --no-install-recommends ffmpeg - -COPY pyproject.toml . -RUN --mount=type=cache,target=/root/.cache \ - pip install --user .[worker,worker_dev] - -FROM python:3.10 AS build-image - WORKDIR /code +RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg + +COPY pyproject.toml . +RUN pip install --no-cache-dir --user .[worker,worker_dev] + +COPY scripts/download_model.py . +RUN chmod +x download_model.py && python download_model.py small small.en + +ENV PYTHONIOENCODING=utf-8 ENV PYTHONDONTWRITEBYTECODE 1 ENV PYTHONUNBUFFERED 1 - -COPY --from=compile-image /root/.local /root/.local ENV PATH=/root/.local/bin:$PATH -ENTRYPOINT ["watchmedo", "auto-restart", "-d" , "app/worker", "-p", "*.py", "celery", "--", "--app=app.worker.main.celery", "worker", "--loglevel=info"] +ENTRYPOINT ["watchmedo", "auto-restart", "-d" , "app/worker", "-p", "*.py", "--recursive", "celery", "--", "--app=app.worker.main.celery", "worker", "--loglevel=info", "--concurrency=1"] diff --git a/pyproject.toml b/pyproject.toml index 362a854..8263358 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,8 @@ web=[ ] worker=[ - "whisper-openai ==1.0.0" + "whisper-openai ==1.0.0", + "requests ==2.28.2" ] lint = [ diff --git a/scripts/download_model.py b/scripts/download_model.py new file mode 100644 index 0000000..8ba6d13 --- /dev/null +++ b/scripts/download_model.py @@ -0,0 +1,7 @@ +import sys +from whisper import _download, _MODELS # type: ignore + +if __name__ == "__main__": + args = sys.argv[1:] + for name in args: + _download(_MODELS[name], "/models/", False)