mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-12 21:48:35 +03:00
feat: add language detection task
This commit is contained in:
@@ -1,15 +1,15 @@
|
||||
"""add_job_tables
|
||||
"""add_tables
|
||||
|
||||
Revision ID: dc8582aea0bc
|
||||
Revision ID: 0eee2b7913b7
|
||||
Revises:
|
||||
Create Date: 2023-02-08 12:12:00.808816
|
||||
Create Date: 2023-06-29 08:33:26.123728
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "dc8582aea0bc"
|
||||
revision = "0eee2b7913b7"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
@@ -54,7 +54,7 @@ def upgrade() -> None:
|
||||
sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.Enum("raw_transcript", name="artifacttype"),
|
||||
sa.Enum("raw_transcript", "language_detection", name="artifacttype"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
@@ -19,11 +19,8 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
db: Session = SessionLocal()
|
||||
session: Session = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
yield session
|
||||
finally:
|
||||
db.close()
|
||||
session.close()
|
||||
|
||||
@@ -1,13 +1,39 @@
|
||||
import enum
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, declarative_base, declarative_mixin, declared_attr
|
||||
|
||||
from .schemas import ArtifactType, JobStatus, JobType
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
# Enums
|
||||
|
||||
|
||||
class JobType(str, enum.Enum):
|
||||
"""Requested type of a job."""
|
||||
|
||||
transcript = "transcribe"
|
||||
translation = "translate"
|
||||
language_detection = "detect_language"
|
||||
|
||||
|
||||
class JobStatus(str, enum.Enum):
|
||||
"""Processing status of a job."""
|
||||
|
||||
create = "create"
|
||||
processing = "processing"
|
||||
error = "error"
|
||||
success = "success"
|
||||
|
||||
|
||||
class ArtifactType(str, enum.Enum):
|
||||
raw_transcript = "transcript_raw"
|
||||
language_detection = "language_detection"
|
||||
|
||||
|
||||
# SQLAlchemy models
|
||||
|
||||
|
||||
@declarative_mixin
|
||||
class WithStandardFields:
|
||||
|
||||
@@ -1,42 +1,16 @@
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
|
||||
from app.shared.db.models import ArtifactType, JobStatus, JobType
|
||||
|
||||
class WithDbFields(BaseModel):
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime | None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class ArtifactType(str, enum.Enum):
|
||||
raw_transcript = "raw_transcript"
|
||||
|
||||
|
||||
class JobType(str, enum.Enum):
|
||||
transcript = "transcript"
|
||||
translation = "translation"
|
||||
language_detection = "language_detection"
|
||||
|
||||
|
||||
class JobStatus(str, enum.Enum):
|
||||
"""Processing status of a job."""
|
||||
|
||||
create = "create"
|
||||
processing = "processing"
|
||||
error = "error"
|
||||
success = "success"
|
||||
# JSON field types
|
||||
|
||||
|
||||
class JobConfig(BaseModel):
|
||||
"""Configuration for a job."""
|
||||
|
||||
# TODO: limit to locales selected by whisper.
|
||||
language: str | None = Field(
|
||||
description=(
|
||||
"Spoken language in the media file. "
|
||||
@@ -51,21 +25,12 @@ class JobMeta(BaseModel):
|
||||
error: str | None = Field(
|
||||
description="Will contain a descriptive error message if processing failed."
|
||||
)
|
||||
|
||||
task_id: UUID | None = Field(
|
||||
description="Internal celery id of this job submission."
|
||||
)
|
||||
|
||||
|
||||
class Job(WithDbFields):
|
||||
"""A transcription job for one media file."""
|
||||
|
||||
status: JobStatus
|
||||
type: JobType
|
||||
url: AnyHttpUrl
|
||||
meta: JobMeta | None
|
||||
config: JobConfig | None
|
||||
|
||||
|
||||
class RawTranscript(BaseModel):
|
||||
"""A single transcript passage returned by whisper."""
|
||||
|
||||
@@ -81,9 +46,35 @@ class RawTranscript(BaseModel):
|
||||
no_speech_prob: float
|
||||
|
||||
|
||||
class Artifact(WithDbFields):
|
||||
"""whisper output for one job."""
|
||||
class LanguageDetection(BaseModel):
|
||||
"""A language detection"""
|
||||
|
||||
data: list[RawTranscript] | None
|
||||
code: str
|
||||
|
||||
|
||||
# DB objects
|
||||
|
||||
|
||||
class WithDbFields(BaseModel):
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime | None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class Job(WithDbFields):
|
||||
"""A transcription job for one media file."""
|
||||
|
||||
status: JobStatus
|
||||
type: JobType
|
||||
url: AnyHttpUrl
|
||||
meta: JobMeta | None
|
||||
config: JobConfig | None
|
||||
|
||||
|
||||
class Artifact(WithDbFields):
|
||||
job_id: UUID
|
||||
data: LanguageDetection | RawTranscript | None
|
||||
type: ArtifactType
|
||||
|
||||
Reference in New Issue
Block a user