feat: add whisper tasks

This commit is contained in:
Felix Spöttel
2023-01-28 12:30:02 +01:00
parent 8669a18110
commit a7ce71ed33
17 changed files with 209 additions and 88 deletions

View File

@@ -1,4 +1,4 @@
DATABASE_URI="postgresql://postgres:postgres@localhost:5432/whisperbox_test"
ENVIRONMENT="development"
API_SECRET="foo"
BROKER_URI="redis://localhost:6379/0"
BROKER_URL="redis://localhost:6379/0"

View File

@@ -1,4 +1,4 @@
[flake8]
max-line-length = 90
max-line-length = 88
extend-ignore = E203
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,app/alembic/versions

9
app/shared/celery.py Normal file
View File

@@ -0,0 +1,9 @@
from celery import Celery
from app.shared.config import settings
def get_celery_binding() -> Celery:
    """Build the Celery application shared by the web app and the worker.

    Returns:
        A Celery app named ``"tasks"`` whose broker URL is read from
        ``settings.BROKER_URL``.
    """
    app = Celery("tasks")
    # the broker location comes from configuration, never from code.
    app.conf.broker_url = settings.BROKER_URL
    return app

View File

@@ -9,10 +9,12 @@ class Settings(BaseSettings):
ENVIRONMENT: str
# derived settings
BROKER_URI: str
BROKER_URL: str
if "pytest" in sys.modules:
settings = Settings(_env_file=".env.test", _env_file_encoding="utf-8") # type: ignore
settings = Settings(
_env_file=".env.test", _env_file_encoding="utf-8"
) # type: ignore
else:
settings = Settings()

View File

@@ -1,8 +1,8 @@
"""add_job_tables
Revision ID: bb249ed79907
Revision ID: 426b6bdc3360
Revises:
Create Date: 2023-01-17 14:30:30.920466
Create Date: 2023-01-27 17:55:21.758828
"""
import sqlalchemy as sa
@@ -10,7 +10,7 @@ from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "bb249ed79907"
revision = "426b6bdc3360"
down_revision = None
branch_labels = None
depends_on = None
@@ -23,12 +23,21 @@ def upgrade() -> None:
sa.Column("url", sa.String(length=2048), nullable=True),
sa.Column(
"status",
sa.Enum("create", "error", "processing", "success", name="jobstatus"),
sa.Enum("create", "processing", "error", "success", name="jobstatus"),
nullable=False,
),
sa.Column("type", sa.Enum("transcript", name="jobtype"), nullable=False),
sa.Column("config", sa.JSON(none_as_null=True), nullable=True),
sa.Column("meta", sa.JSON(none_as_null=True), nullable=True),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False
"type",
sa.Enum("transcript", "translation", "language_detection", name="jobtype"),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
@@ -40,10 +49,15 @@ def upgrade() -> None:
sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
sa.Column(
"type", sa.Enum("raw_transcript", name="artifacttype"), nullable=False
"type",
sa.Enum("raw_transcript", name="artifacttype"),
nullable=False,
),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False
"created_at",
sa.DateTime(),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),

View File

@@ -1,27 +0,0 @@
"""add_job_meta_field
Revision ID: 684a5e546314
Revises: bb249ed79907
Create Date: 2023-01-18 13:38:07.692830
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "684a5e546314"
down_revision = "bb249ed79907"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("jobs", sa.Column("meta", sa.JSON(none_as_null=True), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("jobs", "meta")
# ### end Alembic commands ###

View File

@@ -1,6 +1,6 @@
import enum
from datetime import datetime
from typing import Any, Optional
from typing import Any, List, Optional
from uuid import UUID
from pydantic import AnyHttpUrl, BaseModel
@@ -21,6 +21,8 @@ class ArtifactType(str, enum.Enum):
class JobType(str, enum.Enum):
transcript = "transcript"
translation = "translation"
language_detection = "language_detection"
class JobStatus(str, enum.Enum):
@@ -30,8 +32,12 @@ class JobStatus(str, enum.Enum):
success = "success"
class JobMeta(BaseModel):
class JobConfig(BaseModel):
language: Optional[str]
class JobMeta(BaseModel):
error: Optional[str]
task_id: Optional[UUID]
@@ -40,10 +46,11 @@ class Job(WithDbFields):
type: JobType
url: AnyHttpUrl
meta: Optional[JobMeta]
config: Optional[JobConfig]
class Artifact(WithDbFields):
# TODO: narrow type
data: Optional[Any]
data: Optional[List[Any]]
job_id: UUID
type: ArtifactType

View File

@@ -35,6 +35,7 @@ class Job(Base, WithStandardFields):
url = Column(String(length=2048))
status = Column(Enum(JobStatus), nullable=False)
config = Column(JSON(none_as_null=True))
meta = Column(JSON(none_as_null=True))
type = Column(Enum(JobType), nullable=False)

View File

@@ -86,7 +86,7 @@ def test_get_artifact_pass(
auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job
) -> None:
artifact = models.Artifact(
data={}, job_id=mock_job.id, type=dtos.ArtifactType.raw_transcript
data=[], job_id=mock_job.id, type=dtos.ArtifactType.raw_transcript
)
db_session.add(artifact)

View File

@@ -7,11 +7,12 @@ from sqlalchemy.orm import Session
import app.shared.db.dtos as dtos
import app.shared.db.models as models
from app.shared.celery import get_celery_binding
from app.shared.db.base import get_session
from app.web.security import authenticate_api_key
from app.worker.main import transcribe
app = FastAPI()
celery = get_celery_binding()
api_router = APIRouter(prefix="/api/v1", dependencies=[Depends(authenticate_api_key)])
@@ -21,19 +22,33 @@ def api_root() -> Dict:
return {}
class TranscriptPayload(BaseModel):
class PostJobPayload(BaseModel):
url: AnyHttpUrl
type: dtos.JobType
language: Optional[str]
@api_router.post("/jobs", response_model=dtos.Job, status_code=201)
def create_job(
payload: TranscriptPayload, session: Session = Depends(get_session)
payload: PostJobPayload,
session: Session = Depends(get_session),
) -> models.Job:
job = models.Job(url=payload.url, status=dtos.JobStatus.create, type=payload.type)
session.add(job)
session.flush()
# create a job with status "create" and save it to the database.
job = models.Job(
url=payload.url,
status=dtos.JobStatus.create,
type=payload.type,
config={"language": payload.language} if payload.language else None,
)
session.add(job)
session.commit()
# queue an async transcription task.
# we use a signature here to allow full separation of
# worker processes and dependencies.
transcribe = celery.signature("app.worker.main.transcribe")
# TODO: catch delivery errors.
transcribe.delay(job.id)
return job

View File

@@ -1,37 +1,55 @@
from time import sleep
from asyncio.log import logger
from uuid import UUID
from celery import Celery
from sqlalchemy.orm import Session
import app.shared.db.dtos as dtos
import app.shared.db.models as models
from app.shared.config import settings
from app.shared.celery import get_celery_binding
from app.shared.db.base import SessionLocal
from app.worker.strategies.local import LocalStrategy
celery = Celery(__name__)
celery.conf.broker_url = settings.BROKER_URI
celery = get_celery_binding()
def update_job_status(db: Session, job_id: UUID, status: dtos.JobStatus) -> None:
db.begin()
job = db.query(models.Job).filter(models.Job.id == job_id).one()
def update_job_status(db: Session, job: models.Job, status: dtos.JobStatus) -> None:
job.status = status
db.commit()
@celery.task()
def transcribe(job_id: UUID) -> int:
def transcribe(job_id: UUID) -> None:
try:
db: Session = SessionLocal()
update_job_status(db, job_id, dtos.JobStatus.processing)
sleep(60)
update_job_status(db, job_id, dtos.JobStatus.success)
job = db.query(models.Job).filter(models.Job.id == job_id).one()
update_job_status(db, job, dtos.JobStatus.processing)
# pick a transcription strategy.
# currently only `local` is supported.
job_record = dtos.Job.from_orm(job)
strategy = LocalStrategy(
db=db, job_id=job.id, url=job_record.url, config=job_record.config
)
# process selected task.
# `language_detection` is not implemented by the strategy yet.
if job.type == dtos.JobType.transcript:
result = strategy.transcribe()
elif job.type == dtos.JobType.translation:
result = strategy.translate()
else:
result = strategy.detect_language()
artifact = models.Artifact(
job_id=job.id, data=result, type=dtos.ArtifactType.raw_transcript
)
db.add(artifact)
db.commit()
except Exception:
update_job_status(db, job_id, dtos.JobStatus.error)
update_job_status(db, job, dtos.JobStatus.success)
except Exception as e:
logger.error(e)
update_job_status(db, job, dtos.JobStatus.error)
finally:
db.close()
return 0

View File

@@ -0,0 +1,81 @@
import os
import shutil
import tempfile
from os import path
from typing import Any, List, Literal, Optional
from uuid import UUID
import requests
from pydantic import BaseModel
from sqlalchemy.orm import Session
from whisper import load_model
import app.shared.db.dtos as dtos
class DecodeOptions(BaseModel):
    """Validated keyword options that are unpacked into whisper's `transcribe`."""

    # language configured on the job, or `None` when the job did not set one.
    # the default is spelled out so the field stays optional regardless of how
    # the installed pydantic version treats a bare `Optional[...]` annotation.
    language: Optional[str] = None
    # which whisper task to run; any other value fails validation.
    task: Literal["translate", "transcribe"]
class LocalStrategy:
    """Process a job's media with a whisper model loaded in this process.

    The media behind `url` is streamed into a temporary directory named after
    the job id, handed to whisper, and the directory is removed afterwards.
    """

    # (connect, read) timeouts in seconds for the media download, so that an
    # unresponsive host cannot block the worker indefinitely.
    DOWNLOAD_TIMEOUT = (10, 60)

    def __init__(
        self, db: Session, job_id: UUID, url: str, config: Optional[dtos.JobConfig]
    ):
        """Store the job context.

        Args:
            db: database session of the calling task.
            job_id: id of the job; also used to name the temporary directory.
            url: location of the media file to download.
            config: optional job configuration; only `language` is read.
        """
        self.db = db
        self.job_id = job_id
        self.url = url
        self.config = config

    def transcribe(self) -> List[Any]:
        """Download the media and return whisper's segments for `transcribe`."""
        return self._run_task("transcribe")

    def translate(self) -> List[Any]:
        """Download the media and return whisper's segments for `translate`."""
        return self._run_task("translate")

    def detect_language(self) -> List[Any]:
        """Not available yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("detect_language has not been implemented yet.")

    def _run_task(self, task: str) -> List[Any]:
        """Download the media, run whisper with `task` and always clean up."""
        try:
            return self.run_whisper(self._download(), task)
        finally:
            # remove the temporary directory even when the download or whisper
            # fails; otherwise failed jobs would leave their media on disk.
            self._cleanup()

    def _download(self) -> str:
        """Stream the media at `self.url` to disk and return the file path.

        Raises:
            requests.RequestException: on connection errors, timeouts or a
                non-success HTTP status.
        """
        dirname = self._get_tmp_dir()
        filename = path.join(dirname, "media.mp3")

        # re-create folder.
        shutil.rmtree(dirname, ignore_errors=True)
        os.makedirs(dirname)

        # stream media to disk.
        with requests.get(
            self.url, stream=True, timeout=self.DOWNLOAD_TIMEOUT
        ) as response:
            response.raise_for_status()
            with open(filename, "wb") as media_file:
                for chunk in response.iter_content(chunk_size=8192):
                    media_file.write(chunk)

        return filename

    def run_whisper(self, filepath: str, task: str) -> List[Any]:
        """Run the "small" whisper model on `filepath`.

        Args:
            filepath: path of the media file on disk.
            task: `"transcribe"` or `"translate"`; validated by `DecodeOptions`.

        Returns:
            The `segments` entry of whisper's result.
        """
        language = self.config.language if self.config else None
        decode_opts = DecodeOptions(task=task, language=language)

        model = load_model("small", download_root="/models")
        result = model.transcribe(
            filepath, condition_on_previous_text=False, **decode_opts.dict()
        )
        return result["segments"]

    def _get_tmp_dir(self) -> str:
        """Return the job-specific directory below the system temp directory."""
        return path.join(tempfile.gettempdir(), str(self.job_id))

    def _cleanup(self) -> None:
        """Remove the job's temporary directory, ignoring any errors."""
        shutil.rmtree(self._get_tmp_dir(), ignore_errors=True)

    def _convert(self) -> None:
        """Placeholder; currently does nothing."""
        pass

    def _transcribe(self) -> None:
        """Placeholder; currently does nothing."""
        pass

View File

@@ -1,17 +1,12 @@
FROM python:3.10 AS compile-image
COPY pyproject.toml .
RUN --mount=type=cache,target=/root/.cache \
pip install --user .[web]
FROM python:3.10 AS build-image
WORKDIR /code
COPY pyproject.toml .
RUN pip install --no-cache-dir --user .[web]
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
COPY --from=compile-image /root/.local /root/.local
ENV PATH=/root/.local/bin:$PATH
ENTRYPOINT ["bash", "./app/web/start.sh"]

View File

@@ -4,7 +4,7 @@ x-app-variables: &app-variables
API_SECRET: a_very_secret_token
DATABASE_URI: postgresql://postgres:postgres@postgres/whisperbox
ENVIRONMENT: development
BROKER_URI: redis://redis:6379/0
BROKER_URL: redis://redis:6379/0
services:
postgres:

View File

@@ -1,20 +1,18 @@
FROM python:3.10 AS compile-image
RUN --mount=type=cache,target=/var/cache/apt \
apt-get update && apt-get install -y --no-install-recommends ffmpeg
COPY pyproject.toml .
RUN --mount=type=cache,target=/root/.cache \
pip install --user .[worker,worker_dev]
FROM python:3.10 AS build-image
WORKDIR /code
RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg
COPY pyproject.toml .
RUN pip install --no-cache-dir --user .[worker,worker_dev]
COPY scripts/download_model.py .
RUN chmod +x download_model.py && python download_model.py small small.en
ENV PYTHONIOENCODING=utf-8
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
COPY --from=compile-image /root/.local /root/.local
ENV PATH=/root/.local/bin:$PATH
ENTRYPOINT ["watchmedo", "auto-restart", "-d" , "app/worker", "-p", "*.py", "celery", "--", "--app=app.worker.main.celery", "worker", "--loglevel=info"]
ENTRYPOINT ["watchmedo", "auto-restart", "-d" , "app/worker", "-p", "*.py", "--recursive", "celery", "--", "--app=app.worker.main.celery", "worker", "--loglevel=info", "--concurrency=1"]

View File

@@ -18,7 +18,8 @@ web=[
]
worker=[
"whisper-openai ==1.0.0"
"whisper-openai ==1.0.0",
"requests ==2.28.2"
]
lint = [

View File

@@ -0,0 +1,7 @@
import sys
from whisper import _download, _MODELS # type: ignore
if __name__ == "__main__":
    # every command line argument is looked up in whisper's model table and
    # downloaded into /models/.
    # NOTE(review): the trailing `False` is presumably `in_memory` - confirm
    # against the pinned whisper version.
    for model_name in sys.argv[1:]:
        _download(_MODELS[model_name], "/models/", False)