mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-08 03:28:35 +03:00
introduces Sheet models and auth flow
This commit is contained in:
@@ -126,7 +126,8 @@ def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
|
||||
return len(group_name) and len(email) and group_name in get_user_groups(db, email)
|
||||
|
||||
|
||||
def get_user_groups(db: Session, email: str):
|
||||
#TODO: maybe this can be cached? what about the db session?
|
||||
def get_user_groups(db: Session, email: str) -> list[str]:
|
||||
"""
|
||||
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. User does not need to be active.
|
||||
"""
|
||||
@@ -135,15 +136,53 @@ def get_user_groups(db: Session, email: str):
|
||||
|
||||
# get user groups
|
||||
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
|
||||
user_level_groups = [g[0] for g in user_groups]
|
||||
user_level_groups_names = [g[0] for g in user_groups]
|
||||
|
||||
# get domain groups
|
||||
domain = email.split('@')[1]
|
||||
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
|
||||
domain_level_groups = [g[0] for g in domain_level_groups]
|
||||
domain_level_groups_names = [g[0] for g in domain_level_groups]
|
||||
|
||||
# combine and return
|
||||
return list(set(user_level_groups + domain_level_groups))
|
||||
return list(set(user_level_groups_names + domain_level_groups_names))
|
||||
|
||||
|
||||
# --------------- SHEET
|
||||
|
||||
def has_quota_sheet(db: Session, email: str, user_groups_names: list[str]) -> bool:
    """
    Check whether a user is still allowed to create another sheet.

    The quota is the largest "active_sheets" value found in the permissions
    of the groups listed in `user_groups_names`; a value of -1 in any of
    those groups means there is no limit.

    Returns True while the user is below their quota (or unlimited), and
    False once the quota is reached or no group grants any sheets.
    """
    user_groups = db.query(models.Group).filter(models.Group.id.in_(user_groups_names)).all()

    quota = 0
    for group in user_groups:
        # tolerate groups whose permissions value is empty/None
        active_sheets = (group.permissions or {}).get("active_sheets", 0)
        if active_sheets == -1:
            # unlimited: no need to count the user's sheets at all
            return True
        quota = max(quota, active_sheets)

    # only count the user's sheets when there is a finite quota to compare with
    user_sheets = db.query(models.Sheet).filter(models.Sheet.author_id == email).count()
    return user_sheets < quota
|
||||
|
||||
|
||||
def create_sheet(db: Session, sheet_id: str, sheet_name: str, email: str, group_id: str, frequency: str):
    """
    Insert a new Sheet row authored by `email` and return it.

    The row is committed immediately and then refreshed from the database
    before being returned. Errors raised by the commit propagate to the caller.
    """
    new_sheet = models.Sheet(
        id=sheet_id,
        name=sheet_name,
        author_id=email,
        group_id=group_id,
        frequency=frequency,
    )
    db.add(new_sheet)
    db.commit()
    # reload the instance's attributes from the database after the commit
    db.refresh(new_sheet)
    return new_sheet
|
||||
|
||||
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
    """
    Return all sheets whose author is `email`, ordered by last_archived_at descending.
    """
    authored = db.query(models.Sheet).filter(models.Sheet.author_id == email)
    ordered = authored.order_by(models.Sheet.last_archived_at.desc())
    return ordered.all()
|
||||
|
||||
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
    """
    Return the sheet with id `sheet_id` authored by `email`.

    The result comes from Query.first(), so callers should expect a falsy
    value when no matching row exists.
    """
    owned = db.query(models.Sheet).filter(
        models.Sheet.author_id == email,
        models.Sheet.id == sheet_id,
    )
    return owned.first()
|
||||
|
||||
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
    """
    Delete the sheet `sheet_id` if, and only if, it is authored by `email`.

    Returns True when a matching sheet was found and deleted, False otherwise.
    """
    target = db.query(models.Sheet).filter(
        models.Sheet.id == sheet_id,
        models.Sheet.author_id == email,
    ).first()

    # nothing to delete: unknown id or a sheet authored by someone else
    if target is None:
        return False

    db.delete(target)
    db.commit()
    return True
|
||||
|
||||
|
||||
# --------------- INIT User-Groups
|
||||
@@ -255,5 +294,5 @@ def upsert_user_groups(db: Session):
|
||||
db.commit()
|
||||
count_user_groups = db.query(models.association_table_user_groups).count()
|
||||
count_groups = db.query(func.count(models.Group.id)).scalar()
|
||||
|
||||
|
||||
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
|
||||
|
||||
@@ -6,9 +6,11 @@ import uuid
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def generate_uuid():
    """Return a newly generated random (version 4) UUID in its string form."""
    fresh = uuid.uuid4()
    return str(fresh)
|
||||
|
||||
|
||||
# many to many association tables
|
||||
association_table_archive_tags = Table(
|
||||
"mtm_archives_tags",
|
||||
@@ -24,24 +26,29 @@ association_table_user_groups = Table(
|
||||
)
|
||||
|
||||
# data model tables
|
||||
|
||||
|
||||
class Archive(Base):
    """ORM model mapped to the "archives" table."""

    __tablename__ = "archives"

    # identity and payload
    id = Column(String, primary_key=True, index=True)
    url = Column(String, index=True)
    result = Column(JSON, default=None)

    # flags
    public = Column(Boolean, default=True)  # if public=false, access to group and author
    deleted = Column(Boolean, default=False)

    # timestamps: created_at is filled by the server on insert, updated_at on update
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    # foreign keys
    group_id = Column(String, ForeignKey("groups.id"), default=None)
    author_id = Column(String, ForeignKey("users.email"))
    sheet_id = Column(String, ForeignKey("sheets.id"), default=None)

    # relationships (each is paired with a back_populates attribute on the other model)
    tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags)
    group = relationship("Group", back_populates="archives")
    author = relationship("User", back_populates="archives")
    urls = relationship("ArchiveUrl", back_populates="archive")
    sheet = relationship("Sheet", back_populates="archives")
|
||||
|
||||
|
||||
class ArchiveUrl(Base):
|
||||
__tablename__ = "archive_urls"
|
||||
@@ -61,6 +68,7 @@ class Tag(Base):
|
||||
|
||||
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
@@ -68,8 +76,10 @@ class User(Base):
|
||||
is_active = Column(Boolean, default=False)
|
||||
|
||||
archives = relationship("Archive", back_populates="author")
|
||||
sheets = relationship("Sheet", back_populates="author")
|
||||
groups = relationship("Group", back_populates="users", secondary=association_table_user_groups)
|
||||
|
||||
|
||||
class Group(Base):
|
||||
__tablename__ = "groups"
|
||||
|
||||
@@ -81,4 +91,23 @@ class Group(Base):
|
||||
domains = Column(JSON, default=[])
|
||||
|
||||
archives = relationship("Archive", back_populates="group")
|
||||
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
|
||||
sheets = relationship("Sheet", back_populates="group")
|
||||
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
|
||||
|
||||
|
||||
class Sheet(Base):
    """ORM model mapped to the "sheets" table: a Google Sheet registered for archiving."""

    __tablename__ = "sheets"

    id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
    name = Column(String, default=None)
    author_id = Column(String, ForeignKey("users.email"))
    group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
    frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
    # callable default: every insert gets its own fresh dict instead of all rows
    # sharing (and possibly mutating) one module-level {} instance
    stats = Column(JSON, default=dict, doc="Sheet statistics like total links, total rows, ...")
    last_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    # relationships (each is paired with a back_populates attribute on the other model)
    group = relationship("Group", back_populates="sheets")
    author = relationship("User", back_populates="sheets")
    archives = relationship("Archive", back_populates="sheet")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, field_validator
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -6,9 +6,10 @@ class Tag(BaseModel):
|
||||
id: str
|
||||
created_at: datetime
|
||||
|
||||
model_config = { "from_attributes": True }
|
||||
model_config = {"from_attributes": True}
|
||||
__hash__ = object.__hash__
|
||||
|
||||
|
||||
class ArchiveCreate(BaseModel):
|
||||
id: str | None = None
|
||||
url: str
|
||||
@@ -26,7 +27,8 @@ class Archive(ArchiveCreate):
|
||||
updated_at: datetime | None
|
||||
deleted: bool
|
||||
|
||||
model_config = { "from_attributes": True }
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class SubmitSheet(BaseModel):
|
||||
sheet_name: str | None = None
|
||||
@@ -36,31 +38,70 @@ class SubmitSheet(BaseModel):
|
||||
author_id: str | None = None
|
||||
group_id: str | None = None
|
||||
tags: set[str] | None = set()
|
||||
columns: dict | None = {} # TODO: implement
|
||||
columns: dict | None = {} # TODO: implement
|
||||
|
||||
|
||||
class SubmitManual(BaseModel):
    # Schema for submitting an archive result as a JSON string.
    # Only `result` is required; every other field has a default.
    result: str  # should be a Metadata.to_json()
    public: bool = False
    author_id: str | None = None
    group_id: str | None = None
    tags: set[str] | None = set()
|
||||
|
||||
# API REQUESTS BELOW
|
||||
# TODO: replace existing schemas with these
|
||||
|
||||
|
||||
class ArchiveUrl(BaseModel):
    # Request schema carrying a single URL plus visibility/ownership metadata.
    url: str
    public: bool = False
    # nullable, but declared without a default value
    author_id: str | None
    group_id: str | None
    tags: set[str] | None = set()
|
||||
|
||||
# API RESPONSES BELOW
|
||||
|
||||
|
||||
class ArchiveResult(BaseModel):
    # Response schema: an archive's id, url, result payload and creation time.
    id: str
    url: str
    result: dict
    created_at: datetime
|
||||
|
||||
|
||||
class Task(BaseModel):
    # Base response schema for task-related endpoints: just the task id.
    id: str
|
||||
|
||||
|
||||
class TaskResult(Task):
    # Task id (inherited from Task) together with its status and result strings.
    status: str
    result: str
|
||||
|
||||
|
||||
class TaskDelete(Task):
    # Id (inherited from Task) plus a flag telling whether something was deleted.
    deleted: bool
|
||||
|
||||
|
||||
class ActiveUser(BaseModel):
    # Response schema with a single boolean `active` flag.
    active: bool
|
||||
|
||||
|
||||
class SheetAdd(BaseModel):
    # Request schema for registering a Google Sheet: all four fields are required.
    id: str
    name: str
    group_id: str
    frequency: str

    @field_validator('frequency')
    @classmethod
    def validate_frequency(cls, v):
        # one set drives both the membership check and the error message,
        # so the two cannot drift apart
        valid_frequencies = {"hourly", "daily"}
        if v not in valid_frequencies:
            raise ValueError(f"Invalid frequency: {v}. Must be one of {valid_frequencies}.")
        return v
|
||||
|
||||
|
||||
class SheetResponse(SheetAdd):
    # Response schema: the SheetAdd fields plus author, stats and timestamps.
    author_id: str
    # nullable, but declared without a default value
    stats: dict | None
    last_archived_at: datetime | None
    created_at: datetime
|
||||
|
||||
@@ -2,23 +2,80 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
from web.security import get_token_or_user_auth
|
||||
from db import schemas
|
||||
from web.security import token_api_key_auth, get_active_user_auth
|
||||
from db import schemas, crud
|
||||
from db.database import get_db_dependency
|
||||
from worker.main import create_sheet_task
|
||||
|
||||
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
|
||||
|
||||
|
||||
@sheet_router.post("/archive", status_code=201, summary="Submit a Google Sheet archive request, starts a sheet archiving task.", response_description="task_id for the archiving task.")
|
||||
def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_token_or_user_auth)) -> schemas.Task:
|
||||
logger.info(f"SHEET TASK for {sheet=}")
|
||||
if email == ALLOW_ANY_EMAIL:
|
||||
email = sheet.author_id or "api-endpoint"
|
||||
sheet.author_id = email
|
||||
if not sheet.sheet_name and not sheet.sheet_id:
|
||||
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
||||
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
def create_sheet(
    sheet: schemas.SheetAdd,
    email=Depends(get_active_user_auth),
    db: Session = Depends(get_db_dependency),
) -> schemas.SheetResponse:
    # Guard 1: the requested group must be one of the caller's groups -> 403 otherwise.
    allowed_groups = crud.get_user_groups(db, email)
    if sheet.group_id not in allowed_groups:
        raise HTTPException(status_code=403, detail="User does not have access to this group.")

    # Guard 2: the caller must still have sheet quota left -> 429 otherwise.
    quota_left = crud.has_quota_sheet(db, email, allowed_groups)
    if not quota_left:
        raise HTTPException(status_code=429, detail="User has reached their sheet quota.")

    # An integrity error on insert is reported to the client as a duplicate sheet id.
    try:
        created = crud.create_sheet(db, sheet.id, sheet.name, email, sheet.group_id, sheet.frequency)
    except exc.IntegrityError as e:
        raise HTTPException(status_code=400, detail="Sheet with this ID already exists.") from e
    return created
|
||||
|
||||
|
||||
@sheet_router.get("/mine", status_code=200, summary="Get the authenticated user's Google Sheets.")
def get_user_sheets(
    email=Depends(get_active_user_auth),
    db: Session = Depends(get_db_dependency)
) -> list[schemas.SheetResponse]:
    # Thin wrapper: the lookup (and its ordering) is done by the CRUD layer.
    own_sheets = crud.get_user_sheets(db, email)
    return own_sheets
|
||||
|
||||
|
||||
@sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.")
def delete_sheet(
    id: str,
    email=Depends(get_active_user_auth),
    db: Session = Depends(get_db_dependency),
) -> schemas.TaskDelete:
    # Always answers with the requested id; `deleted` carries the boolean
    # returned by the CRUD layer (False when nothing matching was deleted).
    was_deleted = crud.delete_sheet(db, id, email)
    payload = {"id": id, "deleted": was_deleted}
    return JSONResponse(payload)
|
||||
|
||||
|
||||
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
def archive_user_sheet(
    id: str,
    email=Depends(get_active_user_auth),
    db: Session = Depends(get_db_dependency),
) -> schemas.Task:
    # Only sheets authored by the caller can be archived through this route;
    # an unknown id and someone else's sheet both yield the same 403.
    sheet = crud.get_user_sheet(db, email, sheet_id=id)
    if not sheet:
        raise HTTPException(status_code=403, detail="No access to this sheet.")

    # SubmitSheet declares `group_id` (there is no `group` field), so the
    # sheet's group has to be passed under that name to reach the task.
    submission = schemas.SubmitSheet(sheet_id=id, author_id=email, group_id=sheet.group_id)
    task = create_sheet_task.delay(submission.model_dump_json())

    return JSONResponse({"id": task.id}, status_code=201)
|
||||
|
||||
|
||||
@sheet_router.post("/archive", status_code=201, summary="Trigger an archiving task for any GSheet with an API token.", response_description="task_id for the archiving task.")
def archive_sheet(
    sheet: schemas.SubmitSheet,  # TODO: replace with simpler model
    auth=Depends(token_api_key_auth)
) -> schemas.Task:
    # `auth` is unused in the body; it is declared so the token dependency runs.
    # Requests without an author are attributed to "api-endpoint".
    sheet.author_id = sheet.author_id or "api-endpoint"
    if not sheet.sheet_id:
        raise HTTPException(status_code=422, detail="sheet id is required")
    task = create_sheet_task.delay(sheet.model_dump_json())
    return JSONResponse({"id": task.id}, status_code=201)
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
"""add sheet_id to archive table
|
||||
|
||||
Revision ID: 89121d2c96d8
|
||||
Revises: fa012ec405b8
|
||||
Create Date: 2024-11-04 11:12:30.237299
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '89121d2c96d8'
|
||||
down_revision = 'fa012ec405b8'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """
    Add a nullable `sheet_id` column to `archives`, with foreign key
    `fk_sheet_id` pointing at `sheets.id`.

    Does nothing when the column is already present, so the migration can be
    re-run safely.
    """
    # sa.inspect() replaces the deprecated Inspector.from_engine()
    inspector = sa.inspect(op.get_bind())
    existing_columns = {col['name'] for col in inspector.get_columns('archives')}

    if 'sheet_id' in existing_columns:
        return

    with op.batch_alter_table('archives') as batch_op:
        batch_op.add_column(sa.Column('sheet_id', sa.String(), nullable=True, default=None))
        batch_op.create_foreign_key('fk_sheet_id', 'sheets', ['sheet_id'], ['id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
    """
    Remove the `fk_sheet_id` foreign key and the `sheet_id` column from `archives`.

    Each removal is skipped when the constraint/column is not present.
    """
    # sa.inspect() replaces the deprecated Inspector.from_engine()
    inspector = sa.inspect(op.get_bind())
    foreign_keys = {fk['name'] for fk in inspector.get_foreign_keys('archives')}
    columns = {col['name'] for col in inspector.get_columns('archives')}

    with op.batch_alter_table('archives') as batch_op:
        # the constraint is dropped before the column it refers to
        if 'fk_sheet_id' in foreign_keys:
            batch_op.drop_constraint('fk_sheet_id', type_='foreignkey')

        if 'sheet_id' in columns:
            batch_op.drop_column('sheet_id')
|
||||
@@ -35,8 +35,11 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
    """
    Drop the columns description, orchestrator, orchestrator_sheet, permissions
    and domains from `groups`.

    Only columns that currently exist are dropped, so the downgrade does not
    fail on a table where some of them are already missing.
    """
    conn = op.get_bind()
    inspector = Inspector.from_engine(conn)
    columns = {col['name'] for col in inspector.get_columns('groups')}

    column_names = ['description', 'orchestrator', 'orchestrator_sheet', 'permissions', 'domains']
    for column_name in column_names:
        if column_name in columns:
            op.drop_column('groups', column_name)
|
||||
|
||||
@@ -16,6 +16,7 @@ def mock_logger_add():
|
||||
def get_settings():
|
||||
return Settings(_env_file=".env.test")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_settings():
|
||||
with patch('shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings:
|
||||
@@ -26,7 +27,7 @@ def mock_settings():
|
||||
def test_db(get_settings: Settings):
|
||||
from db.database import make_engine
|
||||
from db import models
|
||||
|
||||
|
||||
make_engine.cache_clear()
|
||||
engine = make_engine(get_settings.DATABASE_PATH)
|
||||
|
||||
@@ -72,10 +73,10 @@ def client(app):
|
||||
|
||||
@pytest.fixture()
def app_with_auth(app):
    """
    App whose user-based auth dependencies are overridden with fixed emails.

    The API-token dependency is deliberately NOT overridden here, so
    token-only endpoints still reject clients built from this app.
    """
    from web.security import get_token_or_user_auth, get_user_auth, get_active_user_auth
    app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
    app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
    app.dependency_overrides[get_active_user_auth] = lambda: "morty@example.com"
    return app
|
||||
|
||||
|
||||
@@ -85,6 +86,19 @@ def client_with_auth(app_with_auth):
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture()
def app_with_token(app):
    """App whose API-token dependency is overridden to return "jerry@example.com"."""
    from web.security import token_api_key_auth

    def _token_identity():
        return "jerry@example.com"

    app.dependency_overrides[token_api_key_auth] = _token_identity
    return app
|
||||
|
||||
|
||||
@pytest.fixture()
def client_with_token(app_with_token):
    """TestClient bound to the app built by the app_with_token fixture."""
    return TestClient(app_with_token)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_no_auth():
|
||||
# reusable code to ensure a method/endpoint combination is unauthorized
|
||||
@@ -92,4 +106,4 @@ def test_no_auth():
|
||||
response = http_method(endpoint)
|
||||
assert response.status_code == 403
|
||||
assert response.json() == {"detail": "Not authenticated"}
|
||||
return no_auth
|
||||
return no_auth
|
||||
|
||||
@@ -101,10 +101,16 @@ def test_favicon(client_with_auth):
|
||||
assert r.headers["content-type"] == "image/vnd.microsoft.icon"
|
||||
|
||||
|
||||
def test_endpoint_test_prometheus_no_auth(client, test_no_auth):
    """GET /metrics with the plain client must pass the shared no-auth check."""
    metrics_path = "/metrics"
    test_no_auth(client.get, metrics_path)
|
||||
|
||||
def test_endpoint_test_prometheus_no_user_auth(client_with_auth, test_no_auth):
    """GET /metrics with the user-auth client must also pass the shared no-auth check."""
    metrics_path = "/metrics"
    test_no_auth(client_with_auth.get, metrics_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prometheus_metrics(test_data, client_with_auth, get_settings):
|
||||
async def test_prometheus_metrics(test_data, client_with_token, get_settings):
|
||||
# before metrics calculation
|
||||
r = client_with_auth.get("/metrics")
|
||||
r = client_with_token.get("/metrics")
|
||||
assert r.status_code == 200
|
||||
assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
|
||||
assert "disk_utilization" in r.text
|
||||
@@ -116,7 +122,7 @@ async def test_prometheus_metrics(test_data, client_with_auth, get_settings):
|
||||
# after metrics calculation
|
||||
from utils.metrics import measure_regular_metrics
|
||||
await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100)
|
||||
r2 = client_with_auth.get("/metrics")
|
||||
r2 = client_with_token.get("/metrics")
|
||||
assert 'disk_utilization{type="used"}' in r2.text
|
||||
assert 'disk_utilization{type="free"}' in r2.text
|
||||
assert 'disk_utilization{type="database"}' in r2.text
|
||||
@@ -130,7 +136,7 @@ async def test_prometheus_metrics(test_data, client_with_auth, get_settings):
|
||||
# 30s window, should not change the gauges nor the total in the counters
|
||||
from utils.metrics import measure_regular_metrics
|
||||
await measure_regular_metrics(get_settings.DATABASE_PATH, 30)
|
||||
r3 = client_with_auth.get("/metrics")
|
||||
r3 = client_with_token.get("/metrics")
|
||||
assert 'database_metrics{query="count_archives"} 100.0' in r3.text
|
||||
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text
|
||||
assert 'database_metrics{query="count_users"} 4.0' in r3.text
|
||||
|
||||
@@ -5,15 +5,19 @@ def test_submit_manual_archive_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.post, "/interop/submit-archive")
|
||||
|
||||
|
||||
def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
    """A user-auth client must pass the shared no-auth check on this endpoint."""
    test_no_auth(client_with_auth.post, "/interop/submit-archive")


def test_submit_manual_archive(client_with_token):
    """First submission is created (201); re-submitting the same URL yields 422."""
    aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": []})

    r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]})
    assert r.status_code == 201
    assert "id" in r.json()

    # cannot have the same URL twice
    aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]})
    r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]})
    assert r.status_code == 422
    assert r.json() == {"detail": "Cannot insert into DB due to integrity error"}
|
||||
|
||||
@@ -1,46 +1,210 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from db.schemas import TaskResult
|
||||
|
||||
|
||||
def test_sheet_no_auth(client, test_no_auth):
|
||||
def test_endpoints_no_auth(client, test_no_auth):
    """Every sheet endpoint must pass the shared no-auth check with the plain client."""
    guarded = [
        (client.post, "/sheet/create"),
        (client.get, "/sheet/mine"),
        (client.delete, "/sheet/123-sheet-id"),
        (client.post, "/sheet/123-sheet-id/archive"),
        (client.post, "/sheet/archive"),
    ]
    for http_method, endpoint in guarded:
        test_no_auth(http_method, endpoint)
|
||||
|
||||
|
||||
def test_create_sheet_endpoint(app_with_auth):
    """Exercise POST /sheet/create: success, duplicate id, bad frequency, bad group, quota."""
    client_with_auth = TestClient(app_with_auth)
    good_data = {
        "id": "123-sheet-id",
        "name": "Test Sheet",
        "group_id": "spaceship",
        "frequency": "daily"
    }

    # with good data
    response = client_with_auth.post("/sheet/create", json=good_data)
    assert response.status_code == 201
    j = response.json()
    assert datetime.fromisoformat(j.pop("created_at"))
    assert datetime.fromisoformat(j.pop("last_archived_at"))
    assert j.pop("stats") == {}
    assert j.pop("author_id") == 'morty@example.com'
    # whatever is left must be exactly what was submitted
    assert j == good_data

    # already exists
    response = client_with_auth.post("/sheet/create", json=good_data)
    assert response.status_code == 400
    assert response.json() == {"detail": "Sheet with this ID already exists."}

    # bad frequency
    bad_data = good_data.copy()
    bad_data["frequency"] = "every hour"
    response = client_with_auth.post("/sheet/create", json=bad_data)
    assert response.status_code == 422
    assert "Value error, Invalid frequency: every hour. Must be one of" in response.json()["detail"][0]["msg"]

    # bad group
    bad_data = good_data.copy()
    bad_data["group_id"] = "not a group"
    response = client_with_auth.post("/sheet/create", json=bad_data)
    assert response.status_code == 403
    assert response.json() == {"detail": "User does not have access to this group."}

    # bad quota: switch the authenticated user to jerry and create sheets until refused
    jerry_data = good_data.copy()
    jerry_data["group_id"] = "animated-characters"
    jerry_data["id"] = "jerry-sheet-id"
    from web.security import get_active_user_auth
    app_with_auth.dependency_overrides[get_active_user_auth] = lambda: "jerry@example.com"
    client_jerry = TestClient(app_with_auth)
    response = client_jerry.post("/sheet/create", json=jerry_data)
    assert response.status_code == 201
    response = client_jerry.post("/sheet/create", json=jerry_data)
    assert response.status_code == 429
    assert response.json() == {"detail": "User has reached their sheet quota."}
|
||||
def test_get_user_sheets_endpoint(client_with_auth, db_session):
    """GET /sheet/mine: empty at first, then only the sheets authored by morty."""
    # no data
    empty = client_with_auth.get("/sheet/mine")
    assert empty.status_code == 200
    assert empty.json() == []

    # with data
    from db import models
    first_sheet = models.Sheet(id="123", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")
    db_session.add(first_sheet)
    db_session.commit()
    later_sheets = [
        models.Sheet(id="456", name="Test Sheet 2", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
        models.Sheet(id="789", name="Test Sheet 3", author_id="rick@example.com", group_id="interdimensional", frequency="hourly"),
    ]
    db_session.add_all(later_sheets)
    db_session.commit()

    response = client_with_auth.get("/sheet/mine")
    assert response.status_code == 200
    sheets = response.json()
    assert isinstance(sheets, list)
    assert len(sheets) == 2

    # timestamps vary between runs: check they parse, then remove them before comparing
    for entry in sheets[:2]:
        assert datetime.fromisoformat(entry.pop("created_at"))
        assert datetime.fromisoformat(entry.pop("last_archived_at"))

    expected_first = {
        'id': '123',
        'author_id': 'morty@example.com',
        'frequency': 'hourly',
        'group_id': 'spaceship',
        'name': 'Test Sheet 1',
        'stats': {},
    }
    expected_second = {
        'id': '456',
        'author_id': 'morty@example.com',
        'frequency': 'daily',
        'group_id': 'interdimensional',
        'name': 'Test Sheet 2',
        'stats': {},
    }
    assert sheets[0] == expected_first
    assert sheets[1] == expected_second
|
||||
|
||||
|
||||
def test_delete_sheet_endpoint(client_with_auth, db_session):
    """DELETE /sheet/{id}: always 200, `deleted` is True only for an existing sheet the caller owns."""
    # missing sheet
    missing = client_with_auth.delete("/sheet/123-sheet-id")
    assert missing.status_code == 200
    assert missing.json() == {"id": "123-sheet-id", "deleted": False}

    # add sheets for deletion
    from db import models
    seeded = [
        models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
        models.Sheet(id="456-sheet-id", name="Test Sheet 2", author_id="rick@example.com", group_id="spaceship", frequency="hourly"),
    ]
    db_session.add_all(seeded)
    db_session.commit()

    # the order of the steps matters: the second one relies on the first having deleted the row
    steps = [
        # morty can delete his
        ("/sheet/123-sheet-id", {"id": "123-sheet-id", "deleted": True}),
        # but only once
        ("/sheet/123-sheet-id", {"id": "123-sheet-id", "deleted": False}),
        # and not rick's
        ("/sheet/456-sheet-id", {"id": "456-sheet-id", "deleted": False}),
    ]
    for path, expected_body in steps:
        response = client_with_auth.delete(path)
        assert response.status_code == 200
        assert response.json() == expected_body
|
||||
|
||||
|
||||
# def test_archive_user_sheet_endpoint(client_with_auth):
|
||||
# response = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
# assert response.status_code == 201
|
||||
# assert "id" in response.json()
|
||||
|
||||
|
||||
class TestArchiveUserSheetEndpoint:
    """Tests for POST /sheet/{id}/archive (user-authenticated archive trigger)."""

    def test_token_auth(self, client_with_token, test_no_auth):
        """A token-only client must pass the shared no-auth check on this endpoint."""
        test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")

    def test_missing_data(self, client_with_auth):
        """A sheet id that does not exist is answered with 403."""
        reply = client_with_auth.post("/sheet/123-sheet-id/archive")
        assert reply.status_code == 403
        assert reply.json() == {"detail": "No access to this sheet."}

    def test_no_access(self, client_with_auth, db_session):
        """A sheet authored by another user (rick) is answered with 403."""
        from db import models
        ricks_sheet = models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="rick@example.com", group_id="spaceship", frequency="hourly")
        db_session.add(ricks_sheet)
        db_session.commit()
        reply = client_with_auth.post("/sheet/123-sheet-id/archive")
        assert reply.status_code == 403
        assert reply.json() == {"detail": "No access to this sheet."}

    @patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-taskid", status="PENDING", result=""))
    def test_normal_flow(self, m1, client_with_auth, db_session):
        """For morty's own sheet the task is dispatched once and its id is returned."""
        from db import models
        own_sheet = models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")
        db_session.add(own_sheet)
        db_session.commit()
        reply = client_with_auth.post("/sheet/123-sheet-id/archive")
        assert reply.status_code == 201
        assert reply.json() == {"id": "123-taskid"}
        m1.assert_called_once()
|
||||
|
||||
|
||||
class TestTokenArchiveEndpoint:
    """Tests for POST /sheet/archive as called with an API token."""

    def test_user_auth(self, client_with_auth, test_no_auth):
        """The user-authenticated client is run through the shared no-auth check."""
        # NOTE(review): test_no_auth is a fixture defined elsewhere; presumably it
        # asserts that the request is rejected for this kind of client - confirm.
        test_no_auth(client_with_auth.post, "/sheet/archive")

    def test_missing_data(self, client_with_token):
        """An empty JSON body is rejected with a 422 and an explicit message."""
        r = client_with_token.post("/sheet/archive", json={})
        assert r.status_code == 422
        assert r.json() == {"detail": "sheet id is required"}

    @patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
    def test_normal_flow(self, m1, client_with_token):
        """Both a minimal and a fully specified payload are forwarded to the task.

        The first positional argument handed to the patched ``delay`` is a JSON
        string; it is decoded and compared against the expected sheet payload,
        including the defaults filled in when only ``sheet_id`` is sent.
        """
        # minimum data
        response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id"})
        assert response.status_code == 201
        assert response.json() == {'id': '123-456-789'}

        m1.assert_called_once()
        called_val = m1.call_args.args[0]
        assert json.loads(called_val) == {
            "sheet_id": "123-sheet-id",
            "sheet_name": None,
            "public": False,
            "author_id": "api-endpoint",
            "group_id": None,
            "tags": [],
            "columns": {},
            "header": 1,
        }

        # maximum data
        response = client_with_token.post(
            "/sheet/archive",
            json={
                "sheet_id": "123-sheet-id",
                "sheet_name": "768-sheet-name",
                "author_id": "birdman@example.com",
                "header": 2,
                "public": True,
                "group_id": "456-group-id",
                "tags": ["tag1"],
                "columns": {"col1": "type1"},
            },
        )
        assert response.status_code == 201
        assert response.json() == {'id': '123-456-789'}

        # previously a bare comparison whose result was discarded, so the
        # second call to the task was never actually verified
        assert m1.call_count == 2
        # call_args holds the most recent call, i.e. the "maximum data" request
        called_val = m1.call_args.args[0]
        assert json.loads(called_val) == {
            "sheet_id": "123-sheet-id",
            "sheet_name": "768-sheet-name",
            "public": True,
            "author_id": "birdman@example.com",
            "group_id": "456-group-id",
            "tags": ["tag1"],
            "columns": {"col1": "type1"},
            "header": 2,
        }
|
||||
|
||||
@@ -33,6 +33,7 @@ groups:
|
||||
active_sheets: -1
|
||||
monthly_urls: all
|
||||
monthly_mbs: all
|
||||
alowed_frequency: "hourly"
|
||||
interdimensional:
|
||||
description: "Interdimensional travelers"
|
||||
orchestrator: tests/orchestration.test.yaml
|
||||
@@ -42,12 +43,14 @@ groups:
|
||||
active_sheets: 5
|
||||
monthly_urls: 1000
|
||||
monthly_mbs: 1000
|
||||
alowed_frequency: "hourly"
|
||||
animated-characters:
|
||||
description: "Animated characters"
|
||||
orchestrator: tests/orchestration.test.yaml
|
||||
orchestrator_sheet: tests/orchestration.test.yaml
|
||||
permissions:
|
||||
read: ["animated-characters"]
|
||||
active_sheets: -1
|
||||
monthly_urls: all
|
||||
monthly_mbs: all
|
||||
active_sheets: 1
|
||||
monthly_urls: 2
|
||||
monthly_mbs: 10
|
||||
alowed_frequency: "daily"
|
||||
@@ -4,6 +4,8 @@ from fastapi import HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
from shared.settings import get_settings
|
||||
from db.database import get_db
|
||||
from db import crud
|
||||
|
||||
settings = get_settings()
|
||||
bearer_security = HTTPBearer()
|
||||
@@ -54,6 +56,18 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
|
||||
)
|
||||
|
||||
|
||||
async def get_active_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
    """Validate the Bearer token and require that its user is active.

    Args:
        credentials: Bearer credentials injected through ``bearer_security``.

    Returns:
        The email produced by ``get_user_auth`` when ``crud.is_active_user``
        accepts it.

    Raises:
        HTTPException: 403 with detail "User is not active" when the user is
            not active. Any HTTPException raised by ``get_user_auth``
            propagates to the caller unchanged.
    """
    # No try/except here: catching HTTPException only to re-raise it is a no-op.
    email = await get_user_auth(credentials)
    with get_db() as db:
        if crud.is_active_user(db, email):
            return email
    raise HTTPException(status_code=403, detail="User is not active")
|
||||
|
||||
|
||||
def authenticate_user(access_token):
|
||||
# https://cloud.google.com/docs/authentication/token-types#access
|
||||
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
|
||||
@@ -69,7 +83,7 @@ def authenticate_user(access_token):
|
||||
return False, f"email '{j.get('email')}' not verified"
|
||||
if int(j.get("expires_in", -1)) <= 0:
|
||||
return False, "Token expired"
|
||||
return True, j.get('email')
|
||||
return True, j.get('email').lower()
|
||||
except Exception as e:
|
||||
logger.warning(f"AUTH EXCEPTION occurred: {e}")
|
||||
return False, "exception occurred"
|
||||
|
||||
@@ -64,11 +64,13 @@ def create_sheet_task(self, sheet_json: str):
|
||||
sheet.tags.add("gsheet")
|
||||
logger.info(f"SHEET START {sheet=}")
|
||||
|
||||
#TODO: should this check live here?
|
||||
if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)):
|
||||
return {"error": em}
|
||||
|
||||
config = Config()
|
||||
# TODO: use choose_orchestrator and overwrite the feeder
|
||||
# TODO: drop sheet_name and use only sheet_id (new endpoints/models)
|
||||
config.parse(use_cli=False, yaml_config_filename=get_settings().SHEET_ORCHESTRATION_YAML, overwrite_configs={"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}})
|
||||
orchestrator = ArchivingOrchestrator(config)
|
||||
|
||||
@@ -78,6 +80,7 @@ def create_sheet_task(self, sheet_json: str):
|
||||
logger.error("Got empty result from feeder, an internal error must have occurred.")
|
||||
continue
|
||||
try:
|
||||
#TODO: remove public from sheet in new refactor
|
||||
insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid())
|
||||
stats["archived"] += 1
|
||||
except exc.IntegrityError as e:
|
||||
|
||||
Reference in New Issue
Block a user