feat: configure celery to use rabbitmq broker (#58)

This commit is contained in:
Felix Spöttel
2023-08-17 22:45:51 +02:00
committed by GitHub
parent 423018e92a
commit 504975a07a
12 changed files with 74 additions and 56 deletions

View File

@@ -3,4 +3,6 @@ TRAEFIK_DOMAIN="whisperbox-transcribe.localhost"
WHISPER_MODEL="tiny"
ENVIRONMENT="development"
DATABASE_URI="sqlite:///./whisperbox-transcribe.sqlite"
BROKER_URL="redis://redis:6379/0"
RABBITMQ_DEFAULT_USER="rabbitmq"
RABBITMQ_DEFAULT_PASS="rabbitmq_password"

View File

@@ -16,6 +16,8 @@ TRAEFIK_SSLEMAIL=""
# ---
# below settings match the default docker-compose configuration.
BROKER_URL="redis://redis:6379/0"
RABBITMQ_DEFAULT_USER="rabbitmq"
RABBITMQ_DEFAULT_PASS="rabbitmq_password"
DATABASE_URI="sqlite:////etc/whisperbox-transcribe/data/whisperbox-transcribe.sqlite"
ENVIRONMENT="production"

View File

@@ -57,6 +57,7 @@ Builds and starts the docker containers.
```
# Bindings
http://localhost:5555 => Celery dashboard
http://localhost:15672 => RabbitMQ dashboard
http://whisperbox-transcribe.localhost => API
http://whisperbox-transcribe.localhost/docs => API docs
./whisperbox-transcribe.sqlite => Database

View File

@@ -9,4 +9,5 @@ def get_celery_binding() -> Celery:
broker_connection_retry=False,
broker_connection_retry_on_startup=False,
)
return celery

View File

@@ -52,6 +52,11 @@ class JobConfig(BaseModel):
class JobMeta(BaseModel):
"""(JSON) Metadata relating to a job's execution."""
attempts: int | None = Field(
default=None,
description="Number of processing attempts a job has taken.",
)
error: str | None = Field(
default=None,
description="Will contain a descriptive error message if processing failed.",

View File

@@ -1,4 +1,3 @@
from contextlib import asynccontextmanager
from typing import Annotated, Callable, Generator
from uuid import UUID
@@ -8,7 +7,6 @@ from sqlalchemy.orm import Session
import app.shared.db.models as models
import app.web.dtos as dtos
from app.shared.db.base import SessionLocal
from app.shared.settings import settings
from app.web.security import authenticate_api_key
from app.web.task_queue import TaskQueue
@@ -21,17 +19,10 @@ def app_factory(
task_queue = TaskQueue()
# Lifespan hook handed to the FastAPI constructor via `lifespan=`.
# It opens a database session, asks the task queue to rehydrate from it,
# and yields once.
# NOTE(review): indentation is not visible in this view, so whether the
# `yield` happens inside or after the `with` block should be confirmed.
@asynccontextmanager
async def lifespan(_: FastAPI):
with SessionLocal() as session:
task_queue.rehydrate(session)
yield
app = FastAPI(
description=(
"whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
),
lifespan=lifespan,
title="whisperbox-transcribe",
)

View File

@@ -1,8 +1,4 @@
from asyncio.log import logger
from celery import Celery
from sqlalchemy import or_
from sqlalchemy.orm import Session
import app.shared.db.models as models
from app.shared.celery import get_celery_binding
@@ -22,25 +18,3 @@ class TaskQueue:
transcribe = self.celery.signature("app.worker.main.transcribe")
# TODO: catch delivery errors?
transcribe.delay(job.id)
def rehydrate(self, session: Session):
    """Put every unfinished job back onto the task queue.

    Loads all jobs whose status is `processing` or `create` through
    `session`, ordered by `created_at`, logs how many were found and
    passes each one to `queue_task`.
    """
    # TODO: `acks_late` would let celery handle this scenario by itself.
    # It does not work well here because `visibility_timeout` would have to
    # be very high: whisper workers can be long running.
    # Requeueing app-side risks poison pilling the worker though, so a
    # workaround with an acceptable trade-off is needed (=> retry only once?).
    is_unfinished = or_(
        models.Job.status == models.JobStatus.processing,
        models.Job.status == models.JobStatus.create,
    )
    unfinished_query = session.query(models.Job).filter(is_unfinished)
    pending_jobs = unfinished_query.order_by(models.Job.created_at).all()

    logger.info(f"Requeueing {len(pending_jobs)} jobs.")

    for pending_job in pending_jobs:
        self.queue_task(pending_job)

View File

@@ -39,6 +39,9 @@ class TranscribeTask(Task):
bind=True,
soft_time_limit=settings.TASK_SOFT_TIME_LIMIT,
time_limit=settings.TASK_HARD_TIME_LIMIT,
# Per-task options are spelled without the `task_` prefix; the prefixed names
# (`task_acks_late`, ...) are the app-wide configuration keys and have no
# effect when passed as task options.
# Acknowledge the message only after the task has run, so a job whose worker
# was lost is redelivered instead of silently dropped.
acks_late=True,
acks_on_failure_or_timeout=True,
reject_on_worker_lost=True,
)
def transcribe(self: Task, job_id: UUID) -> None:
try:
@@ -59,9 +62,20 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.debug(f"[{job.id}]: start processing {job.type} job.")
if job.meta:
attempts = 1 + (job.meta.get("attempts") or 0)
else:
attempts = 1
# SAFEGUARD: celery's retry policies do not handle lost workers, retry once.
# @see https://github.com/celery/celery/pull/6103
if attempts > 2:
raise Exception("Maximum number of retries exceeded for killed worker.")
# unit of work: set task status to processing.
job.meta = {"task_id": self.request.id}
job.meta = {"task_id": self.request.id, "attempts": attempts}
job.status = models.JobStatus.processing
db.commit()
@@ -83,7 +97,11 @@ def transcribe(self: Task, job_id: UUID) -> None:
if job and db:
if db.in_transaction():
db.rollback()
job.meta = {**job.meta, "error": str(e)} # type: ignore
if job.meta:
job.meta = {**job.meta, "error": str(e)} # type: ignore
else:
job.meta = {"error": str(e)}
job.status = models.JobStatus.error
db.commit()
raise

1
conf/rabbitmq.conf Normal file
View File

@@ -0,0 +1 @@
vm_memory_high_watermark.absolute = 192MB

View File

@@ -1,3 +1,6 @@
x-broker-environment: &broker-environment
BROKER_URL: "amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq:5672"
version: "3.8"
name: whisperbox-transcribe
@@ -12,46 +15,59 @@ services:
networks:
- traefik
redis:
image: redis:7-alpine
rabbitmq:
env_file: .env
image: rabbitmq:3-alpine
networks:
- app
deploy:
resources:
limits:
memory: 128M
memory: 256M
healthcheck:
test: rabbitmq-diagnostics check_port_connectivity
interval: 3s
timeout: 3s
retries: 10
volumes:
- ./conf/rabbitmq.conf:/etc/rabbitmq/rabbitmq.conf
- rabbitmq-data:/var/lib/rabbitmq/mnesia/
worker:
env_file: .env
environment:
<<: *broker-environment
build:
context: .
dockerfile: worker.Dockerfile
args:
WHISPER_MODEL: ${WHISPER_MODEL}
depends_on:
rabbitmq:
condition: service_healthy
networks:
- app
depends_on:
- redis
healthcheck:
test: ["CMD-SHELL", "celery -b ${BROKER_URL} inspect ping -d celery@$$HOSTNAME"]
interval: 5s
timeout: 5s
retries: 5
web:
env_file: .env
environment:
<<: *broker-environment
build:
context: .
dockerfile: web.Dockerfile
depends_on:
rabbitmq:
condition: service_healthy
networks:
- app
- traefik
depends_on:
worker:
condition: service_healthy
networks:
app:
driver: bridge
traefik:
driver: bridge
volumes:
rabbitmq-data:

View File

@@ -13,6 +13,8 @@ services:
web:
command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info"
# NOTE: the Docker for Mac mount adapter (virtioFS) does not support flock.
# this can corrupt the sqlite database when the worker and the api write to it simultaneously.
volumes:
- ./:/etc/whisperbox-transcribe/
labels:
@@ -26,13 +28,18 @@ services:
volumes:
- ./:/etc/whisperbox-transcribe/
rabbitmq:
image: rabbitmq:3-management-alpine
ports:
- 15672:15672
flower:
image: mher/flower
command: celery --broker redis://redis:6379/0 flower --port=5555
command: celery --broker amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq:5672 flower --port=5555
ports:
- 5555:5555
depends_on:
worker:
condition: service_healthy
- worker
- rabbitmq
networks:
- app

View File

@@ -4,7 +4,7 @@ description = ""
version = "1.0.0"
dependencies=[
"celery[redis] ==5.3.1",
"celery ==5.3.1",
"sqlalchemy[mypy] ==2.0.20",
"pydantic ==2.1.1",
"pydantic-settings ==2.0.3"