diff --git a/src/db/crud.py b/src/db/crud.py index 744c711..47e6c09 100644 --- a/src/db/crud.py +++ b/src/db/crud.py @@ -3,9 +3,10 @@ from sqlalchemy.orm import Session, load_only from sqlalchemy import Column, or_ from loguru import logger from . import models, schemas -import yaml +import yaml, os DOMAIN_GROUPS = {} +DOMAIN_GROUPS_LOADED = False # --------------- TASK = Archive @@ -69,11 +70,12 @@ def search_tags(db: Session, tag: str, skip: int = 0, limit: int = 100): return db.query(models.Tag).filter(models.Tag.url.like(f'%{tag}%')).offset(skip).limit(limit).all() -def get_group_for_user(db: Session, group_name: str, email: str) -> models.Group: - return db.query(models.association_table_user_groups).filter_by(user_id=email, group_id=group_name).first() - +def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group: + return len(group_name) and len(email) and group_name in get_user_groups(db, email) def get_user_groups(db: Session, email: str): + global DOMAIN_GROUPS, DOMAIN_GROUPS_LOADED + if not DOMAIN_GROUPS_LOADED: upsert_user_groups(db) # given an email retrieves the user groups from the DB and then the email-domain groups from a global variable groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all() user_level_groups = [g[0] for g in groups] @@ -104,13 +106,14 @@ def get_group(db: Session, group_name: str) -> models.Group: return db_group -def upsert_user_groups(db: Session, filename: str): - global DOMAIN_GROUPS +def upsert_user_groups(db: Session): + global DOMAIN_GROUPS, DOMAIN_GROUPS_LOADED """ reads the user_groups yaml file and inserts any new users, groups, along with new participation of users in groups """ logger.debug("Updating user-groups configuration.") + filename = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml") # read yaml safely with open(filename) as inf: @@ -142,3 +145,4 @@ def upsert_user_groups(db: Session, filename: str): db.commit() count_user_groups = db.query(models.association_table_user_groups).count() logger.success(f"Completed refresh, now: {count_user_groups} user-groups relationships.") + DOMAIN_GROUPS_LOADED = True diff --git a/src/main.py b/src/main.py index 9f68bd3..23ea984 100644 --- a/src/main.py +++ b/src/main.py @@ -77,7 +77,8 @@ def get_user_groups(db: Session = Depends(get_db), email = Depends(get_bearer_au @app.get("/tasks/search-url", response_model=list[schemas.Archive]) def search_by_url(url:str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_bearer_auth)): - return crud.search_tasks_by_url(db, url, email, skip=skip, limit=limit) + #TODO: test strip + return crud.search_tasks_by_url(db, url.strip(), email, skip=skip, limit=limit) @app.get("/tasks/sync", response_model=list[schemas.Archive]) def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_bearer_auth)): @@ -184,5 +185,4 @@ async def on_startup(): @repeat_every(seconds=60 * 60) # 1 hour async def on_startup(): db: Session = next(get_db()) - USER_GROUPS_FILENAME=os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml") - crud.upsert_user_groups(db, USER_GROUPS_FILENAME) \ No newline at end of file + crud.upsert_user_groups(db) \ No newline at end of file diff --git a/src/worker.py b/src/worker.py index 790eb2a..53c359b 100644 --- a/src/worker.py +++ b/src/worker.py @@ -137,8 +137,7 @@ def is_group_invalid_for_user(public: bool, group_id: str, author_id: str): if not public and group_id and len(group_id) > 0: # ensure group is valid for user with get_db() as session: - db_group = crud.get_group_for_user(session, group_id, author_id) - if not db_group: + if not crud.is_user_in_group(session, group_id, author_id): logger.error(em := f"User {author_id} is not part of {group_id}, no permission") return em return False