refactor: simplify type annotations

This commit is contained in:
Felix Spöttel
2023-06-28 13:17:55 +02:00
parent da751865e0
commit d2223206be
9 changed files with 42 additions and 52 deletions

View File

@@ -1,5 +1,4 @@
import uuid
from typing import Optional
from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func
from sqlalchemy.dialects.postgresql import UUID
@@ -19,7 +18,7 @@ class WithStandardFields:
return Column(DateTime, server_default=func.now(), nullable=False)
@declared_attr
def updated_at(cls) -> Mapped[Optional[DateTime]]:
def updated_at(cls) -> Mapped[DateTime | None]:
return Column(DateTime, onupdate=func.now())
@declared_attr

View File

@@ -1,6 +1,5 @@
import enum
from datetime import datetime
from typing import List, Optional
from uuid import UUID
from pydantic import AnyHttpUrl, BaseModel, Field
@@ -9,7 +8,7 @@ from pydantic import AnyHttpUrl, BaseModel, Field
class WithDbFields(BaseModel):
id: UUID
created_at: datetime
updated_at: Optional[datetime]
updated_at: datetime | None
class Config:
orm_mode = True
@@ -38,7 +37,7 @@ class JobConfig(BaseModel):
"""Configuration for a job."""
# TODO: limit to locales selected by whisper.
language: Optional[str] = Field(
language: str | None = Field(
description=(
"Spoken language in the media file. "
"While optional, this can improve output."
@@ -49,10 +48,10 @@ class JobConfig(BaseModel):
class JobMeta(BaseModel):
"""Metadata relating to a job's execution."""
error: Optional[str] = Field(
error: str | None = Field(
description="Will contain a descriptive error message if processing failed."
)
task_id: Optional[UUID] = Field(
task_id: UUID | None = Field(
description="Internal celery id of this job submission."
)
@@ -63,8 +62,8 @@ class Job(WithDbFields):
status: JobStatus
type: JobType
url: AnyHttpUrl
meta: Optional[JobMeta]
config: Optional[JobConfig]
meta: JobMeta | None
config: JobConfig | None
class RawTranscript(BaseModel):
@@ -75,7 +74,7 @@ class RawTranscript(BaseModel):
start: float
end: float
text: str
tokens: List[int]
tokens: list[int]
temperature: float
avg_logprob: float
compression_ratio: float
@@ -85,6 +84,6 @@ class RawTranscript(BaseModel):
class Artifact(WithDbFields):
"""whisper output for one job."""
data: Optional[List[RawTranscript]]
data: list[RawTranscript] | None
job_id: UUID
type: ArtifactType

View File

@@ -1,4 +1,4 @@
from typing import Dict, Generator
from typing import Generator
import pytest
from sqlalchemy.orm import Session
@@ -22,7 +22,7 @@ def pytest_unconfigure() -> None:
@pytest.fixture(name="auth_headers", scope="function")
def auth_header() -> Dict[str, str]:
def auth_header() -> dict[str, str]:
return {"Authorization": f"Bearer {settings.API_SECRET}"}

View File

@@ -1,5 +1,3 @@
from typing import Dict
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
@@ -26,7 +24,7 @@ def mock_job(db_session: Session) -> models.Job:
# POST /api/v1/jobs
# ---
def test_create_job_pass(auth_headers: Dict[str, str]) -> None:
def test_create_job_pass(auth_headers: dict[str, str]) -> None:
res = client.post(
"/api/v1/jobs",
headers=auth_headers,
@@ -36,12 +34,12 @@ def test_create_job_pass(auth_headers: Dict[str, str]) -> None:
assert isinstance(res.json()["id"], str)
def test_create_job_missing_body(auth_headers: Dict[str, str]) -> None:
def test_create_job_missing_body(auth_headers: dict[str, str]) -> None:
res = client.post("/api/v1/jobs", headers=auth_headers, json={})
assert res.status_code == 422
def test_create_job_malformed_url(auth_headers: Dict[str, str]) -> None:
def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None:
res = client.post(
"/api/v1/jobs",
headers=auth_headers,
@@ -52,7 +50,7 @@ def test_create_job_malformed_url(auth_headers: Dict[str, str]) -> None:
# GET /api/v1/jobs
# ---
def test_get_jobs_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None:
res = client.get(
"/api/v1/jobs?type=transcript",
headers=auth_headers,
@@ -63,7 +61,7 @@ def test_get_jobs_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> No
# GET /api/v1/jobs/:id
# ---
def test_get_job_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
def test_get_job_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None:
res = client.get(
f"/api/v1/jobs/{mock_job.id}",
headers=auth_headers,
@@ -72,7 +70,7 @@ def test_get_job_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> Non
assert res.json()["id"] == str(mock_job.id)
def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
def test_get_job_not_found(auth_headers: dict[str, str], mock_job: models.Job) -> None:
res = client.get(
"/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb",
headers=auth_headers,
@@ -84,7 +82,7 @@ def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) -
# GET /api/v1/jobs/:id/artifacts
# ---
def test_get_artifacts_pass(
auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job
auth_headers: dict[str, str], db_session: Session, mock_job: models.Job
) -> None:
artifact = models.Artifact(
data=[], job_id=str(mock_job.id), type=schemas.ArtifactType.raw_transcript
@@ -104,7 +102,7 @@ def test_get_artifacts_pass(
def test_get_artifacts_not_found(
auth_headers: Dict[str, str], mock_job: models.Job
auth_headers: dict[str, str], mock_job: models.Job
) -> None:
res = client.get(
f"/api/v1/jobs/{mock_job.id}/artifacts",
@@ -118,7 +116,7 @@ def test_get_artifacts_not_found(
# DELETE /api/v1/jobs
# ---
def test_delete_job_pass(
auth_headers: Dict[str, str], mock_job: models.Job, db_session: Session
auth_headers: dict[str, str], mock_job: models.Job, db_session: Session
) -> None:
res = client.delete(
f"/api/v1/jobs/{mock_job.id}",

View File

@@ -1,5 +1,3 @@
from typing import Dict
from fastapi.testclient import TestClient
from app.web.main import app
@@ -22,6 +20,6 @@ def test_incorrect_api_key() -> None:
assert res.status_code == 401
def test_existing_api_key(auth_headers: Dict[str, str]) -> None:
def test_existing_api_key(auth_headers: dict[str, str]) -> None:
res = client.get("/api/v1", headers=auth_headers)
assert res.status_code == 204

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union
from typing import Any
from pydantic import AnyHttpUrl, BaseModel, Field
@@ -9,7 +9,7 @@ class DetailResponse(BaseModel):
detail: str
DEFAULT_RESPONSES: Dict[Union[int, str], Dict[str, Any]] = {
DEFAULT_RESPONSES: dict[int | str, dict[str, Any]] = {
401: {"model": DetailResponse, "description": "Not authenticated"}
}
@@ -28,8 +28,7 @@ class PostJobPayload(BaseModel):
`language_detection` detects language from the first 30 seconds of audio."""
)
# TODO: limit to locales selected by whisper.
language: Optional[str] = Field(
language: str | None = Field(
description=(
"Spoken language in the media file. "
"While optional, this can improve output when set."

View File

@@ -1,5 +1,4 @@
from asyncio.log import logger
from typing import List, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
@@ -85,11 +84,11 @@ def create_job(
@api_router.get(
"/jobs", response_model=List[schemas.Job], summary="Get metadata for all jobs"
"/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
)
def get_transcripts(
type: Optional[schemas.JobType] = None, session: Session = Depends(get_session)
) -> List[models.Job]:
type: schemas.JobType | None = None, session: Session = Depends(get_session)
) -> list[models.Job]:
"""Get metadata for all jobs."""
query = session.query(models.Job)
@@ -107,7 +106,7 @@ def get_transcripts(
)
def get_transcript(
id: UUID = Path(), session: Session = Depends(get_session)
) -> Optional[models.Job]:
) -> models.Job | None:
"""
Use this route to check transcription status of any given job.
"""
@@ -119,12 +118,12 @@ def get_transcript(
@api_router.get(
"/jobs/{id}/artifacts",
response_model=List[schemas.Artifact],
response_model=list[schemas.Artifact],
summary="Get all artifacts for one job",
)
def get_artifacts_for_job(
id: UUID = Path(), session: Session = Depends(get_session)
) -> List[models.Artifact]:
) -> list[models.Artifact]:
"""
Right now, there is only one type of artifact (`raw_transcript`).
Returns an empty array for unfinished or non-existent jobs.

View File

@@ -1,5 +1,5 @@
from asyncio.log import logger
from typing import Any, Optional
from typing import Any
from uuid import UUID
from celery import Task
@@ -21,7 +21,7 @@ class TranscribeTask(Task):
def __init__(self) -> None:
super().__init__()
# currently only `LocalStrategy` is implemented.
self.strategy: Optional[LocalStrategy] = None
self.strategy: LocalStrategy | None = None
def __call__(self, *args: Any, **kwargs: Any) -> Any:
# load model into memory once when the first task is processed.
@@ -94,7 +94,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
job.meta = {**job.meta, "error": str(e)} # type: ignore
job.status = schemas.JobStatus.error
db.commit()
raise (e)
raise
finally:
self.strategy.cleanup(job_id=job_id)
db.close()

View File

@@ -2,7 +2,7 @@ import os
import tempfile
from asyncio.log import logger
from os import path
from typing import Any, List, Literal, Optional
from typing import Any, Literal
from uuid import UUID
import requests
@@ -14,7 +14,7 @@ import app.shared.db.schemas as schemas
class DecodeOptions(BaseModel):
language: Optional[str]
language: str | None
task: Literal["translate", "transcribe"]
@@ -34,15 +34,15 @@ class LocalStrategy:
logger.info("initialized local strategy.")
def transcribe(
self, url: str, job_id: UUID, config: Optional[schemas.JobConfig]
) -> List[Any]:
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> list[Any]:
return self.run_whisper(
self._download(url, job_id), "transcribe", config, job_id
)
def translate(
self, url: str, job_id: UUID, config: Optional[schemas.JobConfig]
) -> List[Any]:
self, url: str, job_id: UUID, config: schemas.JobConfig | None
) -> list[Any]:
return self.run_whisper(
self._download(url, job_id),
"translate",
@@ -50,9 +50,7 @@ class LocalStrategy:
job_id,
)
def detect_language(
self, url: str, config: Optional[schemas.JobConfig]
) -> List[Any]:
def detect_language(self, url: str, config: schemas.JobConfig | None) -> list[Any]:
raise NotImplementedError("detect_language has not been implemented yet.")
def _download(self, url: str, job_id: UUID) -> str:
@@ -73,9 +71,9 @@ class LocalStrategy:
self,
filepath: str,
task: Literal["translate", "transcribe"],
config: Optional[schemas.JobConfig],
config: schemas.JobConfig | None,
job_id: UUID,
) -> List[Any]:
) -> list[Any]:
try:
language = config.language if config else None