feat: add job & artifact tables

* remove `accounts` table in favor of a simple API key auth
This commit is contained in:
Felix Spöttel
2023-01-05 10:14:50 +01:00
parent 4f7cd063f1
commit b3a38846ba
15 changed files with 153 additions and 139 deletions

View File

@@ -1,2 +1,3 @@
DATABASE_URI="postgresql://felix@localhost:5432/whisper_api_test"
ENVIRONMENT="development"
API_SECRET="foo"

View File

@@ -1,4 +1,4 @@
[flake8]
max-line-length = 88
extend-ignore = E203
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,app/alembic/versions

View File

@@ -2,20 +2,6 @@ name: CI
on: push
jobs:
fmt:
runs-on: ubuntu-latest
name: Fmt
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'
- pip install -e .[dev]
- black --check app
- isort --check app
- mypy app
- flake8 app
lint:
runs-on: ubuntu-latest
name: Lint
@@ -26,8 +12,10 @@ jobs:
python-version: '3.11'
cache: 'pip'
- pip install -e .[dev]
- mypy app
- black --check app
- isort --check app
- flake8 app
- mypy app
test:
runs-on: ubuntu-latest
name: Test

View File

@@ -3,7 +3,7 @@ dev:
uvicorn app.main:app --reload
fmt:
black app --check
black app
isort app
test:
@@ -12,6 +12,3 @@ test:
lint:
mypy app
flake8 app
create_account:
python -m scripts.create_account ${name}

View File

@@ -1,46 +0,0 @@
"""add_account_table
Revision ID: 54824f17a11d
Revises:
Create Date: 2022-12-18 17:51:09.172531
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "54824f17a11d"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"accounts",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
"created_at",
sa.DateTime(),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.Column("api_key", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("name", sa.String(length=256), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
op.create_index(op.f("ix_accounts_api_key"), "accounts", ["api_key"], unique=False)
op.create_index(op.f("ix_accounts_id"), "accounts", ["id"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_accounts_id"), table_name="accounts")
op.drop_index(op.f("ix_accounts_api_key"), table_name="accounts")
op.drop_table("accounts")
# ### end Alembic commands ###

View File

@@ -0,0 +1,73 @@
"""add_job_and_artifact_tables
Revision ID: c43a1ddae8b7
Revises:
Create Date: 2023-01-05 12:00:58.824773
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c43a1ddae8b7"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"jobs",
sa.Column("url", sa.String(length=2048), nullable=True),
sa.Column(
"status",
sa.Enum("Create", "Error", "Success", name="jobstatus"),
nullable=False,
),
sa.Column(
"type", sa.Enum("Transcript", name="jobtype"), nullable=False
),
sa.Column(
"created_at",
sa.DateTime(),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_jobs_id"), "jobs", ["id"], unique=False)
op.create_table(
"artifacts",
sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
sa.Column(
"type",
sa.Enum("RawTranscript", name="artifacttype"),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_artifacts_id"), "artifacts", ["id"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_artifacts_id"), table_name="artifacts")
op.drop_table("artifacts")
op.drop_index(op.f("ix_jobs_id"), table_name="jobs")
op.drop_table("jobs")
# ### end Alembic commands ###

View File

@@ -6,6 +6,7 @@ from pydantic import BaseSettings
class Settings(BaseSettings):
DATABASE_URI: str
ENVIRONMENT: str
API_SECRET: str
class Config:
env_file = ".env"

View File

@@ -10,9 +10,12 @@ engine = create_engine(settings.DATABASE_URI)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def get_db() -> Generator[Session, None, None]:
db = SessionLocal()
def get_session() -> Generator[Session, None, None]:
db: Session = SessionLocal()
try:
yield db
db.commit()
except Exception:
db.rollback()
finally:
db.close()

View File

@@ -1,11 +1,26 @@
import enum
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel
from pydantic import AnyHttpUrl, BaseModel, Json
class WithStandardFields(BaseModel):
class ArtifactType(enum.Enum):
RawTranscript = "RawTranscript"
class JobType(enum.Enum):
Transcript = "Transcript"
class JobStatus(enum.Enum):
Create = "Create"
Error = "Error"
Success = "Success"
class WithDbFields(BaseModel):
id: UUID
created_at: datetime
updated_at: Optional[datetime]
@@ -14,10 +29,13 @@ class WithStandardFields(BaseModel):
orm_mode = True
class AccountBase(BaseModel):
api_key: UUID
name: str
class Job(WithDbFields):
status: JobStatus
type: JobType
url: AnyHttpUrl
class Account(AccountBase, WithStandardFields):
pass
class Artifact(WithDbFields):
data: Optional[Json]
job_id: UUID
type: ArtifactType

View File

@@ -1,16 +1,20 @@
from typing import Optional
import uuid
from typing import Optional
from sqlalchemy import Column, DateTime, String, func
from sqlalchemy import JSON, Column, DateTime, Enum, ForeignKey, String, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_mixin, declared_attr, Mapped
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, declarative_mixin, declared_attr
from app.db.dtos import ArtifactType, JobStatus, JobType
Base = declarative_base()
@declarative_mixin
class WithStandardFields:
"""Mixin that adds standard fields (id, created_at, updated_at)."""
@declared_attr
def created_at(cls) -> Mapped[DateTime]:
return Column(DateTime, server_default=func.now(), nullable=False)
@@ -21,11 +25,26 @@ class WithStandardFields:
@declared_attr
def id(cls) -> Mapped[UUID]:
return Column(UUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4)
return Column(
UUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4
)
class Account(Base, WithStandardFields):
__tablename__ = "accounts"
class Job(Base, WithStandardFields):
__tablename__ = "jobs"
api_key = Column(UUID(as_uuid=True), index=True, default=uuid.uuid4)
name = Column(String(length=256), unique=True)
# TODO: job config
url = Column(String(length=2048))
status = Column(Enum(JobStatus), nullable=False)
type = Column(Enum(JobType), nullable=False)
class Artifact(Base, WithStandardFields):
__tablename__ = "artifacts"
job_id = Column(
UUID(as_uuid=True), ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False
)
data = Column(JSON(none_as_null=True))
type = Column(Enum(ArtifactType), nullable=False)

View File

@@ -1,26 +1,16 @@
from uuid import UUID
from hmac import compare_digest
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import NoResultFound
from .db.base import get_db
from .db.models import Account
from app.config import settings
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def authenticate_api_key(
db: Session = Depends(get_db),
api_key: str = Depends(oauth2_scheme),
) -> Account:
try:
account = db.query(Account).filter(Account.api_key == UUID(api_key)).one()
except NoResultFound:
raise HTTPException(status_code=401)
except Exception as e:
print(e)
def authenticate_api_key(token: str = Depends(oauth2_scheme)) -> None:
if not token:
raise HTTPException(status_code=422)
return account
# use compare_digest to counter timing attacks.
if not compare_digest(settings.API_SECRET, token):
raise HTTPException(status_code=401)

View File

@@ -4,8 +4,8 @@ import pytest
from sqlalchemy.orm import Session
from sqlalchemy_utils import create_database, database_exists, drop_database
from app.db.base import SessionLocal, engine, get_db
from app.db.models import Account, Base
from app.db.base import SessionLocal, engine, get_session
from app.db.models import Base
from app.main import app
@@ -26,17 +26,8 @@ def db_session() -> Generator[Session, None, None]:
transaction = connection.begin()
with SessionLocal(bind=connection) as session:
app.dependency_overrides[get_db] = lambda: session
app.dependency_overrides[get_session] = lambda: session
yield session
app.dependency_overrides.clear()
transaction.rollback()
connection.close()
@pytest.fixture(scope="function")
def test_account(db_session: Session) -> Account:
account = Account(name="test_account")
db_session.add(account)
db_session.commit()
db_session.refresh(account)
return account

View File

@@ -2,7 +2,7 @@ from typing import Dict
from fastapi.testclient import TestClient
from app.db.models import Account
from app.config import settings
from app.main import app
client = TestClient(app)
@@ -18,15 +18,15 @@ def test_authorization_header_missing() -> None:
def test_authorization_header_malformed() -> None:
res = client.get("/api/v1", headers=auth_header("not_a_uuid"))
res = client.get("/api/v1", headers={"Authorization": "Bearer"})
assert res.status_code == 422
def test_inexistent_api_key(test_account: Account) -> None:
res = client.get("/api/v1", headers=auth_header(str(test_account.id)))
def test_incorrect_api_key() -> None:
res = client.get("/api/v1", headers=auth_header("not_valid"))
assert res.status_code == 401
def test_existing_api_key(test_account: Account) -> None:
res = client.get("/api/v1", headers=auth_header(str(test_account.api_key)))
def test_existing_api_key() -> None:
res = client.get("/api/v1", headers=auth_header(settings.API_SECRET))
assert res.status_code == 200

View File

View File

@@ -1,21 +0,0 @@
import argparse
from dotenv import load_dotenv
from app.db.base import get_db
from app.db.models import Account
load_dotenv()
def create_account(name: str) -> Account:
db = get_db().__next__()
account = Account(name=name)
db.add(account)
db.commit()
db.refresh(account)
return account
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("name", type=str, nargs=1)
args = parser.parse_args()
create_account(args.name[0])