mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-13 05:58:35 +03:00
feat: add artifact routes, test web api
This commit is contained in:
6
Makefile
6
Makefile
@@ -6,9 +6,9 @@ fmt:
|
||||
black app
|
||||
isort app
|
||||
|
||||
test:
|
||||
ENVIRONMENT=test pytest
|
||||
|
||||
lint:
|
||||
mypy app
|
||||
flake8 app
|
||||
|
||||
test:
|
||||
pytest
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
# whisper-api
|
||||
|
||||
> Tiny HTTP wrapper around [openai/whisper](https://github.com/openai/whisper).
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pydantic import BaseSettings
|
||||
|
||||
@@ -10,7 +10,7 @@ class Settings(BaseSettings):
|
||||
REDIS_URI: str
|
||||
|
||||
|
||||
if "ENVIRONMENT" in os.environ and os.environ["ENVIRONMENT"] == "test":
|
||||
settings = Settings(_env_file=".env.test") # type: ignore
|
||||
if "pytest" in sys.modules:
|
||||
settings = Settings(_env_file=".env.test", _env_file_encoding="utf-8") # type: ignore
|
||||
else:
|
||||
settings = Settings()
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
"""add_job_tables
|
||||
|
||||
Revision ID: bb249ed79907
|
||||
Revises:
|
||||
Create Date: 2023-01-17 14:30:30.920466
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'bb249ed79907'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``jobs`` and ``artifacts`` tables and index their ``id`` columns.

    ``jobs`` is created before ``artifacts`` because ``artifacts.job_id`` is a
    foreign key to ``jobs.id`` (``ON DELETE CASCADE``).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "jobs",
        sa.Column("url", sa.String(length=2048), nullable=True),
        sa.Column(
            "status",
            sa.Enum("create", "error", "success", name="jobstatus"),
            nullable=False,
        ),
        sa.Column("type", sa.Enum("transcript", name="jobtype"), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("updated_at", sa.DateTime(), nullable=True),
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_jobs_id"), "jobs", ["id"], unique=False)
    op.create_table(
        "artifacts",
        sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
        sa.Column(
            "type",
            sa.Enum("raw_transcript", name="artifacttype"),
            nullable=False,
        ),
        sa.Column(
            "created_at",
            sa.DateTime(),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("updated_at", sa.DateTime(), nullable=True),
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_artifacts_id"), "artifacts", ["id"], unique=False)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Undo :func:`upgrade`: drop both tables, their indexes and their enum types.

    ``artifacts`` is dropped before ``jobs`` because it holds the foreign key
    to ``jobs.id``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_artifacts_id'), table_name='artifacts')
    op.drop_table('artifacts')
    op.drop_index(op.f('ix_jobs_id'), table_name='jobs')
    op.drop_table('jobs')
    # ### end Alembic commands ###

    # Dropping a table does not drop the PostgreSQL ENUM types that were
    # created for its sa.Enum columns. Left behind, they make a subsequent
    # upgrade() fail on CREATE TYPE, so remove them explicitly.
    bind = op.get_bind()
    for enum_name in ('artifacttype', 'jobtype', 'jobstatus'):
        postgresql.ENUM(name=enum_name).drop(bind, checkfirst=True)
|
||||
@@ -1,71 +0,0 @@
|
||||
"""add_job_and_artifact_tables
|
||||
|
||||
Revision ID: c43a1ddae8b7
|
||||
Revises:
|
||||
Create Date: 2023-01-05 12:00:58.824773
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c43a1ddae8b7"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create ``jobs``, then ``artifacts`` (which references ``jobs.id``)."""
    # ### commands auto generated by Alembic - please adjust! ###
    job_columns = [
        sa.Column("url", sa.String(length=2048), nullable=True),
        sa.Column(
            "status",
            sa.Enum("Create", "Error", "Success", name="jobstatus"),
            nullable=False,
        ),
        sa.Column("type", sa.Enum("Transcript", name="jobtype"), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("updated_at", sa.DateTime(), nullable=True),
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
    ]
    op.create_table("jobs", *job_columns, sa.PrimaryKeyConstraint("id"))
    op.create_index(op.f("ix_jobs_id"), "jobs", ["id"], unique=False)

    artifact_columns = [
        sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
        sa.Column(
            "type",
            sa.Enum("RawTranscript", name="artifacttype"),
            nullable=False,
        ),
        sa.Column(
            "created_at",
            sa.DateTime(),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("updated_at", sa.DateTime(), nullable=True),
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
    ]
    op.create_table(
        "artifacts",
        *artifact_columns,
        # Deleting a job removes its artifacts as well.
        sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_artifacts_id"), "artifacts", ["id"], unique=False)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Undo :func:`upgrade`: drop both tables, their indexes and their enum types.

    ``artifacts`` goes first because it holds the foreign key to ``jobs.id``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f("ix_artifacts_id"), table_name="artifacts")
    op.drop_table("artifacts")
    op.drop_index(op.f("ix_jobs_id"), table_name="jobs")
    op.drop_table("jobs")
    # ### end Alembic commands ###

    # drop_table() leaves the PostgreSQL ENUM types created for the sa.Enum
    # columns in place; a later upgrade() would then fail on CREATE TYPE.
    bind = op.get_bind()
    for enum_name in ("artifacttype", "jobtype", "jobstatus"):
        postgresql.ENUM(name=enum_name).drop(bind, checkfirst=True)
|
||||
@@ -1,23 +1,23 @@
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel, Json
|
||||
|
||||
|
||||
class ArtifactType(enum.Enum):
    """Kinds of artifact that can be attached to a job."""

    # Member name and stored value are deliberately identical.
    RawTranscript = "RawTranscript"
|
||||
class ArtifactType(str, enum.Enum):
    """Kinds of artifact that can be attached to a job.

    Mixing in ``str`` makes each member compare equal to its plain string value.
    """

    raw_transcript = "raw_transcript"
|
||||
|
||||
|
||||
class JobType(enum.Enum):
    """Kinds of job that can be requested."""

    # Member name and stored value are deliberately identical.
    Transcript = "Transcript"
|
||||
class JobType(str, enum.Enum):
    """Kinds of job that can be requested.

    Mixing in ``str`` makes each member compare equal to its plain string value.
    """

    transcript = "transcript"
|
||||
|
||||
|
||||
class JobStatus(enum.Enum):
    """Status values a job can carry."""

    # Member names and stored values are deliberately identical.
    Create = "Create"
    Error = "Error"
    Success = "Success"
|
||||
class JobStatus(str, enum.Enum):
    """Status values a job can carry.

    Mixing in ``str`` makes each member compare equal to its plain string value.
    """

    create = "create"
    error = "error"
    success = "success"
|
||||
|
||||
|
||||
class WithDbFields(BaseModel):
|
||||
@@ -36,6 +36,7 @@ class Job(WithDbFields):
|
||||
|
||||
|
||||
class Artifact(WithDbFields):
|
||||
data: Optional[Json]
|
||||
# TODO: narrow type
|
||||
data: Optional[Any]
|
||||
job_id: UUID
|
||||
type: ArtifactType
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
from typing import Generator
|
||||
from typing import Dict, Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy_utils import create_database, database_exists, drop_database
|
||||
from app.shared.config import settings
|
||||
|
||||
from app.shared.db.base import SessionLocal, engine, get_session
|
||||
from app.shared.db.models import Base
|
||||
import app.shared.db.models as models
|
||||
from app.web.main import app
|
||||
|
||||
|
||||
def pytest_configure() -> None:
|
||||
if not database_exists(engine.url):
|
||||
create_database(engine.url)
|
||||
Base.metadata.create_all(engine)
|
||||
models.Base.metadata.create_all(engine)
|
||||
|
||||
|
||||
def pytest_unconfigure() -> None:
|
||||
@@ -20,6 +21,11 @@ def pytest_unconfigure() -> None:
|
||||
drop_database(engine.url)
|
||||
|
||||
|
||||
@pytest.fixture(name="auth_headers", scope="function")
def auth_header() -> Dict[str, str]:
    """Provide request headers carrying ``settings.API_SECRET`` as a bearer token.

    Registered under the fixture name ``auth_headers``.
    """
    bearer_value = f"Bearer {settings.API_SECRET}"
    return {"Authorization": bearer_value}
|
||||
|
||||
|
||||
@pytest.fixture(name="db_session", scope="function", autouse=True)
|
||||
def db_session() -> Generator[Session, None, None]:
|
||||
connection = engine.connect()
|
||||
|
||||
124
app/tests/test_api.py
Normal file
124
app/tests/test_api.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import Dict
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.db.dtos import JobType, JobStatus
|
||||
from app.web.main import app
|
||||
import app.shared.db.models as models
|
||||
import app.shared.db.dtos as dtos
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(name="mock_job", scope="function", autouse=False)
def mock_job(db_session: Session) -> models.Job:
    """Add a transcript job in the ``create`` state to the test session.

    The session is flushed (not committed) before the job is returned.
    """
    pending_job = models.Job(
        url="https://example.com",
        type=JobType.transcript,
        status=JobStatus.create,
    )
    db_session.add(pending_job)
    db_session.flush()
    return pending_job
|
||||
|
||||
|
||||
# POST /api/v1/jobs
|
||||
# ---
|
||||
def test_create_job_pass(auth_headers: Dict[str, str]) -> None:
    """A valid payload creates a job: 201 and a string ``id`` in the body."""
    payload = {"url": "https://example.com", "type": JobType.transcript}

    response = client.post("/api/v1/jobs", headers=auth_headers, json=payload)

    assert response.status_code == 201
    assert isinstance(response.json()["id"], str)
|
||||
|
||||
|
||||
def test_create_job_missing_body(auth_headers: Dict[str, str]) -> None:
    """An empty JSON object is rejected with 422."""
    empty_payload: Dict[str, str] = {}
    response = client.post("/api/v1/jobs", headers=auth_headers, json=empty_payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
def test_create_job_malformed_url(auth_headers: Dict[str, str]) -> None:
    """A ``url`` without a scheme (``example.com``) is rejected with 422."""
    payload = {"url": "example.com", "type": JobType.transcript}

    response = client.post("/api/v1/jobs", headers=auth_headers, json=payload)

    assert response.status_code == 422
|
||||
|
||||
|
||||
# GET /api/v1/jobs
|
||||
# ---
|
||||
def test_get_jobs_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
    """Listing jobs filtered by ``type=transcript`` returns the one fixture job."""
    res = client.get(
        "/api/v1/jobs?type=transcript",
        headers=auth_headers,
    )
    # Check the status first: on a failed request the body is an error
    # document, and asserting on its length would hide the real cause.
    assert res.status_code == 200
    assert len(res.json()) == 1
|
||||
|
||||
|
||||
# GET /api/v1/jobs/:id
|
||||
# ---
|
||||
def test_get_job_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
    """Fetching the fixture job by id returns 200 and echoes that id."""
    job_url = f"/api/v1/jobs/{mock_job.id}"

    response = client.get(job_url, headers=auth_headers)

    assert response.status_code == 200
    assert response.json()["id"] == str(mock_job.id)
|
||||
|
||||
|
||||
def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
    """Requesting a job id that is not stored returns 404.

    ``mock_job`` is requested so that a job exists while the lookup runs; the
    hard-coded UUID below is assumed not to collide with its generated id.
    """
    res = client.get(
        # Plain string: the path has no placeholders, so no f-string is needed.
        "/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb",
        headers=auth_headers,
    )
    assert res.status_code == 404
|
||||
|
||||
|
||||
# GET /api/v1/jobs/:id/artifacts
|
||||
# ---
|
||||
def test_get_artifact_pass(
    auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job
) -> None:
    """A job's artifact listing returns the artifact stored for that job."""
    stored_artifact = models.Artifact(
        data={},
        job_id=mock_job.id,
        type=dtos.ArtifactType.raw_transcript,
    )
    db_session.add(stored_artifact)
    db_session.flush()

    artifacts_url = f"/api/v1/jobs/{mock_job.id}/artifacts"
    response = client.get(artifacts_url, headers=auth_headers)

    assert response.status_code == 200
    assert response.json()[0]["job_id"] == str(mock_job.id)
    assert response.json()[0]["id"] == str(stored_artifact.id)
|
||||
|
||||
|
||||
def test_get_artifact_not_found(
    auth_headers: Dict[str, str], mock_job: models.Job
) -> None:
    """A job for which no artifact was stored yields 404 on the artifacts route."""
    artifacts_url = f"/api/v1/jobs/{mock_job.id}/artifacts"

    response = client.get(artifacts_url, headers=auth_headers)

    assert response.status_code == 404
|
||||
|
||||
|
||||
# DELETE /api/v1/jobs
|
||||
# ---
|
||||
def test_delete_job_pass(
    auth_headers: Dict[str, str], mock_job: models.Job, db_session: Session
) -> None:
    """Deleting the fixture job returns 204 and leaves no job rows behind."""
    res = client.delete(
        f"/api/v1/jobs/{mock_job.id}",
        headers=auth_headers,
    )
    # Status first: if the request itself was rejected, report that rather
    # than the follow-on "row still present" symptom.
    assert res.status_code == 204
    assert db_session.query(models.Job).count() == 0
|
||||
@@ -8,10 +8,6 @@ from app.web.main import app
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def auth_header(s: str) -> Dict[str, str]:
    """Build an ``Authorization`` header dict that sends *s* as a bearer token."""
    bearer_value = f"Bearer {s}"
    return {"Authorization": bearer_value}
|
||||
|
||||
|
||||
def test_authorization_header_missing() -> None:
    """A request to the API root without an ``Authorization`` header gets 401."""
    response = client.get("/api/v1")
    assert response.status_code == 401
|
||||
@@ -23,10 +19,10 @@ def test_authorization_header_malformed() -> None:
|
||||
|
||||
|
||||
def test_incorrect_api_key() -> None:
|
||||
res = client.get("/api/v1", headers=auth_header("not_valid"))
|
||||
res = client.get("/api/v1", headers={"Authorization": "Bearer incorrect" })
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_existing_api_key() -> None:
|
||||
res = client.get("/api/v1", headers=auth_header(settings.API_SECRET))
|
||||
def test_existing_api_key(auth_headers: Dict[str, str]) -> None:
|
||||
res = client.get("/api/v1", headers=auth_headers)
|
||||
assert res.status_code == 200
|
||||
|
||||
@@ -22,51 +22,61 @@ def api_root() -> Dict:
|
||||
|
||||
class TranscriptPayload(BaseModel):
    """Request body accepted when creating a job."""

    # Validated as an HTTP(S) URL by pydantic.
    url: AnyHttpUrl
    # Which kind of job to create.
    type: dtos.JobType
|
||||
|
||||
|
||||
@api_router.post("/transcripts", response_model=dtos.Job)
|
||||
def create_transcript(
|
||||
@api_router.post("/jobs", response_model=dtos.Job, status_code=201)
|
||||
def create_job(
|
||||
payload: TranscriptPayload, session: Session = Depends(get_session)
|
||||
) -> models.Job:
|
||||
job = models.Job(
|
||||
url=payload.url, status=dtos.JobStatus.Create, type=dtos.JobType.Transcript
|
||||
)
|
||||
job = models.Job(url=payload.url, status=dtos.JobStatus.create, type=payload.type)
|
||||
|
||||
session.add(job)
|
||||
session.flush()
|
||||
return job
|
||||
|
||||
|
||||
@api_router.get("/transcripts", response_model=List[dtos.Job])
|
||||
def get_transcripts(session: Session = Depends(get_session)) -> List[models.Job]:
|
||||
return (
|
||||
session.query(models.Job)
|
||||
.filter(models.Job.type == dtos.JobType.Transcript)
|
||||
.all()
|
||||
)
|
||||
@api_router.get("/jobs", response_model=List[dtos.Job])
def get_transcripts(
    type: Optional[dtos.JobType] = None, session: Session = Depends(get_session)
) -> List[models.Job]:
    """List jobs, optionally restricted to a single job type.

    Args:
        type: When given, only jobs of this type are returned.
        session: Database session injected by FastAPI.

    Returns:
        All matching job rows.
    """
    jobs = session.query(models.Job)

    # No filter requested: return every job.
    if not type:
        return jobs.all()

    return jobs.filter(models.Job.type == type).all()
|
||||
|
||||
|
||||
@api_router.get("/transcripts/{id}", response_model=dtos.Job)
|
||||
@api_router.get("/jobs/{id}", response_model=dtos.Job)
|
||||
def get_transcript(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
) -> Optional[dtos.Job]:
|
||||
job = (
|
||||
session.query(models.Job)
|
||||
.filter(models.Job.id == id)
|
||||
.filter(models.Job.type == dtos.JobType.Transcript)
|
||||
.one_or_none()
|
||||
)
|
||||
job = session.query(models.Job).filter(models.Job.id == id).one_or_none()
|
||||
if not job:
|
||||
raise HTTPException(status_code=404)
|
||||
return job
|
||||
|
||||
|
||||
@api_router.delete("/transcripts/{id}")
|
||||
@api_router.get("/jobs/{id}/artifacts", response_model=List[dtos.Artifact])
def get_artifacts_for_job(
    id: UUID = Path(), session: Session = Depends(get_session)
) -> List[dtos.Artifact]:
    """Return every artifact whose ``job_id`` equals the given job id.

    Args:
        id: Job id taken from the URL path.
        session: Database session injected by FastAPI.

    Raises:
        HTTPException: 404 when no artifact row matches. The job itself is not
            looked up, so a missing job and a job without artifacts both
            produce this response.
    """
    matching = session.query(models.Artifact).filter(models.Artifact.job_id == id)
    artifacts = matching.all()

    if not artifacts:
        raise HTTPException(status_code=404)

    return artifacts
|
||||
|
||||
|
||||
@api_router.delete("/jobs/{id}", status_code=204)
|
||||
def delete_transcript(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
) -> None:
|
||||
session.query(models.Job).filter(models.Job.id == id).filter(
|
||||
models.Job.type == dtos.JobType.Transcript
|
||||
).delete()
|
||||
session.query(models.Job).filter(models.Job.id == id).delete()
|
||||
return None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user