mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-12 21:48:35 +03:00
minor refactor and user_state tests
This commit is contained in:
@@ -5,15 +5,17 @@ from sqlalchemy import Column, or_, func, select
|
||||
from loguru import logger
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from cachetools import LRUCache, cached
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.shared.db.database import get_db
|
||||
from app.shared.db import models
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.user_groups import UserGroups
|
||||
from app.shared.utils.misc import fnv1a_hash_mod
|
||||
from app.web.utils.misc import convert_priority_to_queue_dict
|
||||
|
||||
|
||||
DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
|
||||
|
||||
|
||||
@@ -33,7 +35,7 @@ def base_query(db: Session):
|
||||
def get_archive(db: Session, id: str, email: str):
|
||||
query = base_query(db).filter(models.Archive.id == id)
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
groups = get_user_groups(email)
|
||||
groups = get_user_group_names(db ,email)
|
||||
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
||||
return query.first()
|
||||
|
||||
@@ -42,7 +44,7 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim
|
||||
# searches for partial URLs, if email is * no ownership filtering happens
|
||||
query = base_query(db)
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
groups = get_user_groups(email)
|
||||
groups = get_user_group_names(db, email)
|
||||
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
||||
if absolute_search:
|
||||
query = query.filter(models.Archive.url == url)
|
||||
@@ -108,38 +110,39 @@ async def soft_delete_expired_archives(db: AsyncSession) -> dict:
|
||||
# --------------- TAG
|
||||
|
||||
|
||||
def is_user_in_group(email: str, group_name: str) -> models.Group:
|
||||
if email == ALLOW_ANY_EMAIL: return True
|
||||
return len(group_name) and len(email) and group_name in get_user_groups(email)
|
||||
|
||||
|
||||
async def get_group_priority_async(db: AsyncSession, group_id: str) -> dict:
|
||||
db_group = await db.get(models.Group, group_id)
|
||||
priority = db_group.permissions.get("priority", "low") if db_group else "low"
|
||||
return convert_priority_to_queue_dict(priority)
|
||||
|
||||
@lru_cache
|
||||
def get_user_groups(email: str) -> list[str]:
|
||||
|
||||
@cached(cache=LRUCache(maxsize=128), key=lambda db, email: hashkey(email))
|
||||
def get_user_group_names(db: Session, email: str) -> list[str]:
|
||||
"""
|
||||
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user.
|
||||
"""
|
||||
if not email or not len(email) or "@" not in email: return []
|
||||
|
||||
with get_db() as db:
|
||||
# get user groups
|
||||
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
|
||||
user_level_groups_names = [g[0] for g in user_groups]
|
||||
# get user groups
|
||||
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
|
||||
user_level_groups_names = [g[0] for g in user_groups]
|
||||
|
||||
# get domain groups
|
||||
domain = email.split('@')[1]
|
||||
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
|
||||
domain_level_groups_names = [g[0] for g in domain_level_groups]
|
||||
# get domain groups
|
||||
domain = email.split('@')[1]
|
||||
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
|
||||
domain_level_groups_names = [g[0] for g in domain_level_groups]
|
||||
|
||||
return list(set(user_level_groups_names + domain_level_groups_names))
|
||||
return list(set(user_level_groups_names + domain_level_groups_names))
|
||||
|
||||
|
||||
def get_user_groups_by_name(db: Session, groups: list[str]) -> list[models.Group]:
|
||||
return db.query(models.Group).filter(
|
||||
models.Group.id.in_(groups)
|
||||
).all()
|
||||
|
||||
# --------------- INIT User-Groups
|
||||
|
||||
|
||||
def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, service_account_email: str, permissions: dict, domains: list) -> models.Group:
|
||||
db_group = db.query(models.Group).filter(models.Group.id == group_name).first()
|
||||
if db_group is None:
|
||||
|
||||
@@ -47,15 +47,13 @@ class UserState:
|
||||
@property
|
||||
def user_groups_names(self):
|
||||
if not hasattr(self, '_user_groups_names'):
|
||||
self._user_groups_names = crud.get_user_groups(self.email) + ["default"]
|
||||
self._user_groups_names = crud.get_user_group_names(self.db, self.email) + ["default"]
|
||||
return self._user_groups_names
|
||||
|
||||
@property
|
||||
def user_groups(self):
|
||||
if not hasattr(self, '_user_groups'):
|
||||
self._user_groups = self.db.query(models.Group).filter(
|
||||
models.Group.id.in_(self.user_groups_names)
|
||||
).all()
|
||||
self._user_groups = crud.get_user_groups_by_name(self.db, self.user_groups_names)
|
||||
return self._user_groups
|
||||
|
||||
@property
|
||||
@@ -150,8 +148,9 @@ class UserState:
|
||||
self._priority = "low"
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if group.permissions.get("priority", "low") == "high":
|
||||
if group.permissions.get("priority", self._priority) == "high":
|
||||
self._priority = "high"
|
||||
break
|
||||
return self._priority
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user