diff --git a/.gitignore b/.gitignore index 5fa62ee..3ec2352 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,6 @@ cython_debug/ whisperbox-transcribe.sqlite* *shm *wal + +# ruff +.ruff_cache diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 52bab7b..6460c18 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -1,13 +1,11 @@ -from typing import Generator - import pytest -from sqlalchemy.orm import Session +from fastapi.testclient import TestClient from sqlalchemy_utils import create_database, database_exists, drop_database import app.shared.db.models as models -from app.shared.db.base import SessionLocal, engine, get_session +from app.shared.db.base import SessionLocal, engine from app.shared.settings import settings -from app.web.main import app +from app.web.main import app_factory def pytest_configure() -> None: @@ -26,15 +24,48 @@ def auth_headers() -> dict[str, str]: @pytest.fixture(scope="function", autouse=True) -def db_session() -> Generator[Session, None, None]: +def db_session(): models.Base.metadata.create_all(engine) - connection = engine.connect() with SessionLocal(bind=connection) as session: - app.dependency_overrides[get_session] = lambda: session yield session - app.dependency_overrides.clear() connection.close() models.Base.metadata.drop_all(bind=engine) + + +@pytest.fixture(scope="function") +def client(db_session): + app = app_factory(lambda: db_session) + client = TestClient(app) + return client + + +@pytest.fixture(scope="function", autouse=False) +def mock_job(db_session): + job = models.Job( + url="https://example.com", + type=models.JobType.transcript, + status=models.JobStatus.create, + ) + db_session.add(job) + db_session.flush() + return job + + +@pytest.fixture(scope="function", autouse=False) +def mock_artifact(db_session, mock_job): + artifact = models.Artifact( + data=None, job_id=str(mock_job.id), type=models.ArtifactType.raw_transcript + ) + db_session.add(artifact) + db_session.flush() + return artifact + + +@pytest.fixture(scope="function") +def sharing_enabled(): + settings.ENABLE_SHARING = True + yield + settings.ENABLE_SHARING = False diff --git a/app/tests/test_api.py b/app/tests/test_api.py index 3d7c0ad..8c950bb 100644 --- a/app/tests/test_api.py +++ b/app/tests/test_api.py @@ -1,30 +1,12 @@ -import pytest from fastapi.testclient import TestClient -from sqlalchemy.orm import Session import app.shared.db.models as models -from app.web.main import app - -client = TestClient(app) - - -@pytest.fixture(name="mock_job", scope="function", autouse=False) -def mock_job(db_session: Session) -> models.Job: - job = models.Job( - url="https://example.com", - type=models.JobType.transcript, - status=models.JobStatus.create, - ) - - db_session.add(job) - db_session.flush() - - return job +from app.web.main import app_factory # POST /api/v1/jobs # --- -def test_create_job_pass(auth_headers: dict[str, str]) -> None: +def test_create_job_pass(client, auth_headers: dict[str, str]): res = client.post( "/api/v1/jobs", headers=auth_headers, @@ -34,12 +16,12 @@ def test_create_job_pass(auth_headers: dict[str, str]) -> None: assert isinstance(res.json()["id"], str) -def test_create_job_missing_body(auth_headers: dict[str, str]) -> None: +def test_create_job_missing_body(client, auth_headers: dict[str, str]): res = client.post("/api/v1/jobs", headers=auth_headers, json={}) assert res.status_code == 422 -def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None: +def test_create_job_malformed_url(client, auth_headers: dict[str, str]): res = client.post( "/api/v1/jobs", headers=auth_headers, @@ -50,7 +32,7 @@ def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None: # GET /api/v1/jobs # --- -def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None: +def test_get_jobs_pass(client, auth_headers: dict[str, str], mock_job: models.Job): res = client.get( "/api/v1/jobs?type=transcribe", headers=auth_headers, @@ -61,7 +43,7 @@ def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> No # GET /api/v1/jobs/:id # --- -def test_get_job_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None: +def test_get_job_pass(client, auth_headers: dict[str, str], mock_job: models.Job): res = client.get( f"/api/v1/jobs/{mock_job.id}", headers=auth_headers, @@ -70,7 +52,7 @@ def test_get_job_pass(auth_headers: dict[str, str], mock_job: models.Job) -> Non assert res.json()["id"] == str(mock_job.id) -def test_get_job_not_found(auth_headers: dict[str, str], mock_job: models.Job) -> None: +def test_get_job_not_found(client, auth_headers: dict[str, str], mock_job): res = client.get( "/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb", headers=auth_headers, @@ -79,18 +61,29 @@ def test_get_job_not_found(auth_headers: dict[str, str], mock_job: models.Job) - assert res.status_code == 404 -# GET /api/v1/jobs/:id/artifacts -# --- -def test_get_artifacts_pass( - auth_headers: dict[str, str], db_session: Session, mock_job: models.Job -) -> None: - artifact = models.Artifact( - data=None, job_id=str(mock_job.id), type=models.ArtifactType.raw_transcript +def test_get_job_sharing_disabled(client, mock_job): + res = client.get( + f"/api/v1/jobs/{mock_job.id}", + headers={}, + ) + assert res.status_code == 401 + + +def test_get_job_sharing_enabled(db_session, mock_job, sharing_enabled): + # HACK: delay construction until settings are patched. + client = TestClient(app_factory(lambda: db_session)) + + res = client.get( + f"/api/v1/jobs/{mock_job.id}", + headers={}, ) - db_session.add(artifact) - db_session.flush() + assert res.status_code == 200 + +# GET /api/v1/jobs/:id/artifacts +# --- +def test_get_artifacts_pass(client, auth_headers, db_session, mock_job, mock_artifact): res = client.get( f"/api/v1/jobs/{mock_job.id}/artifacts", headers=auth_headers, @@ -98,12 +91,10 @@ def test_get_artifacts_pass( assert res.status_code == 200 assert res.json()[0]["job_id"] == str(mock_job.id) - assert res.json()[0]["id"] == str(artifact.id) + assert res.json()[0]["id"] == str(mock_artifact.id) -def test_get_artifacts_not_found( - auth_headers: dict[str, str], mock_job: models.Job -) -> None: +def test_get_artifacts_not_found(client, auth_headers, mock_job): res = client.get( f"/api/v1/jobs/{mock_job.id}/artifacts", headers=auth_headers, @@ -115,12 +106,11 @@ def test_get_artifacts_not_found( # DELETE /api/v1/jobs # --- -def test_delete_job_pass( - auth_headers: dict[str, str], mock_job: models.Job, db_session: Session -) -> None: +def test_delete_job_pass(client, auth_headers, mock_job, db_session): res = client.delete( f"/api/v1/jobs/{mock_job.id}", headers=auth_headers, ) + assert db_session.query(models.Job).count() == 0 assert res.status_code == 204 diff --git a/app/tests/test_auth.py b/app/tests/test_auth.py index fadc2e0..c1822b1 100644 --- a/app/tests/test_auth.py +++ b/app/tests/test_auth.py @@ -1,25 +1,18 @@ -from fastapi.testclient import TestClient - -from app.web.main import app - -client = TestClient(app) - - -def test_authorization_header_missing() -> None: +def test_authorization_header_missing(client): res = client.get("/api/v1/jobs") assert res.status_code == 401 -def test_authorization_header_malformed() -> None: +def test_authorization_header_malformed(client): res = client.get("/api/v1/jobs", headers={"Authorization": "Bearer"}) assert res.status_code == 401 -def test_incorrect_api_key() -> None: +def test_incorrect_api_key(client): res = client.get("/api/v1/jobs", headers={"Authorization": "Bearer incorrect"}) assert res.status_code == 401 -def test_existing_api_key(auth_headers: dict[str, str]) -> None: +def test_existing_api_key(client, auth_headers): res = client.get("/api/v1/jobs", headers=auth_headers) assert res.status_code == 200 diff --git a/app/web/__init__.py b/app/web/__init__.py index e69de29..61dd17a 100644 --- a/app/web/__init__.py +++ b/app/web/__init__.py @@ -0,0 +1,4 @@ +from app.shared.db.base import get_session +from app.web.main import app_factory + +app = app_factory(get_session) diff --git a/app/web/main.py b/app/web/main.py index 775c315..e167d28 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -1,5 +1,5 @@ from contextlib import asynccontextmanager -from typing import Annotated +from typing import Annotated, Callable, Generator from uuid import UUID from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path @@ -8,173 +8,172 @@ from sqlalchemy.orm import Session import app.shared.db.models as models import app.web.dtos as dtos -from app.shared.db.base import SessionLocal, get_session +from app.shared.db.base import SessionLocal from app.shared.settings import settings from app.web.security import authenticate_api_key -from app.web.task_queue import task_queue - -DatabaseSession = Annotated[Session, Depends(get_session)] +from app.web.task_queue import TaskQueue -@asynccontextmanager -async def lifespan(_: FastAPI): - with SessionLocal() as session: - task_queue.rehydrate(session) - yield +def app_factory( + session_getter: Callable[[], Generator[Session, None, None]] +) -> FastAPI: + DatabaseSession = Annotated[Session, Depends(session_getter)] + task_queue = TaskQueue() -app = FastAPI( - description="whisperbox-transcribe is an async HTTP wrapper for openai/whisper.", - lifespan=lifespan, - title="whisperbox-transcribe", -) + @asynccontextmanager + async def lifespan(_: FastAPI): + with SessionLocal() as session: + task_queue.rehydrate(session) + yield - -api_router = APIRouter(prefix="/api/v1") - - -@api_router.get("/", response_model=None, status_code=204) -def api_root() -> None: - return None - - -@api_router.get( - "/jobs", - dependencies=[Depends(authenticate_api_key)], - response_model=list[dtos.Job], - summary="Get metadata for all jobs", -) -def get_transcripts( - session: DatabaseSession, - type: dtos.JobType | None = None, -) -> list[models.Job]: - """Get metadata for all jobs.""" - query = session.query(models.Job).order_by(models.Job.created_at.desc()) - - if type: - query = query.filter(models.Job.type == type) - - return query.all() - - -@api_router.get( - "/jobs/{id}", - dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], - response_model=dtos.Job, - summary="Get metadata for one job", -) -def get_transcript( - session: DatabaseSession, - id: UUID = Path(), -) -> models.Job | None: - """ - Use this route to check transcription status of any given job. - """ - job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none() - - if not job: - raise HTTPException(status_code=404) - - return job - - -@api_router.get( - "/jobs/{id}/artifacts", - dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], - response_model=list[dtos.Artifact], - summary="Get all artifacts for one job", -) -def get_artifacts_for_job( - session: DatabaseSession, - id: UUID = Path(), -) -> list[models.Artifact]: - """ - Returns all artifacts for one job. - See the type of `data` for possible data types. - Returns an empty array for unfinished or non-existant jobs. - """ - artifacts = ( - session.query(models.Artifact).filter(models.Artifact.job_id == str(id)) - ).all() - - return artifacts - - -@api_router.delete( - "/jobs/{id}", - dependencies=[Depends(authenticate_api_key)], - status_code=204, - summary="Delete a job with all artifacts", -) -def delete_transcript( - session: DatabaseSession, - id: UUID = Path(), -) -> None: - """Remove metadata and artifacts for a single job.""" - session.query(models.Job).filter(models.Job.id == str(id)).delete() - return None - - -class PostJobPayload(BaseModel): - url: AnyHttpUrl = Field( + app = FastAPI( description=( - "URL where the media file is available. This needs to be a direct link." + "whisperbox-transcribe is an async HTTP wrapper for openai/whisper." + ), + lifespan=lifespan, + title="whisperbox-transcribe", + ) + + api_router = APIRouter(prefix="/api/v1") + + @api_router.get("/", response_model=None, status_code=204) + def api_root() -> None: + return None + + @api_router.get( + "/jobs", + dependencies=[Depends(authenticate_api_key)], + response_model=list[dtos.Job], + summary="Get metadata for all jobs", + ) + def get_transcripts( + session: DatabaseSession, + type: dtos.JobType | None = None, + ) -> list[models.Job]: + """Get metadata for all jobs.""" + query = session.query(models.Job).order_by(models.Job.created_at.desc()) + + if type: + query = query.filter(models.Job.type == type) + + return query.all() + + @api_router.get( + "/jobs/{id}", + dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], + response_model=dtos.Job, + summary="Get metadata for one job", + ) + def get_transcript( + session: DatabaseSession, + id: UUID = Path(), + ) -> models.Job | None: + """ + Use this route to check transcription status of any given job. + """ + job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none() + + if not job: + raise HTTPException(status_code=404) + + return job + + @api_router.get( + "/jobs/{id}/artifacts", + dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], + response_model=list[dtos.Artifact], + summary="Get all artifacts for one job", + ) + def get_artifacts_for_job( + session: DatabaseSession, + id: UUID = Path(), + ) -> list[models.Artifact]: + """ + Returns all artifacts for one job. + See the type of `data` for possible data types. + Returns an empty array for unfinished or non-existant jobs. + """ + artifacts = ( + session.query(models.Artifact).filter(models.Artifact.job_id == str(id)) + ).all() + + return artifacts + + @api_router.delete( + "/jobs/{id}", + dependencies=[Depends(authenticate_api_key)], + status_code=204, + summary="Delete a job with all artifacts", + ) + def delete_transcript( + session: DatabaseSession, + id: UUID = Path(), + ) -> None: + """Remove metadata and artifacts for a single job.""" + session.query(models.Job).filter(models.Job.id == str(id)).delete() + return None + + class PostJobPayload(BaseModel): + url: AnyHttpUrl = Field( + description=( + "URL where the media file is available. This needs to be a direct link." + ) ) - ) - type: models.JobType = Field( - description="""Type of this job. - `transcript` uses the original language of the audio. - `translation` creates an automatic translation to english. - `language_detection` detects language from the first 30 seconds of audio.""" - ) - - language: str | None = Field( - description=( - "Spoken language in the media file. " - "While optional, this can improve output when set." + type: models.JobType = Field( + description="""Type of this job. + `transcript` uses the original language of the audio. + `translation` creates an automatic translation to english. + `language_detection` detects language from the first 30 seconds of audio.""" ) + + language: str | None = Field( + description=( + "Spoken language in the media file. " + "While optional, this can improve output when set." + ) + ) + + @api_router.post( + "/jobs", + dependencies=[Depends(authenticate_api_key)], + response_model=dtos.Job, + status_code=201, + summary="Enqueue a new job", ) + def create_job( + payload: PostJobPayload, + session: DatabaseSession, + ) -> models.Job: + """ + Enqueue a new whisper job for processing. + Notes: + * Jobs are processed one-by-one in order of creation. + * `payload.url` needs to point directly to a media file. + * The media file is downloaded to a tmp file for the duration of processing. + enough free space needs to be available on disk. + * Media files ideally are audio files with a sampling rate of 16kHz. + other files will be transcoded automatically via ffmpeg which might + consume considerable resources while active. + * Once a job is created, you can query its status by its id. + """ + # create a job with status "create" and save it to the database. + job = models.Job( + url=payload.url, + status=dtos.JobStatus.create, + type=payload.type, + config={"language": payload.language} if payload.language else None, + ) -@api_router.post( - "/jobs", - dependencies=[Depends(authenticate_api_key)], - response_model=dtos.Job, - status_code=201, - summary="Enqueue a new job", -) -def create_job( - payload: PostJobPayload, - session: DatabaseSession, -) -> models.Job: - """ - Enqueue a new whisper job for processing. - Notes: - * Jobs are processed one-by-one in order of creation. - * `payload.url` needs to point directly to a media file. - * The media file is downloaded to a tmp file for the duration of processing. - enough free space needs to be available on disk. - * Media files ideally are audio files with a sampling rate of 16kHz. - other files will be transcoded automatically via ffmpeg which might - consume considerable resources while active. - * Once a job is created, you can query its status by its id. - """ + session.add(job) + session.commit() - # create a job with status "create" and save it to the database. - job = models.Job( - url=payload.url, - status=dtos.JobStatus.create, - type=payload.type, - config={"language": payload.language} if payload.language else None, - ) + task_queue.queue_task(job) - session.add(job) - session.commit() + return job - task_queue.queue_task(job) + app.include_router(api_router) - return job - - -app.include_router(api_router) + return app diff --git a/app/web/task_queue.py b/app/web/task_queue.py index 582f47f..03dd01d 100644 --- a/app/web/task_queue.py +++ b/app/web/task_queue.py @@ -43,6 +43,3 @@ class TaskQueue: for job in jobs: self.queue_task(job) - - -task_queue = TaskQueue() diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 7a27abf..e6bc559 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -12,7 +12,7 @@ services: - "--entrypoints.web.address=:80" web: - command: bash -c "alembic upgrade head && uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info" + command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info" volumes: - ./:/etc/whisperbox-transcribe/ labels: diff --git a/web.Dockerfile b/web.Dockerfile index 3f55f37..e96f386 100644 --- a/web.Dockerfile +++ b/web.Dockerfile @@ -20,4 +20,4 @@ COPY alembic.ini . ENV VIRTUAL_ENV /opt/venv ENV PATH /opt/venv/bin:$PATH -CMD alembic upgrade head && uvicorn app.web.main:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers +CMD alembic upgrade head && uvicorn app.web:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers