refactor: remove shared schemas

This commit is contained in:
Felix Spöttel
2023-06-29 12:38:56 +02:00
parent 908bd48170
commit b3f8d5c82a
8 changed files with 204 additions and 201 deletions

View File

@@ -1,25 +1,40 @@
from pydantic import AnyHttpUrl, BaseModel, Field
from datetime import datetime
from uuid import UUID
import app.shared.db.models as models
from pydantic import AnyHttpUrl, BaseModel
from app.shared.db.models import (
ArtifactData,
ArtifactType,
JobConfig,
JobMeta,
JobStatus,
JobType,
)
# DB objects
class PostJobPayload(BaseModel):
url: AnyHttpUrl = Field(
description=(
"URL where the media file is available. This needs to be a direct link."
)
)
class WithDbFields(BaseModel):
id: UUID
created_at: datetime
updated_at: datetime | None
type: models.JobType = Field(
description="""Type of this job.
`transcript` uses the original language of the audio.
`translation` creates an automatic translation to english.
`language_detection` detects language from the first 30 seconds of audio."""
)
class Config:
orm_mode = True
language: str | None = Field(
description=(
"Spoken language in the media file. "
"While optional, this can improve output when set."
)
)
class Job(WithDbFields):
    """A transcription job for one media file."""

    # Processing state of the job. The job-creation endpoint stores new jobs
    # with the `create` status.
    status: JobStatus
    # Kind of work requested: the job-creation payload documents `transcript`,
    # `translation` and `language_detection`.
    type: JobType
    # Link to the media file, as submitted when the job was created.
    url: AnyHttpUrl
    # NOTE(review): the shape of JobMeta is defined in app.shared.db.models and
    # is not visible here -- confirm what is recorded before relying on it.
    meta: JobMeta | None
    # Optional per-job settings. The job-creation endpoint stores
    # {"language": ...} when a language was supplied and None otherwise.
    config: JobConfig | None
class Artifact(WithDbFields):
    # DTO returned by the "Get all artifacts for one job" endpoint. The `id`,
    # `created_at` and `updated_at` fields and the ORM-mode config are inherited
    # from WithDbFields.

    # Presumably the id of the job this artifact belongs to (artifacts are
    # listed per job) -- verify against the Artifact model.
    job_id: UUID
    # NOTE(review): the structure of ArtifactData is defined in
    # app.shared.db.models and is not visible here.
    data: ArtifactData
    # Kind of artifact; the available members are defined by ArtifactType in
    # app.shared.db.models.
    type: ArtifactType

View File

@@ -3,12 +3,12 @@ from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
from pydantic import AnyHttpUrl, BaseModel, Field
from sqlalchemy.orm import Session
import app.shared.db.models as models
import app.shared.db.schemas as schemas
import app.web.dtos as dtos
from app.shared.db.base import SessionLocal, get_session
from app.web.dtos import PostJobPayload
from app.web.security import authenticate_api_key
from app.web.task_queue import task_queue
@@ -40,54 +40,15 @@ def api_root() -> None:
return None
@api_router.post(
    "/jobs",
    response_model=schemas.Job,
    status_code=201,
    summary="Enqueue a new job",
)
def create_job(
    payload: PostJobPayload,
    session: DatabaseSession,
) -> models.Job:
    """
    Enqueue a new whisper job for processing.
    Notes:
    * Jobs are processed one-by-one in order of creation.
    * `payload.url` needs to point directly to a media file.
    * The media file is downloaded to a tmp file for the duration of processing.
    enough free space needs to be available on disk.
    * Media files ideally are audio files with a sampling rate of 16kHz.
    other files will be transcoded automatically via ffmpeg which might
    consume considerable resources while active.
    * Once a job is created, you can query its status by its id.
    """
    # Only attach a config when the caller actually supplied a language.
    language = payload.language
    job_config = {"language": language} if language else None

    # Persist the job in its initial "create" state before it is queued.
    new_job = models.Job(
        url=payload.url,
        status=schemas.JobStatus.create,
        type=payload.type,
        config=job_config,
    )
    session.add(new_job)
    session.commit()

    # Hand the stored job over to the task queue and return it to the caller.
    task_queue.queue_task(new_job)
    return new_job
@api_router.get(
"/jobs", response_model=list[schemas.Job], summary="Get metadata for all jobs"
"/jobs", response_model=list[dtos.Job], summary="Get metadata for all jobs"
)
def get_transcripts(
session: DatabaseSession,
type: schemas.JobType | None = None,
type: dtos.JobType | None = None,
) -> list[models.Job]:
"""Get metadata for all jobs."""
query = session.query(models.Job)
query = session.query(models.Job).order_by(models.Job.created_at.desc())
if type:
query = query.filter(models.Job.type == type)
@@ -97,7 +58,7 @@ def get_transcripts(
@api_router.get(
"/jobs/{id}",
response_model=schemas.Job,
response_model=dtos.Job,
summary="Get metadata for one job",
)
def get_transcript(
@@ -117,7 +78,7 @@ def get_transcript(
@api_router.get(
"/jobs/{id}/artifacts",
response_model=list[schemas.Artifact],
response_model=list[dtos.Artifact],
summary="Get all artifacts for one job",
)
def get_artifacts_for_job(
@@ -148,4 +109,65 @@ def delete_transcript(
return None
class PostJobPayload(BaseModel):
    """Request body for enqueuing a new job."""

    # Required: link to the media file that should be processed.
    url: AnyHttpUrl = Field(
        description=(
            "URL where the media file is available. This needs to be a direct link."
        )
    )

    # Required: which kind of processing to run on the media file.
    type: models.JobType = Field(
        description=(
            "Type of this job.\n"
            "`transcript` uses the original language of the audio.\n"
            "`translation` creates an automatic translation to english.\n"
            "`language_detection` detects language from the first 30 seconds of audio."
        )
    )

    # Optional. The default is spelled out so the field stays optional no
    # matter how the installed pydantic version treats `X | None` annotations
    # that come without a default.
    language: str | None = Field(
        default=None,
        description=(
            "Spoken language in the media file. "
            "While optional, this can improve output when set."
        ),
    )
@api_router.post(
    "/jobs",
    response_model=dtos.Job,
    status_code=201,
    summary="Enqueue a new job",
)
def create_job(
    payload: PostJobPayload,
    session: DatabaseSession,
) -> models.Job:
    """
    Enqueue a new whisper job for processing.
    Notes:
    * Jobs are processed one-by-one in order of creation.
    * `payload.url` needs to point directly to a media file.
    * The media file is downloaded to a tmp file for the duration of processing.
    enough free space needs to be available on disk.
    * Media files ideally are audio files with a sampling rate of 16kHz.
    other files will be transcoded automatically via ffmpeg which might
    consume considerable resources while active.
    * Once a job is created, you can query its status by its id.
    """
    # Only attach a config when the caller actually supplied a language.
    language = payload.language
    job_config = {"language": language} if language else None

    # Persist the job in its initial "create" state before it is queued.
    new_job = models.Job(
        url=payload.url,
        status=dtos.JobStatus.create,
        type=payload.type,
        config=job_config,
    )
    session.add(new_job)
    session.commit()

    # Hand the stored job over to the task queue and return it to the caller.
    task_queue.queue_task(new_job)
    return new_job
app.include_router(api_router)

View File

@@ -15,9 +15,10 @@ class TaskQueue:
self.celery = get_celery_binding()
def queue_task(self, job: models.Job):
    """
    Queue the asynchronous transcription task for ``job``.

    The task is referenced by name through a celery signature, which allows
    full separation of worker processes and dependencies. Only ``job.id`` is
    passed to the task.
    """
    task_name = "app.worker.main.transcribe"
    self.celery.signature(task_name).delay(job.id)