mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-11 13:08:35 +03:00
chore: rename dtos.py => schemas.py
This commit is contained in:
@@ -6,7 +6,7 @@ from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
||||
from sqlalchemy.orm import Mapped, declarative_mixin # type: ignore
|
||||
|
||||
from .dtos import ArtifactType, JobStatus, JobType
|
||||
from .schemas import ArtifactType, JobStatus, JobType
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.dtos as dtos
|
||||
import app.shared.db.schemas as schemas
|
||||
import app.shared.db.models as models
|
||||
from app.shared.db.dtos import JobStatus, JobType
|
||||
from app.shared.db.schemas import JobStatus, JobType
|
||||
from app.web.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
@@ -86,7 +86,7 @@ def test_get_artifact_pass(
|
||||
auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job
|
||||
) -> None:
|
||||
artifact = models.Artifact(
|
||||
data=[], job_id=mock_job.id, type=dtos.ArtifactType.raw_transcript
|
||||
data=[], job_id=mock_job.id, type=schemas.ArtifactType.raw_transcript
|
||||
)
|
||||
|
||||
db_session.add(artifact)
|
||||
|
||||
@@ -7,7 +7,7 @@ from pydantic import AnyHttpUrl, BaseModel
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.dtos as dtos
|
||||
import app.shared.db.schemas as schemas
|
||||
import app.shared.db.models as models
|
||||
from app.shared.celery import get_celery_binding
|
||||
from app.shared.db.base import get_session
|
||||
@@ -35,11 +35,11 @@ def api_root() -> None:
|
||||
|
||||
class PostJobPayload(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
type: dtos.JobType
|
||||
type: schemas.JobType
|
||||
language: Optional[str]
|
||||
|
||||
|
||||
@api_router.post("/jobs", response_model=dtos.Job, status_code=201)
|
||||
@api_router.post("/jobs", response_model=schemas.Job, status_code=201)
|
||||
def create_job(
|
||||
payload: PostJobPayload,
|
||||
session: Session = Depends(get_session),
|
||||
@@ -47,7 +47,7 @@ def create_job(
|
||||
# create a job with status "create" and save it to the database.
|
||||
job = models.Job(
|
||||
url=payload.url,
|
||||
status=dtos.JobStatus.create,
|
||||
status=schemas.JobStatus.create,
|
||||
type=payload.type,
|
||||
config={"language": payload.language} if payload.language else None,
|
||||
)
|
||||
@@ -65,9 +65,9 @@ def create_job(
|
||||
return job
|
||||
|
||||
|
||||
@api_router.get("/jobs", response_model=List[dtos.Job])
|
||||
@api_router.get("/jobs", response_model=List[schemas.Job])
|
||||
def get_transcripts(
|
||||
type: Optional[dtos.JobType] = None, session: Session = Depends(get_session)
|
||||
type: Optional[schemas.JobType] = None, session: Session = Depends(get_session)
|
||||
) -> List[models.Job]:
|
||||
query = session.query(models.Job)
|
||||
|
||||
@@ -77,7 +77,7 @@ def get_transcripts(
|
||||
return query.all()
|
||||
|
||||
|
||||
@api_router.get("/jobs/{id}", response_model=dtos.Job)
|
||||
@api_router.get("/jobs/{id}", response_model=schemas.Job)
|
||||
def get_transcript(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
) -> Optional[models.Job]:
|
||||
@@ -87,7 +87,7 @@ def get_transcript(
|
||||
return job
|
||||
|
||||
|
||||
@api_router.get("/jobs/{id}/artifacts", response_model=List[dtos.Artifact])
|
||||
@api_router.get("/jobs/{id}/artifacts", response_model=List[schemas.Artifact])
|
||||
def get_artifacts_for_job(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
) -> List[models.Artifact]:
|
||||
@@ -125,8 +125,8 @@ def on_startup() -> None:
|
||||
session.query(models.Job)
|
||||
.filter(
|
||||
or_(
|
||||
models.Job.status == dtos.JobStatus.processing,
|
||||
models.Job.status == dtos.JobStatus.create,
|
||||
models.Job.status == schemas.JobStatus.processing,
|
||||
models.Job.status == schemas.JobStatus.create,
|
||||
)
|
||||
)
|
||||
.order_by(models.Job.created_at)
|
||||
|
||||
@@ -4,7 +4,7 @@ from uuid import UUID
|
||||
from celery import Task
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.shared.db.dtos as dtos
|
||||
import app.shared.db.schemas as schemas
|
||||
import app.shared.db.models as models
|
||||
from app.shared.celery import get_celery_binding
|
||||
from app.shared.db.base import SessionLocal
|
||||
@@ -19,52 +19,52 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
db: Session = SessionLocal()
|
||||
job = db.query(models.Job).filter(models.Job.id == job_id).one()
|
||||
|
||||
if job.status == dtos.JobStatus.error or job.status == dtos.JobStatus.success:
|
||||
if job.status == schemas.JobStatus.error or job.status == schemas.JobStatus.success:
|
||||
logger.warn(
|
||||
"[{job.id}]: Received job that has already been processed, abort."
|
||||
)
|
||||
return
|
||||
|
||||
job.meta = {"task_id": self.request.id}
|
||||
job.status = dtos.JobStatus.processing
|
||||
job.status = schemas.JobStatus.processing
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[{job.id}]: set task to status processing.")
|
||||
|
||||
# pick a transcription strategy.
|
||||
# currently only `local` is supported.
|
||||
job_record = dtos.Job.from_orm(job)
|
||||
job_record = schemas.Job.from_orm(job)
|
||||
strategy = LocalStrategy(
|
||||
db=db, job_id=job.id, url=job_record.url, config=job_record.config
|
||||
)
|
||||
|
||||
# process selected task.
|
||||
# currently only `transcribe` is supported.
|
||||
if job.type == dtos.JobType.transcript:
|
||||
if job.type == schemas.JobType.transcript:
|
||||
result = strategy.transcribe()
|
||||
logger.info(f"[{job.id}]: successfully transcribed audio.")
|
||||
elif job.type == dtos.JobType.translation:
|
||||
elif job.type == schemas.JobType.translation:
|
||||
result = strategy.translate()
|
||||
logger.info(f"[{job.id}]: successfully translated audio.")
|
||||
else:
|
||||
result = strategy.detect_language()
|
||||
|
||||
artifact = models.Artifact(
|
||||
job_id=str(job.id), data=result, type=dtos.ArtifactType.raw_transcript
|
||||
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
|
||||
)
|
||||
|
||||
db.add(artifact)
|
||||
db.commit()
|
||||
logger.info(f"[{job.id}]: successfully stored artifact.")
|
||||
|
||||
job.status = dtos.JobStatus.success
|
||||
job.status = schemas.JobStatus.success
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[{job.id}]: set task to status success.")
|
||||
except Exception as e:
|
||||
if job and db:
|
||||
job.meta = {**job.meta.__dict__, "error": str(e)}
|
||||
job.status = dtos.JobStatus.error
|
||||
job.status = schemas.JobStatus.error
|
||||
db.commit()
|
||||
raise (e)
|
||||
finally:
|
||||
|
||||
@@ -10,7 +10,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from whisper import load_model
|
||||
|
||||
import app.shared.db.dtos as dtos
|
||||
import app.shared.db.schemas as schemas
|
||||
|
||||
|
||||
class DecodeOptions(BaseModel):
|
||||
@@ -20,7 +20,7 @@ class DecodeOptions(BaseModel):
|
||||
|
||||
class LocalStrategy:
|
||||
def __init__(
|
||||
self, db: Session, job_id: UUID, url: str, config: Optional[dtos.JobConfig]
|
||||
self, db: Session, job_id: UUID, url: str, config: Optional[schemas.JobConfig]
|
||||
):
|
||||
self.db = db
|
||||
self.job_id = job_id
|
||||
|
||||
Reference in New Issue
Block a user