mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-13 05:58:35 +03:00
pushing bulk of changes
This commit is contained in:
@@ -21,6 +21,8 @@ orchestration must be from the console(?)
|
|||||||
* turn off VPNs if connection to docker is not working
|
* turn off VPNs if connection to docker is not working
|
||||||
|
|
||||||
## User management
|
## User management
|
||||||
|
TODO: update description and example
|
||||||
|
- users/domains/groups
|
||||||
Copy [example.user-groups.yaml](src/example.user-groups.yaml) into a new file and set the environment variable `USER_GROUPS_FILENAME` to that filename (defaults to `user-groups.yaml`).
|
Copy [example.user-groups.yaml](src/example.user-groups.yaml) into a new file and set the environment variable `USER_GROUPS_FILENAME` to that filename (defaults to `user-groups.yaml`).
|
||||||
|
|
||||||
This file contains a user-groups specification in 2 parts. Each user can archive URLs publicly, privately, or privately for a group, so long as they are declared as part of that group. In the example below, `email1` has 2 groups while `email3` has none.
|
This file contains a user-groups specification in 2 parts. Each user can archive URLs publicly, privately, or privately for a group, so long as they are declared as part of that group. In the example below, `email1` has 2 groups while `email3` has none.
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ verify_ssl = true
|
|||||||
name = "pypi"
|
name = "pypi"
|
||||||
|
|
||||||
[packages]
|
[packages]
|
||||||
|
oscrypto = {git = "https://github.com/wbond/oscrypto.git", ref = "d5f3437ed24257895ae1edd9e503cfb352e635a8"}
|
||||||
aiofiles = "==0.6.0"
|
aiofiles = "==0.6.0"
|
||||||
celery = ">=5.0"
|
celery = ">=5.0"
|
||||||
fastapi = "*"
|
fastapi = "*"
|
||||||
|
|||||||
3759
src/Pipfile.lock
generated
3759
src/Pipfile.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
|||||||
based on https://fastapi-users.github.io/fastapi-users/10.4/configuration/oauth/
|
|
||||||
@@ -1,11 +1,12 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import cache
|
from functools import lru_cache
|
||||||
from sqlalchemy.orm import Session, load_only
|
from sqlalchemy.orm import Session, load_only
|
||||||
from sqlalchemy import Column, or_, func
|
from sqlalchemy import Column, or_, func
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from core.config import ALLOW_ANY_EMAIL
|
from core.config import ALLOW_ANY_EMAIL
|
||||||
|
from db.database import get_db
|
||||||
from shared.settings import get_settings
|
from shared.settings import get_settings
|
||||||
from . import models, schemas
|
from . import models, schemas
|
||||||
import yaml
|
import yaml
|
||||||
@@ -23,7 +24,7 @@ def get_archive(db: Session, id: str, email: str):
|
|||||||
email = email.lower()
|
email = email.lower()
|
||||||
query = base_query(db).filter(models.Archive.id == id)
|
query = base_query(db).filter(models.Archive.id == id)
|
||||||
if email != ALLOW_ANY_EMAIL:
|
if email != ALLOW_ANY_EMAIL:
|
||||||
groups = get_user_groups(db, email)
|
groups = get_user_groups(email)
|
||||||
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
||||||
return query.first()
|
return query.first()
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim
|
|||||||
query = base_query(db)
|
query = base_query(db)
|
||||||
if email != ALLOW_ANY_EMAIL:
|
if email != ALLOW_ANY_EMAIL:
|
||||||
email = email.lower()
|
email = email.lower()
|
||||||
groups = get_user_groups(db, email)
|
groups = get_user_groups(email)
|
||||||
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
||||||
if absolute_search:
|
if absolute_search:
|
||||||
query = query.filter(models.Archive.url == url)
|
query = query.filter(models.Archive.url == url)
|
||||||
@@ -121,72 +122,37 @@ def is_active_user(db: Session, email: str) -> bool:
|
|||||||
return db.query(models.Group).filter(models.Group.domains.contains(domain)).first() is not None
|
return db.query(models.Group).filter(models.Group.domains.contains(domain)).first() is not None
|
||||||
|
|
||||||
|
|
||||||
def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
|
def is_user_in_group(db: Session, email: str, group_name: str) -> models.Group:
|
||||||
if email == ALLOW_ANY_EMAIL: return True
|
if email == ALLOW_ANY_EMAIL: return True
|
||||||
return len(group_name) and len(email) and group_name in get_user_groups(db, email)
|
return len(group_name) and len(email) and group_name in get_user_groups(email)
|
||||||
|
|
||||||
|
|
||||||
#TODO: maybe this can be cached? what about the db session?
|
@lru_cache
|
||||||
def get_user_groups(db: Session, email: str) -> list[str]:
|
def get_user_groups(email: str) -> list[str]:
|
||||||
"""
|
"""
|
||||||
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. User does not need to be active.
|
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. User does not need to be active.
|
||||||
"""
|
"""
|
||||||
if not email or not len(email) or "@" not in email: return []
|
if not email or not len(email) or "@" not in email: return []
|
||||||
email = email.lower()
|
email = email.lower()
|
||||||
|
|
||||||
# get user groups
|
with get_db() as db:
|
||||||
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
|
# get user groups
|
||||||
user_level_groups_names = [g[0] for g in user_groups]
|
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
|
||||||
|
user_level_groups_names = [g[0] for g in user_groups]
|
||||||
|
|
||||||
# get domain groups
|
# get domain groups
|
||||||
domain = email.split('@')[1]
|
domain = email.split('@')[1]
|
||||||
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
|
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
|
||||||
domain_level_groups_names = [g[0] for g in domain_level_groups]
|
domain_level_groups_names = [g[0] for g in domain_level_groups]
|
||||||
|
|
||||||
return list(set(user_level_groups_names + domain_level_groups_names))
|
return list(set(user_level_groups_names + domain_level_groups_names))
|
||||||
|
|
||||||
|
|
||||||
# --------------- SHEET
|
|
||||||
|
|
||||||
def has_quota_sheet(db: Session, email: str, user_groups_names: list[str]) -> bool:
|
|
||||||
"""
|
|
||||||
checks if a user has reached their sheet quota
|
|
||||||
"""
|
|
||||||
user_sheets = db.query(models.Sheet).filter(models.Sheet.author_id == email).count()
|
|
||||||
|
|
||||||
user_groups = db.query(models.Group).filter(models.Group.id.in_(user_groups_names)).all()
|
|
||||||
|
|
||||||
quota = 0
|
|
||||||
for group in user_groups:
|
|
||||||
active_sheets = group.permissions.get("active_sheets", 0)
|
|
||||||
if active_sheets == -1: return True
|
|
||||||
quota = max(quota, active_sheets)
|
|
||||||
return user_sheets < quota
|
|
||||||
|
|
||||||
|
|
||||||
def create_sheet(db: Session, sheet_id: str, sheet_name: str, email: str, group_id: str, frequency: str):
|
|
||||||
db_sheet = models.Sheet(id=sheet_id, name=sheet_name, author_id=email, group_id=group_id, frequency=frequency)
|
|
||||||
db.add(db_sheet)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(db_sheet)
|
|
||||||
return db_sheet
|
|
||||||
|
|
||||||
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
|
|
||||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_archived_at.desc()).all()
|
|
||||||
|
|
||||||
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
|
|
||||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
|
|
||||||
|
|
||||||
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
|
|
||||||
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
|
|
||||||
if db_sheet:
|
|
||||||
db.delete(db_sheet)
|
|
||||||
db.commit()
|
|
||||||
return db_sheet is not None
|
|
||||||
|
|
||||||
|
|
||||||
# --------------- INIT User-Groups
|
# --------------- INIT User-Groups
|
||||||
|
|
||||||
|
def get_group(db: Session, group_name: str) -> models.Group:
    """
    looks up a single Group by its id, returns None when no row matches
    """
    matching_groups = db.query(models.Group).filter(models.Group.id == group_name)
    return matching_groups.first()
|
||||||
|
|
||||||
|
|
||||||
def create_or_get_user(db: Session, author_id: str, is_active: bool = models.User.is_active.default.arg) -> models.User:
|
def create_or_get_user(db: Session, author_id: str, is_active: bool = models.User.is_active.default.arg) -> models.User:
|
||||||
if type(author_id) == str: author_id = author_id.lower()
|
if type(author_id) == str: author_id = author_id.lower()
|
||||||
@@ -296,3 +262,28 @@ def upsert_user_groups(db: Session):
|
|||||||
count_groups = db.query(func.count(models.Group.id)).scalar()
|
count_groups = db.query(func.count(models.Group.id)).scalar()
|
||||||
|
|
||||||
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
|
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
|
||||||
|
|
||||||
|
|
||||||
|
# --------------- SHEET
|
||||||
|
def create_sheet(db: Session, sheet_id: str, sheet_name: str, email: str, group_id: str, frequency: str):
    """
    stores a new Sheet authored by `email` and returns the stored instance
    """
    new_sheet = models.Sheet(
        id=sheet_id,
        name=sheet_name,
        author_id=email,
        group_id=group_id,
        frequency=frequency,
    )
    db.add(new_sheet)
    db.commit()
    # reload the instance's attributes from the database after the commit
    db.refresh(new_sheet)
    return new_sheet
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
    """
    returns the Sheet with id `sheet_id` only if it is authored by `email`, otherwise None
    """
    owned_by_user = models.Sheet.author_id == email
    has_requested_id = models.Sheet.id == sheet_id
    return db.query(models.Sheet).filter(owned_by_user, has_requested_id).first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
    """
    returns every Sheet authored by `email`, ordered by last_archived_at descending
    """
    authored = db.query(models.Sheet).filter(models.Sheet.author_id == email)
    most_recent_first = authored.order_by(models.Sheet.last_archived_at.desc())
    return most_recent_first.all()
|
||||||
|
|
||||||
|
|
||||||
|
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
    """
    deletes the Sheet with id `sheet_id` if it is authored by `email`
    returns True when a matching sheet was found, False otherwise
    """
    owned_sheet_query = db.query(models.Sheet).filter(
        models.Sheet.id == sheet_id,
        models.Sheet.author_id == email,
    )
    target = owned_sheet_query.first()
    if target:
        db.delete(target)
        db.commit()
    return target is not None
|
||||||
@@ -87,7 +87,7 @@ class Group(Base):
|
|||||||
description = Column(String, default=None)
|
description = Column(String, default=None)
|
||||||
orchestrator = Column(String, default=None)
|
orchestrator = Column(String, default=None)
|
||||||
orchestrator_sheet = Column(String, default=None)
|
orchestrator_sheet = Column(String, default=None)
|
||||||
permissions = Column(JSON, default=None)
|
permissions = Column(JSON, default={})
|
||||||
domains = Column(JSON, default=[])
|
domains = Column(JSON, default=[])
|
||||||
|
|
||||||
archives = relationship("Archive", back_populates="group")
|
archives = relationship("Archive", back_populates="group")
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
from annotated_types import Len
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -105,3 +107,10 @@ class SheetResponse(SheetAdd):
|
|||||||
stats: dict | None
|
stats: dict | None
|
||||||
last_archived_at: datetime | None
|
last_archived_at: datetime | None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class ArchiveTrigger(BaseModel):
    """
    Request body for triggering the archiving of a single URL.
    """
    # URLs of 5 characters or fewer are rejected, matching the `len(url) <= 5`
    # check previously performed inside the /url/archive handler
    url: Annotated[str, Len(min_length=6)]
    public: bool = True
    # when provided, must be a non-empty group id
    group_id: Annotated[str, Len(min_length=1)] | None = None
    tags: set[Tag] | None = set()
|
||||||
|
|||||||
142
src/db/user_state.py
Normal file
142
src/db/user_state.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import func
|
||||||
|
from db import crud, models
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class UserState:
    """
    Manage a user's state and permissions.

    Group names and the matching Group rows are looked up on first access and
    then kept on the instance, so the permission checks below reuse them
    instead of repeating the lookups.

    Quota permissions (`active_sheets`, `monthly_urls`, `monthly_mbs`) are
    combined across the user's groups: -1 in any group means unlimited,
    otherwise the largest value applies and 0 is used when no group sets one.
    """

    def __init__(self, db: "Session", email: str, active=False):
        self.db = db
        self.email = email
        self.active = active

    @property
    def user_groups_names(self):
        """names of the user's groups as returned by crud.get_user_groups"""
        if not hasattr(self, '_user_groups_names'):
            self._user_groups_names = crud.get_user_groups(self.email)
        return self._user_groups_names

    @property
    def user_groups(self):
        """Group rows whose id is one of `user_groups_names`"""
        if not hasattr(self, '_user_groups'):
            self._user_groups = self.db.query(models.Group).filter(
                models.Group.id.in_(self.user_groups_names)
            ).all()
        return self._user_groups

    @property
    def allowed_frequencies(self):
        """
        set of sheet frequencies granted by the user's groups
        "hourly" implies "daily"
        """
        if not hasattr(self, '_allowed_frequencies'):
            frequencies = set()
            for group in self.user_groups:
                if not group.permissions: continue
                frequency = group.permissions.get("allowed_frequency")
                # a group that does not set this permission grants nothing;
                # recording None would make None look like an allowed frequency
                if frequency is not None:
                    frequencies.add(frequency)
            if "hourly" in frequencies:
                frequencies.add("daily")
            self._allowed_frequencies = frequencies
        return self._allowed_frequencies

    def _max_quota(self, permission: str) -> int:
        """
        largest value of `permission` over the user's groups
        -1 (unlimited) in any group wins, groups without permissions are ignored
        """
        quota = 0
        for group in self.user_groups:
            if not group.permissions: continue
            value = group.permissions.get(permission, 0)
            if value == -1: return -1
            quota = max(quota, value)
        return quota

    @property
    def sheet_quota(self):
        """
        infer the user's sheet quota from the groups
        -1 means unlimited
        """
        return self._max_quota("active_sheets")

    @property
    def monthly_urls(self):
        """
        infer the user's monthly url quota from the groups
        -1 means unlimited
        """
        return self._max_quota("monthly_urls")

    @property
    def monthly_mbs(self):
        """
        infer the user's monthly mb quota from the groups
        -1 means unlimited
        """
        return self._max_quota("monthly_mbs")

    def in_group(self, group_id: str) -> bool:
        """checks if `group_id` is one of the user's groups"""
        return group_id in self.user_groups_names

    def has_quota_sheet(self) -> bool:
        """
        checks if a user has reached their sheet quota
        """
        if self.sheet_quota == -1: return True

        user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email).count()

        return user_sheets < self.sheet_quota

    def _archives_this_month(self):
        """query over the user's archives created in the current month and year"""
        # read the clock once so month and year always belong to the same instant
        now = datetime.now()
        return self.db.query(models.Archive).filter(
            models.Archive.author_id == self.email,
            func.extract('month', models.Archive.created_at) == now.month,
            func.extract('year', models.Archive.created_at) == now.year
        )

    def has_quota_monthly_urls(self) -> bool:
        """
        checks if a user has reached their monthly url quota
        """
        quota = self.monthly_urls
        if quota == -1: return True

        user_urls = self._archives_this_month().count()

        return user_urls < quota

    def has_quota_monthly_mbs(self) -> bool:
        """
        checks if a user has reached their monthly mb quota
        """
        quota = self.monthly_mbs
        if quota == -1: return True

        # find and sum all user bytes over this month
        # archives without a total_bytes value count as 0
        user_bytes = self._archives_this_month().with_entities(func.coalesce(func.sum(
            func.coalesce(
                func.cast(
                    func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
                    sqlalchemy.Integer
                ), 0
            )
        ), 0).label('total')).scalar()

        # convert bytes to mb
        user_mbs = int(user_bytes / 1024 / 1024)
        return user_mbs < quota

    # def can_manually_trigger(self) -> bool:
    #     """
    #     checks if a user is allowed to manually trigger a sheet
    #     """
    #     for group in self.user_groups:
    #         if not group.permissions: continue
    #         if group.permissions.get("manual_trigger", False):
    #             return True
    #     return False

    def is_sheet_frequency_allowed(self, frequency: str) -> bool:
        """
        checks if a user is allowed to create a sheet with this frequency
        """
        return frequency in self.allowed_frequencies
|
||||||
@@ -6,8 +6,9 @@ from sqlalchemy.orm import Session
|
|||||||
from core.config import VERSION, BREAKING_CHANGES
|
from core.config import VERSION, BREAKING_CHANGES
|
||||||
from core.logging import log_error
|
from core.logging import log_error
|
||||||
from db import crud, schemas
|
from db import crud, schemas
|
||||||
from db.database import get_db_dependency, get_db
|
from db.database import get_db_dependency
|
||||||
from web.security import get_user_auth, bearer_security
|
from db.user_state import UserState
|
||||||
|
from web.security import get_user_auth, bearer_security, get_active_user_state
|
||||||
|
|
||||||
default_router = APIRouter()
|
default_router = APIRouter()
|
||||||
|
|
||||||
@@ -18,8 +19,7 @@ async def home(request: Request):
|
|||||||
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
|
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
|
||||||
try:
|
try:
|
||||||
email = await get_user_auth(await bearer_security(request))
|
email = await get_user_auth(await bearer_security(request))
|
||||||
with get_db() as db:
|
status["groups"] = crud.get_user_groups(email)
|
||||||
status["groups"] = crud.get_user_groups(db, email)
|
|
||||||
except HTTPException: pass # not authenticated is fine
|
except HTTPException: pass # not authenticated is fine
|
||||||
except Exception as e: log_error(e)
|
except Exception as e: log_error(e)
|
||||||
return JSONResponse(status)
|
return JSONResponse(status)
|
||||||
@@ -31,13 +31,28 @@ async def health():
|
|||||||
|
|
||||||
|
|
||||||
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
|
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
|
||||||
|
# TODO: reorder db dependencies to after auth
|
||||||
async def active(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> schemas.ActiveUser:
|
async def active(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> schemas.ActiveUser:
|
||||||
return {"active": crud.is_active_user(db, email)}
|
return {"active": crud.is_active_user(db, email)}
|
||||||
|
|
||||||
|
|
||||||
@default_router.get("/groups")
|
@default_router.get("/groups")
|
||||||
def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> list[str]:
|
def get_user_groups(email=Depends(get_user_auth)) -> list[str]:
|
||||||
return crud.get_user_groups(db, email)
|
return crud.get_user_groups(email)
|
||||||
|
|
||||||
|
|
||||||
|
@default_router.get("/permissions")
# named distinctly from the /groups handler so it does not rebind that module-level name
def get_user_permissions(
    user: UserState = Depends(get_active_user_state),
) -> JSONResponse:
    """
    Returns the authenticated user's groups, allowed sheet frequencies and quotas.
    """
    # NOTE(review): reads user.monthly_urls and user.monthly_mbs — confirm UserState exposes both attributes
    return JSONResponse({
        "groups": user.user_groups_names,
        # sets are not JSON serializable, hence the list
        "allowedFrequencies": list(user.allowed_frequencies),
        "sheet_quota": user.sheet_quota,
        "monthly_urls": user.monthly_urls,
        "monthly_mbs": user.monthly_mbs,
        # TODO: should this return ... (original note is incomplete)
    })
|
||||||
|
|
||||||
|
|
||||||
@default_router.get('/favicon.ico', include_in_schema=False)
|
@default_router.get('/favicon.ico', include_in_schema=False)
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ from fastapi.responses import JSONResponse
|
|||||||
from sqlalchemy import exc
|
from sqlalchemy import exc
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from web.security import token_api_key_auth, get_active_user_auth
|
from db.user_state import UserState
|
||||||
|
from web.security import token_api_key_auth, get_active_user_auth, get_active_user_state
|
||||||
from db import schemas, crud
|
from db import schemas, crud
|
||||||
from db.database import get_db_dependency
|
from db.database import get_db_dependency
|
||||||
from worker.main import create_sheet_task
|
from worker.main import create_sheet_task
|
||||||
@@ -16,19 +17,21 @@ sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"]
|
|||||||
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
|
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
|
||||||
def create_sheet(
|
def create_sheet(
|
||||||
sheet: schemas.SheetAdd,
|
sheet: schemas.SheetAdd,
|
||||||
email=Depends(get_active_user_auth),
|
user: UserState = Depends(get_active_user_state),
|
||||||
db: Session = Depends(get_db_dependency),
|
db: Session = Depends(get_db_dependency),
|
||||||
) -> schemas.SheetResponse:
|
) -> schemas.SheetResponse:
|
||||||
user_groups_names = crud.get_user_groups(db, email)
|
|
||||||
|
|
||||||
if sheet.group_id not in user_groups_names:
|
if not user.in_group(sheet.group_id):
|
||||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||||
|
|
||||||
if not crud.has_quota_sheet(db, email, user_groups_names):
|
if not user.has_quota_sheet():
|
||||||
raise HTTPException(status_code=429, detail="User has reached their sheet quota.")
|
raise HTTPException(status_code=429, detail="User has reached their sheet quota.")
|
||||||
|
|
||||||
|
if not user.is_sheet_frequency_allowed(sheet.frequency):
|
||||||
|
raise HTTPException(status_code=422, detail=f"Invalid frequency: {sheet.frequency}. Must be one of {user.allowed_frequencies}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return crud.create_sheet(db, sheet.id, sheet.name, email, sheet.group_id, sheet.frequency)
|
return crud.create_sheet(db, sheet.id, sheet.name, user.email, sheet.group_id, sheet.frequency)
|
||||||
except exc.IntegrityError as e:
|
except exc.IntegrityError as e:
|
||||||
raise HTTPException(status_code=400, detail="Sheet with this ID already exists.") from e
|
raise HTTPException(status_code=400, detail="Sheet with this ID already exists.") from e
|
||||||
|
|
||||||
@@ -56,22 +59,30 @@ def delete_sheet(
|
|||||||
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
|
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
|
||||||
def archive_user_sheet(
|
def archive_user_sheet(
|
||||||
id: str,
|
id: str,
|
||||||
email=Depends(get_active_user_auth),
|
user: UserState = Depends(get_active_user_state),
|
||||||
db: Session = Depends(get_db_dependency),
|
db: Session = Depends(get_db_dependency),
|
||||||
) -> schemas.Task:
|
) -> schemas.Task:
|
||||||
|
|
||||||
sheet = crud.get_user_sheet(db, email, sheet_id=id)
|
#TODO: are we enabling manual triggers?
|
||||||
|
# if not user.can_manually_trigger():
|
||||||
|
# raise HTTPException(status_code=429, detail="User cannot manually trigger archiving tasks.")
|
||||||
|
|
||||||
|
sheet = crud.get_user_sheet(db, user.email, sheet_id=id)
|
||||||
if not sheet:
|
if not sheet:
|
||||||
raise HTTPException(status_code=403, detail="No access to this sheet.")
|
raise HTTPException(status_code=403, detail="No access to this sheet.")
|
||||||
|
|
||||||
task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=email, group=sheet.group_id).model_dump_json())
|
# TODO: what happens if user is taken out of group after sheet is created? this should be checked in a cronjob that notifies the user
|
||||||
|
if not user.in_group(sheet.group_id):
|
||||||
|
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||||
|
|
||||||
|
task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=user.email, group=sheet.group_id).model_dump_json())
|
||||||
|
|
||||||
return JSONResponse({"id": task.id}, status_code=201)
|
return JSONResponse({"id": task.id}, status_code=201)
|
||||||
|
|
||||||
|
|
||||||
@sheet_router.post("/archive", status_code=201, summary="Trigger an archiving task for any GSheet with an API token.", response_description="task_id for the archiving task.")
|
@sheet_router.post("/archive", status_code=201, summary="Trigger an archiving task for any GSheet with an API token.", response_description="task_id for the archiving task.")
|
||||||
def archive_sheet(
|
def archive_sheet(
|
||||||
sheet: schemas.SubmitSheet, #TODO: replace with simpler model
|
sheet: schemas.SubmitSheet, # TODO: replace with simpler model
|
||||||
auth=Depends(token_api_key_auth)
|
auth=Depends(token_api_key_auth)
|
||||||
) -> schemas.Task:
|
) -> schemas.Task:
|
||||||
sheet.author_id = sheet.author_id or "api-endpoint"
|
sheet.author_id = sheet.author_id or "api-endpoint"
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from web.security import get_user_auth, get_token_or_user_auth
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from db import crud, schemas
|
from db import crud, schemas
|
||||||
from db.database import get_db_dependency
|
from db.database import get_db, get_db_dependency
|
||||||
|
|
||||||
from worker.main import create_archive_task
|
from worker.main import create_archive_task
|
||||||
|
|
||||||
@@ -16,14 +16,28 @@ url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
|
|||||||
|
|
||||||
|
|
||||||
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
|
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
|
||||||
def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)) -> schemas.Task:
|
def archive_url(
|
||||||
archive.author_id = email
|
archive: schemas.ArchiveTrigger,
|
||||||
url = archive.url
|
email=Depends(get_token_or_user_auth)
|
||||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
|
) -> schemas.Task:
|
||||||
if type(url) != str or len(url) <= 5:
|
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
|
||||||
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
|
|
||||||
logger.info("creating task")
|
# TODO: implement quota
|
||||||
task = create_archive_task.delay(archive.model_dump_json())
|
|
||||||
|
if archive.group_id:
|
||||||
|
with get_db() as db:
|
||||||
|
if not crud.is_user_in_group(db, email, archive.group_id):
|
||||||
|
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||||
|
|
||||||
|
# TODO: deprecate ArchiveCreate
|
||||||
|
backwards_compatible_archive = schemas.ArchiveCreate(
|
||||||
|
url=archive.url,
|
||||||
|
author_id=email,
|
||||||
|
group_id=archive.group_id,
|
||||||
|
public=archive.public,
|
||||||
|
)
|
||||||
|
|
||||||
|
task = create_archive_task.delay(backwards_compatible_archive.model_dump_json())
|
||||||
task_response = schemas.Task(id=task.id)
|
task_response = schemas.Task(id=task.id)
|
||||||
return JSONResponse(task_response.model_dump(), status_code=201)
|
return JSONResponse(task_response.model_dump(), status_code=201)
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import os
|
|||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
from db.user_state import UserState
|
||||||
from shared.settings import Settings
|
from shared.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
@@ -27,7 +28,9 @@ def mock_settings():
|
|||||||
def test_db(get_settings: Settings):
|
def test_db(get_settings: Settings):
|
||||||
from db.database import make_engine
|
from db.database import make_engine
|
||||||
from db import models
|
from db import models
|
||||||
|
from db.crud import get_user_groups
|
||||||
|
|
||||||
|
get_user_groups.cache_clear()
|
||||||
make_engine.cache_clear()
|
make_engine.cache_clear()
|
||||||
engine = make_engine(get_settings.DATABASE_PATH)
|
engine = make_engine(get_settings.DATABASE_PATH)
|
||||||
|
|
||||||
@@ -72,11 +75,12 @@ def client(app):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def app_with_auth(app):
|
def app_with_auth(app, db_session):
|
||||||
from web.security import get_token_or_user_auth, get_user_auth, get_active_user_auth
|
from web.security import get_token_or_user_auth, get_user_auth, get_active_user_auth, get_active_user_state
|
||||||
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
|
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
|
||||||
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
|
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
|
||||||
app.dependency_overrides[get_active_user_auth] = lambda: "morty@example.com"
|
app.dependency_overrides[get_active_user_auth] = lambda: "morty@example.com"
|
||||||
|
app.dependency_overrides[get_active_user_state] = lambda: UserState(db_session, "morty@example.com", active=True)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,12 @@ def test_data(db_session):
|
|||||||
archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
|
archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
|
||||||
db_session.add(archive)
|
db_session.add(archive)
|
||||||
|
|
||||||
|
# creates a sheet for each user
|
||||||
|
for i, email in enumerate(authors):
|
||||||
|
db_session.add(models.Sheet(id=f"sheet-{i}", name=f"sheet-{i}", author_id=email, group_id=None, frequency="daily"))
|
||||||
|
if email == "rick@example.com":
|
||||||
|
db_session.add(models.Sheet(id=f"sheet-{i}-2", name=f"sheet-{i}-2", author_id=email, group_id="spaceship", frequency="hourly"))
|
||||||
|
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
assert db_session.query(models.Archive).count() == 100
|
assert db_session.query(models.Archive).count() == 100
|
||||||
@@ -253,6 +259,7 @@ def test_count_archive_urls(test_data, db_session):
|
|||||||
assert crud.count_archives(db_session) == 99
|
assert crud.count_archives(db_session) == 99
|
||||||
assert crud.count_archive_urls(db_session) == 999
|
assert crud.count_archive_urls(db_session) == 999
|
||||||
|
|
||||||
|
|
||||||
def test_count_users(test_data, db_session):
|
def test_count_users(test_data, db_session):
|
||||||
from db import crud
|
from db import crud
|
||||||
|
|
||||||
@@ -261,6 +268,7 @@ def test_count_users(test_data, db_session):
|
|||||||
db_session.commit()
|
db_session.commit()
|
||||||
assert crud.count_users(db_session) == 3
|
assert crud.count_users(db_session) == 3
|
||||||
|
|
||||||
|
|
||||||
def test_count_by_users_since(test_data, db_session):
|
def test_count_by_users_since(test_data, db_session):
|
||||||
from db import crud
|
from db import crud
|
||||||
|
|
||||||
@@ -294,6 +302,7 @@ def test_create_tag(db_session):
|
|||||||
assert second_tag.id == "tag-102"
|
assert second_tag.id == "tag-102"
|
||||||
assert db_session.query(models.Tag).count() == 2
|
assert db_session.query(models.Tag).count() == 2
|
||||||
|
|
||||||
|
|
||||||
def test_is_active_user(test_data, db_session):
|
def test_is_active_user(test_data, db_session):
|
||||||
from db import crud
|
from db import crud
|
||||||
|
|
||||||
@@ -329,7 +338,7 @@ def test_is_user_in_group(test_data, db_session):
|
|||||||
|
|
||||||
("jerry@example.com", "spaceship", False),
|
("jerry@example.com", "spaceship", False),
|
||||||
("jerry@example.com", "interdimensional", False),
|
("jerry@example.com", "interdimensional", False),
|
||||||
("jerry@example.com", "the-jerrys-club", False), # group not in 'groups'
|
("jerry@example.com", "the-jerrys-club", False), # group not in 'groups'
|
||||||
|
|
||||||
("rick@example.com", "animated-characters", True),
|
("rick@example.com", "animated-characters", True),
|
||||||
("morty@example.com", "animated-characters", True),
|
("morty@example.com", "animated-characters", True),
|
||||||
@@ -337,7 +346,7 @@ def test_is_user_in_group(test_data, db_session):
|
|||||||
("ANYONE@example.com", "animated-characters", True),
|
("ANYONE@example.com", "animated-characters", True),
|
||||||
("ANYONE@birdy.com", "animated-characters", True),
|
("ANYONE@birdy.com", "animated-characters", True),
|
||||||
|
|
||||||
("summer@herself.com", "animated-characters", False),
|
("summer@herself.com", "animated-characters", False),
|
||||||
|
|
||||||
("rick@example.com", "", False),
|
("rick@example.com", "", False),
|
||||||
("", "spaceship", False),
|
("", "spaceship", False),
|
||||||
@@ -345,7 +354,16 @@ def test_is_user_in_group(test_data, db_session):
|
|||||||
]
|
]
|
||||||
for email, group, expected in test_pairs:
|
for email, group, expected in test_pairs:
|
||||||
print(f"{email} in {group} == {expected}")
|
print(f"{email} in {group} == {expected}")
|
||||||
assert crud.is_user_in_group(db_session, group, email) == expected
|
assert crud.is_user_in_group(db_session, email, group) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_group(test_data, db_session):
    """Check that crud.get_group finds the expected groups and returns None otherwise."""
    from db import crud

    # group ids this test expects to be present
    for known_group in ("spaceship", "interdimensional", "animated-characters"):
        assert crud.get_group(db_session, known_group) is not None

    # an id that matches no group must yield None
    assert crud.get_group(db_session, "non-existant!@#!%!") is None
|
||||||
|
|
||||||
|
|
||||||
def test_create_or_get_user(test_data, db_session):
|
def test_create_or_get_user(test_data, db_session):
|
||||||
@@ -403,13 +421,12 @@ def test_upsert_group(test_data, db_session):
|
|||||||
def test_upsert_user_groups(db_session):
|
def test_upsert_user_groups(db_session):
|
||||||
from db import crud
|
from db import crud
|
||||||
|
|
||||||
@patch('db.crud.get_settings', new = lambda: bad_setings)
|
@patch('db.crud.get_settings', new=lambda: bad_setings)
|
||||||
def test_missing_yaml(db_session):
|
def test_missing_yaml(db_session):
|
||||||
with pytest.raises(FileNotFoundError):
|
with pytest.raises(FileNotFoundError):
|
||||||
crud.upsert_user_groups(db_session)
|
crud.upsert_user_groups(db_session)
|
||||||
|
|
||||||
|
@patch('db.crud.get_settings', new=lambda: bad_setings)
|
||||||
@patch('db.crud.get_settings', new = lambda: bad_setings)
|
|
||||||
def test_broken_yaml(db_session):
|
def test_broken_yaml(db_session):
|
||||||
with pytest.raises(yaml.YAMLError):
|
with pytest.raises(yaml.YAMLError):
|
||||||
crud.upsert_user_groups(db_session)
|
crud.upsert_user_groups(db_session)
|
||||||
@@ -421,3 +438,53 @@ def test_upsert_user_groups(db_session):
|
|||||||
|
|
||||||
bad_setings.USER_GROUPS_FILENAME = "tests/user-groups.test.broken.yaml"
|
bad_setings.USER_GROUPS_FILENAME = "tests/user-groups.test.broken.yaml"
|
||||||
test_broken_yaml(db_session)
|
test_broken_yaml(db_session)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_sheet(db_session):
    """Create one sheet, verify its stored fields, then verify a duplicate id is rejected."""
    import sqlalchemy
    from db import crud

    sheet_count = lambda: db_session.query(models.Sheet).count()

    # this test starts from an empty Sheet table
    assert sheet_count() == 0

    created = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly")
    assert created is not None

    # each positional argument should end up on the matching attribute
    expected_fields = {
        "id": "sheet-id-123",
        "name": "sheet name",
        "author_id": "email@example.com",
        "group_id": "group-id",
        "frequency": "hourly",
    }
    for attribute, value in expected_fields.items():
        assert getattr(created, attribute) == value

    assert sheet_count() == 1

    # duplicate id
    with pytest.raises(sqlalchemy.exc.IntegrityError):
        crud.create_sheet(db_session, "sheet-id-123", "I thought this was another sheet", "email", "group-id", "hourly")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_sheet(test_data, db_session):
    """Check which (email, sheet id) pairs crud.get_user_sheet resolves to a sheet."""
    from db import crud

    # (email, sheet id, whether a sheet is expected to be returned)
    lookups = [
        ("", "sheet-0", False),
        ("morty@example.com", "sheet-0", False),
        ("rick@example.com", "sheet-0", True),
        ("rick@example.com", "sheet-0-2", True),
        ("morty@example.com", "sheet-1", True),
    ]
    for email, sheet_id, should_exist in lookups:
        found = crud.get_user_sheet(db_session, email, sheet_id)
        if should_exist:
            assert found is not None
        else:
            assert found is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_sheets(test_data, db_session):
    """Check how many sheets crud.get_user_sheets returns per user, and their ids for rick."""
    from db import crud

    # an empty email owns nothing
    nobody_sheets = crud.get_user_sheets(db_session, "")
    assert len(nobody_sheets) == 0

    # rick is expected to have two sheets, in this order
    rick_sheets = crud.get_user_sheets(db_session, "rick@example.com")
    assert len(rick_sheets) == 2
    rick_sheet_ids = [sheet.id for sheet in rick_sheets]
    assert rick_sheet_ids == ["sheet-0", "sheet-0-2"]

    morty_sheets = crud.get_user_sheets(db_session, "morty@example.com")
    assert len(morty_sheets) == 1
|
||||||
|
|
||||||
|
def test_delete_sheet(test_data, db_session):
    """Check the result of crud.delete_sheet for an empty email, the first delete, and a repeat delete."""
    from db import crud

    # order matters: the second rick call repeats a delete that already succeeded
    attempts = [
        ("", False),
        ("rick@example.com", True),
        ("rick@example.com", False),
    ]
    for email, expected in attempts:
        assert crud.delete_sheet(db_session, "sheet-0", email) == expected
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ def test_endpoints_no_auth(client, test_no_auth):
|
|||||||
test_no_auth(client.post, "/sheet/archive")
|
test_no_auth(client.post, "/sheet/archive")
|
||||||
|
|
||||||
|
|
||||||
def test_create_sheet_endpoint(app_with_auth):
|
def test_create_sheet_endpoint(app_with_auth, db_session):
|
||||||
client_with_auth = TestClient(app_with_auth)
|
client_with_auth = TestClient(app_with_auth)
|
||||||
good_data = {
|
good_data = {
|
||||||
"id": "123-sheet-id",
|
"id": "123-sheet-id",
|
||||||
@@ -53,13 +53,23 @@ def test_create_sheet_endpoint(app_with_auth):
|
|||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
assert response.json() == {"detail": "User does not have access to this group."}
|
assert response.json() == {"detail": "User does not have access to this group."}
|
||||||
|
|
||||||
# bad quota
|
# switch to jerry who's got less quota/permissions
|
||||||
|
from web.security import get_active_user_state
|
||||||
|
from db.user_state import UserState
|
||||||
|
app_with_auth.dependency_overrides[get_active_user_state] = lambda: UserState(db_session, "jerry@example.com", active=True)
|
||||||
|
client_jerry = TestClient(app_with_auth)
|
||||||
|
|
||||||
|
# frequency not allowed
|
||||||
jerry_data = good_data.copy()
|
jerry_data = good_data.copy()
|
||||||
jerry_data["group_id"] = "animated-characters"
|
jerry_data["group_id"] = "animated-characters"
|
||||||
|
jerry_data["frequency"] = "hourly"
|
||||||
jerry_data["id"] = "jerry-sheet-id"
|
jerry_data["id"] = "jerry-sheet-id"
|
||||||
from web.security import get_active_user_auth
|
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||||
app_with_auth.dependency_overrides[get_active_user_auth] = lambda: "jerry@example.com"
|
assert response.status_code == 422
|
||||||
client_jerry = TestClient(app_with_auth)
|
assert "Invalid frequency: hourly" in response.json()["detail"]
|
||||||
|
jerry_data["frequency"] = "daily"
|
||||||
|
|
||||||
|
# success for the first sheet, bad quota on second
|
||||||
response = client_jerry.post("/sheet/create", json=jerry_data)
|
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
|
|
||||||
@@ -144,12 +154,6 @@ def test_delete_sheet_endpoint(client_with_auth, db_session):
|
|||||||
assert response.json() == {"id": "456-sheet-id", "deleted": False}
|
assert response.json() == {"id": "456-sheet-id", "deleted": False}
|
||||||
|
|
||||||
|
|
||||||
# def test_archive_user_sheet_endpoint(client_with_auth):
|
|
||||||
# response = client_with_auth.post("/sheet/123-sheet-id/archive")
|
|
||||||
# assert response.status_code == 201
|
|
||||||
# assert "id" in response.json()
|
|
||||||
|
|
||||||
|
|
||||||
class TestArchiveUserSheetEndpoint:
|
class TestArchiveUserSheetEndpoint:
|
||||||
def test_token_auth(self, client_with_token, test_no_auth):
|
def test_token_auth(self, client_with_token, test_no_auth):
|
||||||
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
|
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
|
||||||
@@ -177,6 +181,14 @@ class TestArchiveUserSheetEndpoint:
|
|||||||
assert r.json() == {"id": "123-taskid"}
|
assert r.json() == {"id": "123-taskid"}
|
||||||
m1.assert_called_once()
|
m1.assert_called_once()
|
||||||
|
|
||||||
|
def test_user_not_in_group(self, client_with_auth, db_session):
    """Archiving a sheet stored under the "interdimensional" group is expected to return 403."""
    from db import models

    # store a sheet authored by morty but attached to the "interdimensional" group
    group_sheet = models.Sheet(
        id="123-sheet-id",
        name="Test Sheet 1",
        author_id="morty@example.com",
        group_id="interdimensional",
        frequency="hourly",
    )
    db_session.add(group_sheet)
    db_session.commit()

    response = client_with_auth.post("/sheet/123-sheet-id/archive")
    assert response.status_code == 403
    assert response.json() == {"detail": "User does not have access to this group."}
|
||||||
|
|
||||||
|
|
||||||
class TestTokenArchiveEndpoint:
|
class TestTokenArchiveEndpoint:
|
||||||
|
|
||||||
|
|||||||
@@ -10,11 +10,13 @@ def test_archive_url_unauthenticated(client, test_no_auth):
|
|||||||
|
|
||||||
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
|
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
|
||||||
def test_archive_url(m1, client_with_auth):
|
def test_archive_url(m1, client_with_auth):
|
||||||
|
# url is too short
|
||||||
response = client_with_auth.post("/url/archive", json={"url": "bad"})
|
response = client_with_auth.post("/url/archive", json={"url": "bad"})
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
assert response.json() == {'detail': 'Invalid URL received: bad'}
|
assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
|
||||||
m1.assert_not_called()
|
m1.assert_not_called()
|
||||||
|
|
||||||
|
# valid request
|
||||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
|
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
|
||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
assert response.json() == {'id': '123-456-789'}
|
assert response.json() == {'id': '123-456-789'}
|
||||||
@@ -23,6 +25,20 @@ def test_archive_url(m1, client_with_auth):
|
|||||||
called_val = m1.call_args.args[0]
|
called_val = m1.call_args.args[0]
|
||||||
assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True}
|
assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True}
|
||||||
|
|
||||||
|
# user is not in group
|
||||||
|
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "new-group"})
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json()["detail"] == "User does not have access to this group."
|
||||||
|
|
||||||
|
# user is in group
|
||||||
|
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
|
||||||
|
assert response.status_code == 201
|
||||||
|
assert response.json() == {'id': '123-456-789'}
|
||||||
|
|
||||||
|
assert m1.call_count == 2
|
||||||
|
called_val = m1.call_args.args[0]
|
||||||
|
assert json.loads(called_val)["group_id"] == "spaceship"
|
||||||
|
|
||||||
|
|
||||||
def test_search_by_url_unauthenticated(client, test_no_auth):
|
def test_search_by_url_unauthenticated(client, test_no_auth):
|
||||||
test_no_auth(client.get, "/url/search")
|
test_no_auth(client.get, "/url/search")
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ orchestrators:
|
|||||||
interdimensional: tests/orchestration.test.yaml
|
interdimensional: tests/orchestration.test.yaml
|
||||||
default: tests/orchestration.test.yaml
|
default: tests/orchestration.test.yaml
|
||||||
|
|
||||||
|
default_orchestrator: tests/orchestration.test.yaml
|
||||||
|
|
||||||
groups:
|
groups:
|
||||||
spaceship:
|
spaceship:
|
||||||
description: "The spaceship crew"
|
description: "The spaceship crew"
|
||||||
@@ -31,9 +33,9 @@ groups:
|
|||||||
permissions:
|
permissions:
|
||||||
read: ["all"]
|
read: ["all"]
|
||||||
active_sheets: -1
|
active_sheets: -1
|
||||||
monthly_urls: all
|
monthly_urls: -1
|
||||||
monthly_mbs: all
|
monthly_mbs: -1
|
||||||
alowed_frequency: "hourly"
|
allowed_frequency: "hourly"
|
||||||
interdimensional:
|
interdimensional:
|
||||||
description: "Interdimensional travelers"
|
description: "Interdimensional travelers"
|
||||||
orchestrator: tests/orchestration.test.yaml
|
orchestrator: tests/orchestration.test.yaml
|
||||||
@@ -43,7 +45,7 @@ groups:
|
|||||||
active_sheets: 5
|
active_sheets: 5
|
||||||
monthly_urls: 1000
|
monthly_urls: 1000
|
||||||
monthly_mbs: 1000
|
monthly_mbs: 1000
|
||||||
alowed_frequency: "hourly"
|
allowed_frequency: "hourly"
|
||||||
animated-characters:
|
animated-characters:
|
||||||
description: "Animated characters"
|
description: "Animated characters"
|
||||||
orchestrator: tests/orchestration.test.yaml
|
orchestrator: tests/orchestration.test.yaml
|
||||||
@@ -53,4 +55,4 @@ groups:
|
|||||||
active_sheets: 1
|
active_sheets: 1
|
||||||
monthly_urls: 2
|
monthly_urls: 2
|
||||||
monthly_mbs: 10
|
monthly_mbs: 10
|
||||||
alowed_frequency: "daily"
|
allowed_frequency: "daily"
|
||||||
@@ -122,14 +122,6 @@ class Test_create_sheet_task():
|
|||||||
|
|
||||||
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
|
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
|
||||||
|
|
||||||
@patch("worker.main.is_group_invalid_for_user", return_value="Access denied")
|
|
||||||
def test_error_access(self, m_insert, worker_init, db_session):
|
|
||||||
from worker.main import create_sheet_task
|
|
||||||
|
|
||||||
res = create_sheet_task(self.sheet.model_dump_json())
|
|
||||||
assert "error" in res
|
|
||||||
assert res["error"] == "Access denied"
|
|
||||||
|
|
||||||
|
|
||||||
def test_choose_orchestrator(worker_init):
|
def test_choose_orchestrator(worker_init):
|
||||||
from worker.main import choose_orchestrator
|
from worker.main import choose_orchestrator
|
||||||
|
|||||||
@@ -131,17 +131,19 @@ def app_factory(settings = get_settings()):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/sheet", status_code=201, deprecated=True) # DEPRECATED
|
@app.post("/sheet", status_code=201, deprecated=True) # DEPRECATED
|
||||||
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth)):
|
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth), db: Session = Depends(get_db_dependency)):
|
||||||
logger.info(f"SHEET TASK for {sheet=}")
|
logger.info(f"SHEET TASK for {sheet=}")
|
||||||
sheet.author_id = email
|
sheet.author_id = email
|
||||||
if not sheet.sheet_name and not sheet.sheet_id:
|
if not sheet.sheet_name and not sheet.sheet_id:
|
||||||
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
||||||
|
if not crud.is_user_in_group(db, email, sheet.group_id):
|
||||||
|
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||||
task = create_sheet_task.delay(sheet.model_dump_json())
|
task = create_sheet_task.delay(sheet.model_dump_json())
|
||||||
return JSONResponse({"id": task.id})
|
return JSONResponse({"id": task.id})
|
||||||
|
|
||||||
|
|
||||||
@app.post("/sheet_service", status_code=201, deprecated=True) # DEPRECATED
|
@app.post("/sheet_service", status_code=201, deprecated=True) # DEPRECATED
|
||||||
def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth)):
|
def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth), db: Session = Depends(get_db_dependency)):
|
||||||
logger.info(f"SHEET TASK for {sheet=}")
|
logger.info(f"SHEET TASK for {sheet=}")
|
||||||
sheet.author_id = sheet.author_id or "api-endpoint"
|
sheet.author_id = sheet.author_id or "api-endpoint"
|
||||||
if not sheet.sheet_name and not sheet.sheet_id:
|
if not sheet.sheet_name and not sheet.sheet_id:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from core.config import ALLOW_ANY_EMAIL
|
|||||||
from shared.settings import get_settings
|
from shared.settings import get_settings
|
||||||
from db.database import get_db
|
from db.database import get_db
|
||||||
from db import crud
|
from db import crud
|
||||||
|
from db.user_state import UserState
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
bearer_security = HTTPBearer()
|
bearer_security = HTTPBearer()
|
||||||
@@ -84,3 +85,8 @@ def authenticate_user(access_token):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"AUTH EXCEPTION occurred: {e}")
|
logger.warning(f"AUTH EXCEPTION occurred: {e}")
|
||||||
return False, "exception occurred"
|
return False, "exception occurred"
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_user_state(email=Depends(get_active_user_auth)):
    """FastAPI dependency that provides a UserState for the current active user.

    Implemented as a generator ("dependency with yield") so that the session
    obtained from ``get_db()`` is still inside its ``with`` block while the
    request handler uses the UserState; the context manager only exits after
    the handler is done.

    Args:
        email: value resolved by the ``get_active_user_auth`` dependency.

    Yields:
        UserState constructed with the open session, ``email`` and ``active=True``.
    """
    with get_db() as db:
        # yield rather than return: a return here would leave the `with` block
        # (presumably releasing the session -- confirm against get_db) before
        # the caller ever touches the UserState.
        yield UserState(db, email, active=True)
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
import traceback, yaml, datetime
|
import traceback, yaml, datetime
|
||||||
from typing import List, Set
|
from typing import List, Set
|
||||||
|
|
||||||
@@ -30,6 +31,7 @@ Rdis = redis.Redis.from_url(celery.conf.broker_url)
|
|||||||
def create_archive_task(self, archive_json: str):
|
def create_archive_task(self, archive_json: str):
|
||||||
archive = schemas.ArchiveCreate.model_validate_json(archive_json)
|
archive = schemas.ArchiveCreate.model_validate_json(archive_json)
|
||||||
logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}")
|
logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}")
|
||||||
|
#TODO: move group checks out of here
|
||||||
invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)
|
invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)
|
||||||
if invalid:
|
if invalid:
|
||||||
raise Exception(invalid) # marks task FAILED, saves the Exception as result
|
raise Exception(invalid) # marks task FAILED, saves the Exception as result
|
||||||
@@ -64,10 +66,6 @@ def create_sheet_task(self, sheet_json: str):
|
|||||||
sheet.tags.add("gsheet")
|
sheet.tags.add("gsheet")
|
||||||
logger.info(f"SHEET START {sheet=}")
|
logger.info(f"SHEET START {sheet=}")
|
||||||
|
|
||||||
#TODO: should this check live here?
|
|
||||||
if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)):
|
|
||||||
return {"error": em}
|
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
# TODO: use choose_orchestrator and overwrite the feeder
|
# TODO: use choose_orchestrator and overwrite the feeder
|
||||||
# TODO: drop sheet_name and use only sheet_id (new endpoints/models)
|
# TODO: drop sheet_name and use only sheet_id (new endpoints/models)
|
||||||
@@ -161,7 +159,7 @@ def is_group_invalid_for_user(public: bool, group_id: str, author_id: str):
|
|||||||
|
|
||||||
# otherwise group must match
|
# otherwise group must match
|
||||||
with get_db() as session:
|
with get_db() as session:
|
||||||
if not crud.is_user_in_group(session, group_id, author_id):
|
if not crud.is_user_in_group(session, author_id, group_id):
|
||||||
logger.error(em := f"User {author_id} is not part of {group_id}, no permission")
|
logger.error(em := f"User {author_id} is not part of {group_id}, no permission")
|
||||||
return em
|
return em
|
||||||
return False
|
return False
|
||||||
@@ -220,3 +218,13 @@ def at_start(sender, **kwargs):
|
|||||||
ORCHESTRATORS = {}
|
ORCHESTRATORS = {}
|
||||||
load_orchestrators()
|
load_orchestrators()
|
||||||
logger.info("Orchestrators loaded successfully.")
|
logger.info("Orchestrators loaded successfully.")
|
||||||
|
|
||||||
|
@lru_cache
def get_url_orchestrator(group_name):
    """Look up the group called ``group_name``; building its orchestrator is still TODO.

    Args:
        group_name: name passed to ``crud.get_group``; also the lru_cache key.

    Returns:
        None for now -- the orchestrator construction below is commented out.
        NOTE(review): because of ``@lru_cache`` that None is cached per
        group_name once the lookup succeeds.

    Raises:
        AssertionError: if ``crud.get_group`` returns nothing for ``group_name``.
    """
    with get_db() as db:
        group = crud.get_group(db, group_name)
        if not group:
            # explicit raise instead of `assert` so the check is not stripped
            # when Python runs with -O; exception type and message are unchanged
            raise AssertionError(f"Group {group_name} not found")

        # TODO: build and return the orchestrator for this group
        # config = Config()
        # config.parse(use_cli=False, yaml_config_filename=group.orchestrator_sheet)
        # return ArchivingOrchestrator(config)
|
||||||
Reference in New Issue
Block a user