pushing bulk of changes

This commit is contained in:
msramalho
2025-01-22 13:21:16 +00:00
parent 2209b09a9a
commit 9f9bbc9344
21 changed files with 2048 additions and 2250 deletions

View File

@@ -21,6 +21,8 @@ orchestration must be from the console(?)
* turn off VPNs if connection to docker is not working * turn off VPNs if connection to docker is not working
## User management ## User management
TODO: update description and example
- users/domains/groups
Copy [example.user-groups.yaml](src/example.user-groups.yaml) into a new file and set the environment variable `USER_GROUPS_FILENAME` to that filename (defaults to `user-groups.yaml`). Copy [example.user-groups.yaml](src/example.user-groups.yaml) into a new file and set the environment variable `USER_GROUPS_FILENAME` to that filename (defaults to `user-groups.yaml`).
This file contains 2 parts user-groups specifications. Each user can archive URLs publicly, privately, or privately for a group so long as they are declared as part of that group. In the example bellow `email1` has 2 groups while `email3` has none. This file contains 2 parts user-groups specifications. Each user can archive URLs publicly, privately, or privately for a group so long as they are declared as part of that group. In the example bellow `email1` has 2 groups while `email3` has none.

View File

@@ -4,6 +4,7 @@ verify_ssl = true
name = "pypi" name = "pypi"
[packages] [packages]
oscrypto = {git = "https://github.com/wbond/oscrypto.git", ref = "d5f3437ed24257895ae1edd9e503cfb352e635a8"}
aiofiles = "==0.6.0" aiofiles = "==0.6.0"
celery = ">=5.0" celery = ">=5.0"
fastapi = "*" fastapi = "*"

3759
src/Pipfile.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1 +0,0 @@
based on https://fastapi-users.github.io/fastapi-users/10.4/configuration/oauth/

View File

@@ -1,11 +1,12 @@
from collections import defaultdict from collections import defaultdict
from functools import cache from functools import lru_cache
from sqlalchemy.orm import Session, load_only from sqlalchemy.orm import Session, load_only
from sqlalchemy import Column, or_, func from sqlalchemy import Column, or_, func
from loguru import logger from loguru import logger
from datetime import datetime, timedelta from datetime import datetime, timedelta
from core.config import ALLOW_ANY_EMAIL from core.config import ALLOW_ANY_EMAIL
from db.database import get_db
from shared.settings import get_settings from shared.settings import get_settings
from . import models, schemas from . import models, schemas
import yaml import yaml
@@ -23,7 +24,7 @@ def get_archive(db: Session, id: str, email: str):
email = email.lower() email = email.lower()
query = base_query(db).filter(models.Archive.id == id) query = base_query(db).filter(models.Archive.id == id)
if email != ALLOW_ANY_EMAIL: if email != ALLOW_ANY_EMAIL:
groups = get_user_groups(db, email) groups = get_user_groups(email)
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups))) query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
return query.first() return query.first()
@@ -33,7 +34,7 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim
query = base_query(db) query = base_query(db)
if email != ALLOW_ANY_EMAIL: if email != ALLOW_ANY_EMAIL:
email = email.lower() email = email.lower()
groups = get_user_groups(db, email) groups = get_user_groups(email)
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups))) query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
if absolute_search: if absolute_search:
query = query.filter(models.Archive.url == url) query = query.filter(models.Archive.url == url)
@@ -121,72 +122,37 @@ def is_active_user(db: Session, email: str) -> bool:
return db.query(models.Group).filter(models.Group.domains.contains(domain)).first() is not None return db.query(models.Group).filter(models.Group.domains.contains(domain)).first() is not None
def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group: def is_user_in_group(db: Session, email: str, group_name: str) -> models.Group:
if email == ALLOW_ANY_EMAIL: return True if email == ALLOW_ANY_EMAIL: return True
return len(group_name) and len(email) and group_name in get_user_groups(db, email) return len(group_name) and len(email) and group_name in get_user_groups(email)
#TODO: maybe this can be cached? what about the db session? @lru_cache
def get_user_groups(db: Session, email: str) -> list[str]: def get_user_groups(email: str) -> list[str]:
""" """
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. User does not need to be active. given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. User does not need to be active.
""" """
if not email or not len(email) or "@" not in email: return [] if not email or not len(email) or "@" not in email: return []
email = email.lower() email = email.lower()
# get user groups with get_db() as db:
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all() # get user groups
user_level_groups_names = [g[0] for g in user_groups] user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
user_level_groups_names = [g[0] for g in user_groups]
# get domain groups # get domain groups
domain = email.split('@')[1] domain = email.split('@')[1]
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all() domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
domain_level_groups_names = [g[0] for g in domain_level_groups] domain_level_groups_names = [g[0] for g in domain_level_groups]
return list(set(user_level_groups_names + domain_level_groups_names)) return list(set(user_level_groups_names + domain_level_groups_names))
# --------------- SHEET
def has_quota_sheet(db: Session, email: str, user_groups_names: list[str]) -> bool:
"""
checks if a user has reached their sheet quota
"""
user_sheets = db.query(models.Sheet).filter(models.Sheet.author_id == email).count()
user_groups = db.query(models.Group).filter(models.Group.id.in_(user_groups_names)).all()
quota = 0
for group in user_groups:
active_sheets = group.permissions.get("active_sheets", 0)
if active_sheets == -1: return True
quota = max(quota, active_sheets)
return user_sheets < quota
def create_sheet(db: Session, sheet_id: str, sheet_name: str, email: str, group_id: str, frequency: str):
db_sheet = models.Sheet(id=sheet_id, name=sheet_name, author_id=email, group_id=group_id, frequency=frequency)
db.add(db_sheet)
db.commit()
db.refresh(db_sheet)
return db_sheet
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_archived_at.desc()).all()
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
if db_sheet:
db.delete(db_sheet)
db.commit()
return db_sheet is not None
# --------------- INIT User-Groups # --------------- INIT User-Groups
def get_group(db: Session, group_name: str) -> models.Group:
return db.query(models.Group).filter(models.Group.id == group_name).first()
def create_or_get_user(db: Session, author_id: str, is_active: bool = models.User.is_active.default.arg) -> models.User: def create_or_get_user(db: Session, author_id: str, is_active: bool = models.User.is_active.default.arg) -> models.User:
if type(author_id) == str: author_id = author_id.lower() if type(author_id) == str: author_id = author_id.lower()
@@ -296,3 +262,28 @@ def upsert_user_groups(db: Session):
count_groups = db.query(func.count(models.Group.id)).scalar() count_groups = db.query(func.count(models.Group.id)).scalar()
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].") logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
# --------------- SHEET
def create_sheet(db: Session, sheet_id: str, sheet_name: str, email: str, group_id: str, frequency: str):
db_sheet = models.Sheet(id=sheet_id, name=sheet_name, author_id=email, group_id=group_id, frequency=frequency)
db.add(db_sheet)
db.commit()
db.refresh(db_sheet)
return db_sheet
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_archived_at.desc()).all()
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
if db_sheet:
db.delete(db_sheet)
db.commit()
return db_sheet is not None

View File

@@ -87,7 +87,7 @@ class Group(Base):
description = Column(String, default=None) description = Column(String, default=None)
orchestrator = Column(String, default=None) orchestrator = Column(String, default=None)
orchestrator_sheet = Column(String, default=None) orchestrator_sheet = Column(String, default=None)
permissions = Column(JSON, default=None) permissions = Column(JSON, default={})
domains = Column(JSON, default=[]) domains = Column(JSON, default=[])
archives = relationship("Archive", back_populates="group") archives = relationship("Archive", back_populates="group")

View File

@@ -1,3 +1,5 @@
from typing import Annotated
from annotated_types import Len
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from datetime import datetime from datetime import datetime
@@ -105,3 +107,10 @@ class SheetResponse(SheetAdd):
stats: dict | None stats: dict | None
last_archived_at: datetime | None last_archived_at: datetime | None
created_at: datetime created_at: datetime
class ArchiveTrigger(BaseModel):
url: Annotated[str, Len(min_length=5)]
public: bool = True
group_id: Annotated[str, Len(min_length=1)] | None = None
tags: set[Tag] | None = set()

142
src/db/user_state.py Normal file
View File

@@ -0,0 +1,142 @@
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy import func
from db import crud, models
from datetime import datetime
class UserState:
"""
Manage a user's state and permissions
"""
def __init__(self, db: Session, email: str, active=False):
self.db = db
self.email = email
self.active = active
@property
def user_groups_names(self):
if not hasattr(self, '_user_groups_names'):
self._user_groups_names = crud.get_user_groups(self.email)
return self._user_groups_names
@property
def user_groups(self):
if not hasattr(self, '_user_groups'):
self._user_groups = self.db.query(models.Group).filter(
models.Group.id.in_(self.user_groups_names)
).all()
return self._user_groups
@property
def allowed_frequencies(self):
if not hasattr(self, '_allowed_frequencies'):
self._allowed_frequencies = set()
for group in self.user_groups:
if not group.permissions: continue
self._allowed_frequencies.add(group.permissions.get("allowed_frequency", None))
if "hourly" in self._allowed_frequencies:
self._allowed_frequencies.add("daily")
return self._allowed_frequencies
@property
def sheet_quota(self):
"""
infer the user's sheet quota from the groups
-1 means unlimited
"""
if not hasattr(self, '_sheet_quota'):
self._sheet_quota = 0
for group in self.user_groups:
if not group.permissions: continue
active_sheets = group.permissions.get("active_sheets", 0)
if active_sheets == -1:
self._sheet_quota = -1
return self._sheet_quota
self._sheet_quota = max(self._sheet_quota, active_sheets)
return self._sheet_quota
def in_group(self, group_id: str) -> bool:
return group_id in self.user_groups_names
def has_quota_sheet(self) -> bool:
"""
checks if a user has reached their sheet quota
"""
if self.sheet_quota == -1: return True
user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email).count()
return user_sheets < self.sheet_quota
def has_quota_monthly_urls(self) -> bool:
"""
checks if a user has reached their monthly url quota
"""
quota = 0
for group in self.user_groups:
if not group.permissions: continue
monthly_urls = group.permissions.get("monthly_urls", 0)
if monthly_urls == -1: return True
quota = max(quota, monthly_urls)
current_month = datetime.now().month
current_year = datetime.now().year
user_urls = self.db.query(models.Archive).filter(
models.Archive.author_id == self.email,
func.extract('month', models.Archive.created_at) == current_month,
func.extract('year', models.Archive.created_at) == current_year
).count()
return user_urls < quota
def has_quota_monthly_mbs(self) -> bool:
"""
checks if a user has reached their monthly mb quota
"""
quota = 0
for group in self.user_groups:
if not group.permissions: continue
monthly_mbs = group.permissions.get("monthly_mbs", 0)
if monthly_mbs == -1: return True
quota = max(quota, monthly_mbs)
current_month = datetime.now().month
current_year = datetime.now().year
# find and sum all user bytes over this month
user_bytes = self.db.query(models.Archive).filter(
models.Archive.author_id == self.email,
func.extract('month', models.Archive.created_at) == current_month,
func.extract('year', models.Archive.created_at) == current_year
).with_entities(func.coalesce(func.sum(
func.coalesce(
func.cast(
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
sqlalchemy.Integer
), 0
)
), 0).label('total')).scalar()
# convert bytes to mb
user_mbs = int(user_bytes / 1024 / 1024)
return user_mbs < quota
# def can_manually_trigger(self) -> bool:
# """
# checks if a user is allowed to manually trigger a sheet
# """
# for group in self.user_groups:
# if not group.permissions: continue
# if group.permissions.get("manual_trigger", False):
# return True
# return False
def is_sheet_frequency_allowed(self, frequency: str) -> bool:
"""
checks if a user is allowed to create a sheet with this frequency
"""
return frequency in self.allowed_frequencies

View File

@@ -6,8 +6,9 @@ from sqlalchemy.orm import Session
from core.config import VERSION, BREAKING_CHANGES from core.config import VERSION, BREAKING_CHANGES
from core.logging import log_error from core.logging import log_error
from db import crud, schemas from db import crud, schemas
from db.database import get_db_dependency, get_db from db.database import get_db_dependency
from web.security import get_user_auth, bearer_security from db.user_state import UserState
from web.security import get_user_auth, bearer_security, get_active_user_state
default_router = APIRouter() default_router = APIRouter()
@@ -18,8 +19,7 @@ async def home(request: Request):
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES} status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
try: try:
email = await get_user_auth(await bearer_security(request)) email = await get_user_auth(await bearer_security(request))
with get_db() as db: status["groups"] = crud.get_user_groups(email)
status["groups"] = crud.get_user_groups(db, email)
except HTTPException: pass # not authenticated is fine except HTTPException: pass # not authenticated is fine
except Exception as e: log_error(e) except Exception as e: log_error(e)
return JSONResponse(status) return JSONResponse(status)
@@ -31,13 +31,28 @@ async def health():
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.") @default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
# TODO: reorder db dependencies to after auth
async def active(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> schemas.ActiveUser: async def active(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> schemas.ActiveUser:
return {"active": crud.is_active_user(db, email)} return {"active": crud.is_active_user(db, email)}
@default_router.get("/groups") @default_router.get("/groups")
def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> list[str]: def get_user_groups(email=Depends(get_user_auth)) -> list[str]:
return crud.get_user_groups(db, email) return crud.get_user_groups(email)
@default_router.get("/permissions")
def get_user_groups(
user: UserState = Depends(get_active_user_state),
) -> list[str]:
return JSONResponse({
"groups": user.user_groups_names,
"allowedFrequencies": list(user.allowed_frequencies),
"sheet_quota": user.sheet_quota,
"monthly_urls": user.monthly_urls,
"monthly_mbs": user.monthly_mbs,
#TODO: should this return
})
@default_router.get('/favicon.ico', include_in_schema=False) @default_router.get('/favicon.ico', include_in_schema=False)

View File

@@ -5,7 +5,8 @@ from fastapi.responses import JSONResponse
from sqlalchemy import exc from sqlalchemy import exc
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from web.security import token_api_key_auth, get_active_user_auth from db.user_state import UserState
from web.security import token_api_key_auth, get_active_user_auth, get_active_user_state
from db import schemas, crud from db import schemas, crud
from db.database import get_db_dependency from db.database import get_db_dependency
from worker.main import create_sheet_task from worker.main import create_sheet_task
@@ -16,19 +17,21 @@ sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"]
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.") @sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
def create_sheet( def create_sheet(
sheet: schemas.SheetAdd, sheet: schemas.SheetAdd,
email=Depends(get_active_user_auth), user: UserState = Depends(get_active_user_state),
db: Session = Depends(get_db_dependency), db: Session = Depends(get_db_dependency),
) -> schemas.SheetResponse: ) -> schemas.SheetResponse:
user_groups_names = crud.get_user_groups(db, email)
if sheet.group_id not in user_groups_names: if not user.in_group(sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.") raise HTTPException(status_code=403, detail="User does not have access to this group.")
if not crud.has_quota_sheet(db, email, user_groups_names): if not user.has_quota_sheet():
raise HTTPException(status_code=429, detail="User has reached their sheet quota.") raise HTTPException(status_code=429, detail="User has reached their sheet quota.")
if not user.is_sheet_frequency_allowed(sheet.frequency):
raise HTTPException(status_code=422, detail=f"Invalid frequency: {sheet.frequency}. Must be one of {user.allowed_frequencies}")
try: try:
return crud.create_sheet(db, sheet.id, sheet.name, email, sheet.group_id, sheet.frequency) return crud.create_sheet(db, sheet.id, sheet.name, user.email, sheet.group_id, sheet.frequency)
except exc.IntegrityError as e: except exc.IntegrityError as e:
raise HTTPException(status_code=400, detail="Sheet with this ID already exists.") from e raise HTTPException(status_code=400, detail="Sheet with this ID already exists.") from e
@@ -56,22 +59,30 @@ def delete_sheet(
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.") @sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
def archive_user_sheet( def archive_user_sheet(
id: str, id: str,
email=Depends(get_active_user_auth), user: UserState = Depends(get_active_user_state),
db: Session = Depends(get_db_dependency), db: Session = Depends(get_db_dependency),
) -> schemas.Task: ) -> schemas.Task:
sheet = crud.get_user_sheet(db, email, sheet_id=id) #TODO: are we enabling manual triggers?
# if not user.can_manually_trigger():
# raise HTTPException(status_code=429, detail="User cannot manually trigger archiving tasks.")
sheet = crud.get_user_sheet(db, user.email, sheet_id=id)
if not sheet: if not sheet:
raise HTTPException(status_code=403, detail="No access to this sheet.") raise HTTPException(status_code=403, detail="No access to this sheet.")
task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=email, group=sheet.group_id).model_dump_json()) # TODO: what happens if user is taken out of group after sheet is created? this should be checked in a cronjob that notifies the user
if not user.in_group(sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=user.email, group=sheet.group_id).model_dump_json())
return JSONResponse({"id": task.id}, status_code=201) return JSONResponse({"id": task.id}, status_code=201)
@sheet_router.post("/archive", status_code=201, summary="Trigger an archiving task for any GSheet with an API token.", response_description="task_id for the archiving task.") @sheet_router.post("/archive", status_code=201, summary="Trigger an archiving task for any GSheet with an API token.", response_description="task_id for the archiving task.")
def archive_sheet( def archive_sheet(
sheet: schemas.SubmitSheet, #TODO: replace with simpler model sheet: schemas.SubmitSheet, # TODO: replace with simpler model
auth=Depends(token_api_key_auth) auth=Depends(token_api_key_auth)
) -> schemas.Task: ) -> schemas.Task:
sheet.author_id = sheet.author_id or "api-endpoint" sheet.author_id = sheet.author_id or "api-endpoint"

View File

@@ -8,7 +8,7 @@ from web.security import get_user_auth, get_token_or_user_auth
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from db import crud, schemas from db import crud, schemas
from db.database import get_db_dependency from db.database import get_db, get_db_dependency
from worker.main import create_archive_task from worker.main import create_archive_task
@@ -16,14 +16,28 @@ url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.") @url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)) -> schemas.Task: def archive_url(
archive.author_id = email archive: schemas.ArchiveTrigger,
url = archive.url email=Depends(get_token_or_user_auth)
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}") ) -> schemas.Task:
if type(url) != str or len(url) <= 5: logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
logger.info("creating task") # TODO: implement quota
task = create_archive_task.delay(archive.model_dump_json())
if archive.group_id:
with get_db() as db:
if not crud.is_user_in_group(db, email, archive.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
# TODO: deprecate ArchiveCreate
backwards_compatible_archive = schemas.ArchiveCreate(
url=archive.url,
author_id=email,
group_id=archive.group_id,
public=archive.public,
)
task = create_archive_task.delay(backwards_compatible_archive.model_dump_json())
task_response = schemas.Task(id=task.id) task_response = schemas.Task(id=task.id)
return JSONResponse(task_response.model_dump(), status_code=201) return JSONResponse(task_response.model_dump(), status_code=201)

View File

@@ -2,6 +2,7 @@ import os
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
import pytest import pytest
from unittest.mock import patch from unittest.mock import patch
from db.user_state import UserState
from shared.settings import Settings from shared.settings import Settings
@@ -27,7 +28,9 @@ def mock_settings():
def test_db(get_settings: Settings): def test_db(get_settings: Settings):
from db.database import make_engine from db.database import make_engine
from db import models from db import models
from db.crud import get_user_groups
get_user_groups.cache_clear()
make_engine.cache_clear() make_engine.cache_clear()
engine = make_engine(get_settings.DATABASE_PATH) engine = make_engine(get_settings.DATABASE_PATH)
@@ -72,11 +75,12 @@ def client(app):
@pytest.fixture() @pytest.fixture()
def app_with_auth(app): def app_with_auth(app, db_session):
from web.security import get_token_or_user_auth, get_user_auth, get_active_user_auth from web.security import get_token_or_user_auth, get_user_auth, get_active_user_auth, get_active_user_state
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com" app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com" app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
app.dependency_overrides[get_active_user_auth] = lambda: "morty@example.com" app.dependency_overrides[get_active_user_auth] = lambda: "morty@example.com"
app.dependency_overrides[get_active_user_state] = lambda: UserState(db_session, "morty@example.com", active=True)
return app return app

View File

@@ -40,6 +40,12 @@ def test_data(db_session):
archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}")) archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
db_session.add(archive) db_session.add(archive)
# creates a sheet for each user
for i, email in enumerate(authors):
db_session.add(models.Sheet(id=f"sheet-{i}", name=f"sheet-{i}", author_id=email, group_id=None, frequency="daily"))
if email == "rick@example.com":
db_session.add(models.Sheet(id=f"sheet-{i}-2", name=f"sheet-{i}-2", author_id=email, group_id="spaceship", frequency="hourly"))
db_session.commit() db_session.commit()
assert db_session.query(models.Archive).count() == 100 assert db_session.query(models.Archive).count() == 100
@@ -253,6 +259,7 @@ def test_count_archive_urls(test_data, db_session):
assert crud.count_archives(db_session) == 99 assert crud.count_archives(db_session) == 99
assert crud.count_archive_urls(db_session) == 999 assert crud.count_archive_urls(db_session) == 999
def test_count_users(test_data, db_session): def test_count_users(test_data, db_session):
from db import crud from db import crud
@@ -261,6 +268,7 @@ def test_count_users(test_data, db_session):
db_session.commit() db_session.commit()
assert crud.count_users(db_session) == 3 assert crud.count_users(db_session) == 3
def test_count_by_users_since(test_data, db_session): def test_count_by_users_since(test_data, db_session):
from db import crud from db import crud
@@ -294,6 +302,7 @@ def test_create_tag(db_session):
assert second_tag.id == "tag-102" assert second_tag.id == "tag-102"
assert db_session.query(models.Tag).count() == 2 assert db_session.query(models.Tag).count() == 2
def test_is_active_user(test_data, db_session): def test_is_active_user(test_data, db_session):
from db import crud from db import crud
@@ -329,7 +338,7 @@ def test_is_user_in_group(test_data, db_session):
("jerry@example.com", "spaceship", False), ("jerry@example.com", "spaceship", False),
("jerry@example.com", "interdimensional", False), ("jerry@example.com", "interdimensional", False),
("jerry@example.com", "the-jerrys-club", False), # group not in 'groups' ("jerry@example.com", "the-jerrys-club", False), # group not in 'groups'
("rick@example.com", "animated-characters", True), ("rick@example.com", "animated-characters", True),
("morty@example.com", "animated-characters", True), ("morty@example.com", "animated-characters", True),
@@ -337,7 +346,7 @@ def test_is_user_in_group(test_data, db_session):
("ANYONE@example.com", "animated-characters", True), ("ANYONE@example.com", "animated-characters", True),
("ANYONE@birdy.com", "animated-characters", True), ("ANYONE@birdy.com", "animated-characters", True),
("summer@herself.com", "animated-characters", False), ("summer@herself.com", "animated-characters", False),
("rick@example.com", "", False), ("rick@example.com", "", False),
("", "spaceship", False), ("", "spaceship", False),
@@ -345,7 +354,16 @@ def test_is_user_in_group(test_data, db_session):
] ]
for email, group, expected in test_pairs: for email, group, expected in test_pairs:
print(f"{email} in {group} == {expected}") print(f"{email} in {group} == {expected}")
assert crud.is_user_in_group(db_session, group, email) == expected assert crud.is_user_in_group(db_session, email, group) == expected
def test_get_group(test_data, db_session):
from db import crud
assert crud.get_group(db_session, "spaceship") is not None
assert crud.get_group(db_session, "interdimensional") is not None
assert crud.get_group(db_session, "animated-characters") is not None
assert crud.get_group(db_session, "non-existant!@#!%!") is None
def test_create_or_get_user(test_data, db_session): def test_create_or_get_user(test_data, db_session):
@@ -403,13 +421,12 @@ def test_upsert_group(test_data, db_session):
def test_upsert_user_groups(db_session): def test_upsert_user_groups(db_session):
from db import crud from db import crud
@patch('db.crud.get_settings', new = lambda: bad_setings) @patch('db.crud.get_settings', new=lambda: bad_setings)
def test_missing_yaml(db_session): def test_missing_yaml(db_session):
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
crud.upsert_user_groups(db_session) crud.upsert_user_groups(db_session)
@patch('db.crud.get_settings', new=lambda: bad_setings)
@patch('db.crud.get_settings', new = lambda: bad_setings)
def test_broken_yaml(db_session): def test_broken_yaml(db_session):
with pytest.raises(yaml.YAMLError): with pytest.raises(yaml.YAMLError):
crud.upsert_user_groups(db_session) crud.upsert_user_groups(db_session)
@@ -421,3 +438,53 @@ def test_upsert_user_groups(db_session):
bad_setings.USER_GROUPS_FILENAME = "tests/user-groups.test.broken.yaml" bad_setings.USER_GROUPS_FILENAME = "tests/user-groups.test.broken.yaml"
test_broken_yaml(db_session) test_broken_yaml(db_session)
def test_create_sheet(db_session):
from db import crud
assert db_session.query(models.Sheet).count() == 0
s = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly")
assert s is not None
assert s.id == "sheet-id-123"
assert s.name == "sheet name"
assert s.author_id == "email@example.com"
assert s.group_id == "group-id"
assert s.frequency == "hourly"
assert db_session.query(models.Sheet).count() == 1
# duplicate id
import sqlalchemy
with pytest.raises(sqlalchemy.exc.IntegrityError):
crud.create_sheet(db_session, "sheet-id-123", "I thought this was another sheet", "email", "group-id", "hourly")
def test_get_user_sheet(test_data, db_session):
from db import crud
assert crud.get_user_sheet(db_session, "", "sheet-0") is None
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0") is not None
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2") is not None
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-1") is not None
def test_get_user_sheets(test_data, db_session):
from db import crud
assert len(crud.get_user_sheets(db_session, "")) == 0
rick_sheets = crud.get_user_sheets(db_session, "rick@example.com")
assert len(rick_sheets) == 2
assert [s.id for s in rick_sheets] == ["sheet-0", "sheet-0-2"]
assert len(crud.get_user_sheets(db_session, "morty@example.com")) == 1
def test_delete_sheet(test_data, db_session):
from db import crud
assert crud.delete_sheet(db_session, "sheet-0", "") == False
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == True
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == False

View File

@@ -15,7 +15,7 @@ def test_endpoints_no_auth(client, test_no_auth):
test_no_auth(client.post, "/sheet/archive") test_no_auth(client.post, "/sheet/archive")
def test_create_sheet_endpoint(app_with_auth): def test_create_sheet_endpoint(app_with_auth, db_session):
client_with_auth = TestClient(app_with_auth) client_with_auth = TestClient(app_with_auth)
good_data = { good_data = {
"id": "123-sheet-id", "id": "123-sheet-id",
@@ -53,13 +53,23 @@ def test_create_sheet_endpoint(app_with_auth):
assert response.status_code == 403 assert response.status_code == 403
assert response.json() == {"detail": "User does not have access to this group."} assert response.json() == {"detail": "User does not have access to this group."}
# bad quota # switch to jerry who's got less quota/permissions
from web.security import get_active_user_state
from db.user_state import UserState
app_with_auth.dependency_overrides[get_active_user_state] = lambda: UserState(db_session, "jerry@example.com", active=True)
client_jerry = TestClient(app_with_auth)
# frequency not allowed
jerry_data = good_data.copy() jerry_data = good_data.copy()
jerry_data["group_id"] = "animated-characters" jerry_data["group_id"] = "animated-characters"
jerry_data["frequency"] = "hourly"
jerry_data["id"] = "jerry-sheet-id" jerry_data["id"] = "jerry-sheet-id"
from web.security import get_active_user_auth response = client_jerry.post("/sheet/create", json=jerry_data)
app_with_auth.dependency_overrides[get_active_user_auth] = lambda: "jerry@example.com" assert response.status_code == 422
client_jerry = TestClient(app_with_auth) assert "Invalid frequency: hourly" in response.json()["detail"]
jerry_data["frequency"] = "daily"
# success for the first sheet, bad quota on second
response = client_jerry.post("/sheet/create", json=jerry_data) response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == 201 assert response.status_code == 201
@@ -144,12 +154,6 @@ def test_delete_sheet_endpoint(client_with_auth, db_session):
assert response.json() == {"id": "456-sheet-id", "deleted": False} assert response.json() == {"id": "456-sheet-id", "deleted": False}
# def test_archive_user_sheet_endpoint(client_with_auth):
# response = client_with_auth.post("/sheet/123-sheet-id/archive")
# assert response.status_code == 201
# assert "id" in response.json()
class TestArchiveUserSheetEndpoint: class TestArchiveUserSheetEndpoint:
def test_token_auth(self, client_with_token, test_no_auth): def test_token_auth(self, client_with_token, test_no_auth):
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive") test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
@@ -177,6 +181,14 @@ class TestArchiveUserSheetEndpoint:
assert r.json() == {"id": "123-taskid"} assert r.json() == {"id": "123-taskid"}
m1.assert_called_once() m1.assert_called_once()
def test_user_not_in_group(self, client_with_auth, db_session):
from db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "User does not have access to this group."}
class TestTokenArchiveEndpoint: class TestTokenArchiveEndpoint:

View File

@@ -10,11 +10,13 @@ def test_archive_url_unauthenticated(client, test_no_auth):
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result="")) @patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
def test_archive_url(m1, client_with_auth): def test_archive_url(m1, client_with_auth):
# url is too short
response = client_with_auth.post("/url/archive", json={"url": "bad"}) response = client_with_auth.post("/url/archive", json={"url": "bad"})
assert response.status_code == 422 assert response.status_code == 422
assert response.json() == {'detail': 'Invalid URL received: bad'} assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
m1.assert_not_called() m1.assert_not_called()
# valid request
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"}) response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 201 assert response.status_code == 201
assert response.json() == {'id': '123-456-789'} assert response.json() == {'id': '123-456-789'}
@@ -23,6 +25,20 @@ def test_archive_url(m1, client_with_auth):
called_val = m1.call_args.args[0] called_val = m1.call_args.args[0]
assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True} assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True}
# user is not in group
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "new-group"})
assert response.status_code == 403
assert response.json()["detail"] == "User does not have access to this group."
# user is in group
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
assert m1.call_count == 2
called_val = m1.call_args.args[0]
assert json.loads(called_val)["group_id"] == "spaceship"
def test_search_by_url_unauthenticated(client, test_no_auth): def test_search_by_url_unauthenticated(client, test_no_auth):
test_no_auth(client.get, "/url/search") test_no_auth(client.get, "/url/search")

View File

@@ -23,6 +23,8 @@ orchestrators:
interdimensional: tests/orchestration.test.yaml interdimensional: tests/orchestration.test.yaml
default: tests/orchestration.test.yaml default: tests/orchestration.test.yaml
default_orchestrator: tests/orchestration.test.yaml
groups: groups:
spaceship: spaceship:
description: "The spaceship crew" description: "The spaceship crew"
@@ -31,9 +33,9 @@ groups:
permissions: permissions:
read: ["all"] read: ["all"]
active_sheets: -1 active_sheets: -1
monthly_urls: all monthly_urls: -1
monthly_mbs: all monthly_mbs: -1
alowed_frequency: "hourly" allowed_frequency: "hourly"
interdimensional: interdimensional:
description: "Interdimensional travelers" description: "Interdimensional travelers"
orchestrator: tests/orchestration.test.yaml orchestrator: tests/orchestration.test.yaml
@@ -43,7 +45,7 @@ groups:
active_sheets: 5 active_sheets: 5
monthly_urls: 1000 monthly_urls: 1000
monthly_mbs: 1000 monthly_mbs: 1000
alowed_frequency: "hourly" allowed_frequency: "hourly"
animated-characters: animated-characters:
description: "Animated characters" description: "Animated characters"
orchestrator: tests/orchestration.test.yaml orchestrator: tests/orchestration.test.yaml
@@ -53,4 +55,4 @@ groups:
active_sheets: 1 active_sheets: 1
monthly_urls: 2 monthly_urls: 2
monthly_mbs: 10 monthly_mbs: 10
alowed_frequency: "daily" allowed_frequency: "daily"

View File

@@ -122,14 +122,6 @@ class Test_create_sheet_task():
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0 assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
@patch("worker.main.is_group_invalid_for_user", return_value="Access denied")
def test_error_access(self, m_insert, worker_init, db_session):
from worker.main import create_sheet_task
res = create_sheet_task(self.sheet.model_dump_json())
assert "error" in res
assert res["error"] == "Access denied"
def test_choose_orchestrator(worker_init): def test_choose_orchestrator(worker_init):
from worker.main import choose_orchestrator from worker.main import choose_orchestrator

View File

@@ -131,17 +131,19 @@ def app_factory(settings = get_settings()):
@app.post("/sheet", status_code=201, deprecated=True) # DEPRECATED @app.post("/sheet", status_code=201, deprecated=True) # DEPRECATED
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth)): def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth), db: Session = Depends(get_db_dependency)):
logger.info(f"SHEET TASK for {sheet=}") logger.info(f"SHEET TASK for {sheet=}")
sheet.author_id = email sheet.author_id = email
if not sheet.sheet_name and not sheet.sheet_id: if not sheet.sheet_name and not sheet.sheet_id:
raise HTTPException(status_code=422, detail=f"sheet name or id is required") raise HTTPException(status_code=422, detail=f"sheet name or id is required")
if not crud.is_user_in_group(db, email, sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
task = create_sheet_task.delay(sheet.model_dump_json()) task = create_sheet_task.delay(sheet.model_dump_json())
return JSONResponse({"id": task.id}) return JSONResponse({"id": task.id})
@app.post("/sheet_service", status_code=201, deprecated=True) # DEPRECATED @app.post("/sheet_service", status_code=201, deprecated=True) # DEPRECATED
def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth)): def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth), db: Session = Depends(get_db_dependency)):
logger.info(f"SHEET TASK for {sheet=}") logger.info(f"SHEET TASK for {sheet=}")
sheet.author_id = sheet.author_id or "api-endpoint" sheet.author_id = sheet.author_id or "api-endpoint"
if not sheet.sheet_name and not sheet.sheet_id: if not sheet.sheet_name and not sheet.sheet_id:

View File

@@ -6,6 +6,7 @@ from core.config import ALLOW_ANY_EMAIL
from shared.settings import get_settings from shared.settings import get_settings
from db.database import get_db from db.database import get_db
from db import crud from db import crud
from db.user_state import UserState
settings = get_settings() settings = get_settings()
bearer_security = HTTPBearer() bearer_security = HTTPBearer()
@@ -84,3 +85,8 @@ def authenticate_user(access_token):
except Exception as e: except Exception as e:
logger.warning(f"AUTH EXCEPTION occurred: {e}") logger.warning(f"AUTH EXCEPTION occurred: {e}")
return False, "exception occurred" return False, "exception occurred"
def get_active_user_state(email=Depends(get_active_user_auth)):
with get_db() as db:
return UserState(db, email, active=True)

View File

@@ -1,4 +1,5 @@
from functools import lru_cache
import traceback, yaml, datetime import traceback, yaml, datetime
from typing import List, Set from typing import List, Set
@@ -30,6 +31,7 @@ Rdis = redis.Redis.from_url(celery.conf.broker_url)
def create_archive_task(self, archive_json: str): def create_archive_task(self, archive_json: str):
archive = schemas.ArchiveCreate.model_validate_json(archive_json) archive = schemas.ArchiveCreate.model_validate_json(archive_json)
logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}") logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}")
#TODO: move group checks out of here
invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id) invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)
if invalid: if invalid:
raise Exception(invalid) # marks task FAILED, saves the Exception as result raise Exception(invalid) # marks task FAILED, saves the Exception as result
@@ -64,10 +66,6 @@ def create_sheet_task(self, sheet_json: str):
sheet.tags.add("gsheet") sheet.tags.add("gsheet")
logger.info(f"SHEET START {sheet=}") logger.info(f"SHEET START {sheet=}")
#TODO: should this check live here?
if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)):
return {"error": em}
config = Config() config = Config()
# TODO: use choose_orchestrator and overwrite the feeder # TODO: use choose_orchestrator and overwrite the feeder
# TODO: drop sheet_name and use only sheet_id (new endpoints/models) # TODO: drop sheet_name and use only sheet_id (new endpoints/models)
@@ -161,7 +159,7 @@ def is_group_invalid_for_user(public: bool, group_id: str, author_id: str):
# otherwise group must match # otherwise group must match
with get_db() as session: with get_db() as session:
if not crud.is_user_in_group(session, group_id, author_id): if not crud.is_user_in_group(session, author_id, group_id):
logger.error(em := f"User {author_id} is not part of {group_id}, no permission") logger.error(em := f"User {author_id} is not part of {group_id}, no permission")
return em return em
return False return False
@@ -220,3 +218,13 @@ def at_start(sender, **kwargs):
ORCHESTRATORS = {} ORCHESTRATORS = {}
load_orchestrators() load_orchestrators()
logger.info("Orchestrators loaded successfully.") logger.info("Orchestrators loaded successfully.")
@lru_cache
def get_url_orchestrator(group_name):
with get_db() as db:
group = crud.get_group(db, group_name)
assert group, f"Group {group_name} not found"
# config = Config()
# config.parse(use_cli=False, yaml_config_filename=group.orchestrator_sheet)
# return ArchivingOrchestrator(config)