mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-12 13:38:34 +03:00
feat: add job & artifact tables
* remove `accounts` table in favor of a simple API key auth
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
DATABASE_URI="postgresql://felix@localhost:5432/whisper_api_test"
|
||||
ENVIRONMENT="development"
|
||||
API_SECRET="foo"
|
||||
|
||||
2
.flake8
2
.flake8
@@ -1,4 +1,4 @@
|
||||
[flake8]
|
||||
max-line-length = 88
|
||||
extend-ignore = E203
|
||||
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache
|
||||
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,app/alembic/versions
|
||||
|
||||
18
.github/workflows/ci.yml
vendored
18
.github/workflows/ci.yml
vendored
@@ -2,20 +2,6 @@ name: CI
|
||||
on: push
|
||||
|
||||
jobs:
|
||||
fmt:
|
||||
runs-on: ubuntu-latest
|
||||
name: Fmt
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
- pip install -e .[dev]
|
||||
- black --check app
|
||||
- isort --check app
|
||||
- mypy app
|
||||
- flake8 app
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
name: Lint
|
||||
@@ -26,8 +12,10 @@ jobs:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
- pip install -e .[dev]
|
||||
- mypy app
|
||||
- black --check app
|
||||
- isort --check app
|
||||
- flake8 app
|
||||
- mypy app
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
name: Test
|
||||
|
||||
5
Makefile
5
Makefile
@@ -3,7 +3,7 @@ dev:
|
||||
uvicorn app.main:app --reload
|
||||
|
||||
fmt:
|
||||
black app --check
|
||||
black app
|
||||
isort app
|
||||
|
||||
test:
|
||||
@@ -12,6 +12,3 @@ test:
|
||||
lint:
|
||||
mypy app
|
||||
flake8 app
|
||||
|
||||
create_account:
|
||||
python -m scripts.create_account ${name}
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
"""add_account_table
|
||||
|
||||
Revision ID: 54824f17a11d
|
||||
Revises:
|
||||
Create Date: 2022-12-18 17:51:09.172531
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "54824f17a11d"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``accounts`` table and its two non-unique indexes.

    The table gets a UUID ``id`` primary key, a unique ``name`` and an
    ``api_key`` column; ``created_at`` defaults to ``now()`` on the server.
    """
    account_columns = [
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("updated_at", sa.DateTime(), nullable=True),
        sa.Column("api_key", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("name", sa.String(length=256), nullable=True),
    ]
    account_constraints = [
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("name"),
    ]
    op.create_table("accounts", *account_columns, *account_constraints)

    # Same creation order as before: api_key first, then id.
    for indexed_column in ("api_key", "id"):
        op.create_index(
            op.f(f"ix_accounts_{indexed_column}"),
            "accounts",
            [indexed_column],
            unique=False,
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``accounts`` indexes, then the ``accounts`` table itself."""
    # Same removal order as before: id index first, then api_key.
    for indexed_column in ("id", "api_key"):
        op.drop_index(op.f(f"ix_accounts_{indexed_column}"), table_name="accounts")
    op.drop_table("accounts")
|
||||
@@ -0,0 +1,73 @@
|
||||
"""add_job_and_artifact_tables
|
||||
|
||||
Revision ID: c43a1ddae8b7
|
||||
Revises:
|
||||
Create Date: 2023-01-05 12:00:58.824773
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c43a1ddae8b7"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``jobs`` and ``artifacts`` tables and their ``id`` indexes.

    ``artifacts.job_id`` references ``jobs.id`` with ``ON DELETE CASCADE``.
    Both tables end with the same three bookkeeping columns
    (``created_at``, ``updated_at``, ``id``), in that order.
    """

    def bookkeeping_columns() -> list:
        # Built fresh on every call: a Column object can be attached to only
        # one table.
        return [
            sa.Column(
                "created_at",
                sa.DateTime(),
                server_default=sa.text("now()"),
                nullable=False,
            ),
            sa.Column("updated_at", sa.DateTime(), nullable=True),
            sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        ]

    job_status = sa.Enum("Create", "Error", "Success", name="jobstatus")
    job_type = sa.Enum("Transcript", name="jobtype")
    op.create_table(
        "jobs",
        sa.Column("url", sa.String(length=2048), nullable=True),
        sa.Column("status", job_status, nullable=False),
        sa.Column("type", job_type, nullable=False),
        *bookkeeping_columns(),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_jobs_id"), "jobs", ["id"], unique=False)

    artifact_type = sa.Enum("RawTranscript", name="artifacttype")
    op.create_table(
        "artifacts",
        sa.Column("job_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("data", sa.JSON(none_as_null=True), nullable=True),
        sa.Column("type", artifact_type, nullable=False),
        *bookkeeping_columns(),
        sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_artifacts_id"), "artifacts", ["id"], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove everything ``upgrade`` created: indexes, tables and enum types.

    ``artifacts`` is dropped before ``jobs`` because it holds the foreign key
    to ``jobs.id``.
    """
    op.drop_index(op.f("ix_artifacts_id"), table_name="artifacts")
    op.drop_table("artifacts")
    op.drop_index(op.f("ix_jobs_id"), table_name="jobs")
    op.drop_table("jobs")

    # Dropping a table does not drop the PostgreSQL enum types its columns
    # used. Without this, re-running upgrade() fails because the types
    # already exist. Done after the tables, since the types are in use
    # until then.
    for enum_name in ("artifacttype", "jobtype", "jobstatus"):
        op.execute(f"DROP TYPE IF EXISTS {enum_name}")
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseSettings
|
||||
class Settings(BaseSettings):
|
||||
DATABASE_URI: str
|
||||
ENVIRONMENT: str
|
||||
API_SECRET: str
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
@@ -10,9 +10,12 @@ engine = create_engine(settings.DATABASE_URI)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
db = SessionLocal()
|
||||
def get_session() -> Generator[Session, None, None]:
    """Yield a database session and finish its transaction when the caller is done.

    The session is committed once control returns normally from the
    ``yield``. If an exception is raised (either thrown in at the ``yield``
    or raised by the commit), the session is rolled back and the exception
    is re-raised so the caller still sees the failure. The session is closed
    in every case.

    Yields:
        A ``Session`` created from ``SessionLocal``.
    """
    db: Session = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        # Re-raise: swallowing here would make a failed request or a failed
        # commit look like success to whoever drives this generator.
        raise
    finally:
        db.close()
|
||||
|
||||
@@ -1,11 +1,26 @@
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import AnyHttpUrl, BaseModel, Json
|
||||
|
||||
|
||||
class WithStandardFields(BaseModel):
|
||||
@enum.unique
class ArtifactType(enum.Enum):
    """Closed set of artifact kinds; each member's value equals its name."""

    RawTranscript = "RawTranscript"
|
||||
|
||||
|
||||
@enum.unique
class JobType(enum.Enum):
    """Closed set of job kinds; each member's value equals its name."""

    Transcript = "Transcript"
|
||||
|
||||
|
||||
@enum.unique
class JobStatus(enum.Enum):
    """Closed set of job states; each member's value equals its name."""

    Create = "Create"
    Error = "Error"
    Success = "Success"
|
||||
|
||||
|
||||
class WithDbFields(BaseModel):
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
@@ -14,10 +29,13 @@ class WithStandardFields(BaseModel):
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class AccountBase(BaseModel):
|
||||
api_key: UUID
|
||||
name: str
|
||||
# Pydantic schema for a job: the fields inherited from WithDbFields plus the
# job's status, type and URL.
class Job(WithDbFields):
    status: JobStatus
    type: JobType
    # NOTE(review): declared as required here, while the `jobs.url` column is
    # created nullable — confirm which of the two is intended.
    url: AnyHttpUrl
|
||||
|
||||
|
||||
class Account(AccountBase, WithStandardFields):
|
||||
pass
|
||||
# Pydantic schema for an artifact: the fields inherited from WithDbFields plus
# an optional JSON payload, the owning job's id and the artifact type.
class Artifact(WithDbFields):
    # NOTE(review): pydantic's `Json` type validates by parsing a JSON string.
    # Presumably the ORM hands back already-decoded data for a JSON column —
    # verify that this field accepts it.
    data: Optional[Json]
    job_id: UUID
    type: ArtifactType
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
from typing import Optional
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, DateTime, String, func
|
||||
from sqlalchemy import JSON, Column, DateTime, Enum, ForeignKey, String, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import declarative_mixin, declared_attr, Mapped
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Mapped, declarative_mixin, declared_attr
|
||||
|
||||
from app.db.dtos import ArtifactType, JobStatus, JobType
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@declarative_mixin
|
||||
class WithStandardFields:
|
||||
"""Mixin that adds standard fields (id, created_at, updated_at)."""
|
||||
|
||||
@declared_attr
|
||||
def created_at(cls) -> Mapped[DateTime]:
|
||||
return Column(DateTime, server_default=func.now(), nullable=False)
|
||||
@@ -21,11 +25,26 @@ class WithStandardFields:
|
||||
|
||||
@declared_attr
|
||||
def id(cls) -> Mapped[UUID]:
|
||||
return Column(UUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4)
|
||||
return Column(
|
||||
UUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4
|
||||
)
|
||||
|
||||
|
||||
class Account(Base, WithStandardFields):
|
||||
__tablename__ = "accounts"
|
||||
class Job(Base, WithStandardFields):
|
||||
__tablename__ = "jobs"
|
||||
|
||||
api_key = Column(UUID(as_uuid=True), index=True, default=uuid.uuid4)
|
||||
name = Column(String(length=256), unique=True)
|
||||
# TODO: job config
|
||||
url = Column(String(length=2048))
|
||||
status = Column(Enum(JobStatus), nullable=False)
|
||||
type = Column(Enum(JobType), nullable=False)
|
||||
|
||||
|
||||
class Artifact(Base, WithStandardFields):
    """ORM model mapped to the ``artifacts`` table."""

    __tablename__ = "artifacts"

    # Required reference to the owning row in `jobs`; the foreign key is
    # declared with ON DELETE CASCADE.
    job_id = Column(
        UUID(as_uuid=True), ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False
    )

    # Nullable JSON payload, declared with none_as_null=True.
    data = Column(JSON(none_as_null=True))
    # Required artifact kind, stored as an enum built from ArtifactType.
    type = Column(Enum(ArtifactType), nullable=False)
|
||||
|
||||
@@ -1,26 +1,16 @@
|
||||
from uuid import UUID
|
||||
from hmac import compare_digest
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
from .db.base import get_db
|
||||
from .db.models import Account
|
||||
from app.config import settings
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
def authenticate_api_key(
|
||||
db: Session = Depends(get_db),
|
||||
api_key: str = Depends(oauth2_scheme),
|
||||
) -> Account:
|
||||
try:
|
||||
account = db.query(Account).filter(Account.api_key == UUID(api_key)).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=401)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
def authenticate_api_key(token: str = Depends(oauth2_scheme)) -> None:
    """Check the bearer token against the configured ``API_SECRET``.

    Args:
        token: Bearer token supplied through ``oauth2_scheme``.

    Raises:
        HTTPException: With status 422 when the token is empty, or 401 when
            it does not equal ``settings.API_SECRET``.
    """
    if not token:
        raise HTTPException(status_code=422)

    # compare_digest runs in constant time to counter timing attacks, but it
    # accepts `str` arguments only when they are pure ASCII. Comparing UTF-8
    # bytes keeps a non-ASCII token or secret from raising TypeError (which
    # would surface as a server error instead of a 401).
    expected = settings.API_SECRET.encode("utf-8")
    provided = token.encode("utf-8")
    if not compare_digest(expected, provided):
        raise HTTPException(status_code=401)
|
||||
|
||||
@@ -4,8 +4,8 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy_utils import create_database, database_exists, drop_database
|
||||
|
||||
from app.db.base import SessionLocal, engine, get_db
|
||||
from app.db.models import Account, Base
|
||||
from app.db.base import SessionLocal, engine, get_session
|
||||
from app.db.models import Base
|
||||
from app.main import app
|
||||
|
||||
|
||||
@@ -26,17 +26,8 @@ def db_session() -> Generator[Session, None, None]:
|
||||
transaction = connection.begin()
|
||||
|
||||
with SessionLocal(bind=connection) as session:
|
||||
app.dependency_overrides[get_db] = lambda: session
|
||||
app.dependency_overrides[get_session] = lambda: session
|
||||
yield session
|
||||
app.dependency_overrides.clear()
|
||||
transaction.rollback()
|
||||
connection.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def test_account(db_session: Session) -> Account:
|
||||
account = Account(name="test_account")
|
||||
db_session.add(account)
|
||||
db_session.commit()
|
||||
db_session.refresh(account)
|
||||
return account
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Dict
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.db.models import Account
|
||||
from app.config import settings
|
||||
from app.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
@@ -18,15 +18,15 @@ def test_authorization_header_missing() -> None:
|
||||
|
||||
|
||||
def test_authorization_header_malformed() -> None:
|
||||
res = client.get("/api/v1", headers=auth_header("not_a_uuid"))
|
||||
res = client.get("/api/v1", headers={"Authorization": "Bearer"})
|
||||
assert res.status_code == 422
|
||||
|
||||
|
||||
def test_inexistent_api_key(test_account: Account) -> None:
|
||||
res = client.get("/api/v1", headers=auth_header(str(test_account.id)))
|
||||
def test_incorrect_api_key() -> None:
    """A request carrying the token ``not_valid`` must be answered with 401."""
    response = client.get("/api/v1", headers=auth_header("not_valid"))
    status = response.status_code
    assert status == 401
|
||||
|
||||
|
||||
def test_existing_api_key(test_account: Account) -> None:
|
||||
res = client.get("/api/v1", headers=auth_header(str(test_account.api_key)))
|
||||
def test_existing_api_key() -> None:
    """A request carrying ``settings.API_SECRET`` as token must be answered with 200."""
    valid_headers = auth_header(settings.API_SECRET)
    response = client.get("/api/v1", headers=valid_headers)
    assert response.status_code == 200
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import argparse
|
||||
from dotenv import load_dotenv
|
||||
from app.db.base import get_db
|
||||
from app.db.models import Account
|
||||
|
||||
load_dotenv()
|
||||
|
||||
def create_account(name: str) -> Account:
    """Insert a new ``Account`` with the given name and return it.

    Args:
        name: Name stored on the new account.

    Returns:
        The account, refreshed from the database after the commit. The
        session generator is closed before returning.
    """
    # Keep a reference to the generator. Calling `get_db().__next__()` and
    # discarding it leaves its cleanup to whenever the interpreter finalises
    # the generator, rather than running it once the work is done.
    session_gen = get_db()
    db = next(session_gen)
    try:
        account = Account(name=name)
        db.add(account)
        db.commit()
        db.refresh(account)
        return account
    finally:
        # Runs the generator's own cleanup.
        session_gen.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: one positional account name, passed straight
    # to create_account().
    cli = argparse.ArgumentParser()
    cli.add_argument("name", type=str, nargs=1)
    # nargs=1 makes argparse return a one-element list; unpack that element.
    (account_name,) = cli.parse_args().name
    create_account(account_name)
|
||||
Reference in New Issue
Block a user