From b3a38846ba5d873de73852017c56e92f223575fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Sp=C3=B6ttel?= <1682504+fspoettel@users.noreply.github.com> Date: Thu, 5 Jan 2023 10:14:50 +0100 Subject: [PATCH] feat: add job & artifact tables * remove `accounts` table in favor of a simple API key auth --- .env.test | 1 + .flake8 | 2 +- .github/workflows/ci.yml | 18 +---- Makefile | 5 +- .../54824f17a11d_add_account_table.py | 46 ------------ ...43a1ddae8b7_add_job_and_artifact_tables.py | 73 +++++++++++++++++++ app/config.py | 1 + app/db/base.py | 7 +- app/db/dtos.py | 32 ++++++-- app/db/models.py | 35 +++++++-- app/security.py | 24 ++---- app/tests/conftest.py | 15 +--- app/tests/test_auth.py | 12 +-- scripts/__init__.py | 0 scripts/create_account.py | 21 ------ 15 files changed, 153 insertions(+), 139 deletions(-) delete mode 100644 app/alembic/versions/54824f17a11d_add_account_table.py create mode 100644 app/alembic/versions/c43a1ddae8b7_add_job_and_artifact_tables.py delete mode 100644 scripts/__init__.py delete mode 100644 scripts/create_account.py diff --git a/.env.test b/.env.test index 57a4c47..8cdfcb1 100644 --- a/.env.test +++ b/.env.test @@ -1,2 +1,3 @@ DATABASE_URI="postgresql://felix@localhost:5432/whisper_api_test" ENVIRONMENT="development" +API_SECRET="foo" diff --git a/.flake8 b/.flake8 index 9218e20..c7c7565 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ [flake8] max-line-length = 88 extend-ignore = E203 -exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache +exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,app/alembic/versions diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cb61cf3..a6ca2fe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,20 +2,6 @@ name: CI on: push jobs: - fmt: - runs-on: ubuntu-latest - name: Fmt - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.11' - cache: 'pip' - - pip install -e .[dev] - - black --check app - - isort --check app - - mypy app - - flake8 app lint: runs-on: ubuntu-latest name: Lint @@ -26,8 +12,10 @@ jobs: python-version: '3.11' cache: 'pip' - pip install -e .[dev] - - mypy app + - black --check app + - isort --check app - flake8 app + - mypy app test: runs-on: ubuntu-latest name: Test diff --git a/Makefile b/Makefile index 3c492e8..9b2e58f 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ dev: uvicorn app.main:app --reload fmt: - black app --check + black app isort app test: @@ -12,6 +12,3 @@ test: lint: mypy app flake8 app - -create_account: - python -m scripts.create_account ${name} diff --git a/app/alembic/versions/54824f17a11d_add_account_table.py b/app/alembic/versions/54824f17a11d_add_account_table.py deleted file mode 100644 index 40a5c6c..0000000 --- a/app/alembic/versions/54824f17a11d_add_account_table.py +++ /dev/null @@ -1,46 +0,0 @@ -"""add_account_table - -Revision ID: 54824f17a11d -Revises: -Create Date: 2022-12-18 17:51:09.172531 - -""" -import sqlalchemy as sa -from alembic import op -from sqlalchemy.dialects import postgresql - -# revision identifiers, used by Alembic. -revision = "54824f17a11d" -down_revision = None -branch_labels = None -depends_on = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "accounts", - sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), - sa.Column( - "created_at", - sa.DateTime(), - server_default=sa.text("now()"), - nullable=False, - ), - sa.Column("updated_at", sa.DateTime(), nullable=True), - sa.Column("api_key", postgresql.UUID(as_uuid=True), nullable=True), - sa.Column("name", sa.String(length=256), nullable=True), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("name"), - ) - op.create_index(op.f("ix_accounts_api_key"), "accounts", ["api_key"], unique=False) - op.create_index(op.f("ix_accounts_id"), "accounts", ["id"], unique=False) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f("ix_accounts_id"), table_name="accounts") - op.drop_index(op.f("ix_accounts_api_key"), table_name="accounts") - op.drop_table("accounts") - # ### end Alembic commands ### diff --git a/app/alembic/versions/c43a1ddae8b7_add_job_and_artifact_tables.py b/app/alembic/versions/c43a1ddae8b7_add_job_and_artifact_tables.py new file mode 100644 index 0000000..02fac72 --- /dev/null +++ b/app/alembic/versions/c43a1ddae8b7_add_job_and_artifact_tables.py @@ -0,0 +1,73 @@ +"""add_job_and_artifact_tables + +Revision ID: c43a1ddae8b7 +Revises: +Create Date: 2023-01-05 12:00:58.824773 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "c43a1ddae8b7" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "jobs", + sa.Column("url", sa.String(length=2048), nullable=True), + sa.Column( + "status", + sa.Enum("Create", "Error", "Success", name="jobstatus"), + nullable=False, + ), + sa.Column( + "type", sa.Enum("Transcript", name="jobtype"), nullable=False + ), + sa.Column( + "created_at", + sa.DateTime(), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_jobs_id"), "jobs", ["id"], unique=False) + op.create_table( + "artifacts", + sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("data", sa.JSON(none_as_null=True), nullable=True), + sa.Column( + "type", + sa.Enum("RawTranscript", name="artifacttype"), + nullable=False, + ), + sa.Column( + "created_at", + sa.DateTime(), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_artifacts_id"), "artifacts", ["id"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_artifacts_id"), table_name="artifacts") + op.drop_table("artifacts") + op.drop_index(op.f("ix_jobs_id"), table_name="jobs") + op.drop_table("jobs") + # ### end Alembic commands ### diff --git a/app/config.py b/app/config.py index 82db499..114b610 100644 --- a/app/config.py +++ b/app/config.py @@ -6,6 +6,7 @@ from pydantic import BaseSettings class Settings(BaseSettings): DATABASE_URI: str ENVIRONMENT: str + API_SECRET: str class Config: env_file = ".env" diff --git a/app/db/base.py b/app/db/base.py index 5c6f3e9..7c62f24 100644 --- a/app/db/base.py +++ b/app/db/base.py @@ -10,9 +10,12 @@ engine = create_engine(settings.DATABASE_URI) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -def get_db() -> Generator[Session, None, None]: - db = SessionLocal() +def get_session() -> Generator[Session, None, None]: + db: Session = SessionLocal() try: yield db + db.commit() + except Exception: + db.rollback() finally: db.close() diff --git a/app/db/dtos.py b/app/db/dtos.py index cb2fce3..dae093b 100644 --- a/app/db/dtos.py +++ b/app/db/dtos.py @@ -1,11 +1,26 @@ +import enum from datetime import datetime from typing import Optional from uuid import UUID -from pydantic import BaseModel +from pydantic import AnyHttpUrl, BaseModel, Json -class WithStandardFields(BaseModel): +class ArtifactType(enum.Enum): + RawTranscript = "RawTranscript" + + +class JobType(enum.Enum): + Transcript = "Transcript" + + +class JobStatus(enum.Enum): + Create = "Create" + Error = "Error" + Success = "Success" + + +class WithDbFields(BaseModel): id: UUID created_at: datetime updated_at: Optional[datetime] @@ -14,10 +29,13 @@ class WithStandardFields(BaseModel): orm_mode = True -class AccountBase(BaseModel): - api_key: UUID - name: str +class Job(WithDbFields): + status: JobStatus + type: JobType + url: AnyHttpUrl -class Account(AccountBase, WithStandardFields): - pass +class Artifact(WithDbFields): + data: Optional[Json] + job_id: UUID + type: ArtifactType diff --git a/app/db/models.py b/app/db/models.py index 818e3dc..da1472e 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -1,16 +1,20 @@ -from typing import Optional import uuid +from typing import Optional -from sqlalchemy import Column, DateTime, String, func +from sqlalchemy import JSON, Column, DateTime, Enum, ForeignKey, String, func from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import declarative_mixin, declared_attr, Mapped from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Mapped, declarative_mixin, declared_attr + +from app.db.dtos import ArtifactType, JobStatus, JobType Base = declarative_base() @declarative_mixin class WithStandardFields: + """Mixin that adds standard fields (id, created_at, updated_at).""" + @declared_attr def created_at(cls) -> Mapped[DateTime]: return Column(DateTime, server_default=func.now(), nullable=False) @@ -21,11 +25,26 @@ class WithStandardFields: @declared_attr def id(cls) -> Mapped[UUID]: - return Column(UUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4) + return Column( + UUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4 + ) -class Account(Base, WithStandardFields): - __tablename__ = "accounts" +class Job(Base, WithStandardFields): + __tablename__ = "jobs" - api_key = Column(UUID(as_uuid=True), index=True, default=uuid.uuid4) - name = Column(String(length=256), unique=True) + # TODO: job config + url = Column(String(length=2048)) + status = Column(Enum(JobStatus), nullable=False) + type = Column(Enum(JobType), nullable=False) + + +class Artifact(Base, WithStandardFields): + __tablename__ = "artifacts" + + job_id = Column( + UUID(as_uuid=True), ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False + ) + + data = Column(JSON(none_as_null=True)) + type = Column(Enum(ArtifactType), nullable=False) diff --git a/app/security.py b/app/security.py index bfe0ae4..6194b5d 100644 --- a/app/security.py +++ b/app/security.py @@ -1,26 +1,16 @@ -from uuid import UUID +from hmac import compare_digest from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer -from sqlalchemy.orm import Session -from sqlalchemy.orm.exc import NoResultFound -from .db.base import get_db -from .db.models import Account +from app.config import settings oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -def authenticate_api_key( - db: Session = Depends(get_db), - api_key: str = Depends(oauth2_scheme), -) -> Account: - try: - account = db.query(Account).filter(Account.api_key == UUID(api_key)).one() - except NoResultFound: - raise HTTPException(status_code=401) - except Exception as e: - print(e) +def authenticate_api_key(token: str = Depends(oauth2_scheme)) -> None: + if not token: raise HTTPException(status_code=422) - - return account + # use compare_digest to counter timing attacks. + if not compare_digest(settings.API_SECRET, token): + raise HTTPException(status_code=401) diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 295d405..d2500a9 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -4,8 +4,8 @@ import pytest from sqlalchemy.orm import Session from sqlalchemy_utils import create_database, database_exists, drop_database -from app.db.base import SessionLocal, engine, get_db -from app.db.models import Account, Base +from app.db.base import SessionLocal, engine, get_session +from app.db.models import Base from app.main import app @@ -26,17 +26,8 @@ def db_session() -> Generator[Session, None, None]: transaction = connection.begin() with SessionLocal(bind=connection) as session: - app.dependency_overrides[get_db] = lambda: session + app.dependency_overrides[get_session] = lambda: session yield session app.dependency_overrides.clear() transaction.rollback() connection.close() - - -@pytest.fixture(scope="function") -def test_account(db_session: Session) -> Account: - account = Account(name="test_account") - db_session.add(account) - db_session.commit() - db_session.refresh(account) - return account diff --git a/app/tests/test_auth.py b/app/tests/test_auth.py index 72cdff0..190a9b1 100644 --- a/app/tests/test_auth.py +++ b/app/tests/test_auth.py @@ -2,7 +2,7 @@ from typing import Dict from fastapi.testclient import TestClient -from app.db.models import Account +from app.config import settings from app.main import app client = TestClient(app) @@ -18,15 +18,15 @@ def test_authorization_header_missing() -> None: def test_authorization_header_malformed() -> None: - res = client.get("/api/v1", headers=auth_header("not_a_uuid")) + res = client.get("/api/v1", headers={"Authorization": "Bearer"}) assert res.status_code == 422 -def test_inexistent_api_key(test_account: Account) -> None: - res = client.get("/api/v1", headers=auth_header(str(test_account.id))) +def test_incorrect_api_key() -> None: + res = client.get("/api/v1", headers=auth_header("not_valid")) assert res.status_code == 401 -def test_existing_api_key(test_account: Account) -> None: - res = client.get("/api/v1", headers=auth_header(str(test_account.api_key))) +def test_existing_api_key() -> None: + res = client.get("/api/v1", headers=auth_header(settings.API_SECRET)) assert res.status_code == 200 diff --git a/scripts/__init__.py b/scripts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/scripts/create_account.py b/scripts/create_account.py deleted file mode 100644 index ec7d661..0000000 --- a/scripts/create_account.py +++ /dev/null @@ -1,21 +0,0 @@ -import argparse -from dotenv import load_dotenv -from app.db.base import get_db -from app.db.models import Account - -load_dotenv() - -def create_account(name: str) -> Account: - db = get_db().__next__() - account = Account(name=name) - db.add(account) - db.commit() - db.refresh(account) - return account - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("name", type=str, nargs=1) - args = parser.parse_args() - create_account(args.name[0])