diff --git a/.env.test b/.env.test index 04f290d..b166cd1 100644 --- a/.env.test +++ b/.env.test @@ -1,4 +1,4 @@ -DATABASE_URI="postgresql://postgres:postgres@localhost:5432/whisper_api_test" +DATABASE_URI="postgresql://postgres:postgres@localhost:5432/whisperbox_test" ENVIRONMENT="development" API_SECRET="foo" -REDIS_URI="redis://localhost:6379/0" +BROKER_URI="redis://localhost:6379/0" diff --git a/.flake8 b/.flake8 index c7c7565..5579ce6 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ [flake8] -max-line-length = 88 +max-line-length = 90 extend-ignore = E203 exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,app/alembic/versions diff --git a/.gitignore b/.gitignore index d1c340c..84f123f 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,5 @@ cython_debug/ # VS Code .vscode + +.DS_Store diff --git a/Makefile b/Makefile index 53122a5..040c8f4 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +clean: + docker-compose -f docker/dev.docker-compose.yml down --volumes --remove-orphans dev: docker-compose -f docker/dev.docker-compose.yml build --progress tty docker-compose -f docker/dev.docker-compose.yml up --remove-orphans @@ -7,8 +9,8 @@ fmt: isort app lint: - mypy app flake8 app + mypy app test: pytest diff --git a/app/shared/config.py b/app/shared/config.py index 0c78151..214d6ac 100644 --- a/app/shared/config.py +++ b/app/shared/config.py @@ -7,7 +7,9 @@ class Settings(BaseSettings): API_SECRET: str DATABASE_URI: str ENVIRONMENT: str - REDIS_URI: str + + # derived settings + BROKER_URI: str if "pytest" in sys.modules: diff --git a/app/shared/db/alembic/versions/684a5e546314_add_job_meta_field.py b/app/shared/db/alembic/versions/684a5e546314_add_job_meta_field.py new file mode 100644 index 0000000..4337f57 --- /dev/null +++ b/app/shared/db/alembic/versions/684a5e546314_add_job_meta_field.py @@ -0,0 +1,27 @@ +"""add_job_meta_field + +Revision ID: 684a5e546314 +Revises: bb249ed79907 +Create Date: 2023-01-18 13:38:07.692830 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "684a5e546314" +down_revision = "bb249ed79907" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("jobs", sa.Column("meta", sa.JSON(none_as_null=True), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("jobs", "meta") + # ### end Alembic commands ### diff --git a/app/shared/db/alembic/versions/bb249ed79907_add_job_tables.py b/app/shared/db/alembic/versions/bb249ed79907_add_job_tables.py index 4c498ba..7fb73e3 100644 --- a/app/shared/db/alembic/versions/bb249ed79907_add_job_tables.py +++ b/app/shared/db/alembic/versions/bb249ed79907_add_job_tables.py @@ -1,16 +1,16 @@ """add_job_tables Revision ID: bb249ed79907 -Revises: +Revises: Create Date: 2023-01-17 14:30:30.920466 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. -revision = 'bb249ed79907' +revision = "bb249ed79907" down_revision = None branch_labels = None depends_on = None @@ -18,34 +18,46 @@ depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('jobs', - sa.Column('url', sa.String(length=2048), nullable=True), - sa.Column('status', sa.Enum('create', 'error', 'success', name='jobstatus'), nullable=False), - sa.Column('type', sa.Enum('transcript', name='jobtype'), nullable=False), - sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False), - sa.PrimaryKeyConstraint('id') + op.create_table( + "jobs", + sa.Column("url", sa.String(length=2048), nullable=True), + sa.Column( + "status", + sa.Enum("create", "error", "processing", "success", name="jobstatus"), + nullable=False, + ), + sa.Column("type", sa.Enum("transcript", name="jobtype"), nullable=False), + sa.Column( + "created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False + ), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f('ix_jobs_id'), 'jobs', ['id'], unique=False) - op.create_table('artifacts', - sa.Column('job_id', postgresql.UUID(as_uuid=True), nullable=False), - sa.Column('data', sa.JSON(none_as_null=True), nullable=True), - sa.Column('type', sa.Enum('raw_transcript', name='artifacttype'), nullable=False), - sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False), - sa.ForeignKeyConstraint(['job_id'], ['jobs.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') + op.create_index(op.f("ix_jobs_id"), "jobs", ["id"], unique=False) + op.create_table( + "artifacts", + sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("data", sa.JSON(none_as_null=True), nullable=True), + sa.Column( + "type", sa.Enum("raw_transcript", name="artifacttype"), nullable=False + ), + sa.Column( + "created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False + ), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f('ix_artifacts_id'), 'artifacts', ['id'], unique=False) + op.create_index(op.f("ix_artifacts_id"), "artifacts", ["id"], unique=False) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f('ix_artifacts_id'), table_name='artifacts') - op.drop_table('artifacts') - op.drop_index(op.f('ix_jobs_id'), table_name='jobs') - op.drop_table('jobs') + op.drop_index(op.f("ix_artifacts_id"), table_name="artifacts") + op.drop_table("artifacts") + op.drop_index(op.f("ix_jobs_id"), table_name="jobs") + op.drop_table("jobs") # ### end Alembic commands ### diff --git a/app/shared/db/dtos.py b/app/shared/db/dtos.py index 0946d95..0213d92 100644 --- a/app/shared/db/dtos.py +++ b/app/shared/db/dtos.py @@ -3,7 +3,16 @@ from datetime import datetime from typing import Any, Optional from uuid import UUID -from pydantic import AnyHttpUrl, BaseModel, Json +from pydantic import AnyHttpUrl, BaseModel + + +class WithDbFields(BaseModel): + id: UUID + created_at: datetime + updated_at: Optional[datetime] + + class Config: + orm_mode = True class ArtifactType(str, enum.Enum): @@ -16,23 +25,21 @@ class JobType(str, enum.Enum): class JobStatus(str, enum.Enum): create = "create" + processing = "processing" error = "error" success = "success" -class WithDbFields(BaseModel): - id: UUID - created_at: datetime - updated_at: Optional[datetime] - - class Config: - orm_mode = True +class JobMeta(BaseModel): + language: Optional[str] + task_id: Optional[UUID] class Job(WithDbFields): status: JobStatus type: JobType url: AnyHttpUrl + meta: Optional[JobMeta] class Artifact(WithDbFields): diff --git a/app/shared/db/models.py b/app/shared/db/models.py index 82f3214..1ac52c6 100644 --- a/app/shared/db/models.py +++ b/app/shared/db/models.py @@ -3,8 +3,8 @@ from typing import Optional from sqlalchemy import JSON, Column, DateTime, Enum, ForeignKey, String, func from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Mapped, declarative_mixin, declared_attr +from sqlalchemy.ext.declarative import declarative_base, declared_attr +from sqlalchemy.orm import Mapped, declarative_mixin # type: ignore from .dtos import ArtifactType, JobStatus, JobType @@ -33,9 +33,9 @@ class WithStandardFields: class Job(Base, WithStandardFields): __tablename__ = "jobs" - # TODO: job config url = Column(String(length=2048)) status = Column(Enum(JobStatus), nullable=False) + meta = Column(JSON(none_as_null=True)) type = Column(Enum(JobType), nullable=False) diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 6b817e7..7af6dde 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -3,10 +3,10 @@ from typing import Dict, Generator import pytest from sqlalchemy.orm import Session from sqlalchemy_utils import create_database, database_exists, drop_database -from app.shared.config import settings -from app.shared.db.base import SessionLocal, engine, get_session import app.shared.db.models as models +from app.shared.config import settings +from app.shared.db.base import SessionLocal, engine, get_session from app.web.main import app diff --git a/app/tests/test_api.py b/app/tests/test_api.py index 3eada34..8d2d93e 100644 --- a/app/tests/test_api.py +++ b/app/tests/test_api.py @@ -1,12 +1,13 @@ from typing import Dict -from fastapi.testclient import TestClient + import pytest +from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from app.shared.db.dtos import JobType, JobStatus -from app.web.main import app -import app.shared.db.models as models import app.shared.db.dtos as dtos +import app.shared.db.models as models +from app.shared.db.dtos import JobStatus, JobType +from app.web.main import app client = TestClient(app) @@ -73,7 +74,7 @@ def test_get_job_pass(auth_headers: Dict[str, str], mock_job: models.Job) -> Non def test_get_job_not_found(auth_headers: Dict[str, str], mock_job: models.Job) -> None: res = client.get( - f"/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb", + "/api/v1/jobs/c8ecf5ea-77cf-48a2-9ecd-199ef35e0ccb", headers=auth_headers, ) assert res.status_code == 404 diff --git a/app/tests/test_auth.py b/app/tests/test_auth.py index b961b14..d1501cc 100644 --- a/app/tests/test_auth.py +++ b/app/tests/test_auth.py @@ -2,7 +2,6 @@ from typing import Dict from fastapi.testclient import TestClient -from app.shared.config import settings from app.web.main import app client = TestClient(app) @@ -19,7 +18,7 @@ def test_authorization_header_malformed() -> None: def test_incorrect_api_key() -> None: - res = client.get("/api/v1", headers={"Authorization": "Bearer incorrect" }) + res = client.get("/api/v1", headers={"Authorization": "Bearer incorrect"}) assert res.status_code == 401 diff --git a/app/web/main.py b/app/web/main.py index 4832739..6403fc8 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -9,6 +9,7 @@ import app.shared.db.dtos as dtos import app.shared.db.models as models from app.shared.db.base import get_session from app.web.security import authenticate_api_key +from app.worker.main import transcribe app = FastAPI() @@ -30,9 +31,11 @@ def create_job( payload: TranscriptPayload, session: Session = Depends(get_session) ) -> models.Job: job = models.Job(url=payload.url, status=dtos.JobStatus.create, type=payload.type) - session.add(job) session.flush() + + transcribe.delay(job.id) + return job @@ -51,7 +54,7 @@ def get_transcripts( @api_router.get("/jobs/{id}", response_model=dtos.Job) def get_transcript( id: UUID = Path(), session: Session = Depends(get_session) -) -> Optional[dtos.Job]: +) -> Optional[models.Job]: job = session.query(models.Job).filter(models.Job.id == id).one_or_none() if not job: raise HTTPException(status_code=404) @@ -61,7 +64,7 @@ def get_transcript( @api_router.get("/jobs/{id}/artifacts", response_model=List[dtos.Artifact]) def get_artifacts_for_job( id: UUID = Path(), session: Session = Depends(get_session) -) -> List[dtos.Artifact]: +) -> List[models.Artifact]: artifacts = ( session.query(models.Artifact).filter(models.Artifact.job_id == id) ).all() diff --git a/app/worker/main.py b/app/worker/main.py index e8b1a93..7ce4791 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -1,7 +1,37 @@ -from celery import Celery +from time import sleep +from uuid import UUID +from celery import Celery +from sqlalchemy.orm import Session + +import app.shared.db.dtos as dtos +import app.shared.db.models as models from app.shared.config import settings +from app.shared.db.base import SessionLocal celery = Celery(__name__) -celery.conf.broker_url = settings.REDIS_URI +celery.conf.broker_url = settings.BROKER_URI + + +def update_job_status(db: Session, job_id: UUID, status: dtos.JobStatus) -> None: + db.begin() + job = db.query(models.Job).filter(models.Job.id == job_id).one() + job.status = status + db.commit() + + +@celery.task() +def transcribe(job_id: UUID) -> int: + try: + db: Session = SessionLocal() + update_job_status(db, job_id, dtos.JobStatus.processing) + sleep(60) + update_job_status(db, job_id, dtos.JobStatus.success) + db.commit() + except Exception: + update_job_status(db, job_id, dtos.JobStatus.error) + finally: + db.close() + + return 0 diff --git a/docker/dev.docker-compose.yml b/docker/dev.docker-compose.yml index 8a6affe..c65f7af 100644 --- a/docker/dev.docker-compose.yml +++ b/docker/dev.docker-compose.yml @@ -2,18 +2,18 @@ version: "3.8" x-app-variables: &app-variables API_SECRET: a_very_secret_token - DATABASE_URI: postgresql://postgres:postgres@postgres/whisper_api + DATABASE_URI: postgresql://postgres:postgres@postgres/whisperbox ENVIRONMENT: development - REDIS_URI: redis://redis:6379/0 + BROKER_URI: redis://redis:6379/0 services: postgres: - container_name: whisper_api_postgres + container_name: whisperbox_postgres image: postgres:15-alpine environment: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres - POSTGRES_DB: whisper_api + POSTGRES_DB: whisperbox ports: - "5432:5432" networks: @@ -27,8 +27,9 @@ services: retries: 5 redis: - container_name: whisper_api_redis + container_name: whisperbox_redis image: redis:7-alpine + command: ["redis-server", "--save", "60 1"] ports: - 6379:6379 networks: @@ -37,7 +38,7 @@ services: - redis-data:/data app: - container_name: whisper_api_app + container_name: whisperbox_app build: context: ../ dockerfile: docker/app.dev.Dockerfile @@ -58,7 +59,7 @@ services: build: context: ../ dockerfile: docker/worker.dev.Dockerfile - container_name: whisper_api_worker + container_name: whisperbox_worker volumes: - ../:/code environment: *app-variables @@ -69,7 +70,7 @@ services: - app flower: - container_name: whisper_api_flower + container_name: whisperbox_flower image: mher/flower command: celery --broker redis://redis:6379/0 flower --port=5555 ports: diff --git a/mypy.ini b/mypy.ini index 4249051..9813db0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,4 +1,4 @@ [mypy] -plugins = pydantic.mypy, sqlmypy, sqlalchemy.ext.mypy.plugin +plugins = pydantic.mypy, sqlmypy ignore_missing_imports = True disallow_untyped_defs = True diff --git a/pyproject.toml b/pyproject.toml index b5b728f..362a854 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [project] -name = "whisper-api" +name = "whisperbox" description = "" version = "0.0.1" @@ -7,7 +7,7 @@ dependencies=[ "celery[redis] ==5.2.7", "psycopg2 ==2.9.5", "sqlalchemy[mypy] == 1.4.45", - "python-dotenv ==0.21.0", + "pydantic ==1.10.4" ] [project.optional-dependencies] @@ -21,7 +21,7 @@ worker=[ "whisper-openai ==1.0.0" ] -dev = [ +lint = [ # code formatting "black", "isort", @@ -34,9 +34,14 @@ test = [ "httpx", "sqlalchemy-stubs", "sqlalchemy-utils", + "python-dotenv", "pytest" ] +worker_dev = [ + "watchdog[watchmedo]" +] + [tool.isort] profile = "black"