docs: improve openapi documentation

closes #6
This commit is contained in:
Felix Spöttel
2023-02-08 17:23:59 +01:00
parent d9ce63ee39
commit 4f020853b6
9 changed files with 95 additions and 34 deletions

View File

@@ -3,8 +3,8 @@ from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
from app.shared.settings import settings
from app.shared.db.models import Base
from app.shared.settings import settings
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.

View File

@@ -1,9 +1,9 @@
import enum
from datetime import datetime
from typing import Any, List, Optional
from typing import List, Optional
from uuid import UUID
from pydantic import AnyHttpUrl, BaseModel
from pydantic import AnyHttpUrl, BaseModel, Field
class WithDbFields(BaseModel):
@@ -26,6 +26,8 @@ class JobType(str, enum.Enum):
class JobStatus(str, enum.Enum):
"""Processing status of a job."""
create = "create"
processing = "processing"
error = "error"
@@ -33,15 +35,32 @@ class JobStatus(str, enum.Enum):
class JobConfig(BaseModel):
language: Optional[str]
"""Configuration for a job."""
# TODO: limit to locales selected by whisper.
language: Optional[str] = Field(
    description=(
        # Adjacent string literals are joined verbatim, so the first
        # sentence carries its own trailing space.
        "Spoken language in the media file. "
        "While optional, this can improve output "
        "by selecting a language-specific model. (applies to 'en')"
    )
)
class JobMeta(BaseModel):
error: Optional[str]
task_id: Optional[UUID]
"""Metadata relating to a job's execution."""
error: Optional[str] = Field(
description="Will contain a descriptive error message if processing failed."
)
task_id: Optional[UUID] = Field(
description="Internal celery id of this job submission."
)
class Job(WithDbFields):
"""A transcription job for one media file."""
status: JobStatus
type: JobType
url: AnyHttpUrl
@@ -50,6 +69,8 @@ class Job(WithDbFields):
class RawTranscript(BaseModel):
"""A single transcript passage returned by whisper."""
id: int
seek: int
start: float
@@ -63,6 +84,8 @@ class RawTranscript(BaseModel):
class Artifact(WithDbFields):
"""whisper output for one job."""
data: Optional[List[RawTranscript]]
job_id: UUID
type: ArtifactType

View File

@@ -5,8 +5,8 @@ from sqlalchemy.orm import Session
from sqlalchemy_utils import create_database, database_exists, drop_database
import app.shared.db.models as models
from app.shared.settings import settings
from app.shared.db.base import SessionLocal, engine, get_session
from app.shared.settings import settings
from app.web.main import app

View File

@@ -4,8 +4,8 @@ import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
import app.shared.db.schemas as schemas
import app.shared.db.models as models
import app.shared.db.schemas as schemas
from app.shared.db.schemas import JobStatus, JobType
from app.web.main import app
@@ -77,12 +77,13 @@ def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) -
"/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb",
headers=auth_headers,
)
assert res.status_code == 404
# GET /api/v1/jobs/:id/artifacts
# ---
def test_get_artifact_pass(
def test_get_artifacts_pass(
auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job
) -> None:
artifact = models.Artifact(
@@ -102,14 +103,16 @@ def test_get_artifact_pass(
assert res.json()[0]["id"] == str(artifact.id)
def test_get_artifact_not_found(
def test_get_artifacts_not_found(
auth_headers: Dict[str, str], mock_job: models.Job
) -> None:
res = client.get(
f"/api/v1/jobs/{mock_job.id}/artifacts",
headers=auth_headers,
)
assert res.status_code == 404
assert len(res.json()) == 0
assert res.status_code == 200
# DELETE /api/v1/jobs

View File

@@ -14,7 +14,7 @@ def test_authorization_header_missing() -> None:
def test_authorization_header_malformed() -> None:
res = client.get("/api/v1", headers={"Authorization": "Bearer"})
assert res.status_code == 422
assert res.status_code == 401
def test_incorrect_api_key() -> None:
@@ -24,4 +24,4 @@ def test_incorrect_api_key() -> None:
def test_existing_api_key(auth_headers: Dict[str, str]) -> None:
res = client.get("/api/v1", headers=auth_headers)
assert res.status_code == 200
assert res.status_code == 204

33
app/web/dtos.py Normal file
View File

@@ -0,0 +1,33 @@
from typing import Any, Dict, Optional
from pydantic import AnyHttpUrl, BaseModel, Field
import app.shared.db.schemas as schemas
class DetailResponse(BaseModel):
    """Body of an error response, carrying a single ``detail`` message."""

    detail: str
# Extra response documentation keyed by HTTP status code. It is spread into
# the API router's ``responses`` so that the 401 raised by the API-key check
# is described once rather than on each route.
DEFAULT_RESPONSES: Dict[int | str, Dict[str, Any]] = {
401: {"model": DetailResponse, "description": "Not authenticated"}
}
class PostJobPayload(BaseModel):
    """Request body accepted by the job-creation endpoint (``POST /jobs``)."""

    url: AnyHttpUrl = Field(
        description=(
            "URL where the media file is available. This needs to be a direct link."
        )
    )
    type: schemas.JobType = Field(description="Type of this job.")
    # TODO: limit to locales selected by whisper.
    language: Optional[str] = Field(
        description=(
            # Adjacent string literals are joined verbatim, so the first
            # sentence carries its own trailing space.
            "Spoken language in the media file. "
            "While optional, this can improve output "
            "by selecting a language-specific model. (applies to 'en')"
        )
    )

View File

@@ -3,20 +3,24 @@ from typing import List, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
from pydantic import AnyHttpUrl, BaseModel
from sqlalchemy import or_
from sqlalchemy.orm import Session
import app.shared.db.schemas as schemas
import app.shared.db.models as models
import app.shared.db.schemas as schemas
from app.shared.celery import get_celery_binding
from app.shared.db.base import get_session
from app.web.dtos import DEFAULT_RESPONSES, DetailResponse, PostJobPayload
from app.web.security import authenticate_api_key
app = FastAPI()
celery = get_celery_binding()
api_router = APIRouter(prefix="/api/v1", dependencies=[Depends(authenticate_api_key)])
api_router = APIRouter(
prefix="/api/v1",
dependencies=[Depends(authenticate_api_key)],
responses={**DEFAULT_RESPONSES},
)
def queue_task(job: models.Job) -> None:
@@ -28,17 +32,11 @@ def queue_task(job: models.Job) -> None:
transcribe.delay(job.id)
@api_router.get("/")
@api_router.get("/", response_model=None, status_code=204)
def api_root() -> None:
return None
class PostJobPayload(BaseModel):
url: AnyHttpUrl
type: schemas.JobType
language: Optional[str]
@api_router.post("/jobs", response_model=schemas.Job, status_code=201)
def create_job(
payload: PostJobPayload,
@@ -77,7 +75,11 @@ def get_transcripts(
return query.all()
@api_router.get("/jobs/{id}", response_model=schemas.Job)
@api_router.get(
"/jobs/{id}",
response_model=schemas.Job,
responses={404: {"model": DetailResponse, "description": "Not authenticated"}},
)
def get_transcript(
id: UUID = Path(), session: Session = Depends(get_session)
) -> Optional[models.Job]:
@@ -95,9 +97,6 @@ def get_artifacts_for_job(
session.query(models.Artifact).filter(models.Artifact.job_id == str(id))
).all()
if not len(artifacts):
raise HTTPException(status_code=404)
return artifacts

View File

@@ -1,16 +1,16 @@
from hmac import compare_digest
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.shared.settings import settings
def authenticate_api_key(
token: str = Depends(OAuth2PasswordBearer(tokenUrl="token")),
credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
) -> None:
if not token:
raise HTTPException(status_code=422)
# use compare_digest to counter timing attacks.
if not compare_digest(settings.API_SECRET, token):
# Compare as bytes: compare_digest() raises TypeError when given str
# arguments containing non-ASCII characters, which would turn a malformed
# token into a 500 instead of the intended 401.
if not credentials or not compare_digest(
    settings.API_SECRET.encode("utf-8"),
    credentials.credentials.encode("utf-8"),
):
    raise HTTPException(status_code=401)

View File

@@ -4,8 +4,8 @@ from uuid import UUID
from celery import Task
from sqlalchemy.orm import Session
import app.shared.db.schemas as schemas
import app.shared.db.models as models
import app.shared.db.schemas as schemas
from app.shared.celery import get_celery_binding
from app.shared.db.base import SessionLocal
from app.worker.strategies.local import LocalStrategy
@@ -19,7 +19,10 @@ def transcribe(self: Task, job_id: UUID) -> None:
db: Session = SessionLocal()
job = db.query(models.Job).filter(models.Job.id == job_id).one()
if job.status == schemas.JobStatus.error or job.status == schemas.JobStatus.success:
if (
job.status == schemas.JobStatus.error
or job.status == schemas.JobStatus.success
):
# The message needs the f-prefix for the job id to be interpolated; without
# it the literal text "{job.id}" is logged. `warning` replaces `warn`, which
# is a deprecated alias on stdlib loggers (logger type not visible here).
logger.warning(
    f"[{job.id}]: Received job that has already been processed, abort."
)