diff --git a/app/shared/db/models.py b/app/shared/db/models.py index f92a9f5..c25fa93 100644 --- a/app/shared/db/models.py +++ b/app/shared/db/models.py @@ -6,7 +6,7 @@ from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.orm import Mapped, declarative_mixin # type: ignore -from .dtos import ArtifactType, JobStatus, JobType +from .schemas import ArtifactType, JobStatus, JobType Base = declarative_base() diff --git a/app/shared/db/dtos.py b/app/shared/db/schemas.py similarity index 100% rename from app/shared/db/dtos.py rename to app/shared/db/schemas.py diff --git a/app/tests/test_api.py b/app/tests/test_api.py index 8bdc560..9dd073d 100644 --- a/app/tests/test_api.py +++ b/app/tests/test_api.py @@ -4,9 +4,9 @@ import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session -import app.shared.db.dtos as dtos +import app.shared.db.schemas as schemas import app.shared.db.models as models -from app.shared.db.dtos import JobStatus, JobType +from app.shared.db.schemas import JobStatus, JobType from app.web.main import app client = TestClient(app) @@ -86,7 +86,7 @@ def test_get_artifact_pass( auth_headers: Dict[str, str], db_session: Session, mock_job: models.Job ) -> None: artifact = models.Artifact( - data=[], job_id=mock_job.id, type=dtos.ArtifactType.raw_transcript + data=[], job_id=mock_job.id, type=schemas.ArtifactType.raw_transcript ) db_session.add(artifact) diff --git a/app/web/main.py b/app/web/main.py index bd45949..fd3c903 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -7,7 +7,7 @@ from pydantic import AnyHttpUrl, BaseModel from sqlalchemy import or_ from sqlalchemy.orm import Session -import app.shared.db.dtos as dtos +import app.shared.db.schemas as schemas import app.shared.db.models as models from app.shared.celery import get_celery_binding from app.shared.db.base import get_session @@ -35,11 +35,11 @@ def api_root() -> None: class PostJobPayload(BaseModel): url: AnyHttpUrl - type: dtos.JobType + type: schemas.JobType language: Optional[str] -@api_router.post("/jobs", response_model=dtos.Job, status_code=201) +@api_router.post("/jobs", response_model=schemas.Job, status_code=201) def create_job( payload: PostJobPayload, session: Session = Depends(get_session), @@ -47,7 +47,7 @@ def create_job( # create a job with status "create" and save it to the database. job = models.Job( url=payload.url, - status=dtos.JobStatus.create, + status=schemas.JobStatus.create, type=payload.type, config={"language": payload.language} if payload.language else None, ) @@ -65,9 +65,9 @@ def create_job( return job -@api_router.get("/jobs", response_model=List[dtos.Job]) +@api_router.get("/jobs", response_model=List[schemas.Job]) def get_transcripts( - type: Optional[dtos.JobType] = None, session: Session = Depends(get_session) + type: Optional[schemas.JobType] = None, session: Session = Depends(get_session) ) -> List[models.Job]: query = session.query(models.Job) @@ -77,7 +77,7 @@ def get_transcripts( return query.all() -@api_router.get("/jobs/{id}", response_model=dtos.Job) +@api_router.get("/jobs/{id}", response_model=schemas.Job) def get_transcript( id: UUID = Path(), session: Session = Depends(get_session) ) -> Optional[models.Job]: @@ -87,7 +87,7 @@ def get_transcript( return job -@api_router.get("/jobs/{id}/artifacts", response_model=List[dtos.Artifact]) +@api_router.get("/jobs/{id}/artifacts", response_model=List[schemas.Artifact]) def get_artifacts_for_job( id: UUID = Path(), session: Session = Depends(get_session) ) -> List[models.Artifact]: @@ -125,8 +125,8 @@ def on_startup() -> None: session.query(models.Job) .filter( or_( - models.Job.status == dtos.JobStatus.processing, - models.Job.status == dtos.JobStatus.create, + models.Job.status == schemas.JobStatus.processing, + models.Job.status == schemas.JobStatus.create, ) ) .order_by(models.Job.created_at) diff --git a/app/worker/main.py b/app/worker/main.py index a904ab0..b410c31 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -4,7 +4,7 @@ from uuid import UUID from celery import Task from sqlalchemy.orm import Session -import app.shared.db.dtos as dtos +import app.shared.db.schemas as schemas import app.shared.db.models as models from app.shared.celery import get_celery_binding from app.shared.db.base import SessionLocal @@ -19,52 +19,52 @@ def transcribe(self: Task, job_id: UUID) -> None: db: Session = SessionLocal() job = db.query(models.Job).filter(models.Job.id == job_id).one() - if job.status == dtos.JobStatus.error or job.status == dtos.JobStatus.success: + if job.status == schemas.JobStatus.error or job.status == schemas.JobStatus.success: logger.warn( "[{job.id}]: Received job that has already been processed, abort." ) return job.meta = {"task_id": self.request.id} - job.status = dtos.JobStatus.processing + job.status = schemas.JobStatus.processing db.commit() logger.info(f"[{job.id}]: set task to status processing.") # pick a transcription strategy. # currently only `local` is supported. - job_record = dtos.Job.from_orm(job) + job_record = schemas.Job.from_orm(job) strategy = LocalStrategy( db=db, job_id=job.id, url=job_record.url, config=job_record.config ) # process selected task. # currently only `transcribe` is supported. - if job.type == dtos.JobType.transcript: + if job.type == schemas.JobType.transcript: result = strategy.transcribe() logger.info(f"[{job.id}]: successfully transcribed audio.") - elif job.type == dtos.JobType.translation: + elif job.type == schemas.JobType.translation: result = strategy.translate() logger.info(f"[{job.id}]: successfully translated audio.") else: result = strategy.detect_language() artifact = models.Artifact( - job_id=str(job.id), data=result, type=dtos.ArtifactType.raw_transcript + job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript ) db.add(artifact) db.commit() logger.info(f"[{job.id}]: successfully stored artifact.") - job.status = dtos.JobStatus.success + job.status = schemas.JobStatus.success db.commit() logger.info(f"[{job.id}]: set task to status success.") except Exception as e: if job and db: job.meta = {**job.meta.__dict__, "error": str(e)} - job.status = dtos.JobStatus.error + job.status = schemas.JobStatus.error db.commit() raise (e) finally: diff --git a/app/worker/strategies/local.py b/app/worker/strategies/local.py index f1dd22a..d2eb0f6 100644 --- a/app/worker/strategies/local.py +++ b/app/worker/strategies/local.py @@ -10,7 +10,7 @@ from pydantic import BaseModel from sqlalchemy.orm import Session from whisper import load_model -import app.shared.db.dtos as dtos +import app.shared.db.schemas as schemas class DecodeOptions(BaseModel): @@ -20,7 +20,7 @@ class DecodeOptions(BaseModel): class LocalStrategy: def __init__( - self, db: Session, job_id: UUID, url: str, config: Optional[dtos.JobConfig] + self, db: Session, job_id: UUID, url: str, config: Optional[schemas.JobConfig] ): self.db = db self.job_id = job_id