mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-11 21:18:36 +03:00
Compare commits
31 Commits
v1.0.0
...
renovate/t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8decdf4c02 | ||
|
|
f076ef9a1b | ||
|
|
bcfc3616d6 | ||
|
|
f4f760ee19 | ||
|
|
e5166dab2e | ||
|
|
3559aa5936 | ||
|
|
557de5a442 | ||
|
|
4402dc23bb | ||
|
|
50d5a63232 | ||
|
|
65fca1f597 | ||
|
|
21006d33dd | ||
|
|
3ee1e9f685 | ||
|
|
05eed3f6ea | ||
|
|
9fe10389b8 | ||
|
|
4ae14366a7 | ||
|
|
bbc00affa1 | ||
|
|
f469903d47 | ||
|
|
101903a7a2 | ||
|
|
504975a07a | ||
|
|
423018e92a | ||
|
|
cf07aa6d52 | ||
|
|
aeccad6226 | ||
|
|
21790fffeb | ||
|
|
28754ee0e9 | ||
|
|
ec203127fa | ||
|
|
8e35968b04 | ||
|
|
7baa24ff78 | ||
|
|
3a905148a0 | ||
|
|
8579667777 | ||
|
|
7428cceb0f | ||
|
|
7cb6a3eff6 |
4
.env.dev
4
.env.dev
@@ -3,4 +3,6 @@ TRAEFIK_DOMAIN="whisperbox-transcribe.localhost"
|
|||||||
WHISPER_MODEL="tiny"
|
WHISPER_MODEL="tiny"
|
||||||
ENVIRONMENT="development"
|
ENVIRONMENT="development"
|
||||||
DATABASE_URI="sqlite:///./whisperbox-transcribe.sqlite"
|
DATABASE_URI="sqlite:///./whisperbox-transcribe.sqlite"
|
||||||
BROKER_URL="redis://redis:6379/0"
|
|
||||||
|
RABBITMQ_DEFAULT_USER="rabbitmq"
|
||||||
|
RABBITMQ_DEFAULT_PASS="rabbitmq_password"
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ TRAEFIK_SSLEMAIL=""
|
|||||||
# ---
|
# ---
|
||||||
# below settings match the default docker-compose configuration.
|
# below settings match the default docker-compose configuration.
|
||||||
|
|
||||||
BROKER_URL="redis://redis:6379/0"
|
RABBITMQ_DEFAULT_USER="rabbitmq"
|
||||||
|
RABBITMQ_DEFAULT_PASS="rabbitmq_password"
|
||||||
|
|
||||||
DATABASE_URI="sqlite:////etc/whisperbox-transcribe/data/whisperbox-transcribe.sqlite"
|
DATABASE_URI="sqlite:////etc/whisperbox-transcribe/data/whisperbox-transcribe.sqlite"
|
||||||
ENVIRONMENT="production"
|
ENVIRONMENT="production"
|
||||||
|
|||||||
15
.github/renovate.json
vendored
15
.github/renovate.json
vendored
@@ -2,18 +2,5 @@
|
|||||||
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
||||||
"extends": ["config:base", "schedule:monthly"],
|
"extends": ["config:base", "schedule:monthly"],
|
||||||
"timezone": "Europe/Berlin",
|
"timezone": "Europe/Berlin",
|
||||||
"pip_setup": {
|
"enabledManagers": ["dockerfile", "docker-compose", "pep621"]
|
||||||
"fileMatch": [
|
|
||||||
"(^|/)pyproject\\.toml$"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"dockerfile": {
|
|
||||||
"enabled": false
|
|
||||||
},
|
|
||||||
"docker-compose": {
|
|
||||||
"enabled": false
|
|
||||||
},
|
|
||||||
"pyenv": {
|
|
||||||
"enabled": false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ Builds and starts the docker containers.
|
|||||||
```
|
```
|
||||||
# Bindings
|
# Bindings
|
||||||
http://localhost:5555 => Celery dashboard
|
http://localhost:5555 => Celery dashboard
|
||||||
|
http://localhost:15672 => RabbitMQ dashboard
|
||||||
http://whisperbox-transcribe.localhost => API
|
http://whisperbox-transcribe.localhost => API
|
||||||
http://whisperbox-transcribe.localhost/docs => API docs
|
http://whisperbox-transcribe.localhost/docs => API docs
|
||||||
./whisperbox-transcribe.sqlite => Database
|
./whisperbox-transcribe.sqlite => Database
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
from celery import Celery
|
from celery import Celery
|
||||||
|
|
||||||
from app.shared.settings import settings
|
|
||||||
|
|
||||||
|
def get_celery_binding(broker_url: str) -> Celery:
|
||||||
def get_celery_binding() -> Celery:
|
return Celery(
|
||||||
celery = Celery(
|
broker_url=broker_url,
|
||||||
broker_url=settings.BROKER_URL,
|
|
||||||
broker_connection_retry=False,
|
broker_connection_retry=False,
|
||||||
broker_connection_retry_on_startup=False,
|
broker_connection_retry_on_startup=False,
|
||||||
)
|
)
|
||||||
return celery
|
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ from alembic import context
|
|||||||
from sqlalchemy import engine_from_config, pool
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
|
||||||
from app.shared.db.models import Base
|
from app.shared.db.models import Base
|
||||||
from app.shared.settings import settings
|
from app.shared.settings import Settings
|
||||||
|
|
||||||
|
settings = Settings() # type: ignore
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
# access to the values within the .ini file in use.
|
# access to the values within the .ini file in use.
|
||||||
|
|||||||
@@ -1,26 +1,21 @@
|
|||||||
from typing import Any, Generator
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import create_engine, event
|
from sqlalchemy import Engine, create_engine, event
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from app.shared.settings import settings
|
|
||||||
|
|
||||||
engine = create_engine(settings.DATABASE_URI, connect_args={"check_same_thread": False})
|
|
||||||
|
|
||||||
|
|
||||||
@event.listens_for(engine, "connect")
|
def make_engine(database_url: str):
|
||||||
def set_sqlite_pragma(conn: Any, _: Any) -> None:
|
engine = create_engine(database_url, connect_args={"check_same_thread": False})
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute("PRAGMA journal_mode=WAL")
|
@event.listens_for(engine, "connect")
|
||||||
cursor.close()
|
def set_sqlite_pragma(conn: Any, _: Any) -> None:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("PRAGMA journal_mode=WAL")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
def make_session_local(engine: Engine):
|
||||||
|
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
return session_local
|
||||||
def get_session() -> Generator[Session, None, None]:
|
|
||||||
session: Session = SessionLocal()
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
session.close()
|
|
||||||
|
|||||||
@@ -52,6 +52,11 @@ class JobConfig(BaseModel):
|
|||||||
class JobMeta(BaseModel):
|
class JobMeta(BaseModel):
|
||||||
"""(JSON) Metadata relating to a job's execution."""
|
"""(JSON) Metadata relating to a job's execution."""
|
||||||
|
|
||||||
|
attempts: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Number of processing attempts a job has taken.",
|
||||||
|
)
|
||||||
|
|
||||||
error: str | None = Field(
|
error: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Will contain a descriptive error message if processing failed.",
|
description="Will contain a descriptive error message if processing failed.",
|
||||||
|
|||||||
7
app/shared/logger.py
Normal file
7
app/shared/logger.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
import sys
|
|
||||||
|
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
@@ -13,9 +11,3 @@ class Settings(BaseSettings):
|
|||||||
TASK_HARD_TIME_LIMIT: int = 4 * 60 * 60
|
TASK_HARD_TIME_LIMIT: int = 4 * 60 * 60
|
||||||
|
|
||||||
ENABLE_SHARING: bool = False
|
ENABLE_SHARING: bool = False
|
||||||
|
|
||||||
|
|
||||||
if "pytest" in sys.modules:
|
|
||||||
settings = Settings(_env_file=".env.test") # type: ignore
|
|
||||||
else:
|
|
||||||
settings = Settings() # type: ignore
|
|
||||||
|
|||||||
@@ -3,46 +3,62 @@ from fastapi.testclient import TestClient
|
|||||||
from sqlalchemy_utils import create_database, database_exists, drop_database
|
from sqlalchemy_utils import create_database, database_exists, drop_database
|
||||||
|
|
||||||
import app.shared.db.models as models
|
import app.shared.db.models as models
|
||||||
from app.shared.db.base import SessionLocal, engine
|
from app.shared.db.base import make_engine, make_session_local
|
||||||
from app.shared.settings import settings
|
from app.shared.settings import Settings
|
||||||
|
from app.web.injections.db import get_session
|
||||||
|
from app.web.injections.settings import get_settings
|
||||||
from app.web.main import app_factory
|
from app.web.main import app_factory
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure() -> None:
|
@pytest.fixture()
|
||||||
if not database_exists(engine.url):
|
def settings():
|
||||||
create_database(engine.url)
|
return Settings(_env_file=".env.test") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def pytest_unconfigure() -> None:
|
@pytest.fixture()
|
||||||
if database_exists(engine.url):
|
def auth_headers(settings) -> dict[str, str]:
|
||||||
drop_database(engine.url)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def auth_headers() -> dict[str, str]:
|
|
||||||
return {"Authorization": f"Bearer {settings.API_SECRET}"}
|
return {"Authorization": f"Bearer {settings.API_SECRET}"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture()
|
||||||
def db_session():
|
def test_db(settings):
|
||||||
models.Base.metadata.create_all(engine)
|
engine = make_engine(settings.DATABASE_URI)
|
||||||
connection = engine.connect()
|
|
||||||
|
|
||||||
with SessionLocal(bind=connection) as session:
|
if not database_exists(engine.url):
|
||||||
yield session
|
create_database(engine.url)
|
||||||
connection.close()
|
|
||||||
|
models.Base.metadata.create_all(engine)
|
||||||
|
|
||||||
|
connection = engine.connect()
|
||||||
|
yield connection
|
||||||
|
connection.close()
|
||||||
|
|
||||||
models.Base.metadata.drop_all(bind=engine)
|
models.Base.metadata.drop_all(bind=engine)
|
||||||
|
drop_database(engine.url)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture()
|
||||||
def client(db_session):
|
def db_session(test_db):
|
||||||
app = app_factory(lambda: db_session)
|
session_local = make_session_local(test_db)
|
||||||
|
with session_local() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def app(db_session, settings):
|
||||||
|
app = app_factory()
|
||||||
|
app.dependency_overrides[get_settings] = lambda: settings
|
||||||
|
app.dependency_overrides[get_session] = lambda: db_session
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def client(app):
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=False)
|
@pytest.fixture()
|
||||||
def mock_job(db_session):
|
def mock_job(db_session):
|
||||||
job = models.Job(
|
job = models.Job(
|
||||||
url="https://example.com",
|
url="https://example.com",
|
||||||
@@ -51,22 +67,15 @@ def mock_job(db_session):
|
|||||||
meta={"task_id": "5c790c76-2cc1-4e91-a305-443df55a4a4c"},
|
meta={"task_id": "5c790c76-2cc1-4e91-a305-443df55a4a4c"},
|
||||||
)
|
)
|
||||||
db_session.add(job)
|
db_session.add(job)
|
||||||
db_session.flush()
|
db_session.commit()
|
||||||
return job
|
return job
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=False)
|
@pytest.fixture()
|
||||||
def mock_artifact(db_session, mock_job):
|
def mock_artifact(db_session, mock_job):
|
||||||
artifact = models.Artifact(
|
artifact = models.Artifact(
|
||||||
data=None, job_id=str(mock_job.id), type=models.ArtifactType.raw_transcript
|
data=None, job_id=str(mock_job.id), type=models.ArtifactType.raw_transcript
|
||||||
)
|
)
|
||||||
db_session.add(artifact)
|
db_session.add(artifact)
|
||||||
db_session.flush()
|
db_session.commit()
|
||||||
return artifact
|
return artifact
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def sharing_enabled():
|
|
||||||
settings.ENABLE_SHARING = True
|
|
||||||
yield
|
|
||||||
settings.ENABLE_SHARING = False
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
import app.shared.db.models as models
|
import app.shared.db.models as models
|
||||||
from app.web.main import app_factory
|
from app.shared.settings import Settings
|
||||||
|
from app.web.injections.settings import get_settings
|
||||||
|
|
||||||
|
|
||||||
# POST /api/v1/jobs
|
# POST /api/v1/jobs
|
||||||
@@ -69,9 +68,10 @@ def test_get_job_sharing_disabled(client, mock_job):
|
|||||||
assert res.status_code == 401
|
assert res.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_get_job_sharing_enabled(db_session, mock_job, sharing_enabled):
|
def test_get_job_sharing_enabled(client, app, mock_job):
|
||||||
# HACK: delay construction until settings are patched.
|
app.dependency_overrides[get_settings] = lambda: Settings(
|
||||||
client = TestClient(app_factory(lambda: db_session))
|
_env_file=".env.test", ENABLE_SHARING=True # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
res = client.get(
|
res = client.get(
|
||||||
f"/api/v1/jobs/{mock_job.id}",
|
f"/api/v1/jobs/{mock_job.id}",
|
||||||
@@ -107,10 +107,25 @@ def test_get_artifacts_not_found(client, auth_headers, mock_job):
|
|||||||
# DELETE /api/v1/jobs
|
# DELETE /api/v1/jobs
|
||||||
# ---
|
# ---
|
||||||
def test_delete_job_pass(client, auth_headers, mock_job, db_session):
|
def test_delete_job_pass(client, auth_headers, mock_job, db_session):
|
||||||
res = client.delete(
|
res_job = client.get(
|
||||||
f"/api/v1/jobs/{mock_job.id}",
|
f"/api/v1/jobs/{mock_job.id}",
|
||||||
headers=auth_headers,
|
headers=auth_headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert db_session.query(models.Job).count() == 0
|
assert res_job.status_code == 200
|
||||||
assert res.status_code == 204
|
|
||||||
|
client.delete(
|
||||||
|
f"/api/v1/jobs/{mock_job.id}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# HACK: this catches a missed .commit().
|
||||||
|
# TODO: clean up pytest database handling.
|
||||||
|
db_session.rollback()
|
||||||
|
|
||||||
|
res_job_missing = client.get(
|
||||||
|
f"/api/v1/jobs/{mock_job.id}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert res_job_missing.status_code == 404
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from app.shared.db.base import get_session
|
|
||||||
from app.web.main import app_factory
|
from app.web.main import app_factory
|
||||||
|
|
||||||
app = app_factory(get_session)
|
app = app_factory
|
||||||
|
|||||||
0
app/web/injections/__init__.py
Normal file
0
app/web/injections/__init__.py
Normal file
26
app/web/injections/db.py
Normal file
26
app/web/injections/db.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.shared.db.base import make_engine, make_session_local
|
||||||
|
from app.shared.settings import Settings
|
||||||
|
from app.web.injections.settings import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def session_local(database_url: str):
|
||||||
|
engine = make_engine(database_url)
|
||||||
|
return make_session_local(engine)
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_local(settings: Settings = Depends(get_settings)):
|
||||||
|
return session_local(settings.DATABASE_URI)
|
||||||
|
|
||||||
|
|
||||||
|
def get_session(
|
||||||
|
session_local=Depends(get_session_local),
|
||||||
|
) -> Generator[Session, None, None]:
|
||||||
|
with session_local() as session:
|
||||||
|
yield session
|
||||||
39
app/web/injections/security.py
Normal file
39
app/web/injections/security.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from hmac import compare_digest
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
|
||||||
|
from app.shared.settings import Settings
|
||||||
|
from app.web.injections.settings import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
def api_key_auth(
|
||||||
|
credentials: Annotated[
|
||||||
|
HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
|
||||||
|
],
|
||||||
|
settings: Annotated[Settings, Depends(get_settings)],
|
||||||
|
):
|
||||||
|
validate_credentials(credentials, settings.API_SECRET)
|
||||||
|
|
||||||
|
|
||||||
|
def sharing_auth(
|
||||||
|
credentials: Annotated[
|
||||||
|
HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
|
||||||
|
],
|
||||||
|
settings: Annotated[Settings, Depends(get_settings)],
|
||||||
|
):
|
||||||
|
if settings.ENABLE_SHARING:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
validate_credentials(credentials, settings.API_SECRET)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_credentials(credentials: HTTPAuthorizationCredentials, secret: str):
|
||||||
|
# use compare_digest to counter timing attacks.
|
||||||
|
if (
|
||||||
|
not credentials
|
||||||
|
or not secret
|
||||||
|
or not compare_digest(secret, credentials.credentials)
|
||||||
|
):
|
||||||
|
raise HTTPException(status_code=401)
|
||||||
8
app/web/injections/settings.py
Normal file
8
app/web/injections/settings.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from app.shared.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_settings():
|
||||||
|
return Settings() # type: ignore
|
||||||
16
app/web/injections/task_queue.py
Normal file
16
app/web/injections/task_queue.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
from app.shared.settings import Settings
|
||||||
|
from app.web.injections.settings import get_settings
|
||||||
|
from app.web.task_queue import TaskQueue
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def task_queue(broker_url: str):
|
||||||
|
return TaskQueue(broker_url)
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_queue(settings: Settings = Depends(get_settings)):
|
||||||
|
return task_queue(settings.BROKER_URL)
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
from contextlib import asynccontextmanager
|
from typing import Annotated
|
||||||
from typing import Annotated, Callable, Generator
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
|
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
|
||||||
@@ -8,42 +7,31 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
import app.shared.db.models as models
|
import app.shared.db.models as models
|
||||||
import app.web.dtos as dtos
|
import app.web.dtos as dtos
|
||||||
from app.shared.db.base import SessionLocal
|
from app.web.injections.db import get_session
|
||||||
from app.shared.settings import settings
|
from app.web.injections.security import api_key_auth, sharing_auth
|
||||||
from app.web.security import authenticate_api_key
|
from app.web.injections.task_queue import get_task_queue
|
||||||
from app.web.task_queue import TaskQueue
|
from app.web.task_queue import TaskQueue
|
||||||
|
|
||||||
|
DatabaseSession = Annotated[Session, Depends(get_session)]
|
||||||
|
|
||||||
def app_factory(
|
|
||||||
session_getter: Callable[[], Generator[Session, None, None]]
|
|
||||||
) -> FastAPI:
|
|
||||||
DatabaseSession = Annotated[Session, Depends(session_getter)]
|
|
||||||
|
|
||||||
task_queue = TaskQueue()
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(_: FastAPI):
|
|
||||||
with SessionLocal() as session:
|
|
||||||
task_queue.rehydrate(session)
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
def app_factory():
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
description=(
|
description=(
|
||||||
"whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
|
"whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
|
||||||
),
|
),
|
||||||
lifespan=lifespan,
|
|
||||||
title="whisperbox-transcribe",
|
title="whisperbox-transcribe",
|
||||||
)
|
)
|
||||||
|
|
||||||
api_router = APIRouter(prefix="/api/v1")
|
api_router = APIRouter(prefix="/api/v1")
|
||||||
|
|
||||||
@api_router.get("/", response_model=None, status_code=204)
|
@api_router.get("/", status_code=204)
|
||||||
def api_root() -> None:
|
def api_root():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@api_router.get(
|
@api_router.get(
|
||||||
"/jobs",
|
"/jobs",
|
||||||
dependencies=[Depends(authenticate_api_key)],
|
dependencies=[Depends(api_key_auth)],
|
||||||
response_model=list[dtos.Job],
|
response_model=list[dtos.Job],
|
||||||
summary="Get metadata for all jobs",
|
summary="Get metadata for all jobs",
|
||||||
)
|
)
|
||||||
@@ -61,7 +49,7 @@ def app_factory(
|
|||||||
|
|
||||||
@api_router.get(
|
@api_router.get(
|
||||||
"/jobs/{id}",
|
"/jobs/{id}",
|
||||||
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
|
dependencies=[Depends(sharing_auth)],
|
||||||
response_model=dtos.Job,
|
response_model=dtos.Job,
|
||||||
summary="Get metadata for one job",
|
summary="Get metadata for one job",
|
||||||
)
|
)
|
||||||
@@ -81,7 +69,7 @@ def app_factory(
|
|||||||
|
|
||||||
@api_router.get(
|
@api_router.get(
|
||||||
"/jobs/{id}/artifacts",
|
"/jobs/{id}/artifacts",
|
||||||
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
|
dependencies=[Depends(api_key_auth)],
|
||||||
response_model=list[dtos.Artifact],
|
response_model=list[dtos.Artifact],
|
||||||
summary="Get all artifacts for one job",
|
summary="Get all artifacts for one job",
|
||||||
)
|
)
|
||||||
@@ -102,7 +90,7 @@ def app_factory(
|
|||||||
|
|
||||||
@api_router.delete(
|
@api_router.delete(
|
||||||
"/jobs/{id}",
|
"/jobs/{id}",
|
||||||
dependencies=[Depends(authenticate_api_key)],
|
dependencies=[Depends(sharing_auth)],
|
||||||
status_code=204,
|
status_code=204,
|
||||||
summary="Delete a job with all artifacts",
|
summary="Delete a job with all artifacts",
|
||||||
)
|
)
|
||||||
@@ -112,6 +100,7 @@ def app_factory(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Remove metadata and artifacts for a single job."""
|
"""Remove metadata and artifacts for a single job."""
|
||||||
session.query(models.Job).filter(models.Job.id == str(id)).delete()
|
session.query(models.Job).filter(models.Job.id == str(id)).delete()
|
||||||
|
session.commit()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class PostJobPayload(BaseModel):
|
class PostJobPayload(BaseModel):
|
||||||
@@ -138,7 +127,7 @@ def app_factory(
|
|||||||
|
|
||||||
@api_router.post(
|
@api_router.post(
|
||||||
"/jobs",
|
"/jobs",
|
||||||
dependencies=[Depends(authenticate_api_key)],
|
dependencies=[Depends(api_key_auth)],
|
||||||
response_model=dtos.Job,
|
response_model=dtos.Job,
|
||||||
status_code=201,
|
status_code=201,
|
||||||
summary="Enqueue a new job",
|
summary="Enqueue a new job",
|
||||||
@@ -146,6 +135,7 @@ def app_factory(
|
|||||||
def create_job(
|
def create_job(
|
||||||
payload: PostJobPayload,
|
payload: PostJobPayload,
|
||||||
session: DatabaseSession,
|
session: DatabaseSession,
|
||||||
|
task_queue: Annotated[TaskQueue, Depends(get_task_queue)],
|
||||||
) -> models.Job:
|
) -> models.Job:
|
||||||
"""
|
"""
|
||||||
Enqueue a new whisper job for processing.
|
Enqueue a new whisper job for processing.
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
from hmac import compare_digest
|
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException
|
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
||||||
|
|
||||||
from app.shared.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
def authenticate_api_key(
|
|
||||||
credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
|
|
||||||
) -> None:
|
|
||||||
# use compare_digest to counter timing attacks.
|
|
||||||
if not credentials or not compare_digest(
|
|
||||||
settings.API_SECRET, credentials.credentials
|
|
||||||
):
|
|
||||||
raise HTTPException(status_code=401)
|
|
||||||
@@ -1,8 +1,4 @@
|
|||||||
from asyncio.log import logger
|
|
||||||
|
|
||||||
from celery import Celery
|
from celery import Celery
|
||||||
from sqlalchemy import or_
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
import app.shared.db.models as models
|
import app.shared.db.models as models
|
||||||
from app.shared.celery import get_celery_binding
|
from app.shared.celery import get_celery_binding
|
||||||
@@ -11,8 +7,8 @@ from app.shared.celery import get_celery_binding
|
|||||||
class TaskQueue:
|
class TaskQueue:
|
||||||
celery: Celery
|
celery: Celery
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, broker_url: str) -> None:
|
||||||
self.celery = get_celery_binding()
|
self.celery = get_celery_binding(broker_url=broker_url)
|
||||||
|
|
||||||
def queue_task(self, job: models.Job):
|
def queue_task(self, job: models.Job):
|
||||||
"""
|
"""
|
||||||
@@ -22,25 +18,3 @@ class TaskQueue:
|
|||||||
transcribe = self.celery.signature("app.worker.main.transcribe")
|
transcribe = self.celery.signature("app.worker.main.transcribe")
|
||||||
# TODO: catch delivery errors?
|
# TODO: catch delivery errors?
|
||||||
transcribe.delay(job.id)
|
transcribe.delay(job.id)
|
||||||
|
|
||||||
def rehydrate(self, session: Session):
|
|
||||||
# TODO: we could use `acks_late` to handle this scenario within celery itself.
|
|
||||||
# the reason this does not work well in our case is that `visibility_timeout`
|
|
||||||
# needs to be very high since whisper workers can be long running.
|
|
||||||
# doing this app-side bears the risk of poison pilling the worker though,
|
|
||||||
# implement a workaround with an acceptable trade-off. (=> retry only once?)
|
|
||||||
jobs = (
|
|
||||||
session.query(models.Job)
|
|
||||||
.filter(
|
|
||||||
or_(
|
|
||||||
models.Job.status == models.JobStatus.processing,
|
|
||||||
models.Job.status == models.JobStatus.create,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.order_by(models.Job.created_at)
|
|
||||||
).all()
|
|
||||||
|
|
||||||
logger.info(f"Requeueing {len(jobs)} jobs.")
|
|
||||||
|
|
||||||
for job in jobs:
|
|
||||||
self.queue_task(job)
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from asyncio.log import logger
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@@ -7,11 +6,16 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
import app.shared.db.models as models
|
import app.shared.db.models as models
|
||||||
from app.shared.celery import get_celery_binding
|
from app.shared.celery import get_celery_binding
|
||||||
from app.shared.db.base import SessionLocal
|
from app.shared.db.base import make_engine, make_session_local
|
||||||
from app.shared.settings import settings
|
from app.shared.logger import logger
|
||||||
|
from app.shared.settings import Settings
|
||||||
from app.worker.strategies.local import LocalStrategy
|
from app.worker.strategies.local import LocalStrategy
|
||||||
|
|
||||||
celery = get_celery_binding()
|
# TODO: refactor to be part of a Task instance.
|
||||||
|
settings = Settings() # type: ignore
|
||||||
|
celery = get_celery_binding(settings.BROKER_URL)
|
||||||
|
engine = make_engine(settings.DATABASE_URI)
|
||||||
|
SessionLocal = make_session_local(engine)
|
||||||
|
|
||||||
|
|
||||||
class TranscribeTask(Task):
|
class TranscribeTask(Task):
|
||||||
@@ -39,15 +43,28 @@ class TranscribeTask(Task):
|
|||||||
bind=True,
|
bind=True,
|
||||||
soft_time_limit=settings.TASK_SOFT_TIME_LIMIT,
|
soft_time_limit=settings.TASK_SOFT_TIME_LIMIT,
|
||||||
time_limit=settings.TASK_HARD_TIME_LIMIT,
|
time_limit=settings.TASK_HARD_TIME_LIMIT,
|
||||||
|
task_acks_late=True,
|
||||||
|
task_acks_on_failure_or_timeout=True,
|
||||||
|
task_reject_on_worker_lost=True,
|
||||||
)
|
)
|
||||||
def transcribe(self: Task, job_id: UUID) -> None:
|
def transcribe(self: TranscribeTask, job_id: UUID) -> None:
|
||||||
|
session: Session | None = None
|
||||||
|
job: models.Job | None = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not self.strategy:
|
||||||
|
raise Exception("expected a transcription strategy to be defined.")
|
||||||
|
|
||||||
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
|
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
|
||||||
db: Session = SessionLocal()
|
session = SessionLocal()
|
||||||
|
|
||||||
|
# work around mypy not inferring the sum type correctly.
|
||||||
|
if not session:
|
||||||
|
raise Exception("failed to acquire a session.")
|
||||||
|
|
||||||
# check if passed job should be processed.
|
# check if passed job should be processed.
|
||||||
|
|
||||||
job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none()
|
job = session.query(models.Job).filter(models.Job.id == job_id).one_or_none()
|
||||||
|
|
||||||
if job is None:
|
if job is None:
|
||||||
logger.warn("[{job.id}]: Received unknown job, abort.")
|
logger.warn("[{job.id}]: Received unknown job, abort.")
|
||||||
@@ -59,11 +76,22 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
|||||||
|
|
||||||
logger.debug(f"[{job.id}]: start processing {job.type} job.")
|
logger.debug(f"[{job.id}]: start processing {job.type} job.")
|
||||||
|
|
||||||
|
if job.meta is not None:
|
||||||
|
attempts = 1 + (job.meta.get("attempts") or 0)
|
||||||
|
else:
|
||||||
|
attempts = 1
|
||||||
|
|
||||||
|
# SAFEGUARD: celery's retry policies do not handle lost workers, retry once.
|
||||||
|
# @see https://github.com/celery/celery/pull/6103
|
||||||
|
if attempts > 2:
|
||||||
|
raise Exception("Maximum number of retries exceeded for killed worker.")
|
||||||
|
|
||||||
# unit of work: set task status to processing.
|
# unit of work: set task status to processing.
|
||||||
|
|
||||||
job.meta = {"task_id": self.request.id}
|
job.meta = {"task_id": self.request.id, "attempts": attempts}
|
||||||
|
|
||||||
job.status = models.JobStatus.processing
|
job.status = models.JobStatus.processing
|
||||||
db.commit()
|
session.commit()
|
||||||
|
|
||||||
logger.debug(f"[{job.id}]: finished setting task to {job.status}.")
|
logger.debug(f"[{job.id}]: finished setting task to {job.status}.")
|
||||||
|
|
||||||
@@ -72,21 +100,27 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
|||||||
logger.debug(f"[{job.id}]: successfully processed audio.")
|
logger.debug(f"[{job.id}]: successfully processed audio.")
|
||||||
|
|
||||||
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
|
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
|
||||||
db.add(artifact)
|
session.add(artifact)
|
||||||
|
|
||||||
job.status = models.JobStatus.success
|
job.status = models.JobStatus.success
|
||||||
db.commit()
|
session.commit()
|
||||||
|
|
||||||
logger.debug(f"[{job.id}]: successfully stored artifact.")
|
logger.debug(f"[{job.id}]: successfully stored artifact.")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if job and db:
|
if job and session:
|
||||||
if db.in_transaction():
|
if session.in_transaction():
|
||||||
db.rollback()
|
session.rollback()
|
||||||
job.meta = {**job.meta, "error": str(e)} # type: ignore
|
if job.meta is not None:
|
||||||
|
job.meta = {**job.meta, "error": str(e)}
|
||||||
|
else:
|
||||||
|
job.meta = {"error": str(e)}
|
||||||
|
|
||||||
job.status = models.JobStatus.error
|
job.status = models.JobStatus.error
|
||||||
db.commit()
|
session.commit()
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self.strategy.cleanup(job_id)
|
if self.strategy:
|
||||||
db.close()
|
self.strategy.cleanup(job_id)
|
||||||
|
if session:
|
||||||
|
session.close()
|
||||||
|
|||||||
1
conf/rabbitmq.conf
Normal file
1
conf/rabbitmq.conf
Normal file
@@ -0,0 +1 @@
|
|||||||
|
vm_memory_high_watermark.absolute = 192MB
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
|
x-broker-environment: &broker-environment
|
||||||
|
BROKER_URL: "amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq:5672"
|
||||||
|
|
||||||
version: "3.8"
|
version: "3.8"
|
||||||
name: whisperbox-transcribe
|
name: whisperbox-transcribe
|
||||||
|
|
||||||
@@ -12,46 +15,59 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- traefik
|
- traefik
|
||||||
|
|
||||||
redis:
|
rabbitmq:
|
||||||
image: redis:7-alpine
|
env_file: .env
|
||||||
|
image: rabbitmq:3-alpine
|
||||||
networks:
|
networks:
|
||||||
- app
|
- app
|
||||||
deploy:
|
deploy:
|
||||||
resources:
|
resources:
|
||||||
limits:
|
limits:
|
||||||
memory: 128M
|
memory: 256M
|
||||||
|
healthcheck:
|
||||||
|
test: rabbitmq-diagnostics check_port_connectivity
|
||||||
|
interval: 3s
|
||||||
|
timeout: 3s
|
||||||
|
retries: 10
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
- ./conf/rabbitmq.conf:/etc/rabbitmq/rabbitmq.conf
|
||||||
|
- rabbitmq-data:/var/lib/rabbitmq/mnesia/
|
||||||
|
|
||||||
worker:
|
worker:
|
||||||
env_file: .env
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
<<: *broker-environment
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: worker.Dockerfile
|
dockerfile: worker.Dockerfile
|
||||||
args:
|
args:
|
||||||
WHISPER_MODEL: ${WHISPER_MODEL}
|
WHISPER_MODEL: ${WHISPER_MODEL}
|
||||||
|
depends_on:
|
||||||
|
rabbitmq:
|
||||||
|
condition: service_healthy
|
||||||
networks:
|
networks:
|
||||||
- app
|
- app
|
||||||
depends_on:
|
|
||||||
- redis
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD-SHELL", "celery -b ${BROKER_URL} inspect ping -d celery@$$HOSTNAME"]
|
|
||||||
interval: 5s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 5
|
|
||||||
|
|
||||||
web:
|
web:
|
||||||
env_file: .env
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
<<: *broker-environment
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: web.Dockerfile
|
dockerfile: web.Dockerfile
|
||||||
|
depends_on:
|
||||||
|
rabbitmq:
|
||||||
|
condition: service_healthy
|
||||||
networks:
|
networks:
|
||||||
- app
|
- app
|
||||||
- traefik
|
- traefik
|
||||||
depends_on:
|
|
||||||
worker:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
app:
|
app:
|
||||||
driver: bridge
|
driver: bridge
|
||||||
traefik:
|
traefik:
|
||||||
driver: bridge
|
driver: bridge
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
rabbitmq-data:
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ services:
|
|||||||
- "--entrypoints.web.address=:80"
|
- "--entrypoints.web.address=:80"
|
||||||
|
|
||||||
web:
|
web:
|
||||||
command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info"
|
command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --factory"
|
||||||
|
# NOTE: the docker on mac mount adapter (virtioFS) does not support flock.
|
||||||
|
# this can cause the sqlite database to corrupt when written from worker <> api simultaneously.
|
||||||
volumes:
|
volumes:
|
||||||
- ./:/etc/whisperbox-transcribe/
|
- ./:/etc/whisperbox-transcribe/
|
||||||
labels:
|
labels:
|
||||||
@@ -26,13 +28,18 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- ./:/etc/whisperbox-transcribe/
|
- ./:/etc/whisperbox-transcribe/
|
||||||
|
|
||||||
|
rabbitmq:
|
||||||
|
image: rabbitmq:3-management-alpine
|
||||||
|
ports:
|
||||||
|
- 15672:15672
|
||||||
|
|
||||||
flower:
|
flower:
|
||||||
image: mher/flower
|
image: mher/flower
|
||||||
command: celery --broker redis://redis:6379/0 flower --port=5555
|
command: celery --broker amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq:5672 flower --port=5555
|
||||||
ports:
|
ports:
|
||||||
- 5555:5555
|
- 5555:5555
|
||||||
depends_on:
|
depends_on:
|
||||||
worker:
|
- worker
|
||||||
condition: service_healthy
|
- rabbitmq
|
||||||
networks:
|
networks:
|
||||||
- app
|
- app
|
||||||
|
|||||||
1
mypy.ini
1
mypy.ini
@@ -2,3 +2,4 @@
|
|||||||
plugins = sqlalchemy.ext.mypy.plugin
|
plugins = sqlalchemy.ext.mypy.plugin
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
disallow_untyped_defs = False
|
disallow_untyped_defs = False
|
||||||
|
check_untyped_defs = True
|
||||||
|
|||||||
@@ -1,19 +1,19 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "whisperbox-transcribe"
|
name = "whisperbox-transcribe"
|
||||||
description = ""
|
description = ""
|
||||||
version = "0.1.0"
|
version = "1.0.0"
|
||||||
|
|
||||||
dependencies=[
|
dependencies=[
|
||||||
"celery[redis] ==5.3.1",
|
"celery ==5.3.6",
|
||||||
"sqlalchemy[mypy] ==2.0.19",
|
"sqlalchemy[mypy] ==2.0.24",
|
||||||
"pydantic ==2.1.1",
|
"pydantic ==2.5.3",
|
||||||
"pydantic-settings ==2.0.2"
|
"pydantic-settings ==2.1.0"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
web=[
|
web=[
|
||||||
"alembic ==1.11.1",
|
"alembic ==1.11.3",
|
||||||
"fastapi ==0.100.1",
|
"fastapi ==0.101.1",
|
||||||
"uvicorn[standard] ==0.23.2",
|
"uvicorn[standard] ==0.23.2",
|
||||||
"gunicorn ==21.2.0"
|
"gunicorn ==21.2.0"
|
||||||
]
|
]
|
||||||
@@ -26,17 +26,17 @@ worker=[
|
|||||||
|
|
||||||
tooling = [
|
tooling = [
|
||||||
# code formatting
|
# code formatting
|
||||||
"black ==23.7.0",
|
"black ==23.12.1",
|
||||||
# linting
|
# linting
|
||||||
"ruff ==0.0.280",
|
"ruff ==0.0.292",
|
||||||
# tests
|
# tests
|
||||||
"httpx ==0.24.1",
|
"httpx ==0.26.0",
|
||||||
"sqlalchemy-utils ==0.41.1",
|
"sqlalchemy-utils ==0.41.1",
|
||||||
"python-dotenv ==1.0.0",
|
"python-dotenv ==1.0.1",
|
||||||
"pytest ==7.4.0",
|
"pytest ==7.4.4",
|
||||||
# types
|
# types
|
||||||
"mypy ==1.4.1",
|
"mypy ==1.5.1",
|
||||||
"types-requests ==2.31.0.2"
|
"types-requests ==2.31.0.20231231"
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
|
|||||||
@@ -20,4 +20,4 @@ COPY alembic.ini .
|
|||||||
ENV VIRTUAL_ENV /opt/venv
|
ENV VIRTUAL_ENV /opt/venv
|
||||||
ENV PATH /opt/venv/bin:$PATH
|
ENV PATH /opt/venv/bin:$PATH
|
||||||
|
|
||||||
CMD alembic upgrade head && uvicorn app.web:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers
|
CMD alembic upgrade head && uvicorn app.web:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers --factory
|
||||||
|
|||||||
Reference in New Issue
Block a user