feat: configure celery to use rabbitmq broker (#58)

This commit is contained in:
Felix Spöttel
2023-08-17 22:45:51 +02:00
committed by GitHub
parent 423018e92a
commit 504975a07a
12 changed files with 74 additions and 56 deletions

View File

@@ -3,4 +3,6 @@ TRAEFIK_DOMAIN="whisperbox-transcribe.localhost"
WHISPER_MODEL="tiny" WHISPER_MODEL="tiny"
ENVIRONMENT="development" ENVIRONMENT="development"
DATABASE_URI="sqlite:///./whisperbox-transcribe.sqlite" DATABASE_URI="sqlite:///./whisperbox-transcribe.sqlite"
BROKER_URL="redis://redis:6379/0"
RABBITMQ_DEFAULT_USER="rabbitmq"
RABBITMQ_DEFAULT_PASS="rabbitmq_password"

View File

@@ -16,6 +16,8 @@ TRAEFIK_SSLEMAIL=""
# --- # ---
# below settings match the default docker-compose configuration. # below settings match the default docker-compose configuration.
BROKER_URL="redis://redis:6379/0" RABBITMQ_DEFAULT_USER="rabbitmq"
RABBITMQ_DEFAULT_PASS="rabbitmq_password"
DATABASE_URI="sqlite:////etc/whisperbox-transcribe/data/whisperbox-transcribe.sqlite" DATABASE_URI="sqlite:////etc/whisperbox-transcribe/data/whisperbox-transcribe.sqlite"
ENVIRONMENT="production" ENVIRONMENT="production"

View File

@@ -57,6 +57,7 @@ Builds and starts the docker containers.
``` ```
# Bindings # Bindings
http://localhost:5555 => Celery dashboard http://localhost:5555 => Celery dashboard
http://localhost:15672 => RabbitMQ dashboard
http://whisperbox-transcribe.localhost => API http://whisperbox-transcribe.localhost => API
http://whisperbox-transcribe.localhost/docs => API docs http://whisperbox-transcribe.localhost/docs => API docs
./whisperbox-transcribe.sqlite => Database ./whisperbox-transcribe.sqlite => Database

View File

@@ -9,4 +9,5 @@ def get_celery_binding() -> Celery:
broker_connection_retry=False, broker_connection_retry=False,
broker_connection_retry_on_startup=False, broker_connection_retry_on_startup=False,
) )
return celery return celery

View File

@@ -52,6 +52,11 @@ class JobConfig(BaseModel):
class JobMeta(BaseModel): class JobMeta(BaseModel):
"""(JSON) Metadata relating to a job's execution.""" """(JSON) Metadata relating to a job's execution."""
attempts: int | None = Field(
default=None,
description="Number of processing attempts a job has taken.",
)
error: str | None = Field( error: str | None = Field(
default=None, default=None,
description="Will contain a descriptive error message if processing failed.", description="Will contain a descriptive error message if processing failed.",

View File

@@ -1,4 +1,3 @@
from contextlib import asynccontextmanager
from typing import Annotated, Callable, Generator from typing import Annotated, Callable, Generator
from uuid import UUID from uuid import UUID
@@ -8,7 +7,6 @@ from sqlalchemy.orm import Session
import app.shared.db.models as models import app.shared.db.models as models
import app.web.dtos as dtos import app.web.dtos as dtos
from app.shared.db.base import SessionLocal
from app.shared.settings import settings from app.shared.settings import settings
from app.web.security import authenticate_api_key from app.web.security import authenticate_api_key
from app.web.task_queue import TaskQueue from app.web.task_queue import TaskQueue
@@ -21,17 +19,10 @@ def app_factory(
task_queue = TaskQueue() task_queue = TaskQueue()
@asynccontextmanager
async def lifespan(_: FastAPI):
with SessionLocal() as session:
task_queue.rehydrate(session)
yield
app = FastAPI( app = FastAPI(
description=( description=(
"whisperbox-transcribe is an async HTTP wrapper for openai/whisper." "whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
), ),
lifespan=lifespan,
title="whisperbox-transcribe", title="whisperbox-transcribe",
) )

View File

@@ -1,8 +1,4 @@
from asyncio.log import logger
from celery import Celery from celery import Celery
from sqlalchemy import or_
from sqlalchemy.orm import Session
import app.shared.db.models as models import app.shared.db.models as models
from app.shared.celery import get_celery_binding from app.shared.celery import get_celery_binding
@@ -22,25 +18,3 @@ class TaskQueue:
transcribe = self.celery.signature("app.worker.main.transcribe") transcribe = self.celery.signature("app.worker.main.transcribe")
# TODO: catch delivery errors? # TODO: catch delivery errors?
transcribe.delay(job.id) transcribe.delay(job.id)
def rehydrate(self, session: Session):
# TODO: we could use `acks_late` to handle this scenario within celery itself.
# the reason this does not work well in our case is that `visibility_timeout`
# needs to be very high since whisper workers can be long running.
# doing this app-side bears the risk of poison pilling the worker though,
# implement a workaround with an acceptable trade-off. (=> retry only once?)
jobs = (
session.query(models.Job)
.filter(
or_(
models.Job.status == models.JobStatus.processing,
models.Job.status == models.JobStatus.create,
)
)
.order_by(models.Job.created_at)
).all()
logger.info(f"Requeueing {len(jobs)} jobs.")
for job in jobs:
self.queue_task(job)

View File

@@ -39,6 +39,9 @@ class TranscribeTask(Task):
bind=True, bind=True,
soft_time_limit=settings.TASK_SOFT_TIME_LIMIT, soft_time_limit=settings.TASK_SOFT_TIME_LIMIT,
time_limit=settings.TASK_HARD_TIME_LIMIT, time_limit=settings.TASK_HARD_TIME_LIMIT,
task_acks_late=True,
task_acks_on_failure_or_timeout=True,
task_reject_on_worker_lost=True,
) )
def transcribe(self: Task, job_id: UUID) -> None: def transcribe(self: Task, job_id: UUID) -> None:
try: try:
@@ -59,9 +62,20 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.debug(f"[{job.id}]: start processing {job.type} job.") logger.debug(f"[{job.id}]: start processing {job.type} job.")
if job.meta:
attempts = 1 + (job.meta.get("attempts") or 0)
else:
attempts = 1
# SAFEGUARD: celery's retry policies do not handle lost workers, retry once.
# @see https://github.com/celery/celery/pull/6103
if attempts > 2:
raise Exception("Maximum number of retries exceeded for killed worker.")
# unit of work: set task status to processing. # unit of work: set task status to processing.
job.meta = {"task_id": self.request.id} job.meta = {"task_id": self.request.id, "attempts": attempts}
job.status = models.JobStatus.processing job.status = models.JobStatus.processing
db.commit() db.commit()
@@ -83,7 +97,11 @@ def transcribe(self: Task, job_id: UUID) -> None:
if job and db: if job and db:
if db.in_transaction(): if db.in_transaction():
db.rollback() db.rollback()
job.meta = {**job.meta, "error": str(e)} # type: ignore if job.meta:
job.meta = {**job.meta, "error": str(e)} # type: ignore
else:
job.meta = {"error": str(e)}
job.status = models.JobStatus.error job.status = models.JobStatus.error
db.commit() db.commit()
raise raise

1
conf/rabbitmq.conf Normal file
View File

@@ -0,0 +1 @@
vm_memory_high_watermark.absolute = 192MB

View File

@@ -1,3 +1,6 @@
x-broker-environment: &broker-environment
BROKER_URL: "amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq:5672"
version: "3.8" version: "3.8"
name: whisperbox-transcribe name: whisperbox-transcribe
@@ -12,46 +15,59 @@ services:
networks: networks:
- traefik - traefik
redis: rabbitmq:
image: redis:7-alpine env_file: .env
image: rabbitmq:3-alpine
networks: networks:
- app - app
deploy: deploy:
resources: resources:
limits: limits:
memory: 128M memory: 256M
healthcheck:
test: rabbitmq-diagnostics check_port_connectivity
interval: 3s
timeout: 3s
retries: 10
volumes:
- ./conf/rabbitmq.conf:/etc/rabbitmq/rabbitmq.conf
- rabbitmq-data:/var/lib/rabbitmq/mnesia/
worker: worker:
env_file: .env env_file: .env
environment:
<<: *broker-environment
build: build:
context: . context: .
dockerfile: worker.Dockerfile dockerfile: worker.Dockerfile
args: args:
WHISPER_MODEL: ${WHISPER_MODEL} WHISPER_MODEL: ${WHISPER_MODEL}
depends_on:
rabbitmq:
condition: service_healthy
networks: networks:
- app - app
depends_on:
- redis
healthcheck:
test: ["CMD-SHELL", "celery -b ${BROKER_URL} inspect ping -d celery@$$HOSTNAME"]
interval: 5s
timeout: 5s
retries: 5
web: web:
env_file: .env env_file: .env
environment:
<<: *broker-environment
build: build:
context: . context: .
dockerfile: web.Dockerfile dockerfile: web.Dockerfile
depends_on:
rabbitmq:
condition: service_healthy
networks: networks:
- app - app
- traefik - traefik
depends_on:
worker:
condition: service_healthy
networks: networks:
app: app:
driver: bridge driver: bridge
traefik: traefik:
driver: bridge driver: bridge
volumes:
rabbitmq-data:

View File

@@ -13,6 +13,8 @@ services:
web: web:
command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info" command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info"
# NOTE: the docker on mac mount adapter (virtioFS) does not support flock.
# this can cause the sqlite database to corrupt when written from worker <> api simultaneously.
volumes: volumes:
- ./:/etc/whisperbox-transcribe/ - ./:/etc/whisperbox-transcribe/
labels: labels:
@@ -26,13 +28,18 @@ services:
volumes: volumes:
- ./:/etc/whisperbox-transcribe/ - ./:/etc/whisperbox-transcribe/
rabbitmq:
image: rabbitmq:3-management-alpine
ports:
- 15672:15672
flower: flower:
image: mher/flower image: mher/flower
command: celery --broker redis://redis:6379/0 flower --port=5555 command: celery --broker amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq:5672 flower --port=5555
ports: ports:
- 5555:5555 - 5555:5555
depends_on: depends_on:
worker: - worker
condition: service_healthy - rabbitmq
networks: networks:
- app - app

View File

@@ -4,7 +4,7 @@ description = ""
version = "1.0.0" version = "1.0.0"
dependencies=[ dependencies=[
"celery[redis] ==5.3.1", "celery ==5.3.1",
"sqlalchemy[mypy] ==2.0.20", "sqlalchemy[mypy] ==2.0.20",
"pydantic ==2.1.1", "pydantic ==2.1.1",
"pydantic-settings ==2.0.3" "pydantic-settings ==2.0.3"