mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-08 03:28:35 +03:00
@@ -3,8 +3,8 @@ from logging.config import fileConfig
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from app.shared.settings import settings
|
||||
from app.shared.db.models import Base
|
||||
from app.shared.settings import settings
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
|
||||
|
||||
class WithDbFields(BaseModel):
|
||||
@@ -26,6 +26,8 @@ class JobType(str, enum.Enum):
|
||||
|
||||
|
||||
class JobStatus(str, enum.Enum):
|
||||
"""Processing status of a job."""
|
||||
|
||||
create = "create"
|
||||
processing = "processing"
|
||||
error = "error"
|
||||
@@ -33,15 +35,32 @@ class JobStatus(str, enum.Enum):
|
||||
|
||||
|
||||
class JobConfig(BaseModel):
|
||||
language: Optional[str]
|
||||
"""Configuration for a job."""
|
||||
|
||||
# TODO: limit to locales selected by whisper.
|
||||
language: Optional[str] = Field(
|
||||
description=(
|
||||
"Spoken language in the media file."
|
||||
"While optional, this can improve output "
|
||||
"by selecting a language-specific model. (applies to 'en')"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class JobMeta(BaseModel):
|
||||
error: Optional[str]
|
||||
task_id: Optional[UUID]
|
||||
"""Metadata relating to a job's execution."""
|
||||
|
||||
error: Optional[str] = Field(
|
||||
description="Will contain a descriptive error message if processing failed."
|
||||
)
|
||||
task_id: Optional[UUID] = Field(
|
||||
description="Internal celery id of this job submission."
|
||||
)
|
||||
|
||||
|
||||
class Job(WithDbFields):
|
||||
"""A transcription job for one media file."""
|
||||
|
||||
status: JobStatus
|
||||
type: JobType
|
||||
url: AnyHttpUrl
|
||||
@@ -50,6 +69,8 @@ class Job(WithDbFields):
|
||||
|
||||
|
||||
class RawTranscript(BaseModel):
|
||||
"""A single transcript passage returned by whisper."""
|
||||
|
||||
id: int
|
||||
seek: int
|
||||
start: float
|
||||
@@ -63,6 +84,8 @@ class RawTranscript(BaseModel):
|
||||
|
||||
|
||||
class Artifact(WithDbFields):
|
||||
"""whisper output for one job."""
|
||||
|
||||
data: Optional[List[RawTranscript]]
|
||||
job_id: UUID
|
||||
type: ArtifactType
|
||||
|
||||
@@ -5,8 +5,8 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy_utils import create_database, database_exists, drop_database
|
||||
|
||||
import app.shared.db.models as models
|
||||
from app.shared.settings import settings
|
||||
from app.shared.db.base import SessionLocal, engine, get_session
|
||||
from app.shared.settings import settings
|
||||
from app.web.main import app
|
||||
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.schemas as schemas
|
||||
import app.shared.db.models as models
|
||||
import app.shared.db.schemas as schemas
|
||||
from app.shared.db.schemas import JobStatus, JobType
|
||||
from app.web.main import app
|
||||
|
||||
@@ -77,12 +77,13 @@ def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) -
|
||||
"/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert res.status_code == 404
|
||||
|
||||
|
||||
# GET /api/v1/jobs/:id/artifacts
|
||||
# ---
|
||||
def test_get_artifact_pass(
|
||||
def test_get_artifacts_pass(
|
||||
auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job
|
||||
) -> None:
|
||||
artifact = models.Artifact(
|
||||
@@ -102,14 +103,16 @@ def test_get_artifact_pass(
|
||||
assert res.json()[0]["id"] == str(artifact.id)
|
||||
|
||||
|
||||
def test_get_artifact_not_found(
|
||||
def test_get_artifacts_not_found(
|
||||
auth_headers: Dict[str, str], mock_job: models.Job
|
||||
) -> None:
|
||||
res = client.get(
|
||||
f"/api/v1/jobs/{mock_job.id}/artifacts",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert res.status_code == 404
|
||||
|
||||
assert len(res.json()) == 0
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
# DELETE /api/v1/jobs
|
||||
|
||||
@@ -14,7 +14,7 @@ def test_authorization_header_missing() -> None:
|
||||
|
||||
def test_authorization_header_malformed() -> None:
|
||||
res = client.get("/api/v1", headers={"Authorization": "Bearer"})
|
||||
assert res.status_code == 422
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_incorrect_api_key() -> None:
|
||||
@@ -24,4 +24,4 @@ def test_incorrect_api_key() -> None:
|
||||
|
||||
def test_existing_api_key(auth_headers: Dict[str, str]) -> None:
|
||||
res = client.get("/api/v1", headers=auth_headers)
|
||||
assert res.status_code == 200
|
||||
assert res.status_code == 204
|
||||
|
||||
33
app/web/dtos.py
Normal file
33
app/web/dtos.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
|
||||
import app.shared.db.schemas as schemas
|
||||
|
||||
|
||||
class DetailResponse(BaseModel):
|
||||
detail: str
|
||||
|
||||
|
||||
DEFAULT_RESPONSES: Dict[int | str, Dict[str, Any]] = {
|
||||
401: {"model": DetailResponse, "description": "Not authenticated"}
|
||||
}
|
||||
|
||||
|
||||
class PostJobPayload(BaseModel):
|
||||
url: AnyHttpUrl = Field(
|
||||
description=(
|
||||
"URL where the media file is available. This needs to be a direct link."
|
||||
)
|
||||
)
|
||||
|
||||
type: schemas.JobType = Field(description="Type of this job.")
|
||||
|
||||
# TODO: limit to locales selected by whisper.
|
||||
language: Optional[str] = Field(
|
||||
description=(
|
||||
"Spoken language in the media file."
|
||||
"While optional, this can improve output "
|
||||
"by selecting a language-specific model. (applies to 'en')"
|
||||
)
|
||||
)
|
||||
@@ -3,20 +3,24 @@ from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
|
||||
from pydantic import AnyHttpUrl, BaseModel
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.schemas as schemas
|
||||
import app.shared.db.models as models
|
||||
import app.shared.db.schemas as schemas
|
||||
from app.shared.celery import get_celery_binding
|
||||
from app.shared.db.base import get_session
|
||||
from app.web.dtos import DEFAULT_RESPONSES, DetailResponse, PostJobPayload
|
||||
from app.web.security import authenticate_api_key
|
||||
|
||||
app = FastAPI()
|
||||
celery = get_celery_binding()
|
||||
|
||||
api_router = APIRouter(prefix="/api/v1", dependencies=[Depends(authenticate_api_key)])
|
||||
api_router = APIRouter(
|
||||
prefix="/api/v1",
|
||||
dependencies=[Depends(authenticate_api_key)],
|
||||
responses={**DEFAULT_RESPONSES},
|
||||
)
|
||||
|
||||
|
||||
def queue_task(job: models.Job) -> None:
|
||||
@@ -28,17 +32,11 @@ def queue_task(job: models.Job) -> None:
|
||||
transcribe.delay(job.id)
|
||||
|
||||
|
||||
@api_router.get("/")
|
||||
@api_router.get("/", response_model=None, status_code=204)
|
||||
def api_root() -> None:
|
||||
return None
|
||||
|
||||
|
||||
class PostJobPayload(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
type: schemas.JobType
|
||||
language: Optional[str]
|
||||
|
||||
|
||||
@api_router.post("/jobs", response_model=schemas.Job, status_code=201)
|
||||
def create_job(
|
||||
payload: PostJobPayload,
|
||||
@@ -77,7 +75,11 @@ def get_transcripts(
|
||||
return query.all()
|
||||
|
||||
|
||||
@api_router.get("/jobs/{id}", response_model=schemas.Job)
|
||||
@api_router.get(
|
||||
"/jobs/{id}",
|
||||
response_model=schemas.Job,
|
||||
responses={404: {"model": DetailResponse, "description": "Not authenticated"}},
|
||||
)
|
||||
def get_transcript(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
) -> Optional[models.Job]:
|
||||
@@ -95,9 +97,6 @@ def get_artifacts_for_job(
|
||||
session.query(models.Artifact).filter(models.Artifact.job_id == str(id))
|
||||
).all()
|
||||
|
||||
if not len(artifacts):
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
return artifacts
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from hmac import compare_digest
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from app.shared.settings import settings
|
||||
|
||||
|
||||
def authenticate_api_key(
|
||||
token: str = Depends(OAuth2PasswordBearer(tokenUrl="token")),
|
||||
credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
|
||||
) -> None:
|
||||
if not token:
|
||||
raise HTTPException(status_code=422)
|
||||
# use compare_digest to counter timing attacks.
|
||||
if not compare_digest(settings.API_SECRET, token):
|
||||
if not credentials or not compare_digest(
|
||||
settings.API_SECRET, credentials.credentials
|
||||
):
|
||||
raise HTTPException(status_code=401)
|
||||
|
||||
@@ -4,8 +4,8 @@ from uuid import UUID
|
||||
from celery import Task
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.schemas as schemas
|
||||
import app.shared.db.models as models
|
||||
import app.shared.db.schemas as schemas
|
||||
from app.shared.celery import get_celery_binding
|
||||
from app.shared.db.base import SessionLocal
|
||||
from app.worker.strategies.local import LocalStrategy
|
||||
@@ -19,7 +19,10 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
db: Session = SessionLocal()
|
||||
job = db.query(models.Job).filter(models.Job.id == job_id).one()
|
||||
|
||||
if job.status == schemas.JobStatus.error or job.status == schemas.JobStatus.success:
|
||||
if (
|
||||
job.status == schemas.JobStatus.error
|
||||
or job.status == schemas.JobStatus.success
|
||||
):
|
||||
logger.warn(
|
||||
"[{job.id}]: Received job that has already been processed, abort."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user