feat: postgres => sqlite

This commit is contained in:
Felix Spöttel
2023-02-08 14:09:16 +01:00
parent e995b1f2ff
commit 18921d34c6
16 changed files with 148 additions and 159 deletions

View File

@@ -1,4 +1,4 @@
DATABASE_URI="postgresql://postgres:postgres@localhost:5432/whisperbox_test"
ENVIRONMENT="development"
API_SECRET="foo"
BROKER_URL="redis://localhost:6379/0"
DATABASE_URI="sqlite:///memory"
ENVIRONMENT="test"
API_SECRET="test_secret"
BROKER_URL="memory://"

View File

@@ -16,14 +16,14 @@ jobs:
- isort --check app
- flake8 app
- mypy app
# test:
# runs-on: ubuntu-latest
# name: Test
# steps:
# - uses: actions/checkout@v3
# - uses: actions/setup-python@v4
# with:
# python-version: '3.11'
# cache: 'pip'
# - pip install -e .[test]
# - pytest
# Runs the pytest suite on every workflow run.
test:
  runs-on: ubuntu-latest
  name: Test
  steps:
    - uses: actions/checkout@v3
    - uses: actions/setup-python@v4
      with:
        python-version: '3.11'
        cache: 'pip'
    # Shell commands have to be declared with `run:`; a bare string is not a
    # valid step and makes the workflow fail validation.
    - run: pip install -e .[test]
    - run: pytest

View File

@@ -1,8 +1,9 @@
clean:
docker-compose -f docker/dev.docker-compose.yml down --volumes --remove-orphans
docker-compose -f docker/dev/docker-compose.yml down --volumes --remove-orphans
dev:
docker-compose -f docker/dev.docker-compose.yml build --progress tty
docker-compose -f docker/dev.docker-compose.yml up --remove-orphans
docker-compose -f docker/dev/docker-compose.yml build --progress tty
docker-compose -f docker/dev/docker-compose.yml up --remove-orphans
fmt:
black app
@@ -14,3 +15,7 @@ lint:
test:
pytest
run:
docker-compose -f docker/prod/docker-compose.yml build --progress tty
docker-compose -f docker/prod/docker-compose.yml up --remove-orphans

View File

@@ -1,16 +1,15 @@
"""add_job_tables
Revision ID: 426b6bdc3360
Revision ID: dc8582aea0bc
Revises:
Create Date: 2023-01-27 17:55:21.758828
Create Date: 2023-02-08 12:12:00.808816
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "426b6bdc3360"
revision = "dc8582aea0bc"
down_revision = None
branch_labels = None
depends_on = None
@@ -30,23 +29,28 @@ def upgrade() -> None:
sa.Column("meta", sa.JSON(none_as_null=True), nullable=True),
sa.Column(
"type",
sa.Enum("transcript", "translation", "language_detection", name="jobtype"),
sa.Enum(
"transcript",
"translation",
"language_detection",
name="jobtype",
),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(),
server_default=sa.text("now()"),
server_default=sa.text("(CURRENT_TIMESTAMP)"),
nullable=False,
),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("id", sa.VARCHAR(length=36), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_jobs_id"), "jobs", ["id"], unique=False)
op.create_table(
"artifacts",
sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("job_id", sa.VARCHAR(length=36), nullable=False),
sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
sa.Column(
"type",
@@ -56,11 +60,11 @@ def upgrade() -> None:
sa.Column(
"created_at",
sa.DateTime(),
server_default=sa.text("now()"),
server_default=sa.text("(CURRENT_TIMESTAMP)"),
nullable=False,
),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("id", sa.VARCHAR(length=36), nullable=False),
sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)

View File

@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session, sessionmaker
from app.shared.config import settings
engine = create_engine(settings.DATABASE_URI)
engine = create_engine(settings.DATABASE_URI, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

View File

@@ -1,7 +1,7 @@
import uuid
from typing import Optional
from sqlalchemy import JSON, Column, DateTime, Enum, ForeignKey, String, func
from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import Mapped, declarative_mixin # type: ignore
@@ -26,7 +26,7 @@ class WithStandardFields:
@declared_attr
def id(cls) -> Mapped[UUID]:
return Column(
UUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4
VARCHAR(36), primary_key=True, index=True, default=lambda: str(uuid.uuid4())
)
@@ -44,7 +44,9 @@ class Artifact(Base, WithStandardFields):
__tablename__ = "artifacts"
job_id = Column(
UUID(as_uuid=True), ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False
VARCHAR(36),
ForeignKey("jobs.id", ondelete="CASCADE"),
nullable=False,
)
data = Column(JSON(none_as_null=True))

View File

@@ -7,11 +7,10 @@ from sqlalchemy.orm import Session
import app.shared.db.dtos as dtos
import app.shared.db.models as models
from app.shared.db.dtos import JobStatus, JobType
from app.web.main import app
from app.web.main import app, celery
client = TestClient(app)
@pytest.fixture(name="mock_job", scope="function", autouse=False)
def mock_job(db_session: Session) -> models.Job:
job = models.Job(

View File

@@ -81,7 +81,7 @@ def get_transcripts(
def get_transcript(
id: UUID = Path(), session: Session = Depends(get_session)
) -> Optional[models.Job]:
job = session.query(models.Job).filter(models.Job.id == id).one_or_none()
job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
if not job:
raise HTTPException(status_code=404)
return job
@@ -92,7 +92,7 @@ def get_artifacts_for_job(
id: UUID = Path(), session: Session = Depends(get_session)
) -> List[models.Artifact]:
artifacts = (
session.query(models.Artifact).filter(models.Artifact.job_id == id)
session.query(models.Artifact).filter(models.Artifact.job_id == str(id))
).all()
if not len(artifacts):
@@ -105,14 +105,14 @@ def get_artifacts_for_job(
def delete_transcript(
id: UUID = Path(), session: Session = Depends(get_session)
) -> None:
session.query(models.Job).filter(models.Job.id == id).delete()
session.query(models.Job).filter(models.Job.id == str(id)).delete()
return None
app.include_router(api_router)
# TODO:
# we could use `acks_late` to handle this scenario within celery itself.
# TODO: we could use `acks_late` to handle this scenario within celery itself.
# the reason this does not work well in our case is that `visibility_timeout`
# needs to be very high since whisper workers can be long running.
# doing this application-side bears the risk of poison pilling the worker though,
@@ -123,10 +123,15 @@ def on_startup() -> None:
jobs = (
session.query(models.Job)
.filter(or_(models.Job.status == dtos.JobStatus.processing, models.Job.status == dtos.JobStatus.create))
.order_by(models.Job.created_at)
.filter(
or_(
models.Job.status == dtos.JobStatus.processing,
models.Job.status == dtos.JobStatus.create,
)
)
.order_by(models.Job.created_at)
).all()
logger.info(f"Re-queueing {len(jobs)} jobs.")
logger.info(f"Requeueing {len(jobs)} jobs.")
for job in jobs:
queue_task(job)

View File

@@ -1,9 +0,0 @@
#! /usr/bin/env bash
set -e
# run migrations
alembic upgrade head
# start app
uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-80} --log-level info

View File

@@ -13,17 +13,16 @@ from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding()
@celery.task(
bind=True,
soft_time_limit=2 * 60 * 60 # TODO: make configurable
)
@celery.task(bind=True, soft_time_limit=2 * 60 * 60) # TODO: make configurable
def transcribe(self: Task, job_id: UUID) -> None:
try:
db: Session = SessionLocal()
job = db.query(models.Job).filter(models.Job.id == job_id).one()
if job.status == dtos.JobStatus.error or job.status == dtos.JobStatus.success:
logger.warn("[{job.id}]: Received job that has already been processed, abort.")
logger.warn(
"[{job.id}]: Received job that has already been processed, abort."
)
return
job.meta = {"task_id": self.request.id}
@@ -51,7 +50,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
result = strategy.detect_language()
artifact = models.Artifact(
job_id=job.id, data=result, type=dtos.ArtifactType.raw_transcript
job_id=str(job.id), data=result, type=dtos.ArtifactType.raw_transcript
)
db.add(artifact)
@@ -64,9 +63,9 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.info(f"[{job.id}]: set task to status success.")
except Exception as e:
if job and db:
job.meta = { **job.meta, "error": str(e) }
job.meta = {**job.meta.__dict__, "error": str(e)}
job.status = dtos.JobStatus.error
db.commit()
raise(e)
raise (e)
finally:
db.close()

View File

@@ -1,91 +0,0 @@
version: "3.8"
x-app-variables: &app-variables
API_SECRET: a_very_secret_token
DATABASE_URI: postgresql://postgres:postgres@postgres/whisperbox
ENVIRONMENT: development
BROKER_URL: redis://redis:6379/0
services:
postgres:
container_name: whisperbox_postgres
image: postgres:15-alpine
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: whisperbox
ports:
- "5432:5432"
networks:
- app
volumes:
- postgres-data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
timeout: 5s
retries: 5
redis:
container_name: whisperbox_redis
image: redis:7-alpine
ports:
- 6379:6379
networks:
- app
app:
container_name: whisperbox_app
build:
context: ../
dockerfile: docker/app.dev.Dockerfile
environment: *app-variables
ports:
- "8000:80"
networks:
- app
volumes:
- ../:/code
depends_on:
postgres:
condition: service_healthy
postgres:
condition: service_healthy
redis:
condition: service_started
worker:
build:
context: ../
dockerfile: docker/worker.dev.Dockerfile
container_name: whisperbox_worker
volumes:
- ../:/code
environment: *app-variables
depends_on:
- redis
networks:
- app
healthcheck:
test: ["CMD-SHELL", "celery inspect ping -A app.worker.main.transcribe -d celery@$$HOSTNAME"]
interval: 5s
timeout: 5s
retries: 5
flower:
container_name: whisperbox_flower
image: mher/flower
command: celery --broker redis://redis:6379/0 flower --port=5555
ports:
- 5555:5555
depends_on:
- redis
networks:
- app
volumes:
postgres-data:
networks:
app:
driver: bridge

View File

@@ -0,0 +1,71 @@
# Development stack: redis broker, FastAPI web app, celery worker and flower.
version: "3.8"

# Environment shared by the web and worker containers.
x-app-variables: &app-variables
  API_SECRET: a_very_secret_token
  # The database file must live inside the bind-mounted /code directory so that
  # the web and worker containers open the same SQLite file. A path outside the
  # mount (e.g. /whisperbox.sqlite) would be private to each container.
  DATABASE_URI: sqlite:////code/whisperbox.sqlite
  ENVIRONMENT: development
  BROKER_URL: redis://redis:6379/0
  WHISPER_MODEL: small

services:
  redis:
    container_name: whisperbox_redis_dev
    image: redis:7-alpine
    ports:
      - 6379:6379
    networks:
      - app
    # Resource limits belong under `deploy`; `resources` is not a valid
    # service-level key.
    # NOTE(review): older docker-compose v1 releases may need `--compatibility`
    # for this limit to take effect - verify against the version in use.
    deploy:
      resources:
        limits:
          memory: 128M

  web:
    container_name: whisperbox_web_dev
    build:
      context: ../../
      dockerfile: docker/dev/web.Dockerfile
    environment: *app-variables
    ports:
      - "8000:80"
    networks:
      - app
    volumes:
      - ../../:/code
    depends_on:
      redis:
        condition: service_started

  worker:
    build:
      context: ../../
      dockerfile: docker/dev/worker.Dockerfile
      args:
        # Build-time model selection; keep in sync with WHISPER_MODEL above.
        WHISPER_MODEL: small
    container_name: whisperbox_worker_dev
    volumes:
      - ../../:/code
    environment: *app-variables
    depends_on:
      - redis
    networks:
      - app
    healthcheck:
      test: ["CMD-SHELL", "celery inspect ping -A app.worker.main.transcribe -d celery@$$HOSTNAME"]
      interval: 5s
      timeout: 5s
      retries: 5

  flower:
    container_name: whisperbox_flower_dev
    image: mher/flower
    command: celery --broker redis://redis:6379/0 flower --port=5555
    ports:
      - 5555:5555
    depends_on:
      - redis
    networks:
      - app

networks:
  app:
    driver: bridge

View File

@@ -5,8 +5,9 @@ WORKDIR /code
COPY pyproject.toml .
RUN pip install --no-cache-dir --user .[web]
ENV PYTHONIOENCODING=utf-8
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
ENV PATH=/root/.local/bin:$PATH
ENTRYPOINT ["bash", "./app/web/start.sh"]
CMD alembic upgrade head && uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-80} --log-level info

View File

@@ -1,14 +1,17 @@
FROM python:3.10 AS compile-image
ARG WHISPER_MODEL
WORKDIR /code
RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg
COPY --from=mwader/static-ffmpeg:5.1.2 /ffmpeg /usr/local/bin/
COPY --from=mwader/static-ffmpeg:5.1.2 /ffprobe /usr/local/bin/
COPY pyproject.toml .
RUN pip install --no-cache-dir --user .[worker,worker_dev]
COPY scripts/download_model.py .
RUN chmod +x download_model.py && python download_model.py small small.en
RUN chmod +x download_model.py && python download_model.py ${WHISPER_MODEL}
ENV PYTHONIOENCODING=utf-8
ENV PYTHONDONTWRITEBYTECODE 1

View File

@@ -5,7 +5,6 @@ version = "0.0.1"
dependencies=[
"celery[redis] ==5.2.7",
"psycopg2 ==2.9.5",
"sqlalchemy[mypy] == 1.4.45",
"pydantic ==1.10.4"
]
@@ -34,9 +33,9 @@ lint = [
test = [
"httpx",
"sqlalchemy-stubs",
"sqlalchemy-utils",
"sqlalchemy-utils ==0.39.0",
"python-dotenv",
"pytest"
"pytest",
]
worker_dev = [

View File

@@ -2,6 +2,7 @@ import sys
from whisper import _download, _MODELS # type: ignore
if __name__ == "__main__":
args = sys.argv[1:]
for name in args:
_download(_MODELS[name], "/models/", False)
# Download the requested model and, when one exists, its English-only variant.
# The first CLI argument is the model name (e.g. "small").
model_name = sys.argv[1]
_download(_MODELS[model_name], "/models/", False)

# Not every model has an ".en" counterpart (the previous check special-cased
# only "large"). Looking the key up first avoids a KeyError for any other name
# without an English-only variant, including names that already end in ".en".
english_variant = f"{model_name}.en"
if english_variant in _MODELS:
    _download(_MODELS[english_variant], "/models/", False)