mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-12 21:48:35 +03:00
feat: add celery job queue
This commit is contained in:
@@ -5,8 +5,8 @@ Revises:
|
||||
Create Date: 2023-01-05 12:00:58.824773
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@@ -26,9 +26,7 @@ def upgrade() -> None:
|
||||
sa.Enum("Create", "Error", "Success", name="jobstatus"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"type", sa.Enum("Transcript", name="jobtype"), nullable=False
|
||||
),
|
||||
sa.Column("type", sa.Enum("Transcript", name="jobtype"), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
|
||||
@@ -4,13 +4,10 @@ from pydantic import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
API_SECRET: str
|
||||
DATABASE_URI: str
|
||||
ENVIRONMENT: str
|
||||
API_SECRET: str
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
REDIS_URI: str
|
||||
|
||||
|
||||
if "ENVIRONMENT" in os.environ and os.environ["ENVIRONMENT"] == "test":
|
||||
|
||||
29
app/main.py
29
app/main.py
@@ -4,12 +4,11 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
|
||||
from pydantic import AnyHttpUrl, BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.base import get_session
|
||||
|
||||
from app.db.dtos import Job, JobStatus, JobType
|
||||
import app.db.dtos as dtos
|
||||
import app.db.models as models
|
||||
|
||||
from .security import authenticate_api_key
|
||||
from app.db.base import get_session
|
||||
from app.utils.security import authenticate_api_key
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -25,29 +24,35 @@ class TranscriptPayload(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
|
||||
|
||||
@api_router.post("/transcripts", response_model=Job)
|
||||
@api_router.post("/transcripts", response_model=dtos.Job)
|
||||
def create_transcript(
|
||||
payload: TranscriptPayload, session: Session = Depends(get_session)
|
||||
) -> models.Job:
|
||||
job = models.Job(url=payload.url, status=JobStatus.Create, type=JobType.Transcript)
|
||||
job = models.Job(
|
||||
url=payload.url, status=dtos.JobStatus.Create, type=dtos.JobType.Transcript
|
||||
)
|
||||
session.add(job)
|
||||
session.flush()
|
||||
return job
|
||||
|
||||
|
||||
@api_router.get("/transcripts", response_model=List[Job])
|
||||
@api_router.get("/transcripts", response_model=List[dtos.Job])
|
||||
def get_transcripts(session: Session = Depends(get_session)) -> List[models.Job]:
|
||||
return session.query(models.Job).filter(models.Job.type == JobType.Transcript).all()
|
||||
return (
|
||||
session.query(models.Job)
|
||||
.filter(models.Job.type == dtos.JobType.Transcript)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
@api_router.get("/transcripts/{id}", response_model=Job)
|
||||
@api_router.get("/transcripts/{id}", response_model=dtos.Job)
|
||||
def get_transcript(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
) -> Optional[Job]:
|
||||
) -> Optional[dtos.Job]:
|
||||
job = (
|
||||
session.query(models.Job)
|
||||
.filter(models.Job.id == id)
|
||||
.filter(models.Job.type == JobType.Transcript)
|
||||
.filter(models.Job.type == dtos.JobType.Transcript)
|
||||
.one_or_none()
|
||||
)
|
||||
if not job:
|
||||
@@ -60,7 +65,7 @@ def delete_transcript(
|
||||
id: UUID = Path(), session: Session = Depends(get_session)
|
||||
) -> None:
|
||||
session.query(models.Job).filter(models.Job.id == id).filter(
|
||||
models.Job.type == JobType.Transcript
|
||||
models.Job.type == dtos.JobType.Transcript
|
||||
).delete()
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.main import app
|
||||
import app.db.models as models
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_create_task(db_session: Session) -> None:
|
||||
jobs = db_session.query(models.Job).all()
|
||||
assert len(jobs) == 0
|
||||
0
app/utils/__init__.py
Normal file
0
app/utils/__init__.py
Normal file
7
app/worker.py
Normal file
7
app/worker.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from celery import Celery
|
||||
|
||||
from .config import settings
|
||||
|
||||
celery = Celery(__name__)
|
||||
|
||||
celery.conf.broker_url = settings.REDIS_URI
|
||||
Reference in New Issue
Block a user