31 Commits

Author SHA1 Message Date
renovate[bot]
8decdf4c02 chore(deps): update dependency tooling/python-dotenv to v1.0.1 2024-02-01 01:58:51 +00:00
renovate[bot]
f076ef9a1b chore(deps): update dependency sqlalchemy to v2.0.24 (#63) 2024-01-01 14:30:41 +01:00
renovate[bot]
bcfc3616d6 chore(deps): update dependency pydantic-settings to v2.1.0 (#76) 2024-01-01 14:30:21 +01:00
renovate[bot]
f4f760ee19 chore(deps): update dependency celery to v5.3.6 (#61) 2024-01-01 14:29:31 +01:00
renovate[bot]
e5166dab2e chore(deps): update dependency pydantic to v2.5.3 (#74) 2024-01-01 14:28:38 +01:00
Felix Spöttel
3559aa5936 refactor: use Depends for settings and session_local (#59) 2024-01-01 14:28:19 +01:00
renovate[bot]
557de5a442 chore(deps): update dependency tooling/types-requests to v2.31.0.20231231 (#72) 2024-01-01 11:36:40 +01:00
renovate[bot]
4402dc23bb chore(deps): update dependency tooling/httpx to v0.26.0 (#73) 2024-01-01 11:35:50 +01:00
renovate[bot]
50d5a63232 chore(deps): update dependency tooling/pytest to v7.4.4 (#71) 2024-01-01 11:35:15 +01:00
renovate[bot]
65fca1f597 chore(deps): update dependency tooling/black to v23.12.1 (#77) 2024-01-01 11:34:49 +01:00
renovate[bot]
21006d33dd chore(deps): update dependency pydantic to v2.4.2 (#67) 2023-10-10 13:20:51 +02:00
renovate[bot]
3ee1e9f685 chore(deps): update dependency tooling/black to v23.9.1 (#68) 2023-10-10 13:20:30 +02:00
renovate[bot]
05eed3f6ea chore(deps): update dependency tooling/httpx to v0.25.0 (#69) 2023-10-10 13:20:08 +02:00
renovate[bot]
9fe10389b8 chore(deps): update dependency tooling/types-requests to v2.31.0.8 (#66) 2023-10-10 13:16:10 +02:00
renovate[bot]
4ae14366a7 chore(deps): update dependency tooling/pytest to v7.4.2 (#64) 2023-10-10 13:15:48 +02:00
renovate[bot]
bbc00affa1 chore(deps): update dependency tooling/ruff to v0.0.292 (#65) 2023-10-10 13:13:54 +02:00
renovate[bot]
f469903d47 chore(deps): update dependency tooling/ruff to v0.0.286 (#62) 2023-09-01 09:49:28 +02:00
renovate[bot]
101903a7a2 chore(deps): update dependency tooling/ruff to v0.0.285 (#60) 2023-08-18 00:30:57 +02:00
Felix Spöttel
504975a07a feat: configure celery to use rabbitmq broker (#58) 2023-08-17 22:45:51 +02:00
Felix Spöttel
423018e92a fix: missing commit when deleting jobs (#56) 2023-08-17 13:46:59 +02:00
Felix Spöttel
cf07aa6d52 chore: update renovate manager config (#54) 2023-08-17 13:31:48 +02:00
renovate[bot]
aeccad6226 chore(deps): update dependency mypy to v1.5.1 (#55) 2023-08-17 00:23:14 +02:00
renovate[bot]
21790fffeb chore(deps): update dependency fastapi to v0.101.1 (#49) 2023-08-16 23:12:02 +02:00
renovate[bot]
28754ee0e9 chore(deps): update dependency pydantic-settings to v2.0.3 (#53) 2023-08-16 22:41:35 +02:00
renovate[bot]
ec203127fa chore(deps): update dependency sqlalchemy to v2.0.20 (#52) 2023-08-16 22:36:28 +02:00
renovate[bot]
8e35968b04 chore(deps): update dependency ruff to v0.0.284 (#51) 2023-08-16 22:33:29 +02:00
renovate[bot]
7baa24ff78 chore(deps): update dependency alembic to v1.11.3 (#50) 2023-08-16 22:30:26 +02:00
renovate[bot]
3a905148a0 chore(deps): update dependency alembic to v1.11.2 (#47) 2023-08-07 17:04:55 +02:00
renovate[bot]
8579667777 chore(deps): update dependency ruff to v0.0.282 (#48) 2023-08-07 17:02:22 +02:00
Miguel Sozinho Ramalho
7428cceb0f Merge pull request #46 from bellingcat/build/release 2023-08-07 15:23:59 +01:00
Felix Spöttel
7cb6a3eff6 build(release): v1.0.0 2023-08-07 14:34:38 +02:00
28 changed files with 324 additions and 215 deletions

View File

@@ -3,4 +3,6 @@ TRAEFIK_DOMAIN="whisperbox-transcribe.localhost"
WHISPER_MODEL="tiny" WHISPER_MODEL="tiny"
ENVIRONMENT="development" ENVIRONMENT="development"
DATABASE_URI="sqlite:///./whisperbox-transcribe.sqlite" DATABASE_URI="sqlite:///./whisperbox-transcribe.sqlite"
BROKER_URL="redis://redis:6379/0"
RABBITMQ_DEFAULT_USER="rabbitmq"
RABBITMQ_DEFAULT_PASS="rabbitmq_password"

View File

@@ -16,6 +16,8 @@ TRAEFIK_SSLEMAIL=""
# --- # ---
# below settings match the default docker-compose configuration. # below settings match the default docker-compose configuration.
BROKER_URL="redis://redis:6379/0" RABBITMQ_DEFAULT_USER="rabbitmq"
RABBITMQ_DEFAULT_PASS="rabbitmq_password"
DATABASE_URI="sqlite:////etc/whisperbox-transcribe/data/whisperbox-transcribe.sqlite" DATABASE_URI="sqlite:////etc/whisperbox-transcribe/data/whisperbox-transcribe.sqlite"
ENVIRONMENT="production" ENVIRONMENT="production"

15
.github/renovate.json vendored
View File

@@ -2,18 +2,5 @@
"$schema": "https://docs.renovatebot.com/renovate-schema.json", "$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": ["config:base", "schedule:monthly"], "extends": ["config:base", "schedule:monthly"],
"timezone": "Europe/Berlin", "timezone": "Europe/Berlin",
"pip_setup": { "enabledManagers": ["dockerfile", "docker-compose", "pep621"]
"fileMatch": [
"(^|/)pyproject\\.toml$"
]
},
"dockerfile": {
"enabled": false
},
"docker-compose": {
"enabled": false
},
"pyenv": {
"enabled": false
}
} }

View File

@@ -57,6 +57,7 @@ Builds and starts the docker containers.
``` ```
# Bindings # Bindings
http://localhost:5555 => Celery dashboard http://localhost:5555 => Celery dashboard
http://localhost:15672 => RabbitMQ dashboard
http://whisperbox-transcribe.localhost => API http://whisperbox-transcribe.localhost => API
http://whisperbox-transcribe.localhost/docs => API docs http://whisperbox-transcribe.localhost/docs => API docs
./whisperbox-transcribe.sqlite => Database ./whisperbox-transcribe.sqlite => Database

View File

@@ -1,12 +1,9 @@
from celery import Celery from celery import Celery
from app.shared.settings import settings
def get_celery_binding(broker_url: str) -> Celery:
def get_celery_binding() -> Celery: return Celery(
celery = Celery( broker_url=broker_url,
broker_url=settings.BROKER_URL,
broker_connection_retry=False, broker_connection_retry=False,
broker_connection_retry_on_startup=False, broker_connection_retry_on_startup=False,
) )
return celery

View File

@@ -4,7 +4,9 @@ from alembic import context
from sqlalchemy import engine_from_config, pool from sqlalchemy import engine_from_config, pool
from app.shared.db.models import Base from app.shared.db.models import Base
from app.shared.settings import settings from app.shared.settings import Settings
settings = Settings() # type: ignore
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.

View File

@@ -1,26 +1,21 @@
from typing import Any, Generator from typing import Any
from sqlalchemy import create_engine, event from sqlalchemy import Engine, create_engine, event
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import sessionmaker
from app.shared.settings import settings
engine = create_engine(settings.DATABASE_URI, connect_args={"check_same_thread": False})
@event.listens_for(engine, "connect") def make_engine(database_url: str):
def set_sqlite_pragma(conn: Any, _: Any) -> None: engine = create_engine(database_url, connect_args={"check_same_thread": False})
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL") @event.listens_for(engine, "connect")
cursor.close() def set_sqlite_pragma(conn: Any, _: Any) -> None:
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.close()
return engine
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) def make_session_local(engine: Engine):
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return session_local
def get_session() -> Generator[Session, None, None]:
session: Session = SessionLocal()
try:
yield session
finally:
session.close()

View File

@@ -52,6 +52,11 @@ class JobConfig(BaseModel):
class JobMeta(BaseModel): class JobMeta(BaseModel):
"""(JSON) Metadata relating to a job's execution.""" """(JSON) Metadata relating to a job's execution."""
attempts: int | None = Field(
default=None,
description="Number of processing attempts a job has taken.",
)
error: str | None = Field( error: str | None = Field(
default=None, default=None,
description="Will contain a descriptive error message if processing failed.", description="Will contain a descriptive error message if processing failed.",

7
app/shared/logger.py Normal file
View File

@@ -0,0 +1,7 @@
import logging
# Apply the logging module's default configuration before the logger is handed out.
logging.basicConfig()
# Module-level logger exported for the rest of the application to import.
logger = logging.getLogger(__name__)
# Emit INFO and above; DEBUG records sent through this logger are dropped.
logger.setLevel(logging.INFO)

View File

@@ -1,5 +1,3 @@
import sys
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@@ -13,9 +11,3 @@ class Settings(BaseSettings):
TASK_HARD_TIME_LIMIT: int = 4 * 60 * 60 TASK_HARD_TIME_LIMIT: int = 4 * 60 * 60
ENABLE_SHARING: bool = False ENABLE_SHARING: bool = False
if "pytest" in sys.modules:
settings = Settings(_env_file=".env.test") # type: ignore
else:
settings = Settings() # type: ignore

View File

@@ -3,46 +3,62 @@ from fastapi.testclient import TestClient
from sqlalchemy_utils import create_database, database_exists, drop_database from sqlalchemy_utils import create_database, database_exists, drop_database
import app.shared.db.models as models import app.shared.db.models as models
from app.shared.db.base import SessionLocal, engine from app.shared.db.base import make_engine, make_session_local
from app.shared.settings import settings from app.shared.settings import Settings
from app.web.injections.db import get_session
from app.web.injections.settings import get_settings
from app.web.main import app_factory from app.web.main import app_factory
def pytest_configure() -> None: @pytest.fixture()
if not database_exists(engine.url): def settings():
create_database(engine.url) return Settings(_env_file=".env.test") # type: ignore
def pytest_unconfigure() -> None: @pytest.fixture()
if database_exists(engine.url): def auth_headers(settings) -> dict[str, str]:
drop_database(engine.url)
@pytest.fixture(scope="function")
def auth_headers() -> dict[str, str]:
return {"Authorization": f"Bearer {settings.API_SECRET}"} return {"Authorization": f"Bearer {settings.API_SECRET}"}
@pytest.fixture(scope="function", autouse=True) @pytest.fixture()
def db_session(): def test_db(settings):
models.Base.metadata.create_all(engine) engine = make_engine(settings.DATABASE_URI)
connection = engine.connect()
with SessionLocal(bind=connection) as session: if not database_exists(engine.url):
yield session create_database(engine.url)
connection.close()
models.Base.metadata.create_all(engine)
connection = engine.connect()
yield connection
connection.close()
models.Base.metadata.drop_all(bind=engine) models.Base.metadata.drop_all(bind=engine)
drop_database(engine.url)
@pytest.fixture(scope="function") @pytest.fixture()
def client(db_session): def db_session(test_db):
app = app_factory(lambda: db_session) session_local = make_session_local(test_db)
with session_local() as session:
yield session
@pytest.fixture()
def app(db_session, settings):
app = app_factory()
app.dependency_overrides[get_settings] = lambda: settings
app.dependency_overrides[get_session] = lambda: db_session
return app
@pytest.fixture()
def client(app):
client = TestClient(app) client = TestClient(app)
return client return client
@pytest.fixture(scope="function", autouse=False) @pytest.fixture()
def mock_job(db_session): def mock_job(db_session):
job = models.Job( job = models.Job(
url="https://example.com", url="https://example.com",
@@ -51,22 +67,15 @@ def mock_job(db_session):
meta={"task_id": "5c790c76-2cc1-4e91-a305-443df55a4a4c"}, meta={"task_id": "5c790c76-2cc1-4e91-a305-443df55a4a4c"},
) )
db_session.add(job) db_session.add(job)
db_session.flush() db_session.commit()
return job return job
@pytest.fixture(scope="function", autouse=False) @pytest.fixture()
def mock_artifact(db_session, mock_job): def mock_artifact(db_session, mock_job):
artifact = models.Artifact( artifact = models.Artifact(
data=None, job_id=str(mock_job.id), type=models.ArtifactType.raw_transcript data=None, job_id=str(mock_job.id), type=models.ArtifactType.raw_transcript
) )
db_session.add(artifact) db_session.add(artifact)
db_session.flush() db_session.commit()
return artifact return artifact
@pytest.fixture(scope="function")
def sharing_enabled():
settings.ENABLE_SHARING = True
yield
settings.ENABLE_SHARING = False

View File

@@ -1,7 +1,6 @@
from fastapi.testclient import TestClient
import app.shared.db.models as models import app.shared.db.models as models
from app.web.main import app_factory from app.shared.settings import Settings
from app.web.injections.settings import get_settings
# POST /api/v1/jobs # POST /api/v1/jobs
@@ -69,9 +68,10 @@ def test_get_job_sharing_disabled(client, mock_job):
assert res.status_code == 401 assert res.status_code == 401
def test_get_job_sharing_enabled(db_session, mock_job, sharing_enabled): def test_get_job_sharing_enabled(client, app, mock_job):
# HACK: delay construction until settings are patched. app.dependency_overrides[get_settings] = lambda: Settings(
client = TestClient(app_factory(lambda: db_session)) _env_file=".env.test", ENABLE_SHARING=True # type: ignore
)
res = client.get( res = client.get(
f"/api/v1/jobs/{mock_job.id}", f"/api/v1/jobs/{mock_job.id}",
@@ -107,10 +107,25 @@ def test_get_artifacts_not_found(client, auth_headers, mock_job):
# DELETE /api/v1/jobs # DELETE /api/v1/jobs
# --- # ---
def test_delete_job_pass(client, auth_headers, mock_job, db_session): def test_delete_job_pass(client, auth_headers, mock_job, db_session):
res = client.delete( res_job = client.get(
f"/api/v1/jobs/{mock_job.id}", f"/api/v1/jobs/{mock_job.id}",
headers=auth_headers, headers=auth_headers,
) )
assert db_session.query(models.Job).count() == 0 assert res_job.status_code == 200
assert res.status_code == 204
client.delete(
f"/api/v1/jobs/{mock_job.id}",
headers=auth_headers,
)
# HACK: this catches a missed .commit().
# TODO: clean up pytest database handling.
db_session.rollback()
res_job_missing = client.get(
f"/api/v1/jobs/{mock_job.id}",
headers=auth_headers,
)
assert res_job_missing.status_code == 404

View File

@@ -1,4 +1,3 @@
from app.shared.db.base import get_session
from app.web.main import app_factory from app.web.main import app_factory
app = app_factory(get_session) app = app_factory

View File

26
app/web/injections/db.py Normal file
View File

@@ -0,0 +1,26 @@
from functools import lru_cache
from typing import Generator
from fastapi import Depends
from sqlalchemy.orm import Session
from app.shared.db.base import make_engine, make_session_local
from app.shared.settings import Settings
from app.web.injections.settings import get_settings
@lru_cache
def session_local(database_url: str):
    """Return the session factory for `database_url`.

    The result is memoised by `lru_cache`, so repeated calls with the same URL
    get the cached factory back instead of creating another engine.
    """
    return make_session_local(make_engine(database_url))
def get_session_local(settings: Settings = Depends(get_settings)):
    """FastAPI dependency: resolve the session factory for the configured database."""
    database_uri = settings.DATABASE_URI
    return session_local(database_uri)
def get_session(
    session_local=Depends(get_session_local),
) -> Generator[Session, None, None]:
    """FastAPI dependency: yield one session obtained from the injected factory.

    The session is used as a context manager, so it is released once the
    consumer of this generator is done with it.
    """
    with session_local() as db_session:
        yield db_session

View File

@@ -0,0 +1,39 @@
from hmac import compare_digest
from typing import Annotated
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.shared.settings import Settings
from app.web.injections.settings import get_settings
def api_key_auth(
    credentials: Annotated[
        HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
    ],
    settings: Annotated[Settings, Depends(get_settings)],
):
    """FastAPI dependency: always require a bearer token matching `API_SECRET`.

    Rejection (HTTP 401) is delegated to `validate_credentials`.
    """
    expected_secret = settings.API_SECRET
    validate_credentials(credentials, expected_secret)
def sharing_auth(
    credentials: Annotated[
        HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
    ],
    settings: Annotated[Settings, Depends(get_settings)],
):
    """FastAPI dependency: require the API secret unless sharing is enabled.

    When `settings.ENABLE_SHARING` is truthy the request passes without any
    credential check; otherwise the bearer token must match `API_SECRET`.
    """
    if not settings.ENABLE_SHARING:
        validate_credentials(credentials, settings.API_SECRET)
def validate_credentials(
    credentials: HTTPAuthorizationCredentials | None, secret: str
):
    """Raise HTTP 401 unless `credentials` carries exactly `secret`.

    Args:
        credentials: parsed bearer credentials; falsy (e.g. ``None``) when the
            request carried none.
        secret: the expected API secret. An empty secret never authenticates.

    Raises:
        HTTPException: with status 401 when credentials are missing, the
            secret is empty, or the token does not match.
    """
    if not credentials or not secret:
        raise HTTPException(status_code=401)

    # compare_digest only accepts ASCII when given `str`; a non-ASCII token or
    # secret would raise TypeError (an unhandled 500) instead of a clean 401.
    # Comparing the UTF-8 bytes works for any text.
    expected = secret.encode("utf-8")
    presented = credentials.credentials.encode("utf-8")

    # use compare_digest to counter timing attacks.
    if not compare_digest(expected, presented):
        raise HTTPException(status_code=401)

View File

@@ -0,0 +1,8 @@
from functools import lru_cache
from app.shared.settings import Settings
@lru_cache
def get_settings():
    """Return the application `Settings`, built on first call and cached after."""
    app_settings = Settings()  # type: ignore
    return app_settings

View File

@@ -0,0 +1,16 @@
from functools import lru_cache
from fastapi import Depends
from app.shared.settings import Settings
from app.web.injections.settings import get_settings
from app.web.task_queue import TaskQueue
@lru_cache
def task_queue(broker_url: str):
    """Return the `TaskQueue` for `broker_url`.

    Memoised by `lru_cache`: calls with the same URL get the cached instance.
    """
    queue = TaskQueue(broker_url)
    return queue
def get_task_queue(settings: Settings = Depends(get_settings)):
    """FastAPI dependency: resolve the task queue for the configured broker."""
    broker_url = settings.BROKER_URL
    return task_queue(broker_url)

View File

@@ -1,5 +1,4 @@
from contextlib import asynccontextmanager from typing import Annotated
from typing import Annotated, Callable, Generator
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
@@ -8,42 +7,31 @@ from sqlalchemy.orm import Session
import app.shared.db.models as models import app.shared.db.models as models
import app.web.dtos as dtos import app.web.dtos as dtos
from app.shared.db.base import SessionLocal from app.web.injections.db import get_session
from app.shared.settings import settings from app.web.injections.security import api_key_auth, sharing_auth
from app.web.security import authenticate_api_key from app.web.injections.task_queue import get_task_queue
from app.web.task_queue import TaskQueue from app.web.task_queue import TaskQueue
DatabaseSession = Annotated[Session, Depends(get_session)]
def app_factory(
session_getter: Callable[[], Generator[Session, None, None]]
) -> FastAPI:
DatabaseSession = Annotated[Session, Depends(session_getter)]
task_queue = TaskQueue()
@asynccontextmanager
async def lifespan(_: FastAPI):
with SessionLocal() as session:
task_queue.rehydrate(session)
yield
def app_factory():
app = FastAPI( app = FastAPI(
description=( description=(
"whisperbox-transcribe is an async HTTP wrapper for openai/whisper." "whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
), ),
lifespan=lifespan,
title="whisperbox-transcribe", title="whisperbox-transcribe",
) )
api_router = APIRouter(prefix="/api/v1") api_router = APIRouter(prefix="/api/v1")
@api_router.get("/", response_model=None, status_code=204) @api_router.get("/", status_code=204)
def api_root() -> None: def api_root():
return None return None
@api_router.get( @api_router.get(
"/jobs", "/jobs",
dependencies=[Depends(authenticate_api_key)], dependencies=[Depends(api_key_auth)],
response_model=list[dtos.Job], response_model=list[dtos.Job],
summary="Get metadata for all jobs", summary="Get metadata for all jobs",
) )
@@ -61,7 +49,7 @@ def app_factory(
@api_router.get( @api_router.get(
"/jobs/{id}", "/jobs/{id}",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], dependencies=[Depends(sharing_auth)],
response_model=dtos.Job, response_model=dtos.Job,
summary="Get metadata for one job", summary="Get metadata for one job",
) )
@@ -81,7 +69,7 @@ def app_factory(
@api_router.get( @api_router.get(
"/jobs/{id}/artifacts", "/jobs/{id}/artifacts",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], dependencies=[Depends(api_key_auth)],
response_model=list[dtos.Artifact], response_model=list[dtos.Artifact],
summary="Get all artifacts for one job", summary="Get all artifacts for one job",
) )
@@ -102,7 +90,7 @@ def app_factory(
@api_router.delete( @api_router.delete(
"/jobs/{id}", "/jobs/{id}",
dependencies=[Depends(authenticate_api_key)], dependencies=[Depends(sharing_auth)],
status_code=204, status_code=204,
summary="Delete a job with all artifacts", summary="Delete a job with all artifacts",
) )
@@ -112,6 +100,7 @@ def app_factory(
) -> None: ) -> None:
"""Remove metadata and artifacts for a single job.""" """Remove metadata and artifacts for a single job."""
session.query(models.Job).filter(models.Job.id == str(id)).delete() session.query(models.Job).filter(models.Job.id == str(id)).delete()
session.commit()
return None return None
class PostJobPayload(BaseModel): class PostJobPayload(BaseModel):
@@ -138,7 +127,7 @@ def app_factory(
@api_router.post( @api_router.post(
"/jobs", "/jobs",
dependencies=[Depends(authenticate_api_key)], dependencies=[Depends(api_key_auth)],
response_model=dtos.Job, response_model=dtos.Job,
status_code=201, status_code=201,
summary="Enqueue a new job", summary="Enqueue a new job",
@@ -146,6 +135,7 @@ def app_factory(
def create_job( def create_job(
payload: PostJobPayload, payload: PostJobPayload,
session: DatabaseSession, session: DatabaseSession,
task_queue: Annotated[TaskQueue, Depends(get_task_queue)],
) -> models.Job: ) -> models.Job:
""" """
Enqueue a new whisper job for processing. Enqueue a new whisper job for processing.

View File

@@ -1,16 +0,0 @@
from hmac import compare_digest
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.shared.settings import settings
def authenticate_api_key(
credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
) -> None:
# use compare_digest to counter timing attacks.
if not credentials or not compare_digest(
settings.API_SECRET, credentials.credentials
):
raise HTTPException(status_code=401)

View File

@@ -1,8 +1,4 @@
from asyncio.log import logger
from celery import Celery from celery import Celery
from sqlalchemy import or_
from sqlalchemy.orm import Session
import app.shared.db.models as models import app.shared.db.models as models
from app.shared.celery import get_celery_binding from app.shared.celery import get_celery_binding
@@ -11,8 +7,8 @@ from app.shared.celery import get_celery_binding
class TaskQueue: class TaskQueue:
celery: Celery celery: Celery
def __init__(self) -> None: def __init__(self, broker_url: str) -> None:
self.celery = get_celery_binding() self.celery = get_celery_binding(broker_url=broker_url)
def queue_task(self, job: models.Job): def queue_task(self, job: models.Job):
""" """
@@ -22,25 +18,3 @@ class TaskQueue:
transcribe = self.celery.signature("app.worker.main.transcribe") transcribe = self.celery.signature("app.worker.main.transcribe")
# TODO: catch delivery errors? # TODO: catch delivery errors?
transcribe.delay(job.id) transcribe.delay(job.id)
def rehydrate(self, session: Session):
# TODO: we could use `acks_late` to handle this scenario within celery itself.
# the reason this does not work well in our case is that `visibility_timeout`
# needs to be very high since whisper workers can be long running.
# doing this app-side bears the risk of poison pilling the worker though,
# implement a workaround with an acceptable trade-off. (=> retry only once?)
jobs = (
session.query(models.Job)
.filter(
or_(
models.Job.status == models.JobStatus.processing,
models.Job.status == models.JobStatus.create,
)
)
.order_by(models.Job.created_at)
).all()
logger.info(f"Requeueing {len(jobs)} jobs.")
for job in jobs:
self.queue_task(job)

View File

@@ -1,4 +1,3 @@
from asyncio.log import logger
from typing import Any from typing import Any
from uuid import UUID from uuid import UUID
@@ -7,11 +6,16 @@ from sqlalchemy.orm import Session
import app.shared.db.models as models import app.shared.db.models as models
from app.shared.celery import get_celery_binding from app.shared.celery import get_celery_binding
from app.shared.db.base import SessionLocal from app.shared.db.base import make_engine, make_session_local
from app.shared.settings import settings from app.shared.logger import logger
from app.shared.settings import Settings
from app.worker.strategies.local import LocalStrategy from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding() # TODO: refactor to be part of a Task instance.
settings = Settings() # type: ignore
celery = get_celery_binding(settings.BROKER_URL)
engine = make_engine(settings.DATABASE_URI)
SessionLocal = make_session_local(engine)
class TranscribeTask(Task): class TranscribeTask(Task):
@@ -39,15 +43,28 @@ class TranscribeTask(Task):
bind=True, bind=True,
soft_time_limit=settings.TASK_SOFT_TIME_LIMIT, soft_time_limit=settings.TASK_SOFT_TIME_LIMIT,
time_limit=settings.TASK_HARD_TIME_LIMIT, time_limit=settings.TASK_HARD_TIME_LIMIT,
task_acks_late=True,
task_acks_on_failure_or_timeout=True,
task_reject_on_worker_lost=True,
) )
def transcribe(self: Task, job_id: UUID) -> None: def transcribe(self: TranscribeTask, job_id: UUID) -> None:
session: Session | None = None
job: models.Job | None = None
try: try:
if not self.strategy:
raise Exception("expected a transcription strategy to be defined.")
# runs in a separate thread => requires sqlite's WAL mode to be enabled. # runs in a separate thread => requires sqlite's WAL mode to be enabled.
db: Session = SessionLocal() session = SessionLocal()
# work around mypy not inferring the sum type correctly.
if not session:
raise Exception("failed to acquire a session.")
# check if passed job should be processed. # check if passed job should be processed.
job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none() job = session.query(models.Job).filter(models.Job.id == job_id).one_or_none()
if job is None: if job is None:
logger.warn("[{job.id}]: Received unknown job, abort.") logger.warn("[{job.id}]: Received unknown job, abort.")
@@ -59,11 +76,22 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.debug(f"[{job.id}]: start processing {job.type} job.") logger.debug(f"[{job.id}]: start processing {job.type} job.")
if job.meta is not None:
attempts = 1 + (job.meta.get("attempts") or 0)
else:
attempts = 1
# SAFEGUARD: celery's retry policies do not handle lost workers, retry once.
# @see https://github.com/celery/celery/pull/6103
if attempts > 2:
raise Exception("Maximum number of retries exceeded for killed worker.")
# unit of work: set task status to processing. # unit of work: set task status to processing.
job.meta = {"task_id": self.request.id} job.meta = {"task_id": self.request.id, "attempts": attempts}
job.status = models.JobStatus.processing job.status = models.JobStatus.processing
db.commit() session.commit()
logger.debug(f"[{job.id}]: finished setting task to {job.status}.") logger.debug(f"[{job.id}]: finished setting task to {job.status}.")
@@ -72,21 +100,27 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.debug(f"[{job.id}]: successfully processed audio.") logger.debug(f"[{job.id}]: successfully processed audio.")
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type) artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
db.add(artifact) session.add(artifact)
job.status = models.JobStatus.success job.status = models.JobStatus.success
db.commit() session.commit()
logger.debug(f"[{job.id}]: successfully stored artifact.") logger.debug(f"[{job.id}]: successfully stored artifact.")
except Exception as e: except Exception as e:
if job and db: if job and session:
if db.in_transaction(): if session.in_transaction():
db.rollback() session.rollback()
job.meta = {**job.meta, "error": str(e)} # type: ignore if job.meta is not None:
job.meta = {**job.meta, "error": str(e)}
else:
job.meta = {"error": str(e)}
job.status = models.JobStatus.error job.status = models.JobStatus.error
db.commit() session.commit()
raise raise
finally: finally:
self.strategy.cleanup(job_id) if self.strategy:
db.close() self.strategy.cleanup(job_id)
if session:
session.close()

1
conf/rabbitmq.conf Normal file
View File

@@ -0,0 +1 @@
vm_memory_high_watermark.absolute = 192MB

View File

@@ -1,3 +1,6 @@
x-broker-environment: &broker-environment
BROKER_URL: "amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq:5672"
version: "3.8" version: "3.8"
name: whisperbox-transcribe name: whisperbox-transcribe
@@ -12,46 +15,59 @@ services:
networks: networks:
- traefik - traefik
redis: rabbitmq:
image: redis:7-alpine env_file: .env
image: rabbitmq:3-alpine
networks: networks:
- app - app
deploy: deploy:
resources: resources:
limits: limits:
memory: 128M memory: 256M
healthcheck:
test: rabbitmq-diagnostics check_port_connectivity
interval: 3s
timeout: 3s
retries: 10
volumes:
- ./conf/rabbitmq.conf:/etc/rabbitmq/rabbitmq.conf
- rabbitmq-data:/var/lib/rabbitmq/mnesia/
worker: worker:
env_file: .env env_file: .env
environment:
<<: *broker-environment
build: build:
context: . context: .
dockerfile: worker.Dockerfile dockerfile: worker.Dockerfile
args: args:
WHISPER_MODEL: ${WHISPER_MODEL} WHISPER_MODEL: ${WHISPER_MODEL}
depends_on:
rabbitmq:
condition: service_healthy
networks: networks:
- app - app
depends_on:
- redis
healthcheck:
test: ["CMD-SHELL", "celery -b ${BROKER_URL} inspect ping -d celery@$$HOSTNAME"]
interval: 5s
timeout: 5s
retries: 5
web: web:
env_file: .env env_file: .env
environment:
<<: *broker-environment
build: build:
context: . context: .
dockerfile: web.Dockerfile dockerfile: web.Dockerfile
depends_on:
rabbitmq:
condition: service_healthy
networks: networks:
- app - app
- traefik - traefik
depends_on:
worker:
condition: service_healthy
networks: networks:
app: app:
driver: bridge driver: bridge
traefik: traefik:
driver: bridge driver: bridge
volumes:
rabbitmq-data:

View File

@@ -12,7 +12,9 @@ services:
- "--entrypoints.web.address=:80" - "--entrypoints.web.address=:80"
web: web:
command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info" command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --factory"
# NOTE: the docker on mac mount adapter (virtioFS) does not support flock.
# this can cause the sqlite database to corrupt when written from worker <> api simultaneously.
volumes: volumes:
- ./:/etc/whisperbox-transcribe/ - ./:/etc/whisperbox-transcribe/
labels: labels:
@@ -26,13 +28,18 @@ services:
volumes: volumes:
- ./:/etc/whisperbox-transcribe/ - ./:/etc/whisperbox-transcribe/
rabbitmq:
image: rabbitmq:3-management-alpine
ports:
- 15672:15672
flower: flower:
image: mher/flower image: mher/flower
command: celery --broker redis://redis:6379/0 flower --port=5555 command: celery --broker amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq:5672 flower --port=5555
ports: ports:
- 5555:5555 - 5555:5555
depends_on: depends_on:
worker: - worker
condition: service_healthy - rabbitmq
networks: networks:
- app - app

View File

@@ -2,3 +2,4 @@
plugins = sqlalchemy.ext.mypy.plugin plugins = sqlalchemy.ext.mypy.plugin
ignore_missing_imports = True ignore_missing_imports = True
disallow_untyped_defs = False disallow_untyped_defs = False
check_untyped_defs = True

View File

@@ -1,19 +1,19 @@
[project] [project]
name = "whisperbox-transcribe" name = "whisperbox-transcribe"
description = "" description = ""
version = "0.1.0" version = "1.0.0"
dependencies=[ dependencies=[
"celery[redis] ==5.3.1", "celery ==5.3.6",
"sqlalchemy[mypy] ==2.0.19", "sqlalchemy[mypy] ==2.0.24",
"pydantic ==2.1.1", "pydantic ==2.5.3",
"pydantic-settings ==2.0.2" "pydantic-settings ==2.1.0"
] ]
[project.optional-dependencies] [project.optional-dependencies]
web=[ web=[
"alembic ==1.11.1", "alembic ==1.11.3",
"fastapi ==0.100.1", "fastapi ==0.101.1",
"uvicorn[standard] ==0.23.2", "uvicorn[standard] ==0.23.2",
"gunicorn ==21.2.0" "gunicorn ==21.2.0"
] ]
@@ -26,17 +26,17 @@ worker=[
tooling = [ tooling = [
# code formatting # code formatting
"black ==23.7.0", "black ==23.12.1",
# linting # linting
"ruff ==0.0.280", "ruff ==0.0.292",
# tests # tests
"httpx ==0.24.1", "httpx ==0.26.0",
"sqlalchemy-utils ==0.41.1", "sqlalchemy-utils ==0.41.1",
"python-dotenv ==1.0.0", "python-dotenv ==1.0.1",
"pytest ==7.4.0", "pytest ==7.4.4",
# types # types
"mypy ==1.4.1", "mypy ==1.5.1",
"types-requests ==2.31.0.2" "types-requests ==2.31.0.20231231"
] ]
[tool.ruff] [tool.ruff]

View File

@@ -20,4 +20,4 @@ COPY alembic.ini .
ENV VIRTUAL_ENV /opt/venv ENV VIRTUAL_ENV /opt/venv
ENV PATH /opt/venv/bin:$PATH ENV PATH /opt/venv/bin:$PATH
CMD alembic upgrade head && uvicorn app.web:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers CMD alembic upgrade head && uvicorn app.web:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers --factory