mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-11 13:08:35 +03:00
feat: postgres => sqlite
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
DATABASE_URI="postgresql://postgres:postgres@localhost:5432/whisperbox_test"
|
||||
ENVIRONMENT="development"
|
||||
API_SECRET="foo"
|
||||
BROKER_URL="redis://localhost:6379/0"
|
||||
DATABASE_URI="sqlite:///memory"
|
||||
ENVIRONMENT="test"
|
||||
API_SECRET="test_secret"
|
||||
BROKER_URL="memory://"
|
||||
|
||||
22
.github/workflows/ci.yml
vendored
22
.github/workflows/ci.yml
vendored
@@ -16,14 +16,14 @@ jobs:
|
||||
- isort --check app
|
||||
- flake8 app
|
||||
- mypy app
|
||||
# test:
|
||||
# runs-on: ubuntu-latest
|
||||
# name: Test
|
||||
# steps:
|
||||
# - uses: actions/checkout@v3
|
||||
# - uses: actions/setup-python@v4
|
||||
# with:
|
||||
# python-version: '3.11'
|
||||
# cache: 'pip'
|
||||
# - pip install -e .[test]
|
||||
# - pytest
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
name: Test
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
- run: pip install -e .[test]  # a step must be a mapping; shell commands go under `run:`
|
||||
- run: pytest  # execute the test suite with the freshly installed test extras
|
||||
|
||||
11
Makefile
11
Makefile
@@ -1,8 +1,9 @@
|
||||
clean:
|
||||
docker-compose -f docker/dev.docker-compose.yml down --volumes --remove-orphans
|
||||
docker-compose -f docker/dev/docker-compose.yml down --volumes --remove-orphans
|
||||
|
||||
dev:
|
||||
docker-compose -f docker/dev.docker-compose.yml build --progress tty
|
||||
docker-compose -f docker/dev.docker-compose.yml up --remove-orphans
|
||||
docker-compose -f docker/dev/docker-compose.yml build --progress tty
|
||||
docker-compose -f docker/dev/docker-compose.yml up --remove-orphans
|
||||
|
||||
fmt:
|
||||
black app
|
||||
@@ -14,3 +15,7 @@ lint:
|
||||
|
||||
test:
|
||||
pytest
|
||||
|
||||
run:
|
||||
docker-compose -f docker/prod/docker-compose.yml build --progress tty
|
||||
docker-compose -f docker/prod/docker-compose.yml up --remove-orphans
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
"""add_job_tables
|
||||
|
||||
Revision ID: 426b6bdc3360
|
||||
Revision ID: dc8582aea0bc
|
||||
Revises:
|
||||
Create Date: 2023-01-27 17:55:21.758828
|
||||
Create Date: 2023-02-08 12:12:00.808816
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "426b6bdc3360"
|
||||
revision = "dc8582aea0bc"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
@@ -30,23 +29,28 @@ def upgrade() -> None:
|
||||
sa.Column("meta", sa.JSON(none_as_null=True), nullable=True),
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.Enum("transcript", "translation", "language_detection", name="jobtype"),
|
||||
sa.Enum(
|
||||
"transcript",
|
||||
"translation",
|
||||
"language_detection",
|
||||
name="jobtype",
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
server_default=sa.text("now()"),
|
||||
server_default=sa.text("(CURRENT_TIMESTAMP)"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("id", sa.VARCHAR(length=36), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_jobs_id"), "jobs", ["id"], unique=False)
|
||||
op.create_table(
|
||||
"artifacts",
|
||||
sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("job_id", sa.VARCHAR(length=36), nullable=False),
|
||||
sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
|
||||
sa.Column(
|
||||
"type",
|
||||
@@ -56,11 +60,11 @@ def upgrade() -> None:
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
server_default=sa.text("now()"),
|
||||
server_default=sa.text("(CURRENT_TIMESTAMP)"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("id", sa.VARCHAR(length=36), nullable=False),
|
||||
sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from app.shared.config import settings
|
||||
|
||||
engine = create_engine(settings.DATABASE_URI)
|
||||
engine = create_engine(settings.DATABASE_URI, connect_args={"check_same_thread": False})
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime, Enum, ForeignKey, String, func
|
||||
from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
||||
from sqlalchemy.orm import Mapped, declarative_mixin # type: ignore
|
||||
@@ -26,7 +26,7 @@ class WithStandardFields:
|
||||
@declared_attr
|
||||
def id(cls) -> Mapped[UUID]:
|
||||
return Column(
|
||||
UUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4
|
||||
VARCHAR(36), primary_key=True, index=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
|
||||
|
||||
@@ -44,7 +44,9 @@ class Artifact(Base, WithStandardFields):
|
||||
__tablename__ = "artifacts"
|
||||
|
||||
job_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False
|
||||
VARCHAR(36),
|
||||
ForeignKey("jobs.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
data = Column(JSON(none_as_null=True))
|
||||
|
||||
@@ -7,11 +7,10 @@ from sqlalchemy.orm import Session
|
||||
import app.shared.db.dtos as dtos
|
||||
import app.shared.db.models as models
|
||||
from app.shared.db.dtos import JobStatus, JobType
|
||||
from app.web.main import app
|
||||
from app.web.main import app, celery
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(name="mock_job", scope="function", autouse=False)
|
||||
def mock_job(db_session: Session) -> models.Job:
|
||||
job = models.Job(
|
||||
|
||||
@@ -81,7 +81,7 @@ def get_transcripts(
|
||||
def get_transcript(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
) -> Optional[models.Job]:
|
||||
job = session.query(models.Job).filter(models.Job.id == id).one_or_none()
|
||||
job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
|
||||
if not job:
|
||||
raise HTTPException(status_code=404)
|
||||
return job
|
||||
@@ -92,7 +92,7 @@ def get_artifacts_for_job(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
) -> List[models.Artifact]:
|
||||
artifacts = (
|
||||
session.query(models.Artifact).filter(models.Artifact.job_id == id)
|
||||
session.query(models.Artifact).filter(models.Artifact.job_id == str(id))
|
||||
).all()
|
||||
|
||||
if not len(artifacts):
|
||||
@@ -105,14 +105,14 @@ def get_artifacts_for_job(
|
||||
def delete_transcript(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
) -> None:
|
||||
session.query(models.Job).filter(models.Job.id == id).delete()
|
||||
session.query(models.Job).filter(models.Job.id == str(id)).delete()
|
||||
return None
|
||||
|
||||
|
||||
app.include_router(api_router)
|
||||
|
||||
# TODO:
|
||||
# we could use `acks_late` to handle this scenario within celery itself.
|
||||
|
||||
# TODO: we could use `acks_late` to handle this scenario within celery itself.
|
||||
# the reason this does not work well in our case is that `visibility_timeout`
|
||||
# needs to be very high since whisper workers can be long running.
|
||||
# doing this application-side bears the risk of poison pilling the worker though,
|
||||
@@ -123,10 +123,15 @@ def on_startup() -> None:
|
||||
|
||||
jobs = (
|
||||
session.query(models.Job)
|
||||
.filter(or_(models.Job.status == dtos.JobStatus.processing, models.Job.status == dtos.JobStatus.create))
|
||||
.order_by(models.Job.created_at)
|
||||
.filter(
|
||||
or_(
|
||||
models.Job.status == dtos.JobStatus.processing,
|
||||
models.Job.status == dtos.JobStatus.create,
|
||||
)
|
||||
)
|
||||
.order_by(models.Job.created_at)
|
||||
).all()
|
||||
|
||||
logger.info(f"Re-queueing {len(jobs)} jobs.")
|
||||
logger.info(f"Requeueing {len(jobs)} jobs.")
|
||||
for job in jobs:
|
||||
queue_task(job)
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
# run migrations
|
||||
alembic upgrade head
|
||||
|
||||
# start app
|
||||
uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-80} --log-level info
|
||||
@@ -13,17 +13,16 @@ from app.worker.strategies.local import LocalStrategy
|
||||
celery = get_celery_binding()
|
||||
|
||||
|
||||
@celery.task(
|
||||
bind=True,
|
||||
soft_time_limit=2 * 60 * 60 # TODO: make configurable
|
||||
)
|
||||
@celery.task(bind=True, soft_time_limit=2 * 60 * 60) # TODO: make configurable
|
||||
def transcribe(self: Task, job_id: UUID) -> None:
|
||||
try:
|
||||
db: Session = SessionLocal()
|
||||
job = db.query(models.Job).filter(models.Job.id == job_id).one()
|
||||
|
||||
if job.status == dtos.JobStatus.error or job.status == dtos.JobStatus.success:
|
||||
logger.warn("[{job.id}]: Received job that has already been processed, abort.")
|
||||
logger.warn(
|
||||
f"[{job.id}]: Received job that has already been processed, abort."  # f-string so the job id is actually interpolated
|
||||
)
|
||||
return
|
||||
|
||||
job.meta = {"task_id": self.request.id}
|
||||
@@ -51,7 +50,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
result = strategy.detect_language()
|
||||
|
||||
artifact = models.Artifact(
|
||||
job_id=job.id, data=result, type=dtos.ArtifactType.raw_transcript
|
||||
job_id=str(job.id), data=result, type=dtos.ArtifactType.raw_transcript
|
||||
)
|
||||
|
||||
db.add(artifact)
|
||||
@@ -64,9 +63,9 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
logger.info(f"[{job.id}]: set task to status success.")
|
||||
except Exception as e:
|
||||
if job and db:
|
||||
job.meta = { **job.meta, "error": str(e) }
|
||||
job.meta = {**(job.meta or {}), "error": str(e)}  # meta is a plain dict (or None) with no __dict__; keep existing keys and add the error
|
||||
job.status = dtos.JobStatus.error
|
||||
db.commit()
|
||||
raise(e)
|
||||
raise (e)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
version: "3.8"
|
||||
|
||||
x-app-variables: &app-variables
|
||||
API_SECRET: a_very_secret_token
|
||||
DATABASE_URI: postgresql://postgres:postgres@postgres/whisperbox
|
||||
ENVIRONMENT: development
|
||||
BROKER_URL: redis://redis:6379/0
|
||||
|
||||
services:
|
||||
postgres:
|
||||
container_name: whisperbox_postgres
|
||||
image: postgres:15-alpine
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: whisperbox
|
||||
ports:
|
||||
- "5432:5432"
|
||||
networks:
|
||||
- app
|
||||
volumes:
|
||||
- postgres-data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
redis:
|
||||
container_name: whisperbox_redis
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- 6379:6379
|
||||
networks:
|
||||
- app
|
||||
|
||||
app:
|
||||
container_name: whisperbox_app
|
||||
build:
|
||||
context: ../
|
||||
dockerfile: docker/app.dev.Dockerfile
|
||||
environment: *app-variables
|
||||
ports:
|
||||
- "8000:80"
|
||||
networks:
|
||||
- app
|
||||
volumes:
|
||||
- ../:/code
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_started
|
||||
|
||||
worker:
|
||||
build:
|
||||
context: ../
|
||||
dockerfile: docker/worker.dev.Dockerfile
|
||||
container_name: whisperbox_worker
|
||||
volumes:
|
||||
- ../:/code
|
||||
environment: *app-variables
|
||||
depends_on:
|
||||
- redis
|
||||
networks:
|
||||
- app
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "celery inspect ping -A app.worker.main.transcribe -d celery@$$HOSTNAME"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
flower:
|
||||
container_name: whisperbox_flower
|
||||
image: mher/flower
|
||||
command: celery --broker redis://redis:6379/0 flower --port=5555
|
||||
ports:
|
||||
- 5555:5555
|
||||
depends_on:
|
||||
- redis
|
||||
networks:
|
||||
- app
|
||||
|
||||
volumes:
|
||||
postgres-data:
|
||||
|
||||
networks:
|
||||
app:
|
||||
driver: bridge
|
||||
71
docker/dev/docker-compose.yml
Normal file
71
docker/dev/docker-compose.yml
Normal file
@@ -0,0 +1,71 @@
|
||||
version: "3.8"
|
||||
|
||||
x-app-variables: &app-variables
|
||||
API_SECRET: a_very_secret_token
|
||||
DATABASE_URI: sqlite:////code/whisperbox.sqlite  # under the /code bind mount so web and worker open the same database file
|
||||
ENVIRONMENT: development
|
||||
BROKER_URL: redis://redis:6379/0
|
||||
WHISPER_MODEL: small
|
||||
|
||||
services:
|
||||
redis:
|
||||
container_name: whisperbox_redis_dev
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- 6379:6379
|
||||
networks:
|
||||
- app
|
||||
resources:
|
||||
limits:
|
||||
memory: 128M
|
||||
|
||||
web:
|
||||
container_name: whisperbox_web_dev
|
||||
build:
|
||||
context: ../../
|
||||
dockerfile: docker/dev/web.Dockerfile
|
||||
environment: *app-variables
|
||||
ports:
|
||||
- "8000:80"
|
||||
networks:
|
||||
- app
|
||||
volumes:
|
||||
- ../../:/code
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_started
|
||||
|
||||
worker:
|
||||
build:
|
||||
context: ../../
|
||||
dockerfile: docker/dev/worker.Dockerfile
|
||||
args:
|
||||
WHISPER_MODEL: small
|
||||
container_name: whisperbox_worker_dev
|
||||
volumes:
|
||||
- ../../:/code
|
||||
environment: *app-variables
|
||||
depends_on:
|
||||
- redis
|
||||
networks:
|
||||
- app
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "celery inspect ping -A app.worker.main.transcribe -d celery@$$HOSTNAME"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
flower:
|
||||
container_name: whisperbox_flower_dev
|
||||
image: mher/flower
|
||||
command: celery --broker redis://redis:6379/0 flower --port=5555
|
||||
ports:
|
||||
- 5555:5555
|
||||
depends_on:
|
||||
- redis
|
||||
networks:
|
||||
- app
|
||||
|
||||
networks:
|
||||
app:
|
||||
driver: bridge
|
||||
@@ -5,8 +5,9 @@ WORKDIR /code
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir --user .[web]
|
||||
|
||||
ENV PYTHONIOENCODING=utf-8
|
||||
ENV PYTHONDONTWRITEBYTECODE 1
|
||||
ENV PYTHONUNBUFFERED 1
|
||||
ENV PATH=/root/.local/bin:$PATH
|
||||
|
||||
ENTRYPOINT ["bash", "./app/web/start.sh"]
|
||||
CMD alembic upgrade head && uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-80} --log-level info
|
||||
@@ -1,14 +1,17 @@
|
||||
FROM python:3.10 AS compile-image
|
||||
|
||||
ARG WHISPER_MODEL
|
||||
|
||||
WORKDIR /code
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg
|
||||
COPY --from=mwader/static-ffmpeg:5.1.2 /ffmpeg /usr/local/bin/
|
||||
COPY --from=mwader/static-ffmpeg:5.1.2 /ffprobe /usr/local/bin/
|
||||
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir --user .[worker,worker_dev]
|
||||
|
||||
COPY scripts/download_model.py .
|
||||
RUN chmod +x download_model.py && python download_model.py small small.en
|
||||
RUN chmod +x download_model.py && python download_model.py ${WHISPER_MODEL}
|
||||
|
||||
ENV PYTHONIOENCODING=utf-8
|
||||
ENV PYTHONDONTWRITEBYTECODE 1
|
||||
@@ -5,7 +5,6 @@ version = "0.0.1"
|
||||
|
||||
dependencies=[
|
||||
"celery[redis] ==5.2.7",
|
||||
"psycopg2 ==2.9.5",
|
||||
"sqlalchemy[mypy] == 1.4.45",
|
||||
"pydantic ==1.10.4"
|
||||
]
|
||||
@@ -34,9 +33,9 @@ lint = [
|
||||
test = [
|
||||
"httpx",
|
||||
"sqlalchemy-stubs",
|
||||
"sqlalchemy-utils",
|
||||
"sqlalchemy-utils ==0.39.0",
|
||||
"python-dotenv",
|
||||
"pytest"
|
||||
"pytest",
|
||||
]
|
||||
|
||||
worker_dev = [
|
||||
|
||||
@@ -2,6 +2,7 @@ import sys
|
||||
from whisper import _download, _MODELS # type: ignore
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = sys.argv[1:]
|
||||
for name in args:
|
||||
_download(_MODELS[name], "/models/", False)
|
||||
model_name = sys.argv[1]
|
||||
_download(_MODELS[model_name], "/models/", False)
|
||||
if model_name != "large":
|
||||
_download(_MODELS[f"{model_name}.en"], "/models/", False)
|
||||
|
||||
Reference in New Issue
Block a user