From d2223206bef4064aa095455b3817502da0adced0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Sp=C3=B6ttel?= <1682504+fspoettel@users.noreply.github.com> Date: Wed, 28 Jun 2023 13:17:55 +0200 Subject: [PATCH] refactor: simplify type annotations --- app/shared/db/models.py | 3 +-- app/shared/db/schemas.py | 17 ++++++++--------- app/tests/conftest.py | 4 ++-- app/tests/test_api.py | 20 +++++++++----------- app/tests/test_auth.py | 4 +--- app/web/dtos.py | 7 +++---- app/web/main.py | 13 ++++++------- app/worker/main.py | 6 +++--- app/worker/strategies/local.py | 20 +++++++++----------- 9 files changed, 42 insertions(+), 52 deletions(-) diff --git a/app/shared/db/models.py b/app/shared/db/models.py index 793b3f0..30c008a 100644 --- a/app/shared/db/models.py +++ b/app/shared/db/models.py @@ -1,5 +1,4 @@ import uuid -from typing import Optional from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func from sqlalchemy.dialects.postgresql import UUID @@ -19,7 +18,7 @@ class WithStandardFields: return Column(DateTime, server_default=func.now(), nullable=False) @declared_attr - def updated_at(cls) -> Mapped[Optional[DateTime]]: + def updated_at(cls) -> Mapped[DateTime | None]: return Column(DateTime, onupdate=func.now()) @declared_attr diff --git a/app/shared/db/schemas.py b/app/shared/db/schemas.py index 8a61116..fd7266c 100644 --- a/app/shared/db/schemas.py +++ b/app/shared/db/schemas.py @@ -1,6 +1,5 @@ import enum from datetime import datetime -from typing import List, Optional from uuid import UUID from pydantic import AnyHttpUrl, BaseModel, Field @@ -9,7 +8,7 @@ from pydantic import AnyHttpUrl, BaseModel, Field class WithDbFields(BaseModel): id: UUID created_at: datetime - updated_at: Optional[datetime] + updated_at: datetime | None class Config: orm_mode = True @@ -38,7 +37,7 @@ class JobConfig(BaseModel): """Configuration for a job.""" # TODO: limit to locales selected by whisper. - language: Optional[str] = Field( + language: str | None = Field( description=( "Spoken language in the media file. " "While optional, this can improve output." @@ -49,10 +48,10 @@ class JobConfig(BaseModel): class JobMeta(BaseModel): """Metadata relating to a job's execution.""" - error: Optional[str] = Field( + error: str | None = Field( description="Will contain a descriptive error message if processing failed." ) - task_id: Optional[UUID] = Field( + task_id: UUID | None = Field( description="Internal celery id of this job submission." ) @@ -63,8 +62,8 @@ class Job(WithDbFields): status: JobStatus type: JobType url: AnyHttpUrl - meta: Optional[JobMeta] - config: Optional[JobConfig] + meta: JobMeta | None + config: JobConfig | None class RawTranscript(BaseModel): @@ -75,7 +74,7 @@ class RawTranscript(BaseModel): start: float end: float text: str - tokens: List[int] + tokens: list[int] temperature: float avg_logprob: float compression_ratio: float @@ -85,6 +84,6 @@ class RawTranscript(BaseModel): class Artifact(WithDbFields): """whisper output for one job.""" - data: Optional[List[RawTranscript]] + data: list[RawTranscript] | None job_id: UUID type: ArtifactType diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 7edafb3..d64f9e3 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -1,4 +1,4 @@ -from typing import Dict, Generator +from typing import Generator import pytest from sqlalchemy.orm import Session @@ -22,7 +22,7 @@ def pytest_unconfigure() -> None: @pytest.fixture(name="auth_headers", scope="function") -def auth_header() -> Dict[str, str]: +def auth_header() -> dict[str, str]: return {"Authorization": f"Bearer {settings.API_SECRET}"} diff --git a/app/tests/test_api.py b/app/tests/test_api.py index d0f3e3f..63fca22 100644 --- a/app/tests/test_api.py +++ b/app/tests/test_api.py @@ -1,5 +1,3 @@ -from typing import Dict - import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session @@ -26,7 +24,7 @@ def mock_job(db_session: Session) -> models.Job: # POST /api/v1/jobs # --- -def test_create_job_pass(auth_headers: Dict[str, str]) -> None: +def test_create_job_pass(auth_headers: dict[str, str]) -> None: res = client.post( "/api/v1/jobs", headers=auth_headers, @@ -36,12 +34,12 @@ def test_create_job_pass(auth_headers: Dict[str, str]) -> None: assert isinstance(res.json()["id"], str) -def test_create_job_missing_body(auth_headers: Dict[str, str]) -> None: +def test_create_job_missing_body(auth_headers: dict[str, str]) -> None: res = client.post("/api/v1/jobs", headers=auth_headers, json={}) assert res.status_code == 422 -def test_create_job_malformed_url(auth_headers: Dict[str, str]) -> None: +def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None: res = client.post( "/api/v1/jobs", headers=auth_headers, @@ -52,7 +50,7 @@ def test_create_job_malformed_url(auth_headers: Dict[str, str]) -> None: # GET /api/v1/jobs # --- -def test_get_jobs_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None: +def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None: res = client.get( "/api/v1/jobs?type=transcript", headers=auth_headers, @@ -63,7 +61,7 @@ def test_get_jobs_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> No # GET /api/v1/jobs/:id # --- -def test_get_job_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None: +def test_get_job_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None: res = client.get( f"/api/v1/jobs/{mock_job.id}", headers=auth_headers, @@ -72,7 +70,7 @@ def test_get_job_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> Non assert res.json()["id"] == str(mock_job.id) -def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) -> None: +def test_get_job_not_found(auth_headers: dict[str, str], mock_job: models.Job) -> None: res = client.get( "/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb", headers=auth_headers, @@ -84,7 +82,7 @@ def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) - # GET /api/v1/jobs/:id/artifacts # --- def test_get_artifacts_pass( - auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job + auth_headers: dict[str, str], db_session: Session, mock_job: models.Job ) -> None: artifact = models.Artifact( data=[], job_id=str(mock_job.id), type=schemas.ArtifactType.raw_transcript @@ -104,7 +102,7 @@ def test_get_artifacts_pass( def test_get_artifacts_not_found( - auth_headers: Dict[str, str], mock_job: models.Job + auth_headers: dict[str, str], mock_job: models.Job ) -> None: res = client.get( f"/api/v1/jobs/{mock_job.id}/artifacts", @@ -118,7 +116,7 @@ def test_get_artifacts_not_found( # DELETE /api/v1/jobs # --- def test_delete_job_pass( - auth_headers: Dict[str, str], mock_job: models.Job, db_session: Session + auth_headers: dict[str, str], mock_job: models.Job, db_session: Session ) -> None: res = client.delete( f"/api/v1/jobs/{mock_job.id}", diff --git a/app/tests/test_auth.py b/app/tests/test_auth.py index c2b69a5..4161b6e 100644 --- a/app/tests/test_auth.py +++ b/app/tests/test_auth.py @@ -1,5 +1,3 @@ -from typing import Dict - from fastapi.testclient import TestClient from app.web.main import app @@ -22,6 +20,6 @@ def test_incorrect_api_key() -> None: assert res.status_code == 401 -def test_existing_api_key(auth_headers: Dict[str, str]) -> None: +def test_existing_api_key(auth_headers: dict[str, str]) -> None: res = client.get("/api/v1", headers=auth_headers) assert res.status_code == 204 diff --git a/app/web/dtos.py b/app/web/dtos.py index b3873d9..9ef63cd 100644 --- a/app/web/dtos.py +++ b/app/web/dtos.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any from pydantic import AnyHttpUrl, BaseModel, Field @@ -9,7 +9,7 @@ class DetailResponse(BaseModel): detail: str -DEFAULT_RESPONSES: Dict[Union[int, str], Dict[str, Any]] = { +DEFAULT_RESPONSES: dict[int | str, dict[str, Any]] = { 401: {"model": DetailResponse, "description": "Not authenticated"} } @@ -28,8 +28,7 @@ class PostJobPayload(BaseModel): `language_detection` detects language from the first 30 seconds of audio.""" ) - # TODO: limit to locales selected by whisper. - language: Optional[str] = Field( + language: str | None = Field( description=( "Spoken language in the media file. " "While optional, this can improve output when set." diff --git a/app/web/main.py b/app/web/main.py index 41bb707..d440d52 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -1,5 +1,4 @@ from asyncio.log import logger -from typing import List, Optional from uuid import UUID from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path @@ -85,11 +84,11 @@ def create_job( @api_router.get( - "/jobs", response_model=List[schemas.Job], summary="Get metadata for all jobs" + "/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs" ) def get_transcripts( - type: Optional[schemas.JobType] = None, session: Session = Depends(get_session) -) -> List[models.Job]: + type: schemas.JobType | None = None, session: Session = Depends(get_session) +) -> list[models.Job]: """Get metadata for all jobs.""" query = session.query(models.Job) @@ -107,7 +106,7 @@ def get_transcripts( ) def get_transcript( id: UUID = Path(), session: Session = Depends(get_session) -) -> Optional[models.Job]: +) -> models.Job | None: """ Use this route to check transcription status of any given job. """ @@ -119,12 +118,12 @@ def get_transcript( @api_router.get( "/jobs/{id}/artifacts", - response_model=List[schemas.Artifact], + response_model=list[schemas.Artifact], summary="Get all artifacts for one job", ) def get_artifacts_for_job( id: UUID = Path(), session: Session = Depends(get_session) -) -> List[models.Artifact]: +) -> list[models.Artifact]: """ Right now, there is only one type of artifact (`raw_transcript`). Returns an empty array for unfinished or non-existant jobs. diff --git a/app/worker/main.py b/app/worker/main.py index 391ad66..9a46c76 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -1,5 +1,5 @@ from asyncio.log import logger -from typing import Any, Optional +from typing import Any from uuid import UUID from celery import Task @@ -21,7 +21,7 @@ class TranscribeTask(Task): def __init__(self) -> None: super().__init__() # currently only `LocalStrategy` is implemented. - self.strategy: Optional[LocalStrategy] = None + self.strategy: LocalStrategy | None = None def __call__(self, *args: Any, **kwargs: Any) -> Any: # load model into memory once when the first task is processed. @@ -94,7 +94,7 @@ def transcribe(self: Task, job_id: UUID) -> None: job.meta = {**job.meta, "error": str(e)} # type: ignore job.status = schemas.JobStatus.error db.commit() - raise (e) + raise finally: self.strategy.cleanup(job_id=job_id) db.close() diff --git a/app/worker/strategies/local.py b/app/worker/strategies/local.py index bedf9b3..db3007d 100644 --- a/app/worker/strategies/local.py +++ b/app/worker/strategies/local.py @@ -2,7 +2,7 @@ import os import tempfile from asyncio.log import logger from os import path -from typing import Any, List, Literal, Optional +from typing import Any, Literal from uuid import UUID import requests @@ -14,7 +14,7 @@ import app.shared.db.schemas as schemas class DecodeOptions(BaseModel): - language: Optional[str] + language: str | None task: Literal["translate", "transcribe"] @@ -34,15 +34,15 @@ class LocalStrategy: logger.info("initialized local strategy.") def transcribe( - self, url: str, job_id: UUID, config: Optional[schemas.JobConfig] - ) -> List[Any]: + self, url: str, job_id: UUID, config: schemas.JobConfig | None + ) -> list[Any]: return self.run_whisper( self._download(url, job_id), "transcribe", config, job_id ) def translate( - self, url: str, job_id: UUID, config: Optional[schemas.JobConfig] - ) -> List[Any]: + self, url: str, job_id: UUID, config: schemas.JobConfig | None + ) -> list[Any]: return self.run_whisper( self._download(url, job_id), "translate", @@ -50,9 +50,7 @@ class LocalStrategy: job_id, ) - def detect_language( - self, url: str, config: Optional[schemas.JobConfig] - ) -> List[Any]: + def detect_language(self, url: str, config: schemas.JobConfig | None) -> list[Any]: raise NotImplementedError("detect_language has not been implemented yet.") def _download(self, url: str, job_id: UUID) -> str: @@ -73,9 +71,9 @@ class LocalStrategy: self, filepath: str, task: Literal["translate", "transcribe"], - config: Optional[schemas.JobConfig], + config: schemas.JobConfig | None, job_id: UUID, - ) -> List[Any]: + ) -> list[Any]: try: language = config.language if config else None