pushing bulk of changes

This commit is contained in:
msramalho
2025-01-22 13:21:16 +00:00
parent 2209b09a9a
commit 9f9bbc9344
21 changed files with 2048 additions and 2250 deletions

View File

@@ -21,6 +21,8 @@ orchestration must be from the console(?)
* turn off VPNs if connection to docker is not working
## User management
TODO: update description and example
- users/domains/groups
Copy [example.user-groups.yaml](src/example.user-groups.yaml) into a new file and set the environment variable `USER_GROUPS_FILENAME` to that filename (defaults to `user-groups.yaml`).
This file contains a two-part user-groups specification. Each user can archive URLs publicly, privately, or privately for a group, so long as they are declared as part of that group. In the example below, `email1` has 2 groups while `email3` has none.

View File

@@ -4,6 +4,7 @@ verify_ssl = true
name = "pypi"
[packages]
oscrypto = {git = "https://github.com/wbond/oscrypto.git", ref = "d5f3437ed24257895ae1edd9e503cfb352e635a8"}
aiofiles = "==0.6.0"
celery = ">=5.0"
fastapi = "*"

3759
src/Pipfile.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1 +0,0 @@
based on https://fastapi-users.github.io/fastapi-users/10.4/configuration/oauth/

View File

@@ -1,11 +1,12 @@
from collections import defaultdict
from functools import cache
from functools import lru_cache
from sqlalchemy.orm import Session, load_only
from sqlalchemy import Column, or_, func
from loguru import logger
from datetime import datetime, timedelta
from core.config import ALLOW_ANY_EMAIL
from db.database import get_db
from shared.settings import get_settings
from . import models, schemas
import yaml
@@ -23,7 +24,7 @@ def get_archive(db: Session, id: str, email: str):
email = email.lower()
query = base_query(db).filter(models.Archive.id == id)
if email != ALLOW_ANY_EMAIL:
groups = get_user_groups(db, email)
groups = get_user_groups(email)
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
return query.first()
@@ -33,7 +34,7 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim
query = base_query(db)
if email != ALLOW_ANY_EMAIL:
email = email.lower()
groups = get_user_groups(db, email)
groups = get_user_groups(email)
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
if absolute_search:
query = query.filter(models.Archive.url == url)
@@ -121,72 +122,37 @@ def is_active_user(db: Session, email: str) -> bool:
return db.query(models.Group).filter(models.Group.domains.contains(domain)).first() is not None
def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
def is_user_in_group(db: Session, email: str, group_name: str) -> models.Group:
if email == ALLOW_ANY_EMAIL: return True
return len(group_name) and len(email) and group_name in get_user_groups(db, email)
return len(group_name) and len(email) and group_name in get_user_groups(email)
#TODO: maybe this can be cached? what about the db session?
def get_user_groups(db: Session, email: str) -> list[str]:
@lru_cache
def get_user_groups(email: str) -> list[str]:
"""
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. User does not need to be active.
"""
if not email or not len(email) or "@" not in email: return []
email = email.lower()
# get user groups
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
user_level_groups_names = [g[0] for g in user_groups]
with get_db() as db:
# get user groups
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
user_level_groups_names = [g[0] for g in user_groups]
# get domain groups
domain = email.split('@')[1]
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
domain_level_groups_names = [g[0] for g in domain_level_groups]
# get domain groups
domain = email.split('@')[1]
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
domain_level_groups_names = [g[0] for g in domain_level_groups]
return list(set(user_level_groups_names + domain_level_groups_names))
# --------------- SHEET
def has_quota_sheet(db: Session, email: str, user_groups_names: list[str]) -> bool:
"""
checks if a user has reached their sheet quota
"""
user_sheets = db.query(models.Sheet).filter(models.Sheet.author_id == email).count()
user_groups = db.query(models.Group).filter(models.Group.id.in_(user_groups_names)).all()
quota = 0
for group in user_groups:
active_sheets = group.permissions.get("active_sheets", 0)
if active_sheets == -1: return True
quota = max(quota, active_sheets)
return user_sheets < quota
def create_sheet(db: Session, sheet_id: str, sheet_name: str, email: str, group_id: str, frequency: str):
db_sheet = models.Sheet(id=sheet_id, name=sheet_name, author_id=email, group_id=group_id, frequency=frequency)
db.add(db_sheet)
db.commit()
db.refresh(db_sheet)
return db_sheet
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_archived_at.desc()).all()
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
if db_sheet:
db.delete(db_sheet)
db.commit()
return db_sheet is not None
return list(set(user_level_groups_names + domain_level_groups_names))
# --------------- INIT User-Groups
def get_group(db: Session, group_name: str) -> models.Group:
return db.query(models.Group).filter(models.Group.id == group_name).first()
def create_or_get_user(db: Session, author_id: str, is_active: bool = models.User.is_active.default.arg) -> models.User:
if type(author_id) == str: author_id = author_id.lower()
@@ -296,3 +262,28 @@ def upsert_user_groups(db: Session):
count_groups = db.query(func.count(models.Group.id)).scalar()
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
# --------------- SHEET
def create_sheet(db: Session, sheet_id: str, sheet_name: str, email: str, group_id: str, frequency: str):
db_sheet = models.Sheet(id=sheet_id, name=sheet_name, author_id=email, group_id=group_id, frequency=frequency)
db.add(db_sheet)
db.commit()
db.refresh(db_sheet)
return db_sheet
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_archived_at.desc()).all()
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
if db_sheet:
db.delete(db_sheet)
db.commit()
return db_sheet is not None

View File

@@ -33,4 +33,4 @@ def get_db():
def get_db_dependency():
# to use with Depends and ensure proper session closing
with get_db() as db:
yield db
yield db

View File

@@ -87,7 +87,7 @@ class Group(Base):
description = Column(String, default=None)
orchestrator = Column(String, default=None)
orchestrator_sheet = Column(String, default=None)
permissions = Column(JSON, default=None)
permissions = Column(JSON, default={})
domains = Column(JSON, default=[])
archives = relationship("Archive", back_populates="group")

View File

@@ -1,3 +1,5 @@
from typing import Annotated
from annotated_types import Len
from pydantic import BaseModel, field_validator
from datetime import datetime
@@ -105,3 +107,10 @@ class SheetResponse(SheetAdd):
stats: dict | None
last_archived_at: datetime | None
created_at: datetime
# Request body used by the `POST /url/archive` endpoint (`archive_url`) to
# trigger the archiving of a single URL; it carries no author field.
class ArchiveTrigger(BaseModel):
# URL to archive; validation rejects strings shorter than 5 characters.
url: Annotated[str, Len(min_length=5)]
# Whether the resulting archive is public; defaults to True.
public: bool = True
# Optional group to archive for; when provided it must be a non-empty string.
group_id: Annotated[str, Len(min_length=1)] | None = None
# Optional set of tags; defaults to an empty set, and None is also accepted.
tags: set[Tag] | None = set()

142
src/db/user_state.py Normal file
View File

@@ -0,0 +1,142 @@
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy import func
from db import crud, models
from datetime import datetime
class UserState:
    """
    Manage a user's state and permissions.

    Every group-derived value (group names, group rows, allowed frequencies
    and quotas) is computed lazily on first access and then cached on the
    instance, so one instance reflects a single snapshot of the user's groups.

    Quota values follow one convention: -1 means unlimited, otherwise the
    highest value granted by any of the user's groups applies, and 0 is used
    when no group grants anything.
    """

    def __init__(self, db: "Session", email: str, active=False):
        self.db = db
        self.email = email
        self.active = active

    @property
    def user_groups_names(self):
        """Names of the groups the user belongs to, as given by crud.get_user_groups."""
        if not hasattr(self, '_user_groups_names'):
            self._user_groups_names = crud.get_user_groups(self.email)
        return self._user_groups_names

    @property
    def user_groups(self):
        """Group rows whose id is one of `user_groups_names`."""
        if not hasattr(self, '_user_groups'):
            self._user_groups = self.db.query(models.Group).filter(
                models.Group.id.in_(self.user_groups_names)
            ).all()
        return self._user_groups

    @property
    def allowed_frequencies(self) -> set:
        """
        Sheet frequencies the user may use, collected from the
        `allowed_frequency` permission of each group.
        Groups that do not declare the permission contribute nothing.
        """
        if not hasattr(self, '_allowed_frequencies'):
            frequencies = set()
            for group in self.user_groups:
                if not group.permissions: continue
                frequency = group.permissions.get("allowed_frequency")
                # a missing key must not leak None into the allowed set
                if frequency: frequencies.add(frequency)
            # being allowed the "hourly" frequency also allows "daily"
            if "hourly" in frequencies:
                frequencies.add("daily")
            # cache only once fully computed
            self._allowed_frequencies = frequencies
        return self._allowed_frequencies

    def _max_permission(self, key: str) -> int:
        """
        Highest value of the numeric permission `key` across the user's groups.
        Returns -1 as soon as any group declares -1 (unlimited), and 0 when no
        group declares the permission.
        """
        best = 0
        for group in self.user_groups:
            if not group.permissions: continue
            value = group.permissions.get(key, 0)
            if value == -1: return -1
            best = max(best, value)
        return best

    @property
    def sheet_quota(self) -> int:
        """
        infer the user's sheet quota from the groups
        -1 means unlimited
        """
        if not hasattr(self, '_sheet_quota'):
            self._sheet_quota = self._max_permission("active_sheets")
        return self._sheet_quota

    @property
    def monthly_urls(self) -> int:
        """
        infer the user's monthly url quota from the groups
        -1 means unlimited
        """
        if not hasattr(self, '_monthly_urls'):
            self._monthly_urls = self._max_permission("monthly_urls")
        return self._monthly_urls

    @property
    def monthly_mbs(self) -> int:
        """
        infer the user's monthly megabyte quota from the groups
        -1 means unlimited
        """
        if not hasattr(self, '_monthly_mbs'):
            self._monthly_mbs = self._max_permission("monthly_mbs")
        return self._monthly_mbs

    def in_group(self, group_id: str) -> bool:
        """checks if the user belongs to the group with this id"""
        return group_id in self.user_groups_names

    def has_quota_sheet(self) -> bool:
        """
        checks if a user still has sheet quota left (True) or has reached it (False)
        """
        if self.sheet_quota == -1: return True
        user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email).count()
        return user_sheets < self.sheet_quota

    def _archives_this_month(self):
        """Query over the user's archives created in the current calendar month."""
        # read the clock once so month and year always belong to the same instant
        now = datetime.now()
        return self.db.query(models.Archive).filter(
            models.Archive.author_id == self.email,
            func.extract('month', models.Archive.created_at) == now.month,
            func.extract('year', models.Archive.created_at) == now.year
        )

    def has_quota_monthly_urls(self) -> bool:
        """
        checks if a user still has monthly url quota left (True) or has reached it (False)
        """
        if self.monthly_urls == -1: return True
        user_urls = self._archives_this_month().count()
        return user_urls < self.monthly_urls

    def has_quota_monthly_mbs(self) -> bool:
        """
        checks if a user still has monthly mb quota left (True) or has reached it (False)
        """
        if self.monthly_mbs == -1: return True

        # find and sum all user bytes over this month, archives without a
        # total_bytes value count as 0
        user_bytes = self._archives_this_month().with_entities(func.coalesce(func.sum(
            func.coalesce(
                func.cast(
                    func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
                    sqlalchemy.Integer
                ), 0
            )
        ), 0).label('total')).scalar()

        # convert bytes to mb
        user_mbs = int(user_bytes / 1024 / 1024)
        return user_mbs < self.monthly_mbs

    # TODO: decide whether manual triggers are gated by a permission
    # def can_manually_trigger(self) -> bool:
    #     """
    #     checks if a user is allowed to manually trigger a sheet
    #     """
    #     for group in self.user_groups:
    #         if not group.permissions: continue
    #         if group.permissions.get("manual_trigger", False):
    #             return True
    #     return False

    def is_sheet_frequency_allowed(self, frequency: str) -> bool:
        """
        checks if a user is allowed to create a sheet with this frequency
        """
        return frequency in self.allowed_frequencies

View File

@@ -6,8 +6,9 @@ from sqlalchemy.orm import Session
from core.config import VERSION, BREAKING_CHANGES
from core.logging import log_error
from db import crud, schemas
from db.database import get_db_dependency, get_db
from web.security import get_user_auth, bearer_security
from db.database import get_db_dependency
from db.user_state import UserState
from web.security import get_user_auth, bearer_security, get_active_user_state
default_router = APIRouter()
@@ -18,8 +19,7 @@ async def home(request: Request):
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
try:
email = await get_user_auth(await bearer_security(request))
with get_db() as db:
status["groups"] = crud.get_user_groups(db, email)
status["groups"] = crud.get_user_groups(email)
except HTTPException: pass # not authenticated is fine
except Exception as e: log_error(e)
return JSONResponse(status)
@@ -31,13 +31,28 @@ async def health():
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
# TODO: reorder db dependencies to after auth
async def active(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> schemas.ActiveUser:
return {"active": crud.is_active_user(db, email)}
@default_router.get("/groups")
def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> list[str]:
return crud.get_user_groups(db, email)
def get_user_groups(email=Depends(get_user_auth)) -> list[str]:
return crud.get_user_groups(email)
@default_router.get("/permissions")
def get_user_groups(
    user: UserState = Depends(get_active_user_state),
) -> JSONResponse:
    # Returns the active user's groups, allowed sheet frequencies and quotas.
    # The return annotation is JSONResponse because a response object wrapping
    # a dict is returned directly; the previous `list[str]` annotation made the
    # generated API schema advertise a list of strings.
    # NOTE(review): this handler shares its Python name with the `/groups`
    # handler in this module and therefore rebinds it; both routes stay
    # registered, but a distinct name would be clearer.
    return JSONResponse({
        "groups": user.user_groups_names,
        "allowedFrequencies": list(user.allowed_frequencies),
        "sheet_quota": user.sheet_quota,
        "monthly_urls": user.monthly_urls,
        "monthly_mbs": user.monthly_mbs,
        # TODO: should this return anything else? (the original note was left unfinished)
    })
@default_router.get('/favicon.ico', include_in_schema=False)

View File

@@ -5,7 +5,8 @@ from fastapi.responses import JSONResponse
from sqlalchemy import exc
from sqlalchemy.orm import Session
from web.security import token_api_key_auth, get_active_user_auth
from db.user_state import UserState
from web.security import token_api_key_auth, get_active_user_auth, get_active_user_state
from db import schemas, crud
from db.database import get_db_dependency
from worker.main import create_sheet_task
@@ -16,19 +17,21 @@ sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"]
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
def create_sheet(
sheet: schemas.SheetAdd,
email=Depends(get_active_user_auth),
user: UserState = Depends(get_active_user_state),
db: Session = Depends(get_db_dependency),
) -> schemas.SheetResponse:
user_groups_names = crud.get_user_groups(db, email)
if sheet.group_id not in user_groups_names:
if not user.in_group(sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
if not crud.has_quota_sheet(db, email, user_groups_names):
if not user.has_quota_sheet():
raise HTTPException(status_code=429, detail="User has reached their sheet quota.")
if not user.is_sheet_frequency_allowed(sheet.frequency):
raise HTTPException(status_code=422, detail=f"Invalid frequency: {sheet.frequency}. Must be one of {user.allowed_frequencies}")
try:
return crud.create_sheet(db, sheet.id, sheet.name, email, sheet.group_id, sheet.frequency)
return crud.create_sheet(db, sheet.id, sheet.name, user.email, sheet.group_id, sheet.frequency)
except exc.IntegrityError as e:
raise HTTPException(status_code=400, detail="Sheet with this ID already exists.") from e
@@ -56,22 +59,30 @@ def delete_sheet(
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
def archive_user_sheet(
id: str,
email=Depends(get_active_user_auth),
user: UserState = Depends(get_active_user_state),
db: Session = Depends(get_db_dependency),
) -> schemas.Task:
#TODO: are we enabling manual triggers?
# if not user.can_manually_trigger():
# raise HTTPException(status_code=429, detail="User cannot manually trigger archiving tasks.")
sheet = crud.get_user_sheet(db, email, sheet_id=id)
sheet = crud.get_user_sheet(db, user.email, sheet_id=id)
if not sheet:
raise HTTPException(status_code=403, detail="No access to this sheet.")
task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=email, group=sheet.group_id).model_dump_json())
# TODO: what happens if user is taken out of group after sheet is created? this should be checked in a cronjob that notifies the user
if not user.in_group(sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=user.email, group=sheet.group_id).model_dump_json())
return JSONResponse({"id": task.id}, status_code=201)
@sheet_router.post("/archive", status_code=201, summary="Trigger an archiving task for any GSheet with an API token.", response_description="task_id for the archiving task.")
def archive_sheet(
sheet: schemas.SubmitSheet, #TODO: replace with simpler model
sheet: schemas.SubmitSheet, # TODO: replace with simpler model
auth=Depends(token_api_key_auth)
) -> schemas.Task:
sheet.author_id = sheet.author_id or "api-endpoint"

View File

@@ -8,7 +8,7 @@ from web.security import get_user_auth, get_token_or_user_auth
from sqlalchemy.orm import Session
from db import crud, schemas
from db.database import get_db_dependency
from db.database import get_db, get_db_dependency
from worker.main import create_archive_task
@@ -16,14 +16,28 @@ url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)) -> schemas.Task:
archive.author_id = email
url = archive.url
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
if type(url) != str or len(url) <= 5:
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
logger.info("creating task")
task = create_archive_task.delay(archive.model_dump_json())
def archive_url(
archive: schemas.ArchiveTrigger,
email=Depends(get_token_or_user_auth)
) -> schemas.Task:
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
# TODO: implement quota
if archive.group_id:
with get_db() as db:
if not crud.is_user_in_group(db, email, archive.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
# TODO: deprecate ArchiveCreate
backwards_compatible_archive = schemas.ArchiveCreate(
url=archive.url,
author_id=email,
group_id=archive.group_id,
public=archive.public,
)
task = create_archive_task.delay(backwards_compatible_archive.model_dump_json())
task_response = schemas.Task(id=task.id)
return JSONResponse(task_response.model_dump(), status_code=201)

View File

@@ -2,6 +2,7 @@ import os
from fastapi.testclient import TestClient
import pytest
from unittest.mock import patch
from db.user_state import UserState
from shared.settings import Settings
@@ -27,7 +28,9 @@ def mock_settings():
def test_db(get_settings: Settings):
from db.database import make_engine
from db import models
from db.crud import get_user_groups
get_user_groups.cache_clear()
make_engine.cache_clear()
engine = make_engine(get_settings.DATABASE_PATH)
@@ -72,11 +75,12 @@ def client(app):
@pytest.fixture()
def app_with_auth(app):
from web.security import get_token_or_user_auth, get_user_auth, get_active_user_auth
def app_with_auth(app, db_session):
from web.security import get_token_or_user_auth, get_user_auth, get_active_user_auth, get_active_user_state
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
app.dependency_overrides[get_active_user_auth] = lambda: "morty@example.com"
app.dependency_overrides[get_active_user_state] = lambda: UserState(db_session, "morty@example.com", active=True)
return app

View File

@@ -40,6 +40,12 @@ def test_data(db_session):
archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
db_session.add(archive)
# creates a sheet for each user
for i, email in enumerate(authors):
db_session.add(models.Sheet(id=f"sheet-{i}", name=f"sheet-{i}", author_id=email, group_id=None, frequency="daily"))
if email == "rick@example.com":
db_session.add(models.Sheet(id=f"sheet-{i}-2", name=f"sheet-{i}-2", author_id=email, group_id="spaceship", frequency="hourly"))
db_session.commit()
assert db_session.query(models.Archive).count() == 100
@@ -253,6 +259,7 @@ def test_count_archive_urls(test_data, db_session):
assert crud.count_archives(db_session) == 99
assert crud.count_archive_urls(db_session) == 999
def test_count_users(test_data, db_session):
from db import crud
@@ -261,6 +268,7 @@ def test_count_users(test_data, db_session):
db_session.commit()
assert crud.count_users(db_session) == 3
def test_count_by_users_since(test_data, db_session):
from db import crud
@@ -294,6 +302,7 @@ def test_create_tag(db_session):
assert second_tag.id == "tag-102"
assert db_session.query(models.Tag).count() == 2
def test_is_active_user(test_data, db_session):
from db import crud
@@ -329,7 +338,7 @@ def test_is_user_in_group(test_data, db_session):
("jerry@example.com", "spaceship", False),
("jerry@example.com", "interdimensional", False),
("jerry@example.com", "the-jerrys-club", False), # group not in 'groups'
("jerry@example.com", "the-jerrys-club", False), # group not in 'groups'
("rick@example.com", "animated-characters", True),
("morty@example.com", "animated-characters", True),
@@ -337,7 +346,7 @@ def test_is_user_in_group(test_data, db_session):
("ANYONE@example.com", "animated-characters", True),
("ANYONE@birdy.com", "animated-characters", True),
("summer@herself.com", "animated-characters", False),
("summer@herself.com", "animated-characters", False),
("rick@example.com", "", False),
("", "spaceship", False),
@@ -345,7 +354,16 @@ def test_is_user_in_group(test_data, db_session):
]
for email, group, expected in test_pairs:
print(f"{email} in {group} == {expected}")
assert crud.is_user_in_group(db_session, group, email) == expected
assert crud.is_user_in_group(db_session, email, group) == expected
def test_get_group(test_data, db_session):
from db import crud
assert crud.get_group(db_session, "spaceship") is not None
assert crud.get_group(db_session, "interdimensional") is not None
assert crud.get_group(db_session, "animated-characters") is not None
assert crud.get_group(db_session, "non-existant!@#!%!") is None
def test_create_or_get_user(test_data, db_session):
@@ -403,13 +421,12 @@ def test_upsert_group(test_data, db_session):
def test_upsert_user_groups(db_session):
from db import crud
@patch('db.crud.get_settings', new = lambda: bad_setings)
@patch('db.crud.get_settings', new=lambda: bad_setings)
def test_missing_yaml(db_session):
with pytest.raises(FileNotFoundError):
crud.upsert_user_groups(db_session)
@patch('db.crud.get_settings', new = lambda: bad_setings)
@patch('db.crud.get_settings', new=lambda: bad_setings)
def test_broken_yaml(db_session):
with pytest.raises(yaml.YAMLError):
crud.upsert_user_groups(db_session)
@@ -420,4 +437,54 @@ def test_upsert_user_groups(db_session):
test_missing_yaml(db_session)
bad_setings.USER_GROUPS_FILENAME = "tests/user-groups.test.broken.yaml"
test_broken_yaml(db_session)
test_broken_yaml(db_session)
def test_create_sheet(db_session):
from db import crud
assert db_session.query(models.Sheet).count() == 0
s = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly")
assert s is not None
assert s.id == "sheet-id-123"
assert s.name == "sheet name"
assert s.author_id == "email@example.com"
assert s.group_id == "group-id"
assert s.frequency == "hourly"
assert db_session.query(models.Sheet).count() == 1
# duplicate id
import sqlalchemy
with pytest.raises(sqlalchemy.exc.IntegrityError):
crud.create_sheet(db_session, "sheet-id-123", "I thought this was another sheet", "email", "group-id", "hourly")
def test_get_user_sheet(test_data, db_session):
from db import crud
assert crud.get_user_sheet(db_session, "", "sheet-0") is None
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0") is not None
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2") is not None
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-1") is not None
def test_get_user_sheets(test_data, db_session):
from db import crud
assert len(crud.get_user_sheets(db_session, "")) == 0
rick_sheets = crud.get_user_sheets(db_session, "rick@example.com")
assert len(rick_sheets) == 2
assert [s.id for s in rick_sheets] == ["sheet-0", "sheet-0-2"]
assert len(crud.get_user_sheets(db_session, "morty@example.com")) == 1
def test_delete_sheet(test_data, db_session):
from db import crud
assert crud.delete_sheet(db_session, "sheet-0", "") == False
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == True
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == False

View File

@@ -15,7 +15,7 @@ def test_endpoints_no_auth(client, test_no_auth):
test_no_auth(client.post, "/sheet/archive")
def test_create_sheet_endpoint(app_with_auth):
def test_create_sheet_endpoint(app_with_auth, db_session):
client_with_auth = TestClient(app_with_auth)
good_data = {
"id": "123-sheet-id",
@@ -53,13 +53,23 @@ def test_create_sheet_endpoint(app_with_auth):
assert response.status_code == 403
assert response.json() == {"detail": "User does not have access to this group."}
# bad quota
# switch to jerry who's got less quota/permissions
from web.security import get_active_user_state
from db.user_state import UserState
app_with_auth.dependency_overrides[get_active_user_state] = lambda: UserState(db_session, "jerry@example.com", active=True)
client_jerry = TestClient(app_with_auth)
# frequency not allowed
jerry_data = good_data.copy()
jerry_data["group_id"] = "animated-characters"
jerry_data["frequency"] = "hourly"
jerry_data["id"] = "jerry-sheet-id"
from web.security import get_active_user_auth
app_with_auth.dependency_overrides[get_active_user_auth] = lambda: "jerry@example.com"
client_jerry = TestClient(app_with_auth)
response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == 422
assert "Invalid frequency: hourly" in response.json()["detail"]
jerry_data["frequency"] = "daily"
# success for the first sheet, bad quota on second
response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == 201
@@ -144,12 +154,6 @@ def test_delete_sheet_endpoint(client_with_auth, db_session):
assert response.json() == {"id": "456-sheet-id", "deleted": False}
# def test_archive_user_sheet_endpoint(client_with_auth):
# response = client_with_auth.post("/sheet/123-sheet-id/archive")
# assert response.status_code == 201
# assert "id" in response.json()
class TestArchiveUserSheetEndpoint:
def test_token_auth(self, client_with_token, test_no_auth):
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
@@ -177,6 +181,14 @@ class TestArchiveUserSheetEndpoint:
assert r.json() == {"id": "123-taskid"}
m1.assert_called_once()
def test_user_not_in_group(self, client_with_auth, db_session):
from db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "User does not have access to this group."}
class TestTokenArchiveEndpoint:

View File

@@ -10,11 +10,13 @@ def test_archive_url_unauthenticated(client, test_no_auth):
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
def test_archive_url(m1, client_with_auth):
# url is too short
response = client_with_auth.post("/url/archive", json={"url": "bad"})
assert response.status_code == 422
assert response.json() == {'detail': 'Invalid URL received: bad'}
assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
m1.assert_not_called()
# valid request
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
@@ -23,6 +25,20 @@ def test_archive_url(m1, client_with_auth):
called_val = m1.call_args.args[0]
assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True}
# user is not in group
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "new-group"})
assert response.status_code == 403
assert response.json()["detail"] == "User does not have access to this group."
# user is in group
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
assert m1.call_count == 2
called_val = m1.call_args.args[0]
assert json.loads(called_val)["group_id"] == "spaceship"
def test_search_by_url_unauthenticated(client, test_no_auth):
test_no_auth(client.get, "/url/search")

View File

@@ -23,6 +23,8 @@ orchestrators:
interdimensional: tests/orchestration.test.yaml
default: tests/orchestration.test.yaml
default_orchestrator: tests/orchestration.test.yaml
groups:
spaceship:
description: "The spaceship crew"
@@ -31,9 +33,9 @@ groups:
permissions:
read: ["all"]
active_sheets: -1
monthly_urls: all
monthly_mbs: all
alowed_frequency: "hourly"
monthly_urls: -1
monthly_mbs: -1
allowed_frequency: "hourly"
interdimensional:
description: "Interdimensional travelers"
orchestrator: tests/orchestration.test.yaml
@@ -43,7 +45,7 @@ groups:
active_sheets: 5
monthly_urls: 1000
monthly_mbs: 1000
alowed_frequency: "hourly"
allowed_frequency: "hourly"
animated-characters:
description: "Animated characters"
orchestrator: tests/orchestration.test.yaml
@@ -53,4 +55,4 @@ groups:
active_sheets: 1
monthly_urls: 2
monthly_mbs: 10
alowed_frequency: "daily"
allowed_frequency: "daily"

View File

@@ -122,14 +122,6 @@ class Test_create_sheet_task():
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
@patch("worker.main.is_group_invalid_for_user", return_value="Access denied")
def test_error_access(self, m_insert, worker_init, db_session):
from worker.main import create_sheet_task
res = create_sheet_task(self.sheet.model_dump_json())
assert "error" in res
assert res["error"] == "Access denied"
def test_choose_orchestrator(worker_init):
from worker.main import choose_orchestrator

View File

@@ -131,17 +131,19 @@ def app_factory(settings = get_settings()):
@app.post("/sheet", status_code=201, deprecated=True) # DEPRECATED
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth)):
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth), db: Session = Depends(get_db_dependency)):
logger.info(f"SHEET TASK for {sheet=}")
sheet.author_id = email
if not sheet.sheet_name and not sheet.sheet_id:
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
if not crud.is_user_in_group(db, email, sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
task = create_sheet_task.delay(sheet.model_dump_json())
return JSONResponse({"id": task.id})
@app.post("/sheet_service", status_code=201, deprecated=True) # DEPRECATED
def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth)):
def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth), db: Session = Depends(get_db_dependency)):
logger.info(f"SHEET TASK for {sheet=}")
sheet.author_id = sheet.author_id or "api-endpoint"
if not sheet.sheet_name and not sheet.sheet_id:

View File

@@ -6,6 +6,7 @@ from core.config import ALLOW_ANY_EMAIL
from shared.settings import get_settings
from db.database import get_db
from db import crud
from db.user_state import UserState
settings = get_settings()
bearer_security = HTTPBearer()
@@ -84,3 +85,8 @@ def authenticate_user(access_token):
except Exception as e:
logger.warning(f"AUTH EXCEPTION occurred: {e}")
return False, "exception occurred"
def get_active_user_state(email=Depends(get_active_user_auth)):
    """
    FastAPI dependency that provides a UserState for the authenticated, active user.

    Written as a generator dependency so the `get_db()` context stays open
    while the endpoint runs and is only exited once the request has been
    handled. Returning from inside the `with` block would exit the context
    before the endpoint could use the session.
    """
    with get_db() as db:
        yield UserState(db, email, active=True)

View File

@@ -1,4 +1,5 @@
from functools import lru_cache
import traceback, yaml, datetime
from typing import List, Set
@@ -30,6 +31,7 @@ Rdis = redis.Redis.from_url(celery.conf.broker_url)
def create_archive_task(self, archive_json: str):
archive = schemas.ArchiveCreate.model_validate_json(archive_json)
logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}")
#TODO: move group checks out of here
invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)
if invalid:
raise Exception(invalid) # marks task FAILED, saves the Exception as result
@@ -64,10 +66,6 @@ def create_sheet_task(self, sheet_json: str):
sheet.tags.add("gsheet")
logger.info(f"SHEET START {sheet=}")
#TODO: should this check live here?
if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)):
return {"error": em}
config = Config()
# TODO: use choose_orchestrator and overwrite the feeder
# TODO: drop sheet_name and use only sheet_id (new endpoints/models)
@@ -161,7 +159,7 @@ def is_group_invalid_for_user(public: bool, group_id: str, author_id: str):
# otherwise group must match
with get_db() as session:
if not crud.is_user_in_group(session, group_id, author_id):
if not crud.is_user_in_group(session, author_id, group_id):
logger.error(em := f"User {author_id} is not part of {group_id}, no permission")
return em
return False
@@ -220,3 +218,13 @@ def at_start(sender, **kwargs):
ORCHESTRATORS = {}
load_orchestrators()
logger.info("Orchestrators loaded successfully.")
@lru_cache
def get_url_orchestrator(group_name):
with get_db() as db:
group = crud.get_group(db, group_name)
assert group, f"Group {group_name} not found"
# config = Config()
# config.parse(use_cli=False, yaml_config_filename=group.orchestrator_sheet)
# return ArchivingOrchestrator(config)