mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-12 05:28:34 +03:00
feat: add whisper tasks
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
DATABASE_URI="postgresql://postgres:postgres@localhost:5432/whisperbox_test"
|
||||
ENVIRONMENT="development"
|
||||
API_SECRET="foo"
|
||||
BROKER_URI="redis://localhost:6379/0"
|
||||
BROKER_URL="redis://localhost:6379/0"
|
||||
|
||||
2
.flake8
2
.flake8
@@ -1,4 +1,4 @@
|
||||
[flake8]
|
||||
max-line-length = 90
|
||||
max-line-length = 88
|
||||
extend-ignore = E203
|
||||
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,app/alembic/versions
|
||||
|
||||
9
app/shared/celery.py
Normal file
9
app/shared/celery.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from celery import Celery
|
||||
|
||||
from app.shared.config import settings
|
||||
|
||||
|
||||
def get_celery_binding() -> Celery:
    """Create a Celery application named ``"tasks"``.

    The broker location is taken from ``settings.BROKER_URL``. Every call
    builds and returns a new ``Celery`` instance.
    """
    binding = Celery("tasks")
    binding.conf.broker_url = settings.BROKER_URL
    return binding
|
||||
@@ -9,10 +9,12 @@ class Settings(BaseSettings):
|
||||
ENVIRONMENT: str
|
||||
|
||||
# derived settings
|
||||
BROKER_URI: str
|
||||
BROKER_URL: str
|
||||
|
||||
|
||||
if "pytest" in sys.modules:
|
||||
settings = Settings(_env_file=".env.test", _env_file_encoding="utf-8") # type: ignore
|
||||
settings = Settings(
|
||||
_env_file=".env.test", _env_file_encoding="utf-8"
|
||||
) # type: ignore
|
||||
else:
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""add_job_tables
|
||||
|
||||
Revision ID: bb249ed79907
|
||||
Revision ID: 426b6bdc3360
|
||||
Revises:
|
||||
Create Date: 2023-01-17 14:30:30.920466
|
||||
Create Date: 2023-01-27 17:55:21.758828
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
@@ -10,7 +10,7 @@ from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bb249ed79907"
|
||||
revision = "426b6bdc3360"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
@@ -23,12 +23,21 @@ def upgrade() -> None:
|
||||
sa.Column("url", sa.String(length=2048), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum("create", "error", "processing", "success", name="jobstatus"),
|
||||
sa.Enum("create", "processing", "error", "success", name="jobstatus"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("type", sa.Enum("transcript", name="jobtype"), nullable=False),
|
||||
sa.Column("config", sa.JSON(none_as_null=True), nullable=True),
|
||||
sa.Column("meta", sa.JSON(none_as_null=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False
|
||||
"type",
|
||||
sa.Enum("transcript", "translation", "language_detection", name="jobtype"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
@@ -40,10 +49,15 @@ def upgrade() -> None:
|
||||
sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
|
||||
sa.Column(
|
||||
"type", sa.Enum("raw_transcript", name="artifacttype"), nullable=False
|
||||
"type",
|
||||
sa.Enum("raw_transcript", name="artifacttype"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
@@ -1,27 +0,0 @@
|
||||
"""add_job_meta_field
|
||||
|
||||
Revision ID: 684a5e546314
|
||||
Revises: bb249ed79907
|
||||
Create Date: 2023-01-18 13:38:07.692830
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "684a5e546314"
|
||||
down_revision = "bb249ed79907"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("jobs", sa.Column("meta", sa.JSON(none_as_null=True), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("jobs", "meta")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,6 +1,6 @@
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel
|
||||
@@ -21,6 +21,8 @@ class ArtifactType(str, enum.Enum):
|
||||
|
||||
class JobType(str, enum.Enum):
|
||||
transcript = "transcript"
|
||||
translation = "translation"
|
||||
language_detection = "language_detection"
|
||||
|
||||
|
||||
class JobStatus(str, enum.Enum):
|
||||
@@ -30,8 +32,12 @@ class JobStatus(str, enum.Enum):
|
||||
success = "success"
|
||||
|
||||
|
||||
class JobMeta(BaseModel):
|
||||
class JobConfig(BaseModel):
|
||||
language: Optional[str]
|
||||
|
||||
|
||||
class JobMeta(BaseModel):
|
||||
error: Optional[str]
|
||||
task_id: Optional[UUID]
|
||||
|
||||
|
||||
@@ -40,10 +46,11 @@ class Job(WithDbFields):
|
||||
type: JobType
|
||||
url: AnyHttpUrl
|
||||
meta: Optional[JobMeta]
|
||||
config: Optional[JobConfig]
|
||||
|
||||
|
||||
class Artifact(WithDbFields):
|
||||
# TODO: narrow type
|
||||
data: Optional[Any]
|
||||
data: Optional[List[Any]]
|
||||
job_id: UUID
|
||||
type: ArtifactType
|
||||
|
||||
@@ -35,6 +35,7 @@ class Job(Base, WithStandardFields):
|
||||
|
||||
url = Column(String(length=2048))
|
||||
status = Column(Enum(JobStatus), nullable=False)
|
||||
config = Column(JSON(none_as_null=True))
|
||||
meta = Column(JSON(none_as_null=True))
|
||||
type = Column(Enum(JobType), nullable=False)
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ def test_get_artifact_pass(
|
||||
auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job
|
||||
) -> None:
|
||||
artifact = models.Artifact(
|
||||
data={}, job_id=mock_job.id, type=dtos.ArtifactType.raw_transcript
|
||||
data=[], job_id=mock_job.id, type=dtos.ArtifactType.raw_transcript
|
||||
)
|
||||
|
||||
db_session.add(artifact)
|
||||
|
||||
@@ -7,11 +7,12 @@ from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.dtos as dtos
|
||||
import app.shared.db.models as models
|
||||
from app.shared.celery import get_celery_binding
|
||||
from app.shared.db.base import get_session
|
||||
from app.web.security import authenticate_api_key
|
||||
from app.worker.main import transcribe
|
||||
|
||||
app = FastAPI()
|
||||
celery = get_celery_binding()
|
||||
|
||||
api_router = APIRouter(prefix="/api/v1", dependencies=[Depends(authenticate_api_key)])
|
||||
|
||||
@@ -21,19 +22,33 @@ def api_root() -> Dict:
|
||||
return {}
|
||||
|
||||
|
||||
class PostJobPayload(BaseModel):
    """Request body accepted by the job creation endpoint."""

    # location of the media the job should process.
    url: AnyHttpUrl
    # which kind of job to run.
    type: dtos.JobType
    # optional language setting; stored in the job's config when provided.
    # the explicit default keeps the field optional regardless of how the
    # installed pydantic version treats a bare `Optional[...]` annotation.
    language: Optional[str] = None
|
||||
|
||||
|
||||
@api_router.post("/jobs", response_model=dtos.Job, status_code=201)
def create_job(
    payload: PostJobPayload,
    session: Session = Depends(get_session),
) -> models.Job:
    """Persist a new job and queue its processing task.

    Args:
        payload: validated request body (media url, job type, optional language).
        session: database session provided by the ``get_session`` dependency.

    Returns:
        The committed job row, serialized through ``dtos.Job``.

    Raises:
        HTTPException: 503 when the task could not be handed to the broker.
            The job is marked as ``error`` before the exception is raised.
    """
    # imported locally so this block does not depend on the module's import list.
    from fastapi import HTTPException

    # create a job with status "create" and save it to the database.
    job = models.Job(
        url=payload.url,
        status=dtos.JobStatus.create,
        type=payload.type,
        config={"language": payload.language} if payload.language else None,
    )

    session.add(job)
    session.commit()

    # queue an async transcription task.
    # we use a signature here to allow full separation of
    # worker processes and dependencies.
    task = celery.signature("app.worker.main.transcribe")
    try:
        task.delay(job.id)
    except Exception as e:
        # the task never reached the broker, so no worker will pick this job
        # up: record the failure instead of leaving it in "create" forever.
        job.status = dtos.JobStatus.error
        session.commit()
        raise HTTPException(
            status_code=503, detail="the job could not be queued."
        ) from e

    return job
|
||||
|
||||
@@ -1,37 +1,55 @@
|
||||
from time import sleep
|
||||
from asyncio.log import logger
|
||||
from uuid import UUID
|
||||
|
||||
from celery import Celery
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.dtos as dtos
|
||||
import app.shared.db.models as models
|
||||
from app.shared.config import settings
|
||||
from app.shared.celery import get_celery_binding
|
||||
from app.shared.db.base import SessionLocal
|
||||
from app.worker.strategies.local import LocalStrategy
|
||||
|
||||
celery = Celery(__name__)
|
||||
|
||||
celery.conf.broker_url = settings.BROKER_URI
|
||||
celery = get_celery_binding()
|
||||
|
||||
|
||||
def update_job_status(db: Session, job: models.Job, status: dtos.JobStatus) -> None:
    """Assign ``status`` to ``job`` and commit the session.

    Args:
        db: session used for the commit.
        job: job row whose ``status`` attribute is overwritten.
        status: the new status value.
    """
    job.status = status
    db.commit()
|
||||
|
||||
|
||||
@celery.task()
def transcribe(job_id: UUID) -> None:
    """Run the whisper task configured for a job and store the result.

    The job is moved to ``processing``, handed to a ``LocalStrategy`` and, on
    success, its result is saved as a ``raw_transcript`` artifact before the
    job is marked ``success``. Any exception is logged and the job is marked
    ``error`` (when it could be loaded at all).

    Args:
        job_id: primary key of the job to process.
    """
    # NOTE(review): `logger` is imported from `asyncio.log`, i.e. it is
    # asyncio's logger rather than one named after this module.

    # open the session outside of the `try` so that `finally` never runs
    # against an unbound `db`.
    db: Session = SessionLocal()
    # stays `None` until the job has been loaded, which lets the error
    # handler tell whether there is a job row to mark as failed.
    job = None

    try:
        job = db.query(models.Job).filter(models.Job.id == job_id).one()

        update_job_status(db, job, dtos.JobStatus.processing)

        # pick a transcription strategy.
        # currently only `local` is supported.
        job_record = dtos.Job.from_orm(job)
        strategy = LocalStrategy(
            db=db, job_id=job.id, url=job_record.url, config=job_record.config
        )

        # process selected task.
        # `detect_language` currently raises `NotImplementedError`.
        if job.type == dtos.JobType.transcript:
            result = strategy.transcribe()
        elif job.type == dtos.JobType.translation:
            result = strategy.translate()
        else:
            result = strategy.detect_language()

        artifact = models.Artifact(
            job_id=job.id, data=result, type=dtos.ArtifactType.raw_transcript
        )
        db.add(artifact)
        db.commit()

        update_job_status(db, job, dtos.JobStatus.success)
    except Exception:
        # `exception` keeps the traceback, which `error(e)` would drop.
        logger.exception("processing job %s failed", job_id)
        # a failed flush/commit leaves the session unusable until it is
        # rolled back; do so before writing the error status.
        db.rollback()
        if job is not None:
            update_job_status(db, job, dtos.JobStatus.error)
    finally:
        db.close()
|
||||
|
||||
81
app/worker/strategies/local.py
Normal file
81
app/worker/strategies/local.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from os import path
|
||||
from typing import Any, List, Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from whisper import load_model
|
||||
|
||||
import app.shared.db.dtos as dtos
|
||||
|
||||
|
||||
class DecodeOptions(BaseModel):
    """Options passed as keyword arguments to the whisper model's `transcribe`."""

    # language setting taken from the job config; `None` when not configured.
    # the explicit default keeps the field optional regardless of how the
    # installed pydantic version treats a bare `Optional[...]` annotation.
    language: Optional[str] = None
    # which whisper task to run.
    task: Literal["translate", "transcribe"]
|
||||
|
||||
|
||||
class LocalStrategy:
    """Process a job's media with a whisper model loaded in this process.

    The media is downloaded into a per-job directory below the system temp
    directory, fed to whisper, and the directory is removed again once the
    task has finished, whether it succeeded or not.
    """

    # (connect, read) timeouts in seconds for the media download. The read
    # timeout bounds the wait between received chunks, not the whole transfer.
    DOWNLOAD_TIMEOUT = (10, 60)

    def __init__(
        self, db: Session, job_id: UUID, url: str, config: Optional[dtos.JobConfig]
    ):
        """Store the job context.

        Args:
            db: database session (kept on the instance; unused by the methods
                defined in this class).
            job_id: id of the job, used to name the temp directory.
            url: location the media is downloaded from.
            config: optional job configuration providing the language.
        """
        self.db = db
        self.job_id = job_id
        self.url = url
        self.config = config

    def transcribe(self) -> List[Any]:
        """Download the media and return whisper's segments for "transcribe"."""
        return self._run_task("transcribe")

    def translate(self) -> List[Any]:
        """Download the media and return whisper's segments for "translate"."""
        return self._run_task("translate")

    def detect_language(self) -> List[Any]:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("detect_language has not been implemented yet.")

    def _run_task(self, task: str) -> List[Any]:
        """Download, run whisper with ``task`` and always remove the temp dir."""
        try:
            return self.run_whisper(self._download(), task)
        finally:
            # runs on failure as well, so a broken download or a whisper
            # error does not leave the media behind on disk.
            self._cleanup()

    def _download(self) -> str:
        """Stream the media at ``self.url`` to disk and return the file path.

        Raises:
            requests.RequestException: on connection errors, timeouts or a
                non-success HTTP status.
        """
        dirname = self._get_tmp_dir()
        # the file is always stored under this name.
        filename = path.join(dirname, "media.mp3")

        # re-create folder.
        shutil.rmtree(dirname, ignore_errors=True)
        os.makedirs(dirname)

        # stream media to disk.
        # NOTE(review): `self.url` is supplied by the caller and fetched
        # as-is; consider restricting which hosts may be requested.
        with requests.get(
            self.url, stream=True, timeout=self.DOWNLOAD_TIMEOUT
        ) as r:
            r.raise_for_status()
            with open(filename, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

        return filename

    def run_whisper(self, filepath: str, task: str) -> List[Any]:
        """Run the "small" whisper model on ``filepath``.

        Args:
            filepath: path of the media file to process.
            task: "transcribe" or "translate"; validated by ``DecodeOptions``.

        Returns:
            The ``"segments"`` entry of the whisper result.
        """
        language = self.config.language if self.config else None
        decode_opts = DecodeOptions(task=task, language=language)
        # the model is loaded from /models on every call.
        model = load_model("small", download_root="/models")

        result = model.transcribe(
            filepath, condition_on_previous_text=False, **decode_opts.dict()
        )

        return result["segments"]

    def _get_tmp_dir(self) -> str:
        """Return the per-job directory path below the system temp directory."""
        return path.join(tempfile.gettempdir(), str(self.job_id))

    def _cleanup(self) -> None:
        """Remove the per-job temp directory, ignoring errors."""
        shutil.rmtree(self._get_tmp_dir(), ignore_errors=True)

    def _convert(self) -> None:
        """Placeholder; does nothing."""
        pass

    def _transcribe(self) -> None:
        """Placeholder; does nothing."""
        pass
|
||||
@@ -1,17 +1,12 @@
|
||||
FROM python:3.10 AS compile-image
|
||||
|
||||
COPY pyproject.toml .
|
||||
RUN --mount=type=cache,target=/root/.cache \
|
||||
pip install --user .[web]
|
||||
|
||||
FROM python:3.10 AS build-image
|
||||
|
||||
WORKDIR /code
|
||||
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir --user .[web]
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE 1
|
||||
ENV PYTHONUNBUFFERED 1
|
||||
|
||||
COPY --from=compile-image /root/.local /root/.local
|
||||
ENV PATH=/root/.local/bin:$PATH
|
||||
|
||||
ENTRYPOINT ["bash", "./app/web/start.sh"]
|
||||
|
||||
@@ -4,7 +4,7 @@ x-app-variables: &app-variables
|
||||
API_SECRET: a_very_secret_token
|
||||
DATABASE_URI: postgresql://postgres:postgres@postgres/whisperbox
|
||||
ENVIRONMENT: development
|
||||
BROKER_URI: redis://redis:6379/0
|
||||
BROKER_URL: redis://redis:6379/0
|
||||
|
||||
services:
|
||||
postgres:
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
FROM python:3.10 AS compile-image
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
apt-get update && apt-get install -y --no-install-recommends ffmpeg
|
||||
|
||||
COPY pyproject.toml .
|
||||
RUN --mount=type=cache,target=/root/.cache \
|
||||
pip install --user .[worker,worker_dev]
|
||||
|
||||
FROM python:3.10 AS build-image
|
||||
|
||||
WORKDIR /code
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg
|
||||
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir --user .[worker,worker_dev]
|
||||
|
||||
COPY scripts/download_model.py .
|
||||
RUN chmod +x download_model.py && python download_model.py small small.en
|
||||
|
||||
ENV PYTHONIOENCODING=utf-8
|
||||
ENV PYTHONDONTWRITEBYTECODE 1
|
||||
ENV PYTHONUNBUFFERED 1
|
||||
|
||||
COPY --from=compile-image /root/.local /root/.local
|
||||
ENV PATH=/root/.local/bin:$PATH
|
||||
|
||||
ENTRYPOINT ["watchmedo", "auto-restart", "-d" , "app/worker", "-p", "*.py", "celery", "--", "--app=app.worker.main.celery", "worker", "--loglevel=info"]
|
||||
ENTRYPOINT ["watchmedo", "auto-restart", "-d" , "app/worker", "-p", "*.py", "--recursive", "celery", "--", "--app=app.worker.main.celery", "worker", "--loglevel=info", "--concurrency=1"]
|
||||
|
||||
@@ -18,7 +18,8 @@ web=[
|
||||
]
|
||||
|
||||
worker=[
|
||||
"whisper-openai ==1.0.0"
|
||||
"whisper-openai ==1.0.0",
|
||||
"requests ==2.28.2"
|
||||
]
|
||||
|
||||
lint = [
|
||||
|
||||
7
scripts/download_model.py
Normal file
7
scripts/download_model.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import sys
|
||||
from whisper import _download, _MODELS # type: ignore
|
||||
|
||||
if __name__ == "__main__":
    # every command-line argument is the name of a whisper model to fetch.
    names = sys.argv[1:]

    # reject unknown names up front with a readable message, rather than
    # failing with a bare KeyError after some models were already fetched.
    unknown = [name for name in names if name not in _MODELS]
    if unknown:
        sys.exit(f"unknown whisper model(s): {', '.join(unknown)}")

    for name in names:
        # `_download` is a private whisper helper, called here with the
        # model's `_MODELS` entry and the /models/ target directory.
        # NOTE(review): the last argument is presumably the "in memory" flag;
        # confirm against the pinned whisper version.
        _download(_MODELS[name], "/models/", False)
|
||||
Reference in New Issue
Block a user