mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-07 19:18:35 +03:00
refactor: simplify type annotations
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
@@ -19,7 +18,7 @@ class WithStandardFields:
|
||||
return Column(DateTime, server_default=func.now(), nullable=False)
|
||||
|
||||
@declared_attr
|
||||
def updated_at(cls) -> Mapped[Optional[DateTime]]:
|
||||
def updated_at(cls) -> Mapped[DateTime | None]:
|
||||
return Column(DateTime, onupdate=func.now())
|
||||
|
||||
@declared_attr
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
@@ -9,7 +8,7 @@ from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
class WithDbFields(BaseModel):
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
updated_at: datetime | None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
@@ -38,7 +37,7 @@ class JobConfig(BaseModel):
|
||||
"""Configuration for a job."""
|
||||
|
||||
# TODO: limit to locales selected by whisper.
|
||||
language: Optional[str] = Field(
|
||||
language: str | None = Field(
|
||||
description=(
|
||||
"Spoken language in the media file. "
|
||||
"While optional, this can improve output."
|
||||
@@ -49,10 +48,10 @@ class JobConfig(BaseModel):
|
||||
class JobMeta(BaseModel):
|
||||
"""Metadata relating to a job's execution."""
|
||||
|
||||
error: Optional[str] = Field(
|
||||
error: str | None = Field(
|
||||
description="Will contain a descriptive error message if processing failed."
|
||||
)
|
||||
task_id: Optional[UUID] = Field(
|
||||
task_id: UUID | None = Field(
|
||||
description="Internal celery id of this job submission."
|
||||
)
|
||||
|
||||
@@ -63,8 +62,8 @@ class Job(WithDbFields):
|
||||
status: JobStatus
|
||||
type: JobType
|
||||
url: AnyHttpUrl
|
||||
meta: Optional[JobMeta]
|
||||
config: Optional[JobConfig]
|
||||
meta: JobMeta | None
|
||||
config: JobConfig | None
|
||||
|
||||
|
||||
class RawTranscript(BaseModel):
|
||||
@@ -75,7 +74,7 @@ class RawTranscript(BaseModel):
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
tokens: List[int]
|
||||
tokens: list[int]
|
||||
temperature: float
|
||||
avg_logprob: float
|
||||
compression_ratio: float
|
||||
@@ -85,6 +84,6 @@ class RawTranscript(BaseModel):
|
||||
class Artifact(WithDbFields):
|
||||
"""whisper output for one job."""
|
||||
|
||||
data: Optional[List[RawTranscript]]
|
||||
data: list[RawTranscript] | None
|
||||
job_id: UUID
|
||||
type: ArtifactType
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, Generator
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -22,7 +22,7 @@ def pytest_unconfigure() -> None:
|
||||
|
||||
|
||||
@pytest.fixture(name="auth_headers", scope="function")
|
||||
def auth_header() -> Dict[str, str]:
|
||||
def auth_header() -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {settings.API_SECRET}"}
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -26,7 +24,7 @@ def mock_job(db_session: Session) -> models.Job:
|
||||
|
||||
# POST /api/v1/jobs
|
||||
# ---
|
||||
def test_create_job_pass(auth_headers: Dict[str, str]) -> None:
|
||||
def test_create_job_pass(auth_headers: dict[str, str]) -> None:
|
||||
res = client.post(
|
||||
"/api/v1/jobs",
|
||||
headers=auth_headers,
|
||||
@@ -36,12 +34,12 @@ def test_create_job_pass(auth_headers: Dict[str, str]) -> None:
|
||||
assert isinstance(res.json()["id"], str)
|
||||
|
||||
|
||||
def test_create_job_missing_body(auth_headers: Dict[str, str]) -> None:
|
||||
def test_create_job_missing_body(auth_headers: dict[str, str]) -> None:
|
||||
res = client.post("/api/v1/jobs", headers=auth_headers, json={})
|
||||
assert res.status_code == 422
|
||||
|
||||
|
||||
def test_create_job_malformed_url(auth_headers: Dict[str, str]) -> None:
|
||||
def test_create_job_malformed_url(auth_headers: dict[str, str]) -> None:
|
||||
res = client.post(
|
||||
"/api/v1/jobs",
|
||||
headers=auth_headers,
|
||||
@@ -52,7 +50,7 @@ def test_create_job_malformed_url(auth_headers: Dict[str, str]) -> None:
|
||||
|
||||
# GET /api/v1/jobs
|
||||
# ---
|
||||
def test_get_jobs_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
|
||||
def test_get_jobs_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None:
|
||||
res = client.get(
|
||||
"/api/v1/jobs?type=transcript",
|
||||
headers=auth_headers,
|
||||
@@ -63,7 +61,7 @@ def test_get_jobs_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> No
|
||||
|
||||
# GET /api/v1/jobs/:id
|
||||
# ---
|
||||
def test_get_job_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
|
||||
def test_get_job_pass(auth_headers: dict[str, str], mock_job: models.Job) -> None:
|
||||
res = client.get(
|
||||
f"/api/v1/jobs/{mock_job.id}",
|
||||
headers=auth_headers,
|
||||
@@ -72,7 +70,7 @@ def test_get_job_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> Non
|
||||
assert res.json()["id"] == str(mock_job.id)
|
||||
|
||||
|
||||
def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) -> None:
|
||||
def test_get_job_not_found(auth_headers: dict[str, str], mock_job: models.Job) -> None:
|
||||
res = client.get(
|
||||
"/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb",
|
||||
headers=auth_headers,
|
||||
@@ -84,7 +82,7 @@ def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) -
|
||||
# GET /api/v1/jobs/:id/artifacts
|
||||
# ---
|
||||
def test_get_artifacts_pass(
|
||||
auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job
|
||||
auth_headers: dict[str, str], db_session: Session, mock_job: models.Job
|
||||
) -> None:
|
||||
artifact = models.Artifact(
|
||||
data=[], job_id=str(mock_job.id), type=schemas.ArtifactType.raw_transcript
|
||||
@@ -104,7 +102,7 @@ def test_get_artifacts_pass(
|
||||
|
||||
|
||||
def test_get_artifacts_not_found(
|
||||
auth_headers: Dict[str, str], mock_job: models.Job
|
||||
auth_headers: dict[str, str], mock_job: models.Job
|
||||
) -> None:
|
||||
res = client.get(
|
||||
f"/api/v1/jobs/{mock_job.id}/artifacts",
|
||||
@@ -118,7 +116,7 @@ def test_get_artifacts_not_found(
|
||||
# DELETE /api/v1/jobs
|
||||
# ---
|
||||
def test_delete_job_pass(
|
||||
auth_headers: Dict[str, str], mock_job: models.Job, db_session: Session
|
||||
auth_headers: dict[str, str], mock_job: models.Job, db_session: Session
|
||||
) -> None:
|
||||
res = client.delete(
|
||||
f"/api/v1/jobs/{mock_job.id}",
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Dict
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.web.main import app
|
||||
@@ -22,6 +20,6 @@ def test_incorrect_api_key() -> None:
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_existing_api_key(auth_headers: Dict[str, str]) -> None:
|
||||
def test_existing_api_key(auth_headers: dict[str, str]) -> None:
|
||||
res = client.get("/api/v1", headers=auth_headers)
|
||||
assert res.status_code == 204
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
|
||||
@@ -9,7 +9,7 @@ class DetailResponse(BaseModel):
|
||||
detail: str
|
||||
|
||||
|
||||
DEFAULT_RESPONSES: Dict[Union[int, str], Dict[str, Any]] = {
|
||||
DEFAULT_RESPONSES: dict[int | str, dict[str, Any]] = {
|
||||
401: {"model": DetailResponse, "description": "Not authenticated"}
|
||||
}
|
||||
|
||||
@@ -28,8 +28,7 @@ class PostJobPayload(BaseModel):
|
||||
`language_detection` detects language from the first 30 seconds of audio."""
|
||||
)
|
||||
|
||||
# TODO: limit to locales selected by whisper.
|
||||
language: Optional[str] = Field(
|
||||
language: str | None = Field(
|
||||
description=(
|
||||
"Spoken language in the media file. "
|
||||
"While optional, this can improve output when set."
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from asyncio.log import logger
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
|
||||
@@ -85,11 +84,11 @@ def create_job(
|
||||
|
||||
|
||||
@api_router.get(
|
||||
"/jobs", response_model=List[schemas.Job], summary="Get metadata for all jobs"
|
||||
"/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
|
||||
)
|
||||
def get_transcripts(
|
||||
type: Optional[schemas.JobType] = None, session: Session = Depends(get_session)
|
||||
) -> List[models.Job]:
|
||||
type: schemas.JobType | None = None, session: Session = Depends(get_session)
|
||||
) -> list[models.Job]:
|
||||
"""Get metadata for all jobs."""
|
||||
query = session.query(models.Job)
|
||||
|
||||
@@ -107,7 +106,7 @@ def get_transcripts(
|
||||
)
|
||||
def get_transcript(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
) -> Optional[models.Job]:
|
||||
) -> models.Job | None:
|
||||
"""
|
||||
Use this route to check transcription status of any given job.
|
||||
"""
|
||||
@@ -119,12 +118,12 @@ def get_transcript(
|
||||
|
||||
@api_router.get(
|
||||
"/jobs/{id}/artifacts",
|
||||
response_model=List[schemas.Artifact],
|
||||
response_model=list[schemas.Artifact],
|
||||
summary="Get all artifacts for one job",
|
||||
)
|
||||
def get_artifacts_for_job(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
) -> List[models.Artifact]:
|
||||
) -> list[models.Artifact]:
|
||||
"""
|
||||
Right now, there is only one type of artifact (`raw_transcript`).
|
||||
Returns an empty array for unfinished or non-existent jobs.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from asyncio.log import logger
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from celery import Task
|
||||
@@ -21,7 +21,7 @@ class TranscribeTask(Task):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# currently only `LocalStrategy` is implemented.
|
||||
self.strategy: Optional[LocalStrategy] = None
|
||||
self.strategy: LocalStrategy | None = None
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# load model into memory once when the first task is processed.
|
||||
@@ -94,7 +94,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
job.meta = {**job.meta, "error": str(e)} # type: ignore
|
||||
job.status = schemas.JobStatus.error
|
||||
db.commit()
|
||||
raise (e)
|
||||
raise
|
||||
finally:
|
||||
self.strategy.cleanup(job_id=job_id)
|
||||
db.close()
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import tempfile
|
||||
from asyncio.log import logger
|
||||
from os import path
|
||||
from typing import Any, List, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
@@ -14,7 +14,7 @@ import app.shared.db.schemas as schemas
|
||||
|
||||
|
||||
class DecodeOptions(BaseModel):
|
||||
language: Optional[str]
|
||||
language: str | None
|
||||
task: Literal["translate", "transcribe"]
|
||||
|
||||
|
||||
@@ -34,15 +34,15 @@ class LocalStrategy:
|
||||
logger.info("initialized local strategy.")
|
||||
|
||||
def transcribe(
|
||||
self, url: str, job_id: UUID, config: Optional[schemas.JobConfig]
|
||||
) -> List[Any]:
|
||||
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
||||
) -> list[Any]:
|
||||
return self.run_whisper(
|
||||
self._download(url, job_id), "transcribe", config, job_id
|
||||
)
|
||||
|
||||
def translate(
|
||||
self, url: str, job_id: UUID, config: Optional[schemas.JobConfig]
|
||||
) -> List[Any]:
|
||||
self, url: str, job_id: UUID, config: schemas.JobConfig | None
|
||||
) -> list[Any]:
|
||||
return self.run_whisper(
|
||||
self._download(url, job_id),
|
||||
"translate",
|
||||
@@ -50,9 +50,7 @@ class LocalStrategy:
|
||||
job_id,
|
||||
)
|
||||
|
||||
def detect_language(
|
||||
self, url: str, config: Optional[schemas.JobConfig]
|
||||
) -> List[Any]:
|
||||
def detect_language(self, url: str, config: schemas.JobConfig | None) -> list[Any]:
|
||||
raise NotImplementedError("detect_language has not been implemented yet.")
|
||||
|
||||
def _download(self, url: str, job_id: UUID) -> str:
|
||||
@@ -73,9 +71,9 @@ class LocalStrategy:
|
||||
self,
|
||||
filepath: str,
|
||||
task: Literal["translate", "transcribe"],
|
||||
config: Optional[schemas.JobConfig],
|
||||
config: schemas.JobConfig | None,
|
||||
job_id: UUID,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
try:
|
||||
language = config.language if config else None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user