test: add integration test for sharing

This commit is contained in:
Felix Spöttel
2023-06-29 14:59:43 +02:00
parent f01ea48f57
commit 05ebc17215
9 changed files with 236 additions and 219 deletions

3
.gitignore vendored
View File

@@ -166,3 +166,6 @@ cython_debug/
whisperbox-transcribe.sqlite*
*shm
*wal
# ruff
.ruff_cache

View File

@@ -1,13 +1,11 @@
from typing import Generator
import pytest
from sqlalchemy.orm import Session
from fastapi.testclient import TestClient
from sqlalchemy_utils import create_database, database_exists, drop_database
import app.shared.db.models as models
from app.shared.db.base import SessionLocal, engine, get_session
from app.shared.db.base import SessionLocal, engine
from app.shared.settings import settings
from app.web.main import app
from app.web.main import app_factory
def pytest_configure() -> None:
@@ -26,15 +24,48 @@ def auth_headers() -> dict[str, str]:
@pytest.fixture(scope="function", autouse=True)
def db_session() -> Generator[Session, None, None]:
def db_session():
models.Base.metadata.create_all(engine)
connection = engine.connect()
with SessionLocal(bind=connection) as session:
app.dependency_overrides[get_session] = lambda: session
yield session
app.dependency_overrides.clear()
connection.close()
models.Base.metadata.drop_all(bind=engine)
@pytest.fixture(scope="function")
def client(db_session):
app = app_factory(lambda: db_session)
client = TestClient(app)
return client
@pytest.fixture(scope="function", autouse=False)
def mock_job(db_session):
job = models.Job(
url="https://example.com",
type=models.JobType.transcript,
status=models.JobStatus.create,
)
db_session.add(job)
db_session.flush()
return job
@pytest.fixture(scope="function", autouse=False)
def mock_artifact(db_session, mock_job):
artifact = models.Artifact(
data=None, job_id=str(mock_job.id), type=models.ArtifactType.raw_transcript
)
db_session.add(artifact)
db_session.flush()
return artifact
@pytest.fixture(scope="function")
def sharing_enabled():
settings.ENABLE_SHARING = True
yield
settings.ENABLE_SHARING = False

View File

@@ -1,30 +1,12 @@
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
import app.shared.db.models as models
from app.web.main import app
client = TestClient(app)
@pytest.fixture(name="mock_job", scope="function", autouse=False)
def mock_job(db_session: Session) -> models.Job:
job = models.Job(
url="https://example.com",
type=models.JobType.transcript,
status=models.JobStatus.create,
)
db_session.add(job)
db_session.flush()
return job
from app.web.main import app_factory
# POST /api/v1/jobs
# ---
def test_create_job_pass(auth_headers: dict[str, str]) -> None:
def test_create_job_pass(client, auth_headers: dict[str, str]):
res = client.post(
"/api/v1/jobs",
headers=auth_headers,
@@ -34,12 +16,12 @@ def test_create_job_pass(auth_headers: dict[str, str]) -> None:
assert isinstance(res.json()["id"], str)
def test_create_job_missing_body(auth_headers: dict[str, str]) -> None:
def test_create_job_missing_body(client, auth_headers: dict[str, str]):
res = client.post("/api/v1/jobs", headers=auth_headers, json={})
assert res.status_code == 422
def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None:
def test_create_job_malformed_url(client, auth_headers: dict[str, str]):
res = client.post(
"/api/v1/jobs",
headers=auth_headers,
@@ -50,7 +32,7 @@ def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None:
# GET /api/v1/jobs
# ---
def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None:
def test_get_jobs_pass(client, auth_headers: dict[str, str], mock_job: models.Job):
res = client.get(
"/api/v1/jobs?type=transcribe",
headers=auth_headers,
@@ -61,7 +43,7 @@ def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> No
# GET /api/v1/jobs/:id
# ---
def test_get_job_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None:
def test_get_job_pass(client, auth_headers: dict[str, str], mock_job: models.Job):
res = client.get(
f"/api/v1/jobs/{mock_job.id}",
headers=auth_headers,
@@ -70,7 +52,7 @@ def test_get_job_pass(auth_headers: dict[str, str], mock_job: models.Job) -> Non
assert res.json()["id"] == str(mock_job.id)
def test_get_job_not_found(auth_headers: dict[str, str], mock_job: models.Job) -> None:
def test_get_job_not_found(client, auth_headers: dict[str, str], mock_job):
res = client.get(
"/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb",
headers=auth_headers,
@@ -79,18 +61,29 @@ def test_get_job_not_found(auth_headers: dict[str, str], mock_job: models.Job) -
assert res.status_code == 404
# GET /api/v1/jobs/:id/artifacts
# ---
def test_get_artifacts_pass(
auth_headers: dict[str, str], db_session: Session, mock_job: models.Job
) -> None:
artifact = models.Artifact(
data=None, job_id=str(mock_job.id), type=models.ArtifactType.raw_transcript
def test_get_job_sharing_disabled(client, mock_job):
res = client.get(
f"/api/v1/jobs/{mock_job.id}",
headers={},
)
assert res.status_code == 401
def test_get_job_sharing_enabled(db_session, mock_job, sharing_enabled):
# HACK: delay construction until settings are patched.
client = TestClient(app_factory(lambda: db_session))
res = client.get(
f"/api/v1/jobs/{mock_job.id}",
headers={},
)
db_session.add(artifact)
db_session.flush()
assert res.status_code == 200
# GET /api/v1/jobs/:id/artifacts
# ---
def test_get_artifacts_pass(client, auth_headers, db_session, mock_job, mock_artifact):
res = client.get(
f"/api/v1/jobs/{mock_job.id}/artifacts",
headers=auth_headers,
@@ -98,12 +91,10 @@ def test_get_artifacts_pass(
assert res.status_code == 200
assert res.json()[0]["job_id"] == str(mock_job.id)
assert res.json()[0]["id"] == str(artifact.id)
assert res.json()[0]["id"] == str(mock_artifact.id)
def test_get_artifacts_not_found(
auth_headers: dict[str, str], mock_job: models.Job
) -> None:
def test_get_artifacts_not_found(client, auth_headers, mock_job):
res = client.get(
f"/api/v1/jobs/{mock_job.id}/artifacts",
headers=auth_headers,
@@ -115,12 +106,11 @@ def test_get_artifacts_not_found(
# DELETE /api/v1/jobs
# ---
def test_delete_job_pass(
auth_headers: dict[str, str], mock_job: models.Job, db_session: Session
) -> None:
def test_delete_job_pass(client, auth_headers, mock_job, db_session):
res = client.delete(
f"/api/v1/jobs/{mock_job.id}",
headers=auth_headers,
)
assert db_session.query(models.Job).count() == 0
assert res.status_code == 204

View File

@@ -1,25 +1,18 @@
from fastapi.testclient import TestClient
from app.web.main import app
client = TestClient(app)
def test_authorization_header_missing() -> None:
def test_authorization_header_missing(client):
res = client.get("/api/v1/jobs")
assert res.status_code == 401
def test_authorization_header_malformed() -> None:
def test_authorization_header_malformed(client):
res = client.get("/api/v1/jobs", headers={"Authorization": "Bearer"})
assert res.status_code == 401
def test_incorrect_api_key() -> None:
def test_incorrect_api_key(client):
res = client.get("/api/v1/jobs", headers={"Authorization": "Bearer incorrect"})
assert res.status_code == 401
def test_existing_api_key(auth_headers: dict[str, str]) -> None:
def test_existing_api_key(client, auth_headers):
res = client.get("/api/v1/jobs", headers=auth_headers)
assert res.status_code == 200

View File

@@ -0,0 +1,4 @@
from app.shared.db.base import get_session
from app.web.main import app_factory
app = app_factory(get_session)

View File

@@ -1,5 +1,5 @@
from contextlib import asynccontextmanager
from typing import Annotated
from typing import Annotated, Callable, Generator
from uuid import UUID
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
@@ -8,173 +8,172 @@ from sqlalchemy.orm import Session
import app.shared.db.models as models
import app.web.dtos as dtos
from app.shared.db.base import SessionLocal, get_session
from app.shared.db.base import SessionLocal
from app.shared.settings import settings
from app.web.security import authenticate_api_key
from app.web.task_queue import task_queue
DatabaseSession = Annotated[Session, Depends(get_session)]
from app.web.task_queue import TaskQueue
@asynccontextmanager
async def lifespan(_: FastAPI):
with SessionLocal() as session:
task_queue.rehydrate(session)
yield
def app_factory(
session_getter: Callable[[], Generator[Session, None, None]]
) -> FastAPI:
DatabaseSession = Annotated[Session, Depends(session_getter)]
task_queue = TaskQueue()
app = FastAPI(
description="whisperbox-transcribe is an async HTTP wrapper for openai/whisper.",
lifespan=lifespan,
title="whisperbox-transcribe",
)
@asynccontextmanager
async def lifespan(_: FastAPI):
with SessionLocal() as session:
task_queue.rehydrate(session)
yield
api_router = APIRouter(prefix="/api/v1")
@api_router.get("/", response_model=None, status_code=204)
def api_root() -> None:
return None
@api_router.get(
"/jobs",
dependencies=[Depends(authenticate_api_key)],
response_model=list[dtos.Job],
summary="Get metadata for all jobs",
)
def get_transcripts(
session: DatabaseSession,
type: dtos.JobType | None = None,
) -> list[models.Job]:
"""Get metadata for all jobs."""
query = session.query(models.Job).order_by(models.Job.created_at.desc())
if type:
query = query.filter(models.Job.type == type)
return query.all()
@api_router.get(
"/jobs/{id}",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
response_model=dtos.Job,
summary="Get metadata for one job",
)
def get_transcript(
session: DatabaseSession,
id: UUID = Path(),
) -> models.Job | None:
"""
Use this route to check transcription status of any given job.
"""
job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
if not job:
raise HTTPException(status_code=404)
return job
@api_router.get(
"/jobs/{id}/artifacts",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
response_model=list[dtos.Artifact],
summary="Get all artifacts for one job",
)
def get_artifacts_for_job(
session: DatabaseSession,
id: UUID = Path(),
) -> list[models.Artifact]:
"""
Returns all artifacts for one job.
See the type of `data` for possible data types.
Returns an empty array for unfinished or non-existant jobs.
"""
artifacts = (
session.query(models.Artifact).filter(models.Artifact.job_id == str(id))
).all()
return artifacts
@api_router.delete(
"/jobs/{id}",
dependencies=[Depends(authenticate_api_key)],
status_code=204,
summary="Delete a job with all artifacts",
)
def delete_transcript(
session: DatabaseSession,
id: UUID = Path(),
) -> None:
"""Remove metadata and artifacts for a single job."""
session.query(models.Job).filter(models.Job.id == str(id)).delete()
return None
class PostJobPayload(BaseModel):
url: AnyHttpUrl = Field(
app = FastAPI(
description=(
"URL where the media file is available. This needs to be a direct link."
"whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
),
lifespan=lifespan,
title="whisperbox-transcribe",
)
api_router = APIRouter(prefix="/api/v1")
@api_router.get("/", response_model=None, status_code=204)
def api_root() -> None:
return None
@api_router.get(
"/jobs",
dependencies=[Depends(authenticate_api_key)],
response_model=list[dtos.Job],
summary="Get metadata for all jobs",
)
def get_transcripts(
session: DatabaseSession,
type: dtos.JobType | None = None,
) -> list[models.Job]:
"""Get metadata for all jobs."""
query = session.query(models.Job).order_by(models.Job.created_at.desc())
if type:
query = query.filter(models.Job.type == type)
return query.all()
@api_router.get(
"/jobs/{id}",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
response_model=dtos.Job,
summary="Get metadata for one job",
)
def get_transcript(
session: DatabaseSession,
id: UUID = Path(),
) -> models.Job | None:
"""
Use this route to check transcription status of any given job.
"""
job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
if not job:
raise HTTPException(status_code=404)
return job
@api_router.get(
"/jobs/{id}/artifacts",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
response_model=list[dtos.Artifact],
summary="Get all artifacts for one job",
)
def get_artifacts_for_job(
session: DatabaseSession,
id: UUID = Path(),
) -> list[models.Artifact]:
"""
Returns all artifacts for one job.
See the type of `data` for possible data types.
Returns an empty array for unfinished or non-existant jobs.
"""
artifacts = (
session.query(models.Artifact).filter(models.Artifact.job_id == str(id))
).all()
return artifacts
@api_router.delete(
"/jobs/{id}",
dependencies=[Depends(authenticate_api_key)],
status_code=204,
summary="Delete a job with all artifacts",
)
def delete_transcript(
session: DatabaseSession,
id: UUID = Path(),
) -> None:
"""Remove metadata and artifacts for a single job."""
session.query(models.Job).filter(models.Job.id == str(id)).delete()
return None
class PostJobPayload(BaseModel):
url: AnyHttpUrl = Field(
description=(
"URL where the media file is available. This needs to be a direct link."
)
)
)
type: models.JobType = Field(
description="""Type of this job.
`transcript` uses the original language of the audio.
`translation` creates an automatic translation to english.
`language_detection` detects language from the first 30 seconds of audio."""
)
language: str | None = Field(
description=(
"Spoken language in the media file. "
"While optional, this can improve output when set."
type: models.JobType = Field(
description="""Type of this job.
`transcript` uses the original language of the audio.
`translation` creates an automatic translation to english.
`language_detection` detects language from the first 30 seconds of audio."""
)
language: str | None = Field(
description=(
"Spoken language in the media file. "
"While optional, this can improve output when set."
)
)
@api_router.post(
"/jobs",
dependencies=[Depends(authenticate_api_key)],
response_model=dtos.Job,
status_code=201,
summary="Enqueue a new job",
)
def create_job(
payload: PostJobPayload,
session: DatabaseSession,
) -> models.Job:
"""
Enqueue a new whisper job for processing.
Notes:
* Jobs are processed one-by-one in order of creation.
* `payload.url` needs to point directly to a media file.
* The media file is downloaded to a tmp file for the duration of processing.
enough free space needs to be available on disk.
* Media files ideally are audio files with a sampling rate of 16kHz.
other files will be transcoded automatically via ffmpeg which might
consume considerable resources while active.
* Once a job is created, you can query its status by its id.
"""
# create a job with status "create" and save it to the database.
job = models.Job(
url=payload.url,
status=dtos.JobStatus.create,
type=payload.type,
config={"language": payload.language} if payload.language else None,
)
@api_router.post(
"/jobs",
dependencies=[Depends(authenticate_api_key)],
response_model=dtos.Job,
status_code=201,
summary="Enqueue a new job",
)
def create_job(
payload: PostJobPayload,
session: DatabaseSession,
) -> models.Job:
"""
Enqueue a new whisper job for processing.
Notes:
* Jobs are processed one-by-one in order of creation.
* `payload.url` needs to point directly to a media file.
* The media file is downloaded to a tmp file for the duration of processing.
enough free space needs to be available on disk.
* Media files ideally are audio files with a sampling rate of 16kHz.
other files will be transcoded automatically via ffmpeg which might
consume considerable resources while active.
* Once a job is created, you can query its status by its id.
"""
session.add(job)
session.commit()
# create a job with status "create" and save it to the database.
job = models.Job(
url=payload.url,
status=dtos.JobStatus.create,
type=payload.type,
config={"language": payload.language} if payload.language else None,
)
task_queue.queue_task(job)
session.add(job)
session.commit()
return job
task_queue.queue_task(job)
app.include_router(api_router)
return job
app.include_router(api_router)
return app

View File

@@ -43,6 +43,3 @@ class TaskQueue:
for job in jobs:
self.queue_task(job)
task_queue = TaskQueue()

View File

@@ -12,7 +12,7 @@ services:
- "--entrypoints.web.address=:80"
web:
command: bash -c "alembic upgrade head && uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info"
command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info"
volumes:
- ./:/etc/whisperbox-transcribe/
labels:

View File

@@ -20,4 +20,4 @@ COPY alembic.ini .
ENV VIRTUAL_ENV /opt/venv
ENV PATH /opt/venv/bin:$PATH
CMD alembic upgrade head && uvicorn app.web.main:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers
CMD alembic upgrade head && uvicorn app.web:app --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info --workers 4 --proxy-headers