feat: postgres => sqlite

This commit is contained in:
Felix Spöttel
2023-02-08 14:09:16 +01:00
parent e995b1f2ff
commit 18921d34c6
16 changed files with 148 additions and 159 deletions

View File

@@ -1,16 +1,15 @@
"""add_job_tables
Revision ID: 426b6bdc3360
Revision ID: dc8582aea0bc
Revises:
Create Date: 2023-01-27 17:55:21.758828
Create Date: 2023-02-08 12:12:00.808816
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "426b6bdc3360"
revision = "dc8582aea0bc"
down_revision = None
branch_labels = None
depends_on = None
@@ -30,23 +29,28 @@ def upgrade() -> None:
sa.Column("meta", sa.JSON(none_as_null=True), nullable=True),
sa.Column(
"type",
sa.Enum("transcript", "translation", "language_detection", name="jobtype"),
sa.Enum(
"transcript",
"translation",
"language_detection",
name="jobtype",
),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(),
server_default=sa.text("now()"),
server_default=sa.text("(CURRENT_TIMESTAMP)"),
nullable=False,
),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("id", sa.VARCHAR(length=36), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_jobs_id"), "jobs", ["id"], unique=False)
op.create_table(
"artifacts",
sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("job_id", sa.VARCHAR(length=36), nullable=False),
sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
sa.Column(
"type",
@@ -56,11 +60,11 @@ def upgrade() -> None:
sa.Column(
"created_at",
sa.DateTime(),
server_default=sa.text("now()"),
server_default=sa.text("(CURRENT_TIMESTAMP)"),
nullable=False,
),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("id", sa.VARCHAR(length=36), nullable=False),
sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)

View File

@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session, sessionmaker
from app.shared.config import settings
engine = create_engine(settings.DATABASE_URI)
engine = create_engine(settings.DATABASE_URI, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

View File

@@ -1,7 +1,7 @@
import uuid
from typing import Optional
from sqlalchemy import JSON, Column, DateTime, Enum, ForeignKey, String, func
from sqlalchemy import JSON, VARCHAR, Column, DateTime, Enum, ForeignKey, String, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import Mapped, declarative_mixin # type: ignore
@@ -26,7 +26,7 @@ class WithStandardFields:
@declared_attr
def id(cls) -> Mapped[UUID]:
return Column(
UUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4
VARCHAR(36), primary_key=True, index=True, default=lambda: str(uuid.uuid4())
)
@@ -44,7 +44,9 @@ class Artifact(Base, WithStandardFields):
__tablename__ = "artifacts"
job_id = Column(
UUID(as_uuid=True), ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False
VARCHAR(36),
ForeignKey("jobs.id", ondelete="CASCADE"),
nullable=False,
)
data = Column(JSON(none_as_null=True))

View File

@@ -7,11 +7,10 @@ from sqlalchemy.orm import Session
import app.shared.db.dtos as dtos
import app.shared.db.models as models
from app.shared.db.dtos import JobStatus, JobType
from app.web.main import app
from app.web.main import app, celery
client = TestClient(app)
@pytest.fixture(name="mock_job", scope="function", autouse=False)
def mock_job(db_session: Session) -> models.Job:
job = models.Job(

View File

@@ -81,7 +81,7 @@ def get_transcripts(
def get_transcript(
id: UUID = Path(), session: Session = Depends(get_session)
) -> Optional[models.Job]:
job = session.query(models.Job).filter(models.Job.id == id).one_or_none()
job = session.query(models.Job).filter(models.Job.id == str(id)).one_or_none()
if not job:
raise HTTPException(status_code=404)
return job
@@ -92,7 +92,7 @@ def get_artifacts_for_job(
id: UUID = Path(), session: Session = Depends(get_session)
) -> List[models.Artifact]:
artifacts = (
session.query(models.Artifact).filter(models.Artifact.job_id == id)
session.query(models.Artifact).filter(models.Artifact.job_id == str(id))
).all()
if not len(artifacts):
@@ -105,14 +105,14 @@ def get_artifacts_for_job(
def delete_transcript(
id: UUID = Path(), session: Session = Depends(get_session)
) -> None:
session.query(models.Job).filter(models.Job.id == id).delete()
session.query(models.Job).filter(models.Job.id == str(id)).delete()
return None
app.include_router(api_router)
# TODO:
# we could use `acks_late` to handle this scenario within celery itself.
# TODO: we could use `acks_late` to handle this scenario within celery itself.
# the reason this does not work well in our case is that `visibility_timeout`
# needs to be very high since whisper workers can be long running.
# doing this application-side bears the risk of poison pilling the worker though,
@@ -123,10 +123,15 @@ def on_startup() -> None:
jobs = (
session.query(models.Job)
.filter(or_(models.Job.status == dtos.JobStatus.processing, models.Job.status == dtos.JobStatus.create))
.order_by(models.Job.created_at)
.filter(
or_(
models.Job.status == dtos.JobStatus.processing,
models.Job.status == dtos.JobStatus.create,
)
)
.order_by(models.Job.created_at)
).all()
logger.info(f"Re-queueing {len(jobs)} jobs.")
logger.info(f"Requeueing {len(jobs)} jobs.")
for job in jobs:
queue_task(job)

View File

@@ -1,9 +0,0 @@
#! /usr/bin/env bash
set -e
# run migrations
alembic upgrade head
# start app
uvicorn app.web.main:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-80} --log-level info

View File

@@ -13,17 +13,16 @@ from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding()
@celery.task(
bind=True,
soft_time_limit=2 * 60 * 60 # TODO: make configurable
)
@celery.task(bind=True, soft_time_limit=2 * 60 * 60) # TODO: make configurable
def transcribe(self: Task, job_id: UUID) -> None:
try:
db: Session = SessionLocal()
job = db.query(models.Job).filter(models.Job.id == job_id).one()
if job.status == dtos.JobStatus.error or job.status == dtos.JobStatus.success:
logger.warn("[{job.id}]: Received job that has already been processed, abort.")
logger.warn(
"[{job.id}]: Received job that has already been processed, abort."
)
return
job.meta = {"task_id": self.request.id}
@@ -51,7 +50,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
result = strategy.detect_language()
artifact = models.Artifact(
job_id=job.id, data=result, type=dtos.ArtifactType.raw_transcript
job_id=str(job.id), data=result, type=dtos.ArtifactType.raw_transcript
)
db.add(artifact)
@@ -64,9 +63,9 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.info(f"[{job.id}]: set task to status success.")
except Exception as e:
if job and db:
job.meta = { **job.meta, "error": str(e) }
job.meta = {**job.meta.__dict__, "error": str(e)}
job.status = dtos.JobStatus.error
db.commit()
raise(e)
raise (e)
finally:
db.close()