From 4f020853b6b23ddb748cba2e2d4143ff6f005eec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Sp=C3=B6ttel?= <1682504+fspoettel@users.noreply.github.com> Date: Wed, 8 Feb 2023 17:23:59 +0100 Subject: [PATCH] docs: improve openapi documentation closes #6 --- app/shared/db/alembic/env.py | 2 +- app/shared/db/schemas.py | 33 ++++++++++++++++++++++++++++----- app/tests/conftest.py | 2 +- app/tests/test_api.py | 11 +++++++---- app/tests/test_auth.py | 4 ++-- app/web/dtos.py | 33 +++++++++++++++++++++++++++++++++ app/web/main.py | 27 +++++++++++++-------------- app/web/security.py | 10 +++++----- app/worker/main.py | 7 +++++-- 9 files changed, 95 insertions(+), 34 deletions(-) create mode 100644 app/web/dtos.py diff --git a/app/shared/db/alembic/env.py b/app/shared/db/alembic/env.py index d150dc6..6d9ed9f 100644 --- a/app/shared/db/alembic/env.py +++ b/app/shared/db/alembic/env.py @@ -3,8 +3,8 @@ from logging.config import fileConfig from alembic import context from sqlalchemy import engine_from_config, pool -from app.shared.settings import settings from app.shared.db.models import Base +from app.shared.settings import settings # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/app/shared/db/schemas.py b/app/shared/db/schemas.py index dc2e2d7..87a7281 100644 --- a/app/shared/db/schemas.py +++ b/app/shared/db/schemas.py @@ -1,9 +1,9 @@ import enum from datetime import datetime -from typing import Any, List, Optional +from typing import List, Optional from uuid import UUID -from pydantic import AnyHttpUrl, BaseModel +from pydantic import AnyHttpUrl, BaseModel, Field class WithDbFields(BaseModel): @@ -26,6 +26,8 @@ class JobType(str, enum.Enum): class JobStatus(str, enum.Enum): + """Processing status of a job.""" + create = "create" processing = "processing" error = "error" @@ -33,15 +35,32 @@ class JobStatus(str, enum.Enum): class JobConfig(BaseModel): - language: Optional[str] + """Configuration for a job.""" + + # TODO: limit to locales selected by whisper. + language: Optional[str] = Field( + description=( + "Spoken language in the media file." + "While optional, this can improve output " + "by selecting a language-specific model. (applies to 'en')" + ) + ) class JobMeta(BaseModel): - error: Optional[str] - task_id: Optional[UUID] + """Metadata relating to a job's execution.""" + + error: Optional[str] = Field( + description="Will contain a descriptive error message if processing failed." + ) + task_id: Optional[UUID] = Field( + description="Internal celery id of this job submission." + ) class Job(WithDbFields): + """A transcription job for one media file.""" + status: JobStatus type: JobType url: AnyHttpUrl @@ -50,6 +69,8 @@ class Job(WithDbFields): class RawTranscript(BaseModel): + """A single transcript passage returned by whisper.""" + id: int seek: int start: float @@ -63,6 +84,8 @@ class RawTranscript(BaseModel): class Artifact(WithDbFields): + """whisper output for one job.""" + data: Optional[List[RawTranscript]] job_id: UUID type: ArtifactType diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 5039eac..7edafb3 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -5,8 +5,8 @@ from sqlalchemy.orm import Session from sqlalchemy_utils import create_database, database_exists, drop_database import app.shared.db.models as models -from app.shared.settings import settings from app.shared.db.base import SessionLocal, engine, get_session +from app.shared.settings import settings from app.web.main import app diff --git a/app/tests/test_api.py b/app/tests/test_api.py index 9dd073d..2d2754d 100644 --- a/app/tests/test_api.py +++ b/app/tests/test_api.py @@ -4,8 +4,8 @@ import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session -import app.shared.db.schemas as schemas import app.shared.db.models as models +import app.shared.db.schemas as schemas from app.shared.db.schemas import JobStatus, JobType from app.web.main import app @@ -77,12 +77,13 @@ def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) - "/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb", headers=auth_headers, ) + assert res.status_code == 404 # GET /api/v1/jobs/:id/artifacts # --- -def test_get_artifact_pass( +def test_get_artifacts_pass( auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job ) -> None: artifact = models.Artifact( @@ -102,14 +103,16 @@ def test_get_artifact_pass( assert res.json()[0]["id"] == str(artifact.id) -def test_get_artifact_not_found( +def test_get_artifacts_not_found( auth_headers: Dict[str, str], mock_job: models.Job ) -> None: res = client.get( f"/api/v1/jobs/{mock_job.id}/artifacts", headers=auth_headers, ) - assert res.status_code == 404 + + assert len(res.json()) == 0 + assert res.status_code == 200 # DELETE /api/v1/jobs diff --git a/app/tests/test_auth.py b/app/tests/test_auth.py index d1501cc..c2b69a5 100644 --- a/app/tests/test_auth.py +++ b/app/tests/test_auth.py @@ -14,7 +14,7 @@ def test_authorization_header_missing() -> None: def test_authorization_header_malformed() -> None: res = client.get("/api/v1", headers={"Authorization": "Bearer"}) - assert res.status_code == 422 + assert res.status_code == 401 def test_incorrect_api_key() -> None: @@ -24,4 +24,4 @@ def test_incorrect_api_key() -> None: def test_existing_api_key(auth_headers: Dict[str, str]) -> None: res = client.get("/api/v1", headers=auth_headers) - assert res.status_code == 200 + assert res.status_code == 204 diff --git a/app/web/dtos.py b/app/web/dtos.py new file mode 100644 index 0000000..a5df059 --- /dev/null +++ b/app/web/dtos.py @@ -0,0 +1,33 @@ +from typing import Any, Dict, Optional + +from pydantic import AnyHttpUrl, BaseModel, Field + +import app.shared.db.schemas as schemas + + +class DetailResponse(BaseModel): + detail: str + + +DEFAULT_RESPONSES: Dict[int | str, Dict[str, Any]] = { + 401: {"model": DetailResponse, "description": "Not authenticated"} +} + + +class PostJobPayload(BaseModel): + url: AnyHttpUrl = Field( + description=( + "URL where the media file is available. This needs to be a direct link." + ) + ) + + type: schemas.JobType = Field(description="Type of this job.") + + # TODO: limit to locales selected by whisper. + language: Optional[str] = Field( + description=( + "Spoken language in the media file." + "While optional, this can improve output " + "by selecting a language-specific model. (applies to 'en')" + ) + ) diff --git a/app/web/main.py b/app/web/main.py index fd3c903..62a9738 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -3,20 +3,24 @@ from typing import List, Optional from uuid import UUID from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path -from pydantic import AnyHttpUrl, BaseModel from sqlalchemy import or_ from sqlalchemy.orm import Session -import app.shared.db.schemas as schemas import app.shared.db.models as models +import app.shared.db.schemas as schemas from app.shared.celery import get_celery_binding from app.shared.db.base import get_session +from app.web.dtos import DEFAULT_RESPONSES, DetailResponse, PostJobPayload from app.web.security import authenticate_api_key app = FastAPI() celery = get_celery_binding() -api_router = APIRouter(prefix="/api/v1", dependencies=[Depends(authenticate_api_key)]) +api_router = APIRouter( + prefix="/api/v1", + dependencies=[Depends(authenticate_api_key)], + responses={**DEFAULT_RESPONSES}, +) def queue_task(job: models.Job) -> None: @@ -28,17 +32,11 @@ def queue_task(job: models.Job) -> None: transcribe.delay(job.id) -@api_router.get("/") +@api_router.get("/", response_model=None, status_code=204) def api_root() -> None: return None -class PostJobPayload(BaseModel): - url: AnyHttpUrl - type: schemas.JobType - language: Optional[str] - - @api_router.post("/jobs", response_model=schemas.Job, status_code=201) def create_job( payload: PostJobPayload, @@ -77,7 +75,11 @@ def get_transcripts( return query.all() -@api_router.get("/jobs/{id}", response_model=schemas.Job) +@api_router.get( + "/jobs/{id}", + response_model=schemas.Job, + responses={404: {"model": DetailResponse, "description": "Not authenticated"}}, +) def get_transcript( id: UUID = Path(), session: Session = Depends(get_session) ) -> Optional[models.Job]: @@ -95,9 +97,6 @@ def get_artifacts_for_job( session.query(models.Artifact).filter(models.Artifact.job_id == str(id)) ).all() - if not len(artifacts): - raise HTTPException(status_code=404) - return artifacts diff --git a/app/web/security.py b/app/web/security.py index 635c46d..6d66ef1 100644 --- a/app/web/security.py +++ b/app/web/security.py @@ -1,16 +1,16 @@ from hmac import compare_digest from fastapi import Depends, HTTPException -from fastapi.security import OAuth2PasswordBearer +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from app.shared.settings import settings def authenticate_api_key( - token: str = Depends(OAuth2PasswordBearer(tokenUrl="token")), + credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ) -> None: - if not token: - raise HTTPException(status_code=422) # use compare_digest to counter timing attacks. - if not compare_digest(settings.API_SECRET, token): + if not credentials or not compare_digest( + settings.API_SECRET, credentials.credentials + ): raise HTTPException(status_code=401) diff --git a/app/worker/main.py b/app/worker/main.py index b410c31..c6acc09 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -4,8 +4,8 @@ from uuid import UUID from celery import Task from sqlalchemy.orm import Session -import app.shared.db.schemas as schemas import app.shared.db.models as models +import app.shared.db.schemas as schemas from app.shared.celery import get_celery_binding from app.shared.db.base import SessionLocal from app.worker.strategies.local import LocalStrategy @@ -19,7 +19,10 @@ def transcribe(self: Task, job_id: UUID) -> None: db: Session = SessionLocal() job = db.query(models.Job).filter(models.Job.id == job_id).one() - if job.status == schemas.JobStatus.error or job.status == schemas.JobStatus.success: + if ( + job.status == schemas.JobStatus.error + or job.status == schemas.JobStatus.success + ): logger.warn( "[{job.id}]: Received job that has already been processed, abort." )