diff --git a/Makefile b/Makefile index 04ae1b0..53122a5 100644 --- a/Makefile +++ b/Makefile @@ -6,9 +6,9 @@ fmt: black app isort app -test: - ENVIRONMENT=test pytest - lint: mypy app flake8 app + +test: + pytest diff --git a/README.md b/README.md index b57765a..04cae39 100644 --- a/README.md +++ b/README.md @@ -1 +1,3 @@ # whisper-api + +> Tiny HTTP wrapper around [openai/whisper](https://github.com/openai/whisper). diff --git a/app/shared/config.py b/app/shared/config.py index d5e618b..0c78151 100644 --- a/app/shared/config.py +++ b/app/shared/config.py @@ -1,4 +1,4 @@ -import os +import sys from pydantic import BaseSettings @@ -10,7 +10,7 @@ class Settings(BaseSettings): REDIS_URI: str -if "ENVIRONMENT" in os.environ and os.environ["ENVIRONMENT"] == "test": - settings = Settings(_env_file=".env.test") # type: ignore +if "pytest" in sys.modules: + settings = Settings(_env_file=".env.test", _env_file_encoding="utf-8") # type: ignore else: settings = Settings() diff --git a/app/shared/db/alembic/versions/bb249ed79907_add_job_tables.py b/app/shared/db/alembic/versions/bb249ed79907_add_job_tables.py new file mode 100644 index 0000000..4c498ba --- /dev/null +++ b/app/shared/db/alembic/versions/bb249ed79907_add_job_tables.py @@ -0,0 +1,51 @@ +"""add_job_tables + +Revision ID: bb249ed79907 +Revises: +Create Date: 2023-01-17 14:30:30.920466 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'bb249ed79907' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('jobs', + sa.Column('url', sa.String(length=2048), nullable=True), + sa.Column('status', sa.Enum('create', 'error', 'success', name='jobstatus'), nullable=False), + sa.Column('type', sa.Enum('transcript', name='jobtype'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_jobs_id'), 'jobs', ['id'], unique=False) + op.create_table('artifacts', + sa.Column('job_id', postgresql.UUID(as_uuid=True), nullable=False), + sa.Column('data', sa.JSON(none_as_null=True), nullable=True), + sa.Column('type', sa.Enum('raw_transcript', name='artifacttype'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint(['job_id'], ['jobs.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_artifacts_id'), 'artifacts', ['id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_artifacts_id'), table_name='artifacts') + op.drop_table('artifacts') + op.drop_index(op.f('ix_jobs_id'), table_name='jobs') + op.drop_table('jobs') + # ### end Alembic commands ### diff --git a/app/shared/db/alembic/versions/c43a1ddae8b7_add_job_and_artifact_tables.py b/app/shared/db/alembic/versions/c43a1ddae8b7_add_job_and_artifact_tables.py deleted file mode 100644 index b8d0a71..0000000 --- a/app/shared/db/alembic/versions/c43a1ddae8b7_add_job_and_artifact_tables.py +++ /dev/null @@ -1,71 +0,0 @@ -"""add_job_and_artifact_tables - -Revision ID: c43a1ddae8b7 -Revises: -Create Date: 2023-01-05 12:00:58.824773 - -""" -import sqlalchemy as sa -from alembic import op -from sqlalchemy.dialects import postgresql - -# revision identifiers, used by Alembic. -revision = "c43a1ddae8b7" -down_revision = None -branch_labels = None -depends_on = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "jobs", - sa.Column("url", sa.String(length=2048), nullable=True), - sa.Column( - "status", - sa.Enum("Create", "Error", "Success", name="jobstatus"), - nullable=False, - ), - sa.Column("type", sa.Enum("Transcript", name="jobtype"), nullable=False), - sa.Column( - "created_at", - sa.DateTime(), - server_default=sa.text("now()"), - nullable=False, - ), - sa.Column("updated_at", sa.DateTime(), nullable=True), - sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index(op.f("ix_jobs_id"), "jobs", ["id"], unique=False) - op.create_table( - "artifacts", - sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False), - sa.Column("data", sa.JSON(none_as_null=True), nullable=True), - sa.Column( - "type", - sa.Enum("RawTranscript", name="artifacttype"), - nullable=False, - ), - sa.Column( - "created_at", - sa.DateTime(), - server_default=sa.text("now()"), - nullable=False, - ), - sa.Column("updated_at", sa.DateTime(), nullable=True), - sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), - sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index(op.f("ix_artifacts_id"), "artifacts", ["id"], unique=False) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f("ix_artifacts_id"), table_name="artifacts") - op.drop_table("artifacts") - op.drop_index(op.f("ix_jobs_id"), table_name="jobs") - op.drop_table("jobs") - # ### end Alembic commands ### diff --git a/app/shared/db/dtos.py b/app/shared/db/dtos.py index dae093b..0946d95 100644 --- a/app/shared/db/dtos.py +++ b/app/shared/db/dtos.py @@ -1,23 +1,23 @@ import enum from datetime import datetime -from typing import Optional +from typing import Any, Optional from uuid import UUID from pydantic import AnyHttpUrl, BaseModel, Json -class ArtifactType(enum.Enum): - RawTranscript = "RawTranscript" +class ArtifactType(str, enum.Enum): + raw_transcript = "raw_transcript" -class JobType(enum.Enum): - Transcript = "Transcript" +class JobType(str, enum.Enum): + transcript = "transcript" -class JobStatus(enum.Enum): - Create = "Create" - Error = "Error" - Success = "Success" +class JobStatus(str, enum.Enum): + create = "create" + error = "error" + success = "success" class WithDbFields(BaseModel): @@ -36,6 +36,7 @@ class Job(WithDbFields): class Artifact(WithDbFields): - data: Optional[Json] + # TODO: narrow type + data: Optional[Any] job_id: UUID type: ArtifactType diff --git a/app/tests/conftest.py b/app/tests/conftest.py index feb6762..6b817e7 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -1,18 +1,19 @@ -from typing import Generator +from typing import Dict, Generator import pytest from sqlalchemy.orm import Session from sqlalchemy_utils import create_database, database_exists, drop_database +from app.shared.config import settings from app.shared.db.base import SessionLocal, engine, get_session -from app.shared.db.models import Base +import app.shared.db.models as models from app.web.main import app def pytest_configure() -> None: if not database_exists(engine.url): create_database(engine.url) - Base.metadata.create_all(engine) + models.Base.metadata.create_all(engine) def pytest_unconfigure() -> None: @@ -20,6 +21,11 @@ def pytest_unconfigure() -> None: drop_database(engine.url) +@pytest.fixture(name="auth_headers", scope="function") +def auth_header() -> Dict[str, str]: + return {"Authorization": f"Bearer {settings.API_SECRET}"} + + @pytest.fixture(name="db_session", scope="function", autouse=True) def db_session() -> Generator[Session, None, None]: connection = engine.connect() diff --git a/app/tests/test_api.py b/app/tests/test_api.py new file mode 100644 index 0000000..3eada34 --- /dev/null +++ b/app/tests/test_api.py @@ -0,0 +1,124 @@ +from typing import Dict +from fastapi.testclient import TestClient +import pytest +from sqlalchemy.orm import Session + +from app.shared.db.dtos import JobType, JobStatus +from app.web.main import app +import app.shared.db.models as models +import app.shared.db.dtos as dtos + +client = TestClient(app) + + +@pytest.fixture(name="mock_job", scope="function", autouse=False) +def mock_job(db_session: Session) -> models.Job: + job = models.Job( + url="https://example.com", type=JobType.transcript, status=JobStatus.create + ) + + db_session.add(job) + db_session.flush() + + return job + + +# POST /api/v1/jobs +# --- +def test_create_job_pass(auth_headers: Dict[str, str]) -> None: + res = client.post( + "/api/v1/jobs", + headers=auth_headers, + json={"url": "https://example.com", "type": JobType.transcript}, + ) + assert res.status_code == 201 + assert isinstance(res.json()["id"], str) + + +def test_create_job_missing_body(auth_headers: Dict[str, str]) -> None: + res = client.post("/api/v1/jobs", headers=auth_headers, json={}) + assert res.status_code == 422 + + +def test_create_job_malformed_url(auth_headers: Dict[str, str]) -> None: + res = client.post( + "/api/v1/jobs", + headers=auth_headers, + json={"url": "example.com", "type": JobType.transcript}, + ) + assert res.status_code == 422 + + +# GET /api/v1/jobs +# --- +def test_get_jobs_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None: + res = client.get( + "/api/v1/jobs?type=transcript", + headers=auth_headers, + ) + assert len(res.json()) == 1 + assert res.status_code == 200 + + +# GET /api/v1/jobs/:id +# --- +def test_get_job_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None: + res = client.get( + f"/api/v1/jobs/{mock_job.id}", + headers=auth_headers, + ) + assert res.status_code == 200 + assert res.json()["id"] == str(mock_job.id) + + +def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) -> None: + res = client.get( + f"/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb", + headers=auth_headers, + ) + assert res.status_code == 404 + + +# GET /api/v1/jobs/:id/artifacts +# --- +def test_get_artifact_pass( + auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job +) -> None: + artifact = models.Artifact( + data={}, job_id=mock_job.id, type=dtos.ArtifactType.raw_transcript + ) + + db_session.add(artifact) + db_session.flush() + + res = client.get( + f"/api/v1/jobs/{mock_job.id}/artifacts", + headers=auth_headers, + ) + + assert res.status_code == 200 + assert res.json()[0]["job_id"] == str(mock_job.id) + assert res.json()[0]["id"] == str(artifact.id) + + +def test_get_artifact_not_found( + auth_headers: Dict[str, str], mock_job: models.Job +) -> None: + res = client.get( + f"/api/v1/jobs/{mock_job.id}/artifacts", + headers=auth_headers, + ) + assert res.status_code == 404 + + +# DELETE /api/v1/jobs +# --- +def test_delete_job_pass( + auth_headers: Dict[str, str], mock_job: models.Job, db_session: Session +) -> None: + res = client.delete( + f"/api/v1/jobs/{mock_job.id}", + headers=auth_headers, + ) + assert db_session.query(models.Job).count() == 0 + assert res.status_code == 204 diff --git a/app/tests/test_auth.py b/app/tests/test_auth.py index 26caed7..b961b14 100644 --- a/app/tests/test_auth.py +++ b/app/tests/test_auth.py @@ -8,10 +8,6 @@ from app.web.main import app client = TestClient(app) -def auth_header(s: str) -> Dict[str, str]: - return {"Authorization": f"Bearer {s}"} - - def test_authorization_header_missing() -> None: res = client.get("/api/v1") assert res.status_code == 401 @@ -23,10 +19,10 @@ def test_authorization_header_malformed() -> None: def test_incorrect_api_key() -> None: - res = client.get("/api/v1", headers=auth_header("not_valid")) + res = client.get("/api/v1", headers={"Authorization": "Bearer incorrect" }) assert res.status_code == 401 -def test_existing_api_key() -> None: - res = client.get("/api/v1", headers=auth_header(settings.API_SECRET)) +def test_existing_api_key(auth_headers: Dict[str, str]) -> None: + res = client.get("/api/v1", headers=auth_headers) assert res.status_code == 200 diff --git a/app/web/main.py b/app/web/main.py index 8fe6b7c..4832739 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -22,51 +22,61 @@ def api_root() -> Dict: class TranscriptPayload(BaseModel): url: AnyHttpUrl + type: dtos.JobType -@api_router.post("/transcripts", response_model=dtos.Job) -def create_transcript( +@api_router.post("/jobs", response_model=dtos.Job, status_code=201) +def create_job( payload: TranscriptPayload, session: Session = Depends(get_session) ) -> models.Job: - job = models.Job( - url=payload.url, status=dtos.JobStatus.Create, type=dtos.JobType.Transcript - ) + job = models.Job(url=payload.url, status=dtos.JobStatus.create, type=payload.type) + session.add(job) session.flush() return job -@api_router.get("/transcripts", response_model=List[dtos.Job]) -def get_transcripts(session: Session = Depends(get_session)) -> List[models.Job]: - return ( - session.query(models.Job) - .filter(models.Job.type == dtos.JobType.Transcript) - .all() - ) +@api_router.get("/jobs", response_model=List[dtos.Job]) +def get_transcripts( + type: Optional[dtos.JobType] = None, session: Session = Depends(get_session) +) -> List[models.Job]: + query = session.query(models.Job) + + if type: + query = query.filter(models.Job.type == type) + + return query.all() -@api_router.get("/transcripts/{id}", response_model=dtos.Job) +@api_router.get("/jobs/{id}", response_model=dtos.Job) def get_transcript( id: UUID = Path(), session: Session = Depends(get_session) ) -> Optional[dtos.Job]: - job = ( - session.query(models.Job) - .filter(models.Job.id == id) - .filter(models.Job.type == dtos.JobType.Transcript) - .one_or_none() - ) + job = session.query(models.Job).filter(models.Job.id == id).one_or_none() if not job: raise HTTPException(status_code=404) return job -@api_router.delete("/transcripts/{id}") +@api_router.get("/jobs/{id}/artifacts", response_model=List[dtos.Artifact]) +def get_artifacts_for_job( + id: UUID = Path(), session: Session = Depends(get_session) +) -> List[dtos.Artifact]: + artifacts = ( + session.query(models.Artifact).filter(models.Artifact.job_id == id) + ).all() + + if not len(artifacts): + raise HTTPException(status_code=404) + + return artifacts + + +@api_router.delete("/jobs/{id}", status_code=204) def delete_transcript( id: UUID = Path(), session: Session = Depends(get_session) ) -> None: - session.query(models.Job).filter(models.Job.id == id).filter( - models.Job.type == dtos.JobType.Transcript - ).delete() + session.query(models.Job).filter(models.Job.id == id).delete() return None