minor refactor and user_state tests

This commit is contained in:
msramalho
2025-02-14 23:56:36 +00:00
parent 1a3078055b
commit 1a4508df33
5 changed files with 461 additions and 67 deletions

View File

@@ -5,15 +5,17 @@ from sqlalchemy import Column, or_, func, select
from loguru import logger
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from cachetools import LRUCache, cached
from cachetools.keys import hashkey
from app.web.config import ALLOW_ANY_EMAIL
from app.shared.db.database import get_db
from app.shared.db import models
from app.shared.settings import get_settings
from app.shared.user_groups import UserGroups
from app.shared.utils.misc import fnv1a_hash_mod
from app.web.utils.misc import convert_priority_to_queue_dict
DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
@@ -33,7 +35,7 @@ def base_query(db: Session):
def get_archive(db: Session, id: str, email: str):
query = base_query(db).filter(models.Archive.id == id)
if email != ALLOW_ANY_EMAIL:
groups = get_user_groups(email)
groups = get_user_group_names(db ,email)
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
return query.first()
@@ -42,7 +44,7 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim
# searches for partial URLs, if email is * no ownership filtering happens
query = base_query(db)
if email != ALLOW_ANY_EMAIL:
groups = get_user_groups(email)
groups = get_user_group_names(db, email)
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
if absolute_search:
query = query.filter(models.Archive.url == url)
@@ -108,38 +110,39 @@ async def soft_delete_expired_archives(db: AsyncSession) -> dict:
# --------------- TAG
def is_user_in_group(email: str, group_name: str) -> models.Group:
if email == ALLOW_ANY_EMAIL: return True
return len(group_name) and len(email) and group_name in get_user_groups(email)
async def get_group_priority_async(db: AsyncSession, group_id: str) -> dict:
db_group = await db.get(models.Group, group_id)
priority = db_group.permissions.get("priority", "low") if db_group else "low"
return convert_priority_to_queue_dict(priority)
@lru_cache
def get_user_groups(email: str) -> list[str]:
@cached(cache=LRUCache(maxsize=128), key=lambda db, email: hashkey(email))
def get_user_group_names(db: Session, email: str) -> list[str]:
"""
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user.
"""
if not email or not len(email) or "@" not in email: return []
with get_db() as db:
# get user groups
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
user_level_groups_names = [g[0] for g in user_groups]
# get user groups
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
user_level_groups_names = [g[0] for g in user_groups]
# get domain groups
domain = email.split('@')[1]
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
domain_level_groups_names = [g[0] for g in domain_level_groups]
# get domain groups
domain = email.split('@')[1]
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
domain_level_groups_names = [g[0] for g in domain_level_groups]
return list(set(user_level_groups_names + domain_level_groups_names))
return list(set(user_level_groups_names + domain_level_groups_names))
def get_user_groups_by_name(db: Session, groups: list[str]) -> list[models.Group]:
return db.query(models.Group).filter(
models.Group.id.in_(groups)
).all()
# --------------- INIT User-Groups
def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, service_account_email: str, permissions: dict, domains: list) -> models.Group:
db_group = db.query(models.Group).filter(models.Group.id == group_name).first()
if db_group is None:

View File

@@ -47,15 +47,13 @@ class UserState:
@property
def user_groups_names(self):
if not hasattr(self, '_user_groups_names'):
self._user_groups_names = crud.get_user_groups(self.email) + ["default"]
self._user_groups_names = crud.get_user_group_names(self.db, self.email) + ["default"]
return self._user_groups_names
@property
def user_groups(self):
if not hasattr(self, '_user_groups'):
self._user_groups = self.db.query(models.Group).filter(
models.Group.id.in_(self.user_groups_names)
).all()
self._user_groups = crud.get_user_groups_by_name(self.db, self.user_groups_names)
return self._user_groups
@property
@@ -150,8 +148,9 @@ class UserState:
self._priority = "low"
for group in self.user_groups:
if not group.permissions: continue
if group.permissions.get("priority", "low") == "high":
if group.permissions.get("priority", self._priority) == "high":
self._priority = "high"
break
return self._priority
@property