mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-08 03:28:35 +03:00
pushing bulk of changes
This commit is contained in:
@@ -21,6 +21,8 @@ orchestration must be from the console(?)
|
||||
* turn off VPNs if connection to docker is not working
|
||||
|
||||
## User management
|
||||
TODO: update description and example
|
||||
- users/domains/groups
|
||||
Copy [example.user-groups.yaml](src/example.user-groups.yaml) into a new file and set the environment variable `USER_GROUPS_FILENAME` to that filename (defaults to `user-groups.yaml`).
|
||||
|
||||
This file contains a two-part user-groups specification. Each user can archive URLs publicly, privately, or privately for a group, as long as they are declared as part of that group. In the example below `email1` has 2 groups while `email3` has none.
|
||||
|
||||
@@ -4,6 +4,7 @@ verify_ssl = true
|
||||
name = "pypi"
|
||||
|
||||
[packages]
|
||||
oscrypto = {git = "https://github.com/wbond/oscrypto.git", ref = "d5f3437ed24257895ae1edd9e503cfb352e635a8"}
|
||||
aiofiles = "==0.6.0"
|
||||
celery = ">=5.0"
|
||||
fastapi = "*"
|
||||
|
||||
3759
src/Pipfile.lock
generated
3759
src/Pipfile.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
||||
based on https://fastapi-users.github.io/fastapi-users/10.4/configuration/oauth/
|
||||
@@ -1,11 +1,12 @@
|
||||
from collections import defaultdict
|
||||
from functools import cache
|
||||
from functools import lru_cache
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
from sqlalchemy import Column, or_, func
|
||||
from loguru import logger
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
from db.database import get_db
|
||||
from shared.settings import get_settings
|
||||
from . import models, schemas
|
||||
import yaml
|
||||
@@ -23,7 +24,7 @@ def get_archive(db: Session, id: str, email: str):
|
||||
email = email.lower()
|
||||
query = base_query(db).filter(models.Archive.id == id)
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
groups = get_user_groups(db, email)
|
||||
groups = get_user_groups(email)
|
||||
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
||||
return query.first()
|
||||
|
||||
@@ -33,7 +34,7 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim
|
||||
query = base_query(db)
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
email = email.lower()
|
||||
groups = get_user_groups(db, email)
|
||||
groups = get_user_groups(email)
|
||||
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
||||
if absolute_search:
|
||||
query = query.filter(models.Archive.url == url)
|
||||
@@ -121,72 +122,37 @@ def is_active_user(db: Session, email: str) -> bool:
|
||||
return db.query(models.Group).filter(models.Group.domains.contains(domain)).first() is not None
|
||||
|
||||
|
||||
def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
|
||||
def is_user_in_group(db: Session, email: str, group_name: str) -> models.Group:
|
||||
if email == ALLOW_ANY_EMAIL: return True
|
||||
return len(group_name) and len(email) and group_name in get_user_groups(db, email)
|
||||
return len(group_name) and len(email) and group_name in get_user_groups(email)
|
||||
|
||||
|
||||
#TODO: maybe this can be cached? what about the db session?
|
||||
def get_user_groups(db: Session, email: str) -> list[str]:
|
||||
@lru_cache
|
||||
def get_user_groups(email: str) -> list[str]:
|
||||
"""
|
||||
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. User does not need to be active.
|
||||
"""
|
||||
if not email or not len(email) or "@" not in email: return []
|
||||
email = email.lower()
|
||||
|
||||
# get user groups
|
||||
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
|
||||
user_level_groups_names = [g[0] for g in user_groups]
|
||||
with get_db() as db:
|
||||
# get user groups
|
||||
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
|
||||
user_level_groups_names = [g[0] for g in user_groups]
|
||||
|
||||
# get domain groups
|
||||
domain = email.split('@')[1]
|
||||
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
|
||||
domain_level_groups_names = [g[0] for g in domain_level_groups]
|
||||
# get domain groups
|
||||
domain = email.split('@')[1]
|
||||
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
|
||||
domain_level_groups_names = [g[0] for g in domain_level_groups]
|
||||
|
||||
return list(set(user_level_groups_names + domain_level_groups_names))
|
||||
|
||||
|
||||
# --------------- SHEET
|
||||
|
||||
def has_quota_sheet(db: Session, email: str, user_groups_names: list[str]) -> bool:
|
||||
"""
|
||||
checks if a user has reached their sheet quota
|
||||
"""
|
||||
user_sheets = db.query(models.Sheet).filter(models.Sheet.author_id == email).count()
|
||||
|
||||
user_groups = db.query(models.Group).filter(models.Group.id.in_(user_groups_names)).all()
|
||||
|
||||
quota = 0
|
||||
for group in user_groups:
|
||||
active_sheets = group.permissions.get("active_sheets", 0)
|
||||
if active_sheets == -1: return True
|
||||
quota = max(quota, active_sheets)
|
||||
return user_sheets < quota
|
||||
|
||||
|
||||
def create_sheet(db: Session, sheet_id: str, sheet_name: str, email: str, group_id: str, frequency: str):
|
||||
db_sheet = models.Sheet(id=sheet_id, name=sheet_name, author_id=email, group_id=group_id, frequency=frequency)
|
||||
db.add(db_sheet)
|
||||
db.commit()
|
||||
db.refresh(db_sheet)
|
||||
return db_sheet
|
||||
|
||||
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
|
||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_archived_at.desc()).all()
|
||||
|
||||
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
|
||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
|
||||
|
||||
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
|
||||
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
|
||||
if db_sheet:
|
||||
db.delete(db_sheet)
|
||||
db.commit()
|
||||
return db_sheet is not None
|
||||
return list(set(user_level_groups_names + domain_level_groups_names))
|
||||
|
||||
|
||||
# --------------- INIT User-Groups
|
||||
|
||||
def get_group(db: Session, group_name: str) -> models.Group:
|
||||
return db.query(models.Group).filter(models.Group.id == group_name).first()
|
||||
|
||||
|
||||
def create_or_get_user(db: Session, author_id: str, is_active: bool = models.User.is_active.default.arg) -> models.User:
|
||||
if type(author_id) == str: author_id = author_id.lower()
|
||||
@@ -296,3 +262,28 @@ def upsert_user_groups(db: Session):
|
||||
count_groups = db.query(func.count(models.Group.id)).scalar()
|
||||
|
||||
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
|
||||
|
||||
|
||||
# --------------- SHEET
|
||||
def create_sheet(db: Session, sheet_id: str, sheet_name: str, email: str, group_id: str, frequency: str):
    """
    stores a new Sheet row authored by `email` and returns the ORM instance
    a duplicate `sheet_id` surfaces as an error from the commit (the tests expect sqlalchemy.exc.IntegrityError)
    """
    new_sheet = models.Sheet(
        id=sheet_id,
        name=sheet_name,
        author_id=email,
        group_id=group_id,
        frequency=frequency,
    )
    db.add(new_sheet)
    db.commit()
    # re-read the row after the commit so the returned instance is up to date
    db.refresh(new_sheet)
    return new_sheet
|
||||
|
||||
|
||||
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
    """
    returns the sheet with id `sheet_id` authored by `email`, or None when there is no such sheet
    """
    authored_by_user = models.Sheet.author_id == email
    has_requested_id = models.Sheet.id == sheet_id
    return db.query(models.Sheet).filter(authored_by_user, has_requested_id).first()
|
||||
|
||||
|
||||
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
    """
    lists all sheets authored by `email`, ordered by last_archived_at descending
    """
    sheets_query = db.query(models.Sheet).filter(models.Sheet.author_id == email)
    sheets_query = sheets_query.order_by(models.Sheet.last_archived_at.desc())
    return sheets_query.all()
|
||||
|
||||
|
||||
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
    """
    deletes the sheet `sheet_id` only if it is authored by `email`
    returns True when a matching sheet was found, False otherwise
    """
    owned_sheet = (
        db.query(models.Sheet)
        .filter(models.Sheet.id == sheet_id)
        .filter(models.Sheet.author_id == email)
        .first()
    )
    was_found = owned_sheet is not None
    if owned_sheet:
        db.delete(owned_sheet)
        db.commit()
    return was_found
|
||||
@@ -33,4 +33,4 @@ def get_db():
|
||||
def get_db_dependency():
|
||||
# to use with Depends and ensure proper session closing
|
||||
with get_db() as db:
|
||||
yield db
|
||||
yield db
|
||||
|
||||
@@ -87,7 +87,7 @@ class Group(Base):
|
||||
description = Column(String, default=None)
|
||||
orchestrator = Column(String, default=None)
|
||||
orchestrator_sheet = Column(String, default=None)
|
||||
permissions = Column(JSON, default=None)
|
||||
permissions = Column(JSON, default={})
|
||||
domains = Column(JSON, default=[])
|
||||
|
||||
archives = relationship("Archive", back_populates="group")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Annotated
|
||||
from annotated_types import Len
|
||||
from pydantic import BaseModel, field_validator
|
||||
from datetime import datetime
|
||||
|
||||
@@ -105,3 +107,10 @@ class SheetResponse(SheetAdd):
|
||||
stats: dict | None
|
||||
last_archived_at: datetime | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ArchiveTrigger(BaseModel):
|
||||
url: Annotated[str, Len(min_length=5)]
|
||||
public: bool = True
|
||||
group_id: Annotated[str, Len(min_length=1)] | None = None
|
||||
tags: set[Tag] | None = set()
|
||||
|
||||
142
src/db/user_state.py
Normal file
142
src/db/user_state.py
Normal file
@@ -0,0 +1,142 @@
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from db import crud, models
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class UserState:
    """
    Manage a user's state and permissions

    Group names, group rows and the values derived from them are looked up
    lazily on first access and then kept on the instance, so one instance
    reflects the permissions as they were when each value was first read.
    """

    def __init__(self, db: "Session", email: str, active=False):
        self.db = db
        self.email = email
        self.active = active

    @property
    def user_groups_names(self):
        """
        names of the groups the user belongs to, as returned by crud.get_user_groups
        """
        if not hasattr(self, '_user_groups_names'):
            # crud.get_user_groups is lru_cache'd, so the list it returns is shared
            # between callers: keep a private copy instead of the cached object
            self._user_groups_names = list(crud.get_user_groups(self.email))
        return self._user_groups_names

    @property
    def user_groups(self):
        """
        models.Group rows whose id is one of the user's group names
        """
        if not hasattr(self, '_user_groups'):
            self._user_groups = self.db.query(models.Group).filter(
                models.Group.id.in_(self.user_groups_names)
            ).all()
        return self._user_groups

    @property
    def allowed_frequencies(self):
        """
        set of sheet frequencies granted by the user's groups
        a user allowed "hourly" is also allowed "daily"
        """
        if not hasattr(self, '_allowed_frequencies'):
            # build locally and cache only once complete, so a failed group lookup
            # does not leave an empty set cached
            frequencies = set()
            for group in self.user_groups:
                if not group.permissions: continue
                frequency = group.permissions.get("allowed_frequency")
                # groups without this permission must not add None to the set
                if frequency: frequencies.add(frequency)
            if "hourly" in frequencies:
                frequencies.add("daily")
            self._allowed_frequencies = frequencies
        return self._allowed_frequencies

    def _max_permission(self, key: str) -> int:
        """
        highest value of the numeric permission `key` across the user's groups
        -1 (unlimited) in any group wins; groups without permissions, or without a value for `key`, count as 0
        """
        highest = 0
        for group in self.user_groups:
            if not group.permissions: continue
            value = group.permissions.get(key) or 0
            if value == -1: return -1
            highest = max(highest, value)
        return highest

    @property
    def sheet_quota(self):
        """
        infer the user's sheet quota from the groups
        -1 means unlimited
        """
        if not hasattr(self, '_sheet_quota'):
            self._sheet_quota = self._max_permission("active_sheets")
        return self._sheet_quota

    @property
    def monthly_urls(self):
        """
        infer the user's monthly url quota from the groups
        -1 means unlimited
        """
        if not hasattr(self, '_monthly_urls'):
            self._monthly_urls = self._max_permission("monthly_urls")
        return self._monthly_urls

    @property
    def monthly_mbs(self):
        """
        infer the user's monthly mb quota from the groups
        -1 means unlimited
        """
        if not hasattr(self, '_monthly_mbs'):
            self._monthly_mbs = self._max_permission("monthly_mbs")
        return self._monthly_mbs

    def in_group(self, group_id: str) -> bool:
        """
        checks if `group_id` is one of the user's group names
        """
        return group_id in self.user_groups_names

    def has_quota_sheet(self) -> bool:
        """
        checks if a user can still add a sheet, i.e. has NOT reached their sheet quota
        """
        if self.sheet_quota == -1: return True

        user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email).count()

        return user_sheets < self.sheet_quota

    def has_quota_monthly_urls(self) -> bool:
        """
        checks if a user can still archive a url this month, i.e. has NOT reached their monthly url quota
        """
        quota = self.monthly_urls
        if quota == -1: return True

        # read the clock once so month and year always belong to the same instant
        now = datetime.now()
        user_urls = self.db.query(models.Archive).filter(
            models.Archive.author_id == self.email,
            func.extract('month', models.Archive.created_at) == now.month,
            func.extract('year', models.Archive.created_at) == now.year
        ).count()

        return user_urls < quota

    def has_quota_monthly_mbs(self) -> bool:
        """
        checks if a user can still archive this month, i.e. has NOT reached their monthly mb quota
        """
        quota = self.monthly_mbs
        if quota == -1: return True

        # read the clock once so month and year always belong to the same instant
        now = datetime.now()

        # find and sum all user bytes over this month, archives without
        # '$.metadata.total_bytes' in their result count as 0
        user_bytes = self.db.query(models.Archive).filter(
            models.Archive.author_id == self.email,
            func.extract('month', models.Archive.created_at) == now.month,
            func.extract('year', models.Archive.created_at) == now.year
        ).with_entities(func.coalesce(func.sum(
            func.coalesce(
                func.cast(
                    func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
                    sqlalchemy.Integer
                ), 0
            )
        ), 0).label('total')).scalar()

        # convert bytes to mb
        user_mbs = int(user_bytes / 1024 / 1024)
        return user_mbs < quota

    # def can_manually_trigger(self) -> bool:
    # 	"""
    # 	checks if a user is allowed to manually trigger a sheet
    # 	"""
    # 	for group in self.user_groups:
    # 		if not group.permissions: continue
    # 		if group.permissions.get("manual_trigger", False):
    # 			return True
    # 	return False

    def is_sheet_frequency_allowed(self, frequency: str) -> bool:
        """
        checks if a user is allowed to create a sheet with this frequency
        """
        return frequency in self.allowed_frequencies
|
||||
@@ -6,8 +6,9 @@ from sqlalchemy.orm import Session
|
||||
from core.config import VERSION, BREAKING_CHANGES
|
||||
from core.logging import log_error
|
||||
from db import crud, schemas
|
||||
from db.database import get_db_dependency, get_db
|
||||
from web.security import get_user_auth, bearer_security
|
||||
from db.database import get_db_dependency
|
||||
from db.user_state import UserState
|
||||
from web.security import get_user_auth, bearer_security, get_active_user_state
|
||||
|
||||
default_router = APIRouter()
|
||||
|
||||
@@ -18,8 +19,7 @@ async def home(request: Request):
|
||||
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
|
||||
try:
|
||||
email = await get_user_auth(await bearer_security(request))
|
||||
with get_db() as db:
|
||||
status["groups"] = crud.get_user_groups(db, email)
|
||||
status["groups"] = crud.get_user_groups(email)
|
||||
except HTTPException: pass # not authenticated is fine
|
||||
except Exception as e: log_error(e)
|
||||
return JSONResponse(status)
|
||||
@@ -31,13 +31,28 @@ async def health():
|
||||
|
||||
|
||||
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
|
||||
# TODO: reorder db dependencies to after auth
|
||||
async def active(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> schemas.ActiveUser:
|
||||
return {"active": crud.is_active_user(db, email)}
|
||||
|
||||
|
||||
@default_router.get("/groups")
|
||||
def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> list[str]:
|
||||
return crud.get_user_groups(db, email)
|
||||
def get_user_groups(email=Depends(get_user_auth)) -> list[str]:
|
||||
return crud.get_user_groups(email)
|
||||
|
||||
|
||||
@default_router.get("/permissions")
def get_user_groups(
    user: UserState = Depends(get_active_user_state),
) -> JSONResponse:
    """
    returns the groups, allowed sheet frequencies and quotas of the user resolved by get_active_user_state
    """
    # NOTE(review): this function reuses the name of the `/groups` handler defined
    # just above in this module and therefore shadows it; consider renaming.
    # NOTE(review): relies on UserState exposing `monthly_urls` and `monthly_mbs` - confirm.
    return JSONResponse({
        "groups": user.user_groups_names,
        # sets are not JSON serializable, hence the list
        "allowedFrequencies": list(user.allowed_frequencies),
        "sheet_quota": user.sheet_quota,
        "monthly_urls": user.monthly_urls,
        "monthly_mbs": user.monthly_mbs,
        #TODO: should this return
    })
|
||||
|
||||
|
||||
@default_router.get('/favicon.ico', include_in_schema=False)
|
||||
|
||||
@@ -5,7 +5,8 @@ from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from web.security import token_api_key_auth, get_active_user_auth
|
||||
from db.user_state import UserState
|
||||
from web.security import token_api_key_auth, get_active_user_auth, get_active_user_state
|
||||
from db import schemas, crud
|
||||
from db.database import get_db_dependency
|
||||
from worker.main import create_sheet_task
|
||||
@@ -16,19 +17,21 @@ sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"]
|
||||
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
|
||||
def create_sheet(
|
||||
sheet: schemas.SheetAdd,
|
||||
email=Depends(get_active_user_auth),
|
||||
user: UserState = Depends(get_active_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> schemas.SheetResponse:
|
||||
user_groups_names = crud.get_user_groups(db, email)
|
||||
|
||||
if sheet.group_id not in user_groups_names:
|
||||
if not user.in_group(sheet.group_id):
|
||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||
|
||||
if not crud.has_quota_sheet(db, email, user_groups_names):
|
||||
if not user.has_quota_sheet():
|
||||
raise HTTPException(status_code=429, detail="User has reached their sheet quota.")
|
||||
|
||||
if not user.is_sheet_frequency_allowed(sheet.frequency):
|
||||
raise HTTPException(status_code=422, detail=f"Invalid frequency: {sheet.frequency}. Must be one of {user.allowed_frequencies}")
|
||||
|
||||
try:
|
||||
return crud.create_sheet(db, sheet.id, sheet.name, email, sheet.group_id, sheet.frequency)
|
||||
return crud.create_sheet(db, sheet.id, sheet.name, user.email, sheet.group_id, sheet.frequency)
|
||||
except exc.IntegrityError as e:
|
||||
raise HTTPException(status_code=400, detail="Sheet with this ID already exists.") from e
|
||||
|
||||
@@ -56,22 +59,30 @@ def delete_sheet(
|
||||
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
|
||||
def archive_user_sheet(
|
||||
id: str,
|
||||
email=Depends(get_active_user_auth),
|
||||
user: UserState = Depends(get_active_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> schemas.Task:
|
||||
|
||||
#TODO: are we enabling manual triggers?
|
||||
# if not user.can_manually_trigger():
|
||||
# raise HTTPException(status_code=429, detail="User cannot manually trigger archiving tasks.")
|
||||
|
||||
sheet = crud.get_user_sheet(db, email, sheet_id=id)
|
||||
sheet = crud.get_user_sheet(db, user.email, sheet_id=id)
|
||||
if not sheet:
|
||||
raise HTTPException(status_code=403, detail="No access to this sheet.")
|
||||
|
||||
task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=email, group=sheet.group_id).model_dump_json())
|
||||
# TODO: what happens if user is taken out of group after sheet is created? this should be checked in a cronjob that notifies the user
|
||||
if not user.in_group(sheet.group_id):
|
||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||
|
||||
task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=user.email, group=sheet.group_id).model_dump_json())
|
||||
|
||||
return JSONResponse({"id": task.id}, status_code=201)
|
||||
|
||||
|
||||
@sheet_router.post("/archive", status_code=201, summary="Trigger an archiving task for any GSheet with an API token.", response_description="task_id for the archiving task.")
|
||||
def archive_sheet(
|
||||
sheet: schemas.SubmitSheet, #TODO: replace with simpler model
|
||||
sheet: schemas.SubmitSheet, # TODO: replace with simpler model
|
||||
auth=Depends(token_api_key_auth)
|
||||
) -> schemas.Task:
|
||||
sheet.author_id = sheet.author_id or "api-endpoint"
|
||||
|
||||
@@ -8,7 +8,7 @@ from web.security import get_user_auth, get_token_or_user_auth
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from db import crud, schemas
|
||||
from db.database import get_db_dependency
|
||||
from db.database import get_db, get_db_dependency
|
||||
|
||||
from worker.main import create_archive_task
|
||||
|
||||
@@ -16,14 +16,28 @@ url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
|
||||
|
||||
|
||||
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
|
||||
def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)) -> schemas.Task:
|
||||
archive.author_id = email
|
||||
url = archive.url
|
||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
|
||||
if type(url) != str or len(url) <= 5:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
|
||||
logger.info("creating task")
|
||||
task = create_archive_task.delay(archive.model_dump_json())
|
||||
def archive_url(
|
||||
archive: schemas.ArchiveTrigger,
|
||||
email=Depends(get_token_or_user_auth)
|
||||
) -> schemas.Task:
|
||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
|
||||
|
||||
# TODO: implement quota
|
||||
|
||||
if archive.group_id:
|
||||
with get_db() as db:
|
||||
if not crud.is_user_in_group(db, email, archive.group_id):
|
||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||
|
||||
# TODO: deprecate ArchiveCreate
|
||||
backwards_compatible_archive = schemas.ArchiveCreate(
|
||||
url=archive.url,
|
||||
author_id=email,
|
||||
group_id=archive.group_id,
|
||||
public=archive.public,
|
||||
)
|
||||
|
||||
task = create_archive_task.delay(backwards_compatible_archive.model_dump_json())
|
||||
task_response = schemas.Task(id=task.id)
|
||||
return JSONResponse(task_response.model_dump(), status_code=201)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from db.user_state import UserState
|
||||
from shared.settings import Settings
|
||||
|
||||
|
||||
@@ -27,7 +28,9 @@ def mock_settings():
|
||||
def test_db(get_settings: Settings):
|
||||
from db.database import make_engine
|
||||
from db import models
|
||||
from db.crud import get_user_groups
|
||||
|
||||
get_user_groups.cache_clear()
|
||||
make_engine.cache_clear()
|
||||
engine = make_engine(get_settings.DATABASE_PATH)
|
||||
|
||||
@@ -72,11 +75,12 @@ def client(app):
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app_with_auth(app):
|
||||
from web.security import get_token_or_user_auth, get_user_auth, get_active_user_auth
|
||||
def app_with_auth(app, db_session):
|
||||
from web.security import get_token_or_user_auth, get_user_auth, get_active_user_auth, get_active_user_state
|
||||
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
|
||||
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
|
||||
app.dependency_overrides[get_active_user_auth] = lambda: "morty@example.com"
|
||||
app.dependency_overrides[get_active_user_state] = lambda: UserState(db_session, "morty@example.com", active=True)
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -40,6 +40,12 @@ def test_data(db_session):
|
||||
archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
|
||||
db_session.add(archive)
|
||||
|
||||
# creates a sheet for each user
|
||||
for i, email in enumerate(authors):
|
||||
db_session.add(models.Sheet(id=f"sheet-{i}", name=f"sheet-{i}", author_id=email, group_id=None, frequency="daily"))
|
||||
if email == "rick@example.com":
|
||||
db_session.add(models.Sheet(id=f"sheet-{i}-2", name=f"sheet-{i}-2", author_id=email, group_id="spaceship", frequency="hourly"))
|
||||
|
||||
db_session.commit()
|
||||
|
||||
assert db_session.query(models.Archive).count() == 100
|
||||
@@ -253,6 +259,7 @@ def test_count_archive_urls(test_data, db_session):
|
||||
assert crud.count_archives(db_session) == 99
|
||||
assert crud.count_archive_urls(db_session) == 999
|
||||
|
||||
|
||||
def test_count_users(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
@@ -261,6 +268,7 @@ def test_count_users(test_data, db_session):
|
||||
db_session.commit()
|
||||
assert crud.count_users(db_session) == 3
|
||||
|
||||
|
||||
def test_count_by_users_since(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
@@ -294,6 +302,7 @@ def test_create_tag(db_session):
|
||||
assert second_tag.id == "tag-102"
|
||||
assert db_session.query(models.Tag).count() == 2
|
||||
|
||||
|
||||
def test_is_active_user(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
@@ -329,7 +338,7 @@ def test_is_user_in_group(test_data, db_session):
|
||||
|
||||
("jerry@example.com", "spaceship", False),
|
||||
("jerry@example.com", "interdimensional", False),
|
||||
("jerry@example.com", "the-jerrys-club", False), # group not in 'groups'
|
||||
("jerry@example.com", "the-jerrys-club", False), # group not in 'groups'
|
||||
|
||||
("rick@example.com", "animated-characters", True),
|
||||
("morty@example.com", "animated-characters", True),
|
||||
@@ -337,7 +346,7 @@ def test_is_user_in_group(test_data, db_session):
|
||||
("ANYONE@example.com", "animated-characters", True),
|
||||
("ANYONE@birdy.com", "animated-characters", True),
|
||||
|
||||
("summer@herself.com", "animated-characters", False),
|
||||
("summer@herself.com", "animated-characters", False),
|
||||
|
||||
("rick@example.com", "", False),
|
||||
("", "spaceship", False),
|
||||
@@ -345,7 +354,16 @@ def test_is_user_in_group(test_data, db_session):
|
||||
]
|
||||
for email, group, expected in test_pairs:
|
||||
print(f"{email} in {group} == {expected}")
|
||||
assert crud.is_user_in_group(db_session, group, email) == expected
|
||||
assert crud.is_user_in_group(db_session, email, group) == expected
|
||||
|
||||
|
||||
def test_get_group(test_data, db_session):
    from db import crud

    # group names this test expects to resolve to a Group
    for known_group in ("spaceship", "interdimensional", "animated-characters"):
        assert crud.get_group(db_session, known_group) is not None

    # an unknown group name yields None
    assert crud.get_group(db_session, "non-existant!@#!%!") is None
|
||||
|
||||
|
||||
def test_create_or_get_user(test_data, db_session):
|
||||
@@ -403,13 +421,12 @@ def test_upsert_group(test_data, db_session):
|
||||
def test_upsert_user_groups(db_session):
|
||||
from db import crud
|
||||
|
||||
@patch('db.crud.get_settings', new = lambda: bad_setings)
|
||||
@patch('db.crud.get_settings', new=lambda: bad_setings)
|
||||
def test_missing_yaml(db_session):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
crud.upsert_user_groups(db_session)
|
||||
|
||||
|
||||
@patch('db.crud.get_settings', new = lambda: bad_setings)
|
||||
@patch('db.crud.get_settings', new=lambda: bad_setings)
|
||||
def test_broken_yaml(db_session):
|
||||
with pytest.raises(yaml.YAMLError):
|
||||
crud.upsert_user_groups(db_session)
|
||||
@@ -420,4 +437,54 @@ def test_upsert_user_groups(db_session):
|
||||
test_missing_yaml(db_session)
|
||||
|
||||
bad_setings.USER_GROUPS_FILENAME = "tests/user-groups.test.broken.yaml"
|
||||
test_broken_yaml(db_session)
|
||||
test_broken_yaml(db_session)
|
||||
|
||||
|
||||
def test_create_sheet(db_session):
    import sqlalchemy
    from db import crud

    def stored_sheets() -> int:
        # number of Sheet rows currently visible through the session
        return db_session.query(models.Sheet).count()

    assert stored_sheets() == 0

    created = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly")
    assert created is not None
    expected_attributes = {
        "id": "sheet-id-123",
        "name": "sheet name",
        "author_id": "email@example.com",
        "group_id": "group-id",
        "frequency": "hourly",
    }
    for attribute, expected in expected_attributes.items():
        assert getattr(created, attribute) == expected

    assert stored_sheets() == 1

    # duplicate id
    with pytest.raises(sqlalchemy.exc.IntegrityError):
        crud.create_sheet(db_session, "sheet-id-123", "I thought this was another sheet", "email", "group-id", "hourly")
|
||||
|
||||
|
||||
def test_get_user_sheet(test_data, db_session):
    from db import crud

    # (email, sheet id) pairs that must not resolve: no author, or a sheet authored by someone else
    not_accessible = [
        ("", "sheet-0"),
        ("morty@example.com", "sheet-0"),
    ]
    for email, sheet_id in not_accessible:
        assert crud.get_user_sheet(db_session, email, sheet_id) is None

    # (email, sheet id) pairs where the sheet is authored by that user
    accessible = [
        ("rick@example.com", "sheet-0"),
        ("rick@example.com", "sheet-0-2"),
        ("morty@example.com", "sheet-1"),
    ]
    for email, sheet_id in accessible:
        assert crud.get_user_sheet(db_session, email, sheet_id) is not None
|
||||
|
||||
|
||||
def test_get_user_sheets(test_data, db_session):
    from db import crud

    # an empty email authors no sheets
    assert not crud.get_user_sheets(db_session, "")

    rick_sheet_ids = [sheet.id for sheet in crud.get_user_sheets(db_session, "rick@example.com")]
    assert len(rick_sheet_ids) == 2
    assert rick_sheet_ids == ["sheet-0", "sheet-0-2"]

    morty_sheets = crud.get_user_sheets(db_session, "morty@example.com")
    assert len(morty_sheets) == 1
|
||||
|
||||
def test_delete_sheet(test_data, db_session):
    from db import crud

    # identity checks: delete_sheet is annotated `-> bool`, so a merely
    # truthy/falsy return value (e.g. 1 or 0) should fail the test

    # a user who does not author the sheet cannot delete it
    assert crud.delete_sheet(db_session, "sheet-0", "") is False
    # the author can delete it ...
    assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") is True
    # ... but only once, afterwards the sheet no longer exists
    assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") is False
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ def test_endpoints_no_auth(client, test_no_auth):
|
||||
test_no_auth(client.post, "/sheet/archive")
|
||||
|
||||
|
||||
def test_create_sheet_endpoint(app_with_auth):
|
||||
def test_create_sheet_endpoint(app_with_auth, db_session):
|
||||
client_with_auth = TestClient(app_with_auth)
|
||||
good_data = {
|
||||
"id": "123-sheet-id",
|
||||
@@ -53,13 +53,23 @@ def test_create_sheet_endpoint(app_with_auth):
|
||||
assert response.status_code == 403
|
||||
assert response.json() == {"detail": "User does not have access to this group."}
|
||||
|
||||
# bad quota
|
||||
# switch to jerry who's got less quota/permissions
|
||||
from web.security import get_active_user_state
|
||||
from db.user_state import UserState
|
||||
app_with_auth.dependency_overrides[get_active_user_state] = lambda: UserState(db_session, "jerry@example.com", active=True)
|
||||
client_jerry = TestClient(app_with_auth)
|
||||
|
||||
# frequency not allowed
|
||||
jerry_data = good_data.copy()
|
||||
jerry_data["group_id"] = "animated-characters"
|
||||
jerry_data["frequency"] = "hourly"
|
||||
jerry_data["id"] = "jerry-sheet-id"
|
||||
from web.security import get_active_user_auth
|
||||
app_with_auth.dependency_overrides[get_active_user_auth] = lambda: "jerry@example.com"
|
||||
client_jerry = TestClient(app_with_auth)
|
||||
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||
assert response.status_code == 422
|
||||
assert "Invalid frequency: hourly" in response.json()["detail"]
|
||||
jerry_data["frequency"] = "daily"
|
||||
|
||||
# success for the first sheet, bad quota on second
|
||||
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||
assert response.status_code == 201
|
||||
|
||||
@@ -144,12 +154,6 @@ def test_delete_sheet_endpoint(client_with_auth, db_session):
|
||||
assert response.json() == {"id": "456-sheet-id", "deleted": False}
|
||||
|
||||
|
||||
# def test_archive_user_sheet_endpoint(client_with_auth):
|
||||
# response = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
# assert response.status_code == 201
|
||||
# assert "id" in response.json()
|
||||
|
||||
|
||||
class TestArchiveUserSheetEndpoint:
|
||||
def test_token_auth(self, client_with_token, test_no_auth):
|
||||
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
|
||||
@@ -177,6 +181,14 @@ class TestArchiveUserSheetEndpoint:
|
||||
assert r.json() == {"id": "123-taskid"}
|
||||
m1.assert_called_once()
|
||||
|
||||
def test_user_not_in_group(self, client_with_auth, db_session):
    """Expect a 403 when archiving a sheet stored under 'interdimensional'."""
    from db import models

    # Store a sheet authored by morty under the "interdimensional" group.
    # NOTE(review): presumably the user behind client_with_auth is not a
    # member of that group -- confirm against the test fixtures.
    foreign_sheet = models.Sheet(
        id="123-sheet-id",
        name="Test Sheet 1",
        author_id="morty@example.com",
        group_id="interdimensional",
        frequency="hourly",
    )
    db_session.add(foreign_sheet)
    db_session.commit()

    response = client_with_auth.post("/sheet/123-sheet-id/archive")
    assert response.status_code == 403
    assert response.json() == {"detail": "User does not have access to this group."}
|
||||
|
||||
|
||||
class TestTokenArchiveEndpoint:
|
||||
|
||||
|
||||
@@ -10,11 +10,13 @@ def test_archive_url_unauthenticated(client, test_no_auth):
|
||||
|
||||
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
|
||||
def test_archive_url(m1, client_with_auth):
|
||||
# url is too short
|
||||
response = client_with_auth.post("/url/archive", json={"url": "bad"})
|
||||
assert response.status_code == 422
|
||||
assert response.json() == {'detail': 'Invalid URL received: bad'}
|
||||
assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
|
||||
m1.assert_not_called()
|
||||
|
||||
# valid request
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-456-789'}
|
||||
@@ -23,6 +25,20 @@ def test_archive_url(m1, client_with_auth):
|
||||
called_val = m1.call_args.args[0]
|
||||
assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True}
|
||||
|
||||
# user is not in group
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "new-group"})
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "User does not have access to this group."
|
||||
|
||||
# user is in group
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-456-789'}
|
||||
|
||||
assert m1.call_count == 2
|
||||
called_val = m1.call_args.args[0]
|
||||
assert json.loads(called_val)["group_id"] == "spaceship"
|
||||
|
||||
|
||||
def test_search_by_url_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.get, "/url/search")
|
||||
|
||||
@@ -23,6 +23,8 @@ orchestrators:
|
||||
interdimensional: tests/orchestration.test.yaml
|
||||
default: tests/orchestration.test.yaml
|
||||
|
||||
default_orchestrator: tests/orchestration.test.yaml
|
||||
|
||||
groups:
|
||||
spaceship:
|
||||
description: "The spaceship crew"
|
||||
@@ -31,9 +33,9 @@ groups:
|
||||
permissions:
|
||||
read: ["all"]
|
||||
active_sheets: -1
|
||||
monthly_urls: all
|
||||
monthly_mbs: all
|
||||
alowed_frequency: "hourly"
|
||||
monthly_urls: -1
|
||||
monthly_mbs: -1
|
||||
allowed_frequency: "hourly"
|
||||
interdimensional:
|
||||
description: "Interdimensional travelers"
|
||||
orchestrator: tests/orchestration.test.yaml
|
||||
@@ -43,7 +45,7 @@ groups:
|
||||
active_sheets: 5
|
||||
monthly_urls: 1000
|
||||
monthly_mbs: 1000
|
||||
alowed_frequency: "hourly"
|
||||
allowed_frequency: "hourly"
|
||||
animated-characters:
|
||||
description: "Animated characters"
|
||||
orchestrator: tests/orchestration.test.yaml
|
||||
@@ -53,4 +55,4 @@ groups:
|
||||
active_sheets: 1
|
||||
monthly_urls: 2
|
||||
monthly_mbs: 10
|
||||
alowed_frequency: "daily"
|
||||
allowed_frequency: "daily"
|
||||
@@ -122,14 +122,6 @@ class Test_create_sheet_task():
|
||||
|
||||
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
|
||||
|
||||
@patch("worker.main.is_group_invalid_for_user", return_value="Access denied")
|
||||
def test_error_access(self, m_insert, worker_init, db_session):
|
||||
from worker.main import create_sheet_task
|
||||
|
||||
res = create_sheet_task(self.sheet.model_dump_json())
|
||||
assert "error" in res
|
||||
assert res["error"] == "Access denied"
|
||||
|
||||
|
||||
def test_choose_orchestrator(worker_init):
|
||||
from worker.main import choose_orchestrator
|
||||
|
||||
@@ -131,17 +131,19 @@ def app_factory(settings = get_settings()):
|
||||
|
||||
|
||||
@app.post("/sheet", status_code=201, deprecated=True) # DEPRECATED
|
||||
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth)):
|
||||
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth), db: Session = Depends(get_db_dependency)):
|
||||
logger.info(f"SHEET TASK for {sheet=}")
|
||||
sheet.author_id = email
|
||||
if not sheet.sheet_name and not sheet.sheet_id:
|
||||
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
||||
if not crud.is_user_in_group(db, email, sheet.group_id):
|
||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||
task = create_sheet_task.delay(sheet.model_dump_json())
|
||||
return JSONResponse({"id": task.id})
|
||||
|
||||
|
||||
@app.post("/sheet_service", status_code=201, deprecated=True) # DEPRECATED
|
||||
def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth)):
|
||||
def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth), db: Session = Depends(get_db_dependency)):
|
||||
logger.info(f"SHEET TASK for {sheet=}")
|
||||
sheet.author_id = sheet.author_id or "api-endpoint"
|
||||
if not sheet.sheet_name and not sheet.sheet_id:
|
||||
|
||||
@@ -6,6 +6,7 @@ from core.config import ALLOW_ANY_EMAIL
|
||||
from shared.settings import get_settings
|
||||
from db.database import get_db
|
||||
from db import crud
|
||||
from db.user_state import UserState
|
||||
|
||||
settings = get_settings()
|
||||
bearer_security = HTTPBearer()
|
||||
@@ -84,3 +85,8 @@ def authenticate_user(access_token):
|
||||
except Exception as e:
|
||||
logger.warning(f"AUTH EXCEPTION occurred: {e}")
|
||||
return False, "exception occurred"
|
||||
|
||||
|
||||
def get_active_user_state(email=Depends(get_active_user_auth)):
    """Build a ``UserState`` for the e-mail resolved by ``get_active_user_auth``.

    The state is created with ``active=True`` and the session obtained from
    ``get_db()``.
    """
    with get_db() as session:
        state = UserState(session, email, active=True)
    # NOTE(review): the ``with`` block has already been exited when the state
    # is handed to the caller. If ``get_db()`` closes the session on exit,
    # ``state`` holds a closed session -- confirm, and consider a yield-style
    # dependency if the state needs the session afterwards.
    return state
|
||||
@@ -1,4 +1,5 @@
|
||||
|
||||
from functools import lru_cache
|
||||
import traceback, yaml, datetime
|
||||
from typing import List, Set
|
||||
|
||||
@@ -30,6 +31,7 @@ Rdis = redis.Redis.from_url(celery.conf.broker_url)
|
||||
def create_archive_task(self, archive_json: str):
|
||||
archive = schemas.ArchiveCreate.model_validate_json(archive_json)
|
||||
logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}")
|
||||
#TODO: move group checks out of here
|
||||
invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)
|
||||
if invalid:
|
||||
raise Exception(invalid) # marks task FAILED, saves the Exception as result
|
||||
@@ -64,10 +66,6 @@ def create_sheet_task(self, sheet_json: str):
|
||||
sheet.tags.add("gsheet")
|
||||
logger.info(f"SHEET START {sheet=}")
|
||||
|
||||
#TODO: should this check live here?
|
||||
if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)):
|
||||
return {"error": em}
|
||||
|
||||
config = Config()
|
||||
# TODO: use choose_orchestrator and overwrite the feeder
|
||||
# TODO: drop sheet_name and use only sheet_id (new endpoints/models)
|
||||
@@ -161,7 +159,7 @@ def is_group_invalid_for_user(public: bool, group_id: str, author_id: str):
|
||||
|
||||
# otherwise group must match
|
||||
with get_db() as session:
|
||||
if not crud.is_user_in_group(session, group_id, author_id):
|
||||
if not crud.is_user_in_group(session, author_id, group_id):
|
||||
logger.error(em := f"User {author_id} is not part of {group_id}, no permission")
|
||||
return em
|
||||
return False
|
||||
@@ -220,3 +218,13 @@ def at_start(sender, **kwargs):
|
||||
ORCHESTRATORS = {}
|
||||
load_orchestrators()
|
||||
logger.info("Orchestrators loaded successfully.")
|
||||
|
||||
@lru_cache
def get_url_orchestrator(group_name):
    """Look up ``group_name`` and (eventually) return its orchestrator.

    Work in progress: only the existence of the group is verified for now.
    The orchestrator construction at the bottom is still commented out, so
    the function currently returns ``None``; ``lru_cache`` memoises that
    result per ``group_name``.

    Raises:
        AssertionError: if ``crud.get_group`` returns a falsy value for
            ``group_name``.
    """
    with get_db() as db:
        group = crud.get_group(db, group_name)
        # Raised explicitly instead of using an ``assert`` statement so the
        # check is not stripped when Python runs with -O; the exception type
        # seen by callers stays the same.
        if not group:
            raise AssertionError(f"Group {group_name} not found")

    # TODO: build and return the orchestrator for the group:
    # config = Config()
    # config.parse(use_cli=False, yaml_config_filename=group.orchestrator_sheet)
    # return ArchivingOrchestrator(config)
|
||||
Reference in New Issue
Block a user