# Files
# whisperbox-transcribe/app/tests/test_api.py
# 2023-01-18 12:33:22 +01:00
#
# 125 lines
# 3.2 KiB
# Python

from typing import Dict
from fastapi.testclient import TestClient
import pytest
from sqlalchemy.orm import Session
from app.shared.db.dtos import JobType, JobStatus
from app.web.main import app
import app.shared.db.models as models
import app.shared.db.dtos as dtos
# Module-level test client wrapping the imported `app`; every test below
# issues its HTTP requests through this single instance.
client = TestClient(app)
@pytest.fixture(name="mock_job", scope="function", autouse=False)
def mock_job(db_session: Session) -> models.Job:
    """Provide a transcript job that has been added to and flushed through the session.

    Args:
        db_session: SQLAlchemy session supplied by a fixture defined elsewhere.

    Returns:
        The ``models.Job`` instance that was added to ``db_session``.
    """
    created = models.Job(
        url="https://example.com",
        type=JobType.transcript,
        status=JobStatus.create,
    )
    db_session.add(created)
    # flush() sends the pending INSERT without committing.
    # NOTE(review): presumably this is what populates `created.id` for the
    # tests that build URLs from it -- confirm against the model definition.
    db_session.flush()
    return created
# POST /api/v1/jobs
# ---
def test_create_job_pass(auth_headers: Dict[str, str]) -> None:
    """POSTing a URL plus job type yields 201 and a string ``id`` in the body."""
    payload = {"url": "https://example.com", "type": JobType.transcript}
    response = client.post("/api/v1/jobs", headers=auth_headers, json=payload)

    assert response.status_code == 201
    job_id = response.json()["id"]
    assert isinstance(job_id, str)
def test_create_job_missing_body(auth_headers: Dict[str, str]) -> None:
    """POSTing an empty JSON object is expected to be rejected with 422."""
    empty_payload: Dict[str, str] = {}
    response = client.post(
        "/api/v1/jobs",
        headers=auth_headers,
        json=empty_payload,
    )

    assert response.status_code == 422
def test_create_job_malformed_url(auth_headers: Dict[str, str]) -> None:
    """A ``url`` value without a scheme ("example.com") is expected to yield 422."""
    payload = {"url": "example.com", "type": JobType.transcript}
    response = client.post("/api/v1/jobs", headers=auth_headers, json=payload)

    assert response.status_code == 422
# GET /api/v1/jobs
# ---
def test_get_jobs_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
    """Listing jobs filtered by ``type=transcript`` returns 200 and one entry.

    Args:
        auth_headers: Request headers supplied by a fixture defined elsewhere.
        mock_job: Requested so that one transcript job exists before the call.
    """
    res = client.get(
        "/api/v1/jobs?type=transcript",
        headers=auth_headers,
    )

    # Check the status before touching the body: an error payload should be
    # reported as a wrong status code, not as a confusing length mismatch.
    assert res.status_code == 200
    assert len(res.json()) == 1
# GET /api/v1/jobs/:id
# ---
def test_get_job_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
    """Fetching the fixture job by id returns 200 and echoes that id as a string."""
    job_url = f"/api/v1/jobs/{mock_job.id}"
    response = client.get(job_url, headers=auth_headers)

    assert response.status_code == 200
    returned_id = response.json()["id"]
    assert returned_id == str(mock_job.id)
def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
    """Requesting a job under a hard-coded UUID is expected to return 404.

    Args:
        auth_headers: Request headers supplied by a fixture defined elsewhere.
        mock_job: Requested but not read here.
            NOTE(review): presumably present so the lookup misses by id rather
            than because no job exists at all -- confirm.
    """
    # Plain string literal: the path has no interpolated values, so an
    # f-string prefix would be misleading (and is flagged by linters).
    unknown_job_url = "/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb"
    res = client.get(
        unknown_job_url,
        headers=auth_headers,
    )
    assert res.status_code == 404
# GET /api/v1/jobs/:id/artifacts
# ---
def test_get_artifact_pass(
    auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job
) -> None:
    """An artifact attached to the fixture job is returned by the artifacts endpoint.

    The first element of the response is expected to carry the job's id and
    the artifact's id, both rendered as strings.
    """
    stored = models.Artifact(
        data={},
        job_id=mock_job.id,
        type=dtos.ArtifactType.raw_transcript,
    )
    db_session.add(stored)
    db_session.flush()

    artifacts_url = f"/api/v1/jobs/{mock_job.id}/artifacts"
    response = client.get(artifacts_url, headers=auth_headers)

    assert response.status_code == 200
    first_item = response.json()[0]
    assert first_item["job_id"] == str(mock_job.id)
    assert first_item["id"] == str(stored.id)
def test_get_artifact_not_found(
    auth_headers: Dict[str, str], mock_job: models.Job
) -> None:
    """The artifacts endpoint is expected to return 404 for the fixture job.

    No artifact is created in this test before the request is made.
    """
    artifacts_url = f"/api/v1/jobs/{mock_job.id}/artifacts"
    response = client.get(artifacts_url, headers=auth_headers)

    assert response.status_code == 404
# DELETE /api/v1/jobs/:id
# ---
def test_delete_job_pass(
    auth_headers: Dict[str, str], mock_job: models.Job, db_session: Session
) -> None:
    """Deleting the fixture job returns 204 and leaves no ``Job`` rows behind.

    Args:
        auth_headers: Request headers supplied by a fixture defined elsewhere.
        mock_job: The job whose id is used to build the DELETE URL.
        db_session: Session used to count the remaining ``models.Job`` rows.
    """
    res = client.delete(
        f"/api/v1/jobs/{mock_job.id}",
        headers=auth_headers,
    )

    # Assert the HTTP outcome first so a rejected request is reported as a
    # wrong status code rather than as an unexpected row count.
    assert res.status_code == 204
    assert db_session.query(models.Job).count() == 0