introduces Sheet models and auth flow

This commit is contained in:
msramalho
2024-11-05 10:33:44 +00:00
parent f1525ef85a
commit 59c1be597c
13 changed files with 493 additions and 74 deletions

View File

@@ -126,7 +126,8 @@ def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
return len(group_name) and len(email) and group_name in get_user_groups(db, email) return len(group_name) and len(email) and group_name in get_user_groups(db, email)
def get_user_groups(db: Session, email: str): #TODO: maybe this can be cached? what about the db session?
def get_user_groups(db: Session, email: str) -> list[str]:
""" """
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. User does not need to be active. given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. User does not need to be active.
""" """
@@ -135,15 +136,53 @@ def get_user_groups(db: Session, email: str):
# get user groups # get user groups
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all() user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
user_level_groups = [g[0] for g in user_groups] user_level_groups_names = [g[0] for g in user_groups]
# get domain groups # get domain groups
domain = email.split('@')[1] domain = email.split('@')[1]
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all() domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
domain_level_groups = [g[0] for g in domain_level_groups] domain_level_groups_names = [g[0] for g in domain_level_groups]
# combine and return return list(set(user_level_groups_names + domain_level_groups_names))
return list(set(user_level_groups + domain_level_groups))
# --------------- SHEET
def has_quota_sheet(db: Session, email: str, user_groups_names: list[str]) -> bool:
"""
checks if a user has reached their sheet quota
"""
user_sheets = db.query(models.Sheet).filter(models.Sheet.author_id == email).count()
user_groups = db.query(models.Group).filter(models.Group.id.in_(user_groups_names)).all()
quota = 0
for group in user_groups:
active_sheets = group.permissions.get("active_sheets", 0)
if active_sheets == -1: return True
quota = max(quota, active_sheets)
return user_sheets < quota
def create_sheet(db: Session, sheet_id: str, sheet_name: str, email: str, group_id: str, frequency: str):
db_sheet = models.Sheet(id=sheet_id, name=sheet_name, author_id=email, group_id=group_id, frequency=frequency)
db.add(db_sheet)
db.commit()
db.refresh(db_sheet)
return db_sheet
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_archived_at.desc()).all()
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
if db_sheet:
db.delete(db_sheet)
db.commit()
return db_sheet is not None
# --------------- INIT User-Groups # --------------- INIT User-Groups
@@ -255,5 +294,5 @@ def upsert_user_groups(db: Session):
db.commit() db.commit()
count_user_groups = db.query(models.association_table_user_groups).count() count_user_groups = db.query(models.association_table_user_groups).count()
count_groups = db.query(func.count(models.Group.id)).scalar() count_groups = db.query(func.count(models.Group.id)).scalar()
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].") logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")

View File

@@ -6,9 +6,11 @@ import uuid
Base = declarative_base() Base = declarative_base()
def generate_uuid(): def generate_uuid():
return str(uuid.uuid4()) return str(uuid.uuid4())
# many to many association tables # many to many association tables
association_table_archive_tags = Table( association_table_archive_tags = Table(
"mtm_archives_tags", "mtm_archives_tags",
@@ -24,24 +26,29 @@ association_table_user_groups = Table(
) )
# data model tables # data model tables
class Archive(Base): class Archive(Base):
__tablename__ = "archives" __tablename__ = "archives"
id = Column(String, primary_key=True, index=True) id = Column(String, primary_key=True, index=True)
url = Column(String, index=True) url = Column(String, index=True)
result = Column(JSON, default=None) result = Column(JSON, default=None)
public = Column(Boolean, default=True) # if public=false, access to group and author public = Column(Boolean, default=True) # if public=false, access to group and author
deleted = Column(Boolean, default=False) deleted = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())
group_id = Column(String, ForeignKey("groups.id"), default=None) group_id = Column(String, ForeignKey("groups.id"), default=None)
author_id = Column(String, ForeignKey("users.email")) author_id = Column(String, ForeignKey("users.email"))
sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags) tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags)
group = relationship("Group", back_populates="archives") group = relationship("Group", back_populates="archives")
author = relationship("User", back_populates="archives") author = relationship("User", back_populates="archives")
urls = relationship("ArchiveUrl", back_populates="archive") urls = relationship("ArchiveUrl", back_populates="archive")
sheet = relationship("Sheet", back_populates="archives")
class ArchiveUrl(Base): class ArchiveUrl(Base):
__tablename__ = "archive_urls" __tablename__ = "archive_urls"
@@ -61,6 +68,7 @@ class Tag(Base):
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags) archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
class User(Base): class User(Base):
__tablename__ = "users" __tablename__ = "users"
@@ -68,8 +76,10 @@ class User(Base):
is_active = Column(Boolean, default=False) is_active = Column(Boolean, default=False)
archives = relationship("Archive", back_populates="author") archives = relationship("Archive", back_populates="author")
sheets = relationship("Sheet", back_populates="author")
groups = relationship("Group", back_populates="users", secondary=association_table_user_groups) groups = relationship("Group", back_populates="users", secondary=association_table_user_groups)
class Group(Base): class Group(Base):
__tablename__ = "groups" __tablename__ = "groups"
@@ -81,4 +91,23 @@ class Group(Base):
domains = Column(JSON, default=[]) domains = Column(JSON, default=[])
archives = relationship("Archive", back_populates="group") archives = relationship("Archive", back_populates="group")
users = relationship("User", back_populates="groups", secondary=association_table_user_groups) sheets = relationship("Sheet", back_populates="group")
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
class Sheet(Base):
__tablename__ = "sheets"
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
name = Column(String, default=None)
author_id = Column(String, ForeignKey("users.email"))
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...")
last_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
group = relationship("Group", back_populates="sheets")
author = relationship("User", back_populates="sheets")
archives = relationship("Archive", back_populates="sheet")

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel from pydantic import BaseModel, field_validator
from datetime import datetime from datetime import datetime
@@ -6,9 +6,10 @@ class Tag(BaseModel):
id: str id: str
created_at: datetime created_at: datetime
model_config = { "from_attributes": True } model_config = {"from_attributes": True}
__hash__ = object.__hash__ __hash__ = object.__hash__
class ArchiveCreate(BaseModel): class ArchiveCreate(BaseModel):
id: str | None = None id: str | None = None
url: str url: str
@@ -26,7 +27,8 @@ class Archive(ArchiveCreate):
updated_at: datetime | None updated_at: datetime | None
deleted: bool deleted: bool
model_config = { "from_attributes": True } model_config = {"from_attributes": True}
class SubmitSheet(BaseModel): class SubmitSheet(BaseModel):
sheet_name: str | None = None sheet_name: str | None = None
@@ -36,31 +38,70 @@ class SubmitSheet(BaseModel):
author_id: str | None = None author_id: str | None = None
group_id: str | None = None group_id: str | None = None
tags: set[str] | None = set() tags: set[str] | None = set()
columns: dict | None = {} # TODO: implement columns: dict | None = {} # TODO: implement
class SubmitManual(BaseModel): class SubmitManual(BaseModel):
result: str # should be a Metadata.to_json() result: str # should be a Metadata.to_json()
public: bool = False public: bool = False
author_id: str | None = None author_id: str | None = None
group_id: str | None = None group_id: str | None = None
tags: set[str] | None = set() tags: set[str] | None = set()
# API REQUESTS BELOW
# TODO: replace existing schemas with these
class ArchiveUrl(BaseModel):
url: str
public: bool = False
author_id: str | None
group_id: str | None
tags: set[str] | None = set()
# API RESPONSES BELOW # API RESPONSES BELOW
class ArchiveResult(BaseModel): class ArchiveResult(BaseModel):
id: str id: str
url: str url: str
result: dict result: dict
created_at: datetime created_at: datetime
class Task(BaseModel): class Task(BaseModel):
id: str id: str
class TaskResult(Task): class TaskResult(Task):
status: str status: str
result: str result: str
class TaskDelete(Task): class TaskDelete(Task):
deleted: bool deleted: bool
class ActiveUser(BaseModel): class ActiveUser(BaseModel):
active: bool active: bool
class SheetAdd(BaseModel):
id: str
name: str
group_id: str
frequency: str
@field_validator('frequency')
def validate_frequency(cls, v):
valid_frequencies = {"hourly", "daily"}
if v not in {"hourly", "daily"}:
raise ValueError(f"Invalid frequency: {v}. Must be one of {valid_frequencies}.")
return v
class SheetResponse(SheetAdd):
author_id: str
stats: dict | None
last_archived_at: datetime | None
created_at: datetime

View File

@@ -2,23 +2,80 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from loguru import logger from sqlalchemy import exc
from sqlalchemy.orm import Session
from core.config import ALLOW_ANY_EMAIL from web.security import token_api_key_auth, get_active_user_auth
from web.security import get_token_or_user_auth from db import schemas, crud
from db import schemas from db.database import get_db_dependency
from worker.main import create_sheet_task from worker.main import create_sheet_task
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"]) sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
@sheet_router.post("/archive", status_code=201, summary="Submit a Google Sheet archive request, starts a sheet archiving task.", response_description="task_id for the archiving task.") @sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_token_or_user_auth)) -> schemas.Task: def create_sheet(
logger.info(f"SHEET TASK for {sheet=}") sheet: schemas.SheetAdd,
if email == ALLOW_ANY_EMAIL: email=Depends(get_active_user_auth),
email = sheet.author_id or "api-endpoint" db: Session = Depends(get_db_dependency),
sheet.author_id = email ) -> schemas.SheetResponse:
if not sheet.sheet_name and not sheet.sheet_id: user_groups_names = crud.get_user_groups(db, email)
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
if sheet.group_id not in user_groups_names:
raise HTTPException(status_code=403, detail="User does not have access to this group.")
if not crud.has_quota_sheet(db, email, user_groups_names):
raise HTTPException(status_code=429, detail="User has reached their sheet quota.")
try:
return crud.create_sheet(db, sheet.id, sheet.name, email, sheet.group_id, sheet.frequency)
except exc.IntegrityError as e:
raise HTTPException(status_code=400, detail="Sheet with this ID already exists.") from e
@sheet_router.get("/mine", status_code=200, summary="Get the authenticated user's Google Sheets.")
def get_user_sheets(
email=Depends(get_active_user_auth),
db: Session = Depends(get_db_dependency)
) -> list[schemas.SheetResponse]:
return crud.get_user_sheets(db, email)
@sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.")
def delete_sheet(
id: str,
email=Depends(get_active_user_auth),
db: Session = Depends(get_db_dependency),
) -> schemas.TaskDelete:
return JSONResponse({
"id": id,
"deleted": crud.delete_sheet(db, id, email)
})
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
def archive_user_sheet(
id: str,
email=Depends(get_active_user_auth),
db: Session = Depends(get_db_dependency),
) -> schemas.Task:
sheet = crud.get_user_sheet(db, email, sheet_id=id)
if not sheet:
raise HTTPException(status_code=403, detail="No access to this sheet.")
task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=email, group=sheet.group_id).model_dump_json())
return JSONResponse({"id": task.id}, status_code=201)
@sheet_router.post("/archive", status_code=201, summary="Trigger an archiving task for any GSheet with an API token.", response_description="task_id for the archiving task.")
def archive_sheet(
sheet: schemas.SubmitSheet, #TODO: replace with simpler model
auth=Depends(token_api_key_auth)
) -> schemas.Task:
sheet.author_id = sheet.author_id or "api-endpoint"
if not sheet.sheet_id:
raise HTTPException(status_code=422, detail=f"sheet id is required")
task = create_sheet_task.delay(sheet.model_dump_json()) task = create_sheet_task.delay(sheet.model_dump_json())
return JSONResponse({"id": task.id}, status_code=201) return JSONResponse({"id": task.id}, status_code=201)

View File

@@ -0,0 +1,42 @@
"""add sheet_id to archive table
Revision ID: 89121d2c96d8
Revises: fa012ec405b8
Create Date: 2024-11-04 11:12:30.237299
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector
# revision identifiers, used by Alembic.
revision = '89121d2c96d8'
down_revision = 'fa012ec405b8'
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
columns = [col['name'] for col in inspector.get_columns('archives')]
if 'sheet_id' not in columns:
with op.batch_alter_table('archives') as batch_op:
batch_op.add_column(sa.Column('sheet_id', sa.String(), nullable=True, default=None))
batch_op.create_foreign_key('fk_sheet_id', 'sheets', ['sheet_id'], ['id'])
def downgrade() -> None:
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
foreign_keys = [fk['name'] for fk in inspector.get_foreign_keys('archives')]
columns = [col['name'] for col in inspector.get_columns('archives')]
with op.batch_alter_table('archives') as batch_op:
if 'fk_sheet_id' in foreign_keys:
batch_op.drop_constraint('fk_sheet_id', type_='foreignkey')
if 'sheet_id' in columns:
batch_op.drop_column('sheet_id')

View File

@@ -35,8 +35,11 @@ def upgrade() -> None:
def downgrade() -> None: def downgrade() -> None:
op.drop_column('groups', 'description') conn = op.get_bind()
op.drop_column('groups', 'orchestrator') inspector = Inspector.from_engine(conn)
op.drop_column('groups', 'orchestrator_sheet') columns = [col['name'] for col in inspector.get_columns('groups')]
op.drop_column('groups', 'permissions')
op.drop_column('groups', 'domains') column_names = ['description', 'orchestrator', 'orchestrator_sheet', 'permissions', 'domains']
for column_name in column_names:
if column_name in columns:
op.drop_column('groups', column_name)

View File

@@ -16,6 +16,7 @@ def mock_logger_add():
def get_settings(): def get_settings():
return Settings(_env_file=".env.test") return Settings(_env_file=".env.test")
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_settings(): def mock_settings():
with patch('shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings: with patch('shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings:
@@ -26,7 +27,7 @@ def mock_settings():
def test_db(get_settings: Settings): def test_db(get_settings: Settings):
from db.database import make_engine from db.database import make_engine
from db import models from db import models
make_engine.cache_clear() make_engine.cache_clear()
engine = make_engine(get_settings.DATABASE_PATH) engine = make_engine(get_settings.DATABASE_PATH)
@@ -72,10 +73,10 @@ def client(app):
@pytest.fixture() @pytest.fixture()
def app_with_auth(app): def app_with_auth(app):
from web.security import get_token_or_user_auth, get_user_auth, token_api_key_auth from web.security import get_token_or_user_auth, get_user_auth, get_active_user_auth
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com" app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com" app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
app.dependency_overrides[token_api_key_auth] = lambda: "jerry@example.com" app.dependency_overrides[get_active_user_auth] = lambda: "morty@example.com"
return app return app
@@ -85,6 +86,19 @@ def client_with_auth(app_with_auth):
return client return client
@pytest.fixture()
def app_with_token(app):
from web.security import token_api_key_auth
app.dependency_overrides[token_api_key_auth] = lambda: "jerry@example.com"
return app
@pytest.fixture()
def client_with_token(app_with_token):
client = TestClient(app_with_token)
return client
@pytest.fixture() @pytest.fixture()
def test_no_auth(): def test_no_auth():
# reusable code to ensure a method/endpoint combination is unauthorized # reusable code to ensure a method/endpoint combination is unauthorized
@@ -92,4 +106,4 @@ def test_no_auth():
response = http_method(endpoint) response = http_method(endpoint)
assert response.status_code == 403 assert response.status_code == 403
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
return no_auth return no_auth

View File

@@ -101,10 +101,16 @@ def test_favicon(client_with_auth):
assert r.headers["content-type"] == "image/vnd.microsoft.icon" assert r.headers["content-type"] == "image/vnd.microsoft.icon"
def test_endpoint_test_prometheus_no_auth(client, test_no_auth):
test_no_auth(client.get, "/metrics")
def test_endpoint_test_prometheus_no_user_auth(client_with_auth, test_no_auth):
test_no_auth(client_with_auth.get, "/metrics")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_prometheus_metrics(test_data, client_with_auth, get_settings): async def test_prometheus_metrics(test_data, client_with_token, get_settings):
# before metrics calculation # before metrics calculation
r = client_with_auth.get("/metrics") r = client_with_token.get("/metrics")
assert r.status_code == 200 assert r.status_code == 200
assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8" assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
assert "disk_utilization" in r.text assert "disk_utilization" in r.text
@@ -116,7 +122,7 @@ async def test_prometheus_metrics(test_data, client_with_auth, get_settings):
# after metrics calculation # after metrics calculation
from utils.metrics import measure_regular_metrics from utils.metrics import measure_regular_metrics
await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100) await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100)
r2 = client_with_auth.get("/metrics") r2 = client_with_token.get("/metrics")
assert 'disk_utilization{type="used"}' in r2.text assert 'disk_utilization{type="used"}' in r2.text
assert 'disk_utilization{type="free"}' in r2.text assert 'disk_utilization{type="free"}' in r2.text
assert 'disk_utilization{type="database"}' in r2.text assert 'disk_utilization{type="database"}' in r2.text
@@ -130,7 +136,7 @@ async def test_prometheus_metrics(test_data, client_with_auth, get_settings):
# 30s window, should not change the gauges nor the total in the counters # 30s window, should not change the gauges nor the total in the counters
from utils.metrics import measure_regular_metrics from utils.metrics import measure_regular_metrics
await measure_regular_metrics(get_settings.DATABASE_PATH, 30) await measure_regular_metrics(get_settings.DATABASE_PATH, 30)
r3 = client_with_auth.get("/metrics") r3 = client_with_token.get("/metrics")
assert 'database_metrics{query="count_archives"} 100.0' in r3.text assert 'database_metrics{query="count_archives"} 100.0' in r3.text
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text
assert 'database_metrics{query="count_users"} 4.0' in r3.text assert 'database_metrics{query="count_users"} 4.0' in r3.text

View File

@@ -5,15 +5,19 @@ def test_submit_manual_archive_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/interop/submit-archive") test_no_auth(client.post, "/interop/submit-archive")
def test_submit_manual_archive(client_with_auth): def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
test_no_auth(client_with_auth.post, "/interop/submit-archive")
def test_submit_manual_archive(client_with_token):
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": []}) aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": []})
r = client_with_auth.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]}) r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]})
assert r.status_code == 201 assert r.status_code == 201
assert "id" in r.json() assert "id" in r.json()
# cannot have the same URL twice # cannot have the same URL twice
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]}) aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]})
r = client_with_auth.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]}) r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]})
assert r.status_code == 422 assert r.status_code == 422
assert r.json() == {"detail": "Cannot insert into DB due to integrity error"} assert r.json() == {"detail": "Cannot insert into DB due to integrity error"}

View File

@@ -1,46 +1,210 @@
from datetime import datetime
import json import json
from unittest.mock import patch from unittest.mock import patch
from fastapi.testclient import TestClient
from db.schemas import TaskResult from db.schemas import TaskResult
def test_sheet_no_auth(client, test_no_auth): def test_endpoints_no_auth(client, test_no_auth):
test_no_auth(client.post, "/sheet/create")
test_no_auth(client.get, "/sheet/mine")
test_no_auth(client.delete, "/sheet/123-sheet-id")
test_no_auth(client.post, "/sheet/123-sheet-id/archive")
test_no_auth(client.post, "/sheet/archive") test_no_auth(client.post, "/sheet/archive")
@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result="")) def test_create_sheet_endpoint(app_with_auth):
def test_sheet_rick(m1, client_with_auth): client_with_auth = TestClient(app_with_auth)
good_data = {
"id": "123-sheet-id",
"name": "Test Sheet",
"group_id": "spaceship",
"frequency": "daily"
}
response = client_with_auth.post("/sheet/archive", json={"sheet_id": "123-sheet-id"}) # with good data
response = client_with_auth.post("/sheet/create", json=good_data)
assert response.status_code == 201 assert response.status_code == 201
assert response.json() == {'id': '123-456-789'} j = response.json()
assert datetime.fromisoformat(j.pop("created_at"))
assert datetime.fromisoformat(j.pop("last_archived_at"))
assert j.pop("stats") == {}
assert j.pop("author_id") == 'morty@example.com'
assert j == good_data
m1.assert_called_once() # already exists
called_val = m1.call_args.args[0] response = client_with_auth.post("/sheet/create", json=good_data)
assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": None, "public": False, "author_id": "rick@example.com", "group_id": None, "tags": [], "columns": {}, "header": 1} assert response.status_code == 400
assert response.json() == {"detail": "Sheet with this ID already exists."}
# bad frequency
bad_data = good_data.copy()
bad_data["frequency"] = "every hour"
response = client_with_auth.post("/sheet/create", json=bad_data)
assert response.status_code == 422
assert "Value error, Invalid frequency: every hour. Must be one of" in response.json()["detail"][0]["msg"]
def test_sheet_missing_sheet_data(client_with_auth): # bad group
r = client_with_auth.post("/sheet/archive", json={}) bad_data = good_data.copy()
assert r.status_code == 422 bad_data["group_id"] = "not a group"
assert r.json() == {"detail": "sheet name or id is required"} response = client_with_auth.post("/sheet/create", json=bad_data)
assert response.status_code == 403
assert response.json() == {"detail": "User does not have access to this group."}
# bad quota
@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-API-789", status="PENDING", result="")) jerry_data = good_data.copy()
def test_sheet_api(m1, client): jerry_data["group_id"] = "animated-characters"
jerry_data["id"] = "jerry-sheet-id"
response = client.post("/sheet/archive", json={"sheet_name": "456-sheet_name-id"}, headers={"Authorization": "Bearer this_is_the_test_api_token"}) from web.security import get_active_user_auth
app_with_auth.dependency_overrides[get_active_user_auth] = lambda: "jerry@example.com"
client_jerry = TestClient(app_with_auth)
response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == 201 assert response.status_code == 201
assert response.json() == {'id': '123-API-789'}
m1.assert_called_once() response = client_jerry.post("/sheet/create", json=jerry_data)
called_val = m1.call_args.args[0] assert response.status_code == 429
assert json.loads(called_val) == {"sheet_name": "456-sheet_name-id", "sheet_id": None, "public": False, "author_id": "api-endpoint", "group_id": None, "tags": [], "columns": {}, "header": 1} assert response.json() == {"detail": "User has reached their sheet quota."}
response = client.post("/sheet/archive", json={"sheet_id": "456-sheet-id", "author_id": "custom-author"}, headers={"Authorization": "Bearer this_is_the_test_api_token"})
assert response.status_code == 201
assert response.json() == {'id': '123-API-789'}
assert m1.call_count == 2 def test_get_user_sheets_endpoint(client_with_auth, db_session):
called_val = m1.call_args.args[0] # no data
assert json.loads(called_val) == {"sheet_id": "456-sheet-id", "sheet_name": None, "public": False, "author_id": "custom-author", "group_id": None, "tags": [], "columns": {}, "header": 1} response = client_with_auth.get("/sheet/mine")
assert response.status_code == 200
assert response.json() == []
# with data
from db import models
db_session.add(
models.Sheet(id="123", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")
)
db_session.commit()
db_session.add_all([
models.Sheet(id="456", name="Test Sheet 2", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
models.Sheet(id="789", name="Test Sheet 3", author_id="rick@example.com", group_id="interdimensional", frequency="hourly"),
])
db_session.commit()
response = client_with_auth.get("/sheet/mine")
assert response.status_code == 200
r = response.json()
assert isinstance(r, list)
assert len(r) == 2
assert datetime.fromisoformat(r[0].pop("created_at"))
assert datetime.fromisoformat(r[0].pop("last_archived_at"))
assert datetime.fromisoformat(r[1].pop("created_at"))
assert datetime.fromisoformat(r[1].pop("last_archived_at"))
assert r[0] == {
'id': '123',
'author_id': 'morty@example.com',
'frequency': 'hourly',
'group_id': 'spaceship',
'name': 'Test Sheet 1',
'stats': {},
}
assert r[1] == {
'id': '456',
'author_id': 'morty@example.com',
'frequency': 'daily',
'group_id': 'interdimensional',
'name': 'Test Sheet 2',
'stats': {},
}
def test_delete_sheet_endpoint(client_with_auth, db_session):
# missing sheet
response = client_with_auth.delete("/sheet/123-sheet-id")
assert response.status_code == 200
assert response.json() == {
"id": "123-sheet-id",
"deleted": False
}
# add sheets for deletion
from db import models
db_session.add_all([
models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
models.Sheet(id="456-sheet-id", name="Test Sheet 2", author_id="rick@example.com", group_id="spaceship", frequency="hourly"),
])
db_session.commit()
# morty can delete his
response = client_with_auth.delete("/sheet/123-sheet-id")
assert response.status_code == 200
assert response.json() == {"id": "123-sheet-id", "deleted": True}
# but only once
response = client_with_auth.delete("/sheet/123-sheet-id")
assert response.status_code == 200
assert response.json() == {"id": "123-sheet-id", "deleted": False}
# and not rick's
response = client_with_auth.delete("/sheet/456-sheet-id")
assert response.status_code == 200
assert response.json() == {"id": "456-sheet-id", "deleted": False}
# def test_archive_user_sheet_endpoint(client_with_auth):
# response = client_with_auth.post("/sheet/123-sheet-id/archive")
# assert response.status_code == 201
# assert "id" in response.json()
class TestArchiveUserSheetEndpoint:
def test_token_auth(self, client_with_token, test_no_auth):
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
def test_missing_data(self, client_with_auth):
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "No access to this sheet."}
def test_no_access(self, client_with_auth, db_session):
from db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="rick@example.com", group_id="spaceship", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "No access to this sheet."}
@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-taskid", status="PENDING", result=""))
def test_normal_flow(self, m1, client_with_auth, db_session):
from db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 201
assert r.json() == {"id": "123-taskid"}
m1.assert_called_once()
class TestTokenArchiveEndpoint:
def test_user_auth(self, client_with_auth, test_no_auth):
test_no_auth(client_with_auth.post, "/sheet/archive")
def test_missing_data(self, client_with_token):
r = client_with_token.post("/sheet/archive", json={})
assert r.status_code == 422
assert r.json() == {"detail": "sheet id is required"}
@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
def test_normal_flow(self, m1, client_with_token):
# minimum data
response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
m1.assert_called_once()
called_val = m1.call_args.args[0]
assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": None, "public": False, "author_id": "api-endpoint", "group_id": None, "tags": [], "columns": {}, "header": 1}
# maximum data
response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "author_id": "birdman@example.com", "header": 2, "public": True, "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
m1.call_count == 2
called_val = m1.call_args.args[0]
assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "public": True, "author_id": "birdman@example.com", "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}, "header": 2}

View File

@@ -33,6 +33,7 @@ groups:
active_sheets: -1 active_sheets: -1
monthly_urls: all monthly_urls: all
monthly_mbs: all monthly_mbs: all
alowed_frequency: "hourly"
interdimensional: interdimensional:
description: "Interdimensional travelers" description: "Interdimensional travelers"
orchestrator: tests/orchestration.test.yaml orchestrator: tests/orchestration.test.yaml
@@ -42,12 +43,14 @@ groups:
active_sheets: 5 active_sheets: 5
monthly_urls: 1000 monthly_urls: 1000
monthly_mbs: 1000 monthly_mbs: 1000
alowed_frequency: "hourly"
animated-characters: animated-characters:
description: "Animated characters" description: "Animated characters"
orchestrator: tests/orchestration.test.yaml orchestrator: tests/orchestration.test.yaml
orchestrator_sheet: tests/orchestration.test.yaml orchestrator_sheet: tests/orchestration.test.yaml
permissions: permissions:
read: ["animated-characters"] read: ["animated-characters"]
active_sheets: -1 active_sheets: 1
monthly_urls: all monthly_urls: 2
monthly_mbs: all monthly_mbs: 10
alowed_frequency: "daily"

View File

@@ -4,6 +4,8 @@ from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from core.config import ALLOW_ANY_EMAIL from core.config import ALLOW_ANY_EMAIL
from shared.settings import get_settings from shared.settings import get_settings
from db.database import get_db
from db import crud
settings = get_settings() settings = get_settings()
bearer_security = HTTPBearer() bearer_security = HTTPBearer()
@@ -54,6 +56,18 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
) )
async def get_active_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# validates Bearer token and Active User status
try:
email = await get_user_auth(credentials)
with get_db() as db:
if crud.is_active_user(db, email):
return email
raise HTTPException(status_code=403, detail="User is not active")
except HTTPException as e:
raise e
def authenticate_user(access_token): def authenticate_user(access_token):
# https://cloud.google.com/docs/authentication/token-types#access # https://cloud.google.com/docs/authentication/token-types#access
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token" if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
@@ -69,7 +83,7 @@ def authenticate_user(access_token):
return False, f"email '{j.get('email')}' not verified" return False, f"email '{j.get('email')}' not verified"
if int(j.get("expires_in", -1)) <= 0: if int(j.get("expires_in", -1)) <= 0:
return False, "Token expired" return False, "Token expired"
return True, j.get('email') return True, j.get('email').lower()
except Exception as e: except Exception as e:
logger.warning(f"AUTH EXCEPTION occurred: {e}") logger.warning(f"AUTH EXCEPTION occurred: {e}")
return False, "exception occurred" return False, "exception occurred"

View File

@@ -64,11 +64,13 @@ def create_sheet_task(self, sheet_json: str):
sheet.tags.add("gsheet") sheet.tags.add("gsheet")
logger.info(f"SHEET START {sheet=}") logger.info(f"SHEET START {sheet=}")
#TODO: should this check live here?
if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)): if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)):
return {"error": em} return {"error": em}
config = Config() config = Config()
# TODO: use choose_orchestrator and overwrite the feeder # TODO: use choose_orchestrator and overwrite the feeder
# TODO: drop sheet_name and use only sheet_id (new endpoints/models)
config.parse(use_cli=False, yaml_config_filename=get_settings().SHEET_ORCHESTRATION_YAML, overwrite_configs={"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}}) config.parse(use_cli=False, yaml_config_filename=get_settings().SHEET_ORCHESTRATION_YAML, overwrite_configs={"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}})
orchestrator = ArchivingOrchestrator(config) orchestrator = ArchivingOrchestrator(config)
@@ -78,6 +80,7 @@ def create_sheet_task(self, sheet_json: str):
logger.error("Got empty result from feeder, an internal error must have occurred.") logger.error("Got empty result from feeder, an internal error must have occurred.")
continue continue
try: try:
#TODO: remove public from sheet in new refactor
insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid()) insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid())
stats["archived"] += 1 stats["archived"] += 1
except exc.IntegrityError as e: except exc.IntegrityError as e: