feat: add artifact routes, test web api

This commit is contained in:
Felix Spöttel
2023-01-18 12:33:22 +01:00
parent 4fa1d5c0da
commit 7b6b453b45
10 changed files with 239 additions and 120 deletions

View File

@@ -6,9 +6,9 @@ fmt:
black app
isort app
test:
ENVIRONMENT=test pytest
lint:
mypy app
flake8 app
test:
pytest

View File

@@ -1 +1,3 @@
# whisper-api
> Tiny HTTP wrapper around [openai/whisper](https://github.com/openai/whisper).

View File

@@ -1,4 +1,4 @@
import os
import sys
from pydantic import BaseSettings
@@ -10,7 +10,7 @@ class Settings(BaseSettings):
REDIS_URI: str
if "ENVIRONMENT" in os.environ and os.environ["ENVIRONMENT"] == "test":
settings = Settings(_env_file=".env.test") # type: ignore
if "pytest" in sys.modules:
settings = Settings(_env_file=".env.test", _env_file_encoding="utf-8") # type: ignore
else:
settings = Settings()

View File

@@ -0,0 +1,51 @@
"""add_job_tables
Revision ID: bb249ed79907
Revises:
Create Date: 2023-01-17 14:30:30.920466
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'bb249ed79907'
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('jobs',
sa.Column('url', sa.String(length=2048), nullable=True),
sa.Column('status', sa.Enum('create', 'error', 'success', name='jobstatus'), nullable=False),
sa.Column('type', sa.Enum('transcript', name='jobtype'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_jobs_id'), 'jobs', ['id'], unique=False)
op.create_table('artifacts',
sa.Column('job_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('data', sa.JSON(none_as_null=True), nullable=True),
sa.Column('type', sa.Enum('raw_transcript', name='artifacttype'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.ForeignKeyConstraint(['job_id'], ['jobs.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_artifacts_id'), 'artifacts', ['id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_artifacts_id'), table_name='artifacts')
op.drop_table('artifacts')
op.drop_index(op.f('ix_jobs_id'), table_name='jobs')
op.drop_table('jobs')
# ### end Alembic commands ###

View File

@@ -1,71 +0,0 @@
"""add_job_and_artifact_tables
Revision ID: c43a1ddae8b7
Revises:
Create Date: 2023-01-05 12:00:58.824773
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c43a1ddae8b7"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"jobs",
sa.Column("url", sa.String(length=2048), nullable=True),
sa.Column(
"status",
sa.Enum("Create", "Error", "Success", name="jobstatus"),
nullable=False,
),
sa.Column("type", sa.Enum("Transcript", name="jobtype"), nullable=False),
sa.Column(
"created_at",
sa.DateTime(),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_jobs_id"), "jobs", ["id"], unique=False)
op.create_table(
"artifacts",
sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
sa.Column(
"type",
sa.Enum("RawTranscript", name="artifacttype"),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_artifacts_id"), "artifacts", ["id"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_artifacts_id"), table_name="artifacts")
op.drop_table("artifacts")
op.drop_index(op.f("ix_jobs_id"), table_name="jobs")
op.drop_table("jobs")
# ### end Alembic commands ###

View File

@@ -1,23 +1,23 @@
import enum
from datetime import datetime
from typing import Optional
from typing import Any, Optional
from uuid import UUID
from pydantic import AnyHttpUrl, BaseModel, Json
class ArtifactType(enum.Enum):
RawTranscript = "RawTranscript"
class ArtifactType(str, enum.Enum):
raw_transcript = "raw_transcript"
class JobType(enum.Enum):
Transcript = "Transcript"
class JobType(str, enum.Enum):
transcript = "transcript"
class JobStatus(enum.Enum):
Create = "Create"
Error = "Error"
Success = "Success"
class JobStatus(str, enum.Enum):
create = "create"
error = "error"
success = "success"
class WithDbFields(BaseModel):
@@ -36,6 +36,7 @@ class Job(WithDbFields):
class Artifact(WithDbFields):
data: Optional[Json]
# TODO: narrow type
data: Optional[Any]
job_id: UUID
type: ArtifactType

View File

@@ -1,18 +1,19 @@
from typing import Generator
from typing import Dict, Generator
import pytest
from sqlalchemy.orm import Session
from sqlalchemy_utils import create_database, database_exists, drop_database
from app.shared.config import settings
from app.shared.db.base import SessionLocal, engine, get_session
from app.shared.db.models import Base
import app.shared.db.models as models
from app.web.main import app
def pytest_configure() -> None:
if not database_exists(engine.url):
create_database(engine.url)
Base.metadata.create_all(engine)
models.Base.metadata.create_all(engine)
def pytest_unconfigure() -> None:
@@ -20,6 +21,11 @@ def pytest_unconfigure() -> None:
drop_database(engine.url)
@pytest.fixture(name="auth_headers", scope="function")
def auth_header() -> Dict[str, str]:
return {"Authorization": f"Bearer {settings.API_SECRET}"}
@pytest.fixture(name="db_session", scope="function", autouse=True)
def db_session() -> Generator[Session, None, None]:
connection = engine.connect()

124
app/tests/test_api.py Normal file
View File

@@ -0,0 +1,124 @@
from typing import Dict
from fastapi.testclient import TestClient
import pytest
from sqlalchemy.orm import Session
from app.shared.db.dtos import JobType, JobStatus
from app.web.main import app
import app.shared.db.models as models
import app.shared.db.dtos as dtos
client = TestClient(app)
@pytest.fixture(name="mock_job", scope="function", autouse=False)
def mock_job(db_session: Session) -> models.Job:
job = models.Job(
url="https://example.com", type=JobType.transcript, status=JobStatus.create
)
db_session.add(job)
db_session.flush()
return job
# POST /api/v1/jobs
# ---
def test_create_job_pass(auth_headers: Dict[str, str]) -> None:
res = client.post(
"/api/v1/jobs",
headers=auth_headers,
json={"url": "https://example.com", "type": JobType.transcript},
)
assert res.status_code == 201
assert isinstance(res.json()["id"], str)
def test_create_job_missing_body(auth_headers: Dict[str, str]) -> None:
res = client.post("/api/v1/jobs", headers=auth_headers, json={})
assert res.status_code == 422
def test_create_job_malformed_url(auth_headers: Dict[str, str]) -> None:
res = client.post(
"/api/v1/jobs",
headers=auth_headers,
json={"url": "example.com", "type": JobType.transcript},
)
assert res.status_code == 422
# GET /api/v1/jobs
# ---
def test_get_jobs_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
res = client.get(
"/api/v1/jobs?type=transcript",
headers=auth_headers,
)
assert len(res.json()) == 1
assert res.status_code == 200
# GET /api/v1/jobs/:id
# ---
def test_get_job_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
res = client.get(
f"/api/v1/jobs/{mock_job.id}",
headers=auth_headers,
)
assert res.status_code == 200
assert res.json()["id"] == str(mock_job.id)
def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
res = client.get(
f"/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb",
headers=auth_headers,
)
assert res.status_code == 404
# GET /api/v1/jobs/:id/artifacts
# ---
def test_get_artifact_pass(
auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job
) -> None:
artifact = models.Artifact(
data={}, job_id=mock_job.id, type=dtos.ArtifactType.raw_transcript
)
db_session.add(artifact)
db_session.flush()
res = client.get(
f"/api/v1/jobs/{mock_job.id}/artifacts",
headers=auth_headers,
)
assert res.status_code == 200
assert res.json()[0]["job_id"] == str(mock_job.id)
assert res.json()[0]["id"] == str(artifact.id)
def test_get_artifact_not_found(
auth_headers: Dict[str, str], mock_job: models.Job
) -> None:
res = client.get(
f"/api/v1/jobs/{mock_job.id}/artifacts",
headers=auth_headers,
)
assert res.status_code == 404
# DELETE /api/v1/jobs
# ---
def test_delete_job_pass(
auth_headers: Dict[str, str], mock_job: models.Job, db_session: Session
) -> None:
res = client.delete(
f"/api/v1/jobs/{mock_job.id}",
headers=auth_headers,
)
assert db_session.query(models.Job).count() == 0
assert res.status_code == 204

View File

@@ -8,10 +8,6 @@ from app.web.main import app
client = TestClient(app)
def auth_header(s: str) -> Dict[str, str]:
return {"Authorization": f"Bearer {s}"}
def test_authorization_header_missing() -> None:
res = client.get("/api/v1")
assert res.status_code == 401
@@ -23,10 +19,10 @@ def test_authorization_header_malformed() -> None:
def test_incorrect_api_key() -> None:
res = client.get("/api/v1", headers=auth_header("not_valid"))
res = client.get("/api/v1", headers={"Authorization": "Bearer incorrect" })
assert res.status_code == 401
def test_existing_api_key() -> None:
res = client.get("/api/v1", headers=auth_header(settings.API_SECRET))
def test_existing_api_key(auth_headers: Dict[str, str]) -> None:
res = client.get("/api/v1", headers=auth_headers)
assert res.status_code == 200

View File

@@ -22,51 +22,61 @@ def api_root() -> Dict:
class TranscriptPayload(BaseModel):
url: AnyHttpUrl
type: dtos.JobType
@api_router.post("/transcripts", response_model=dtos.Job)
def create_transcript(
@api_router.post("/jobs", response_model=dtos.Job, status_code=201)
def create_job(
payload: TranscriptPayload, session: Session = Depends(get_session)
) -> models.Job:
job = models.Job(
url=payload.url, status=dtos.JobStatus.Create, type=dtos.JobType.Transcript
)
job = models.Job(url=payload.url, status=dtos.JobStatus.create, type=payload.type)
session.add(job)
session.flush()
return job
@api_router.get("/transcripts", response_model=List[dtos.Job])
def get_transcripts(session: Session = Depends(get_session)) -> List[models.Job]:
return (
session.query(models.Job)
.filter(models.Job.type == dtos.JobType.Transcript)
.all()
)
@api_router.get("/jobs", response_model=List[dtos.Job])
def get_transcripts(
type: Optional[dtos.JobType] = None, session: Session = Depends(get_session)
) -> List[models.Job]:
query = session.query(models.Job)
if type:
query = query.filter(models.Job.type == type)
return query.all()
@api_router.get("/transcripts/{id}", response_model=dtos.Job)
@api_router.get("/jobs/{id}", response_model=dtos.Job)
def get_transcript(
id: UUID = Path(), session: Session = Depends(get_session)
) -> Optional[dtos.Job]:
job = (
session.query(models.Job)
.filter(models.Job.id == id)
.filter(models.Job.type == dtos.JobType.Transcript)
.one_or_none()
)
job = session.query(models.Job).filter(models.Job.id == id).one_or_none()
if not job:
raise HTTPException(status_code=404)
return job
@api_router.delete("/transcripts/{id}")
@api_router.get("/jobs/{id}/artifacts", response_model=List[dtos.Artifact])
def get_artifacts_for_job(
id: UUID = Path(), session: Session = Depends(get_session)
) -> List[dtos.Artifact]:
artifacts = (
session.query(models.Artifact).filter(models.Artifact.job_id == id)
).all()
if not len(artifacts):
raise HTTPException(status_code=404)
return artifacts
@api_router.delete("/jobs/{id}", status_code=204)
def delete_transcript(
id: UUID = Path(), session: Session = Depends(get_session)
) -> None:
session.query(models.Job).filter(models.Job.id == id).filter(
models.Job.type == dtos.JobType.Transcript
).delete()
session.query(models.Job).filter(models.Job.id == id).delete()
return None