diff --git a/src/db/crud.py b/src/db/crud.py index 1368bc0..d09a4c8 100644 --- a/src/db/crud.py +++ b/src/db/crud.py @@ -112,17 +112,6 @@ def create_tag(db: Session, tag: str): return db_tag -def is_active_user(db: Session, email: str) -> bool: - email = email.lower() - if not email or not len(email) or "@" not in email: return False - domain = email.split('@')[1] - - explicitly_active = db.query(models.User).filter(models.User.email == email, models.User.is_active == True).first() is not None - if explicitly_active: return True - - return db.query(models.Group).filter(models.Group.domains.contains(domain)).first() is not None - - def is_user_in_group(db: Session, email: str, group_name: str) -> models.Group: if email == ALLOW_ANY_EMAIL: return True return len(group_name) and len(email) and group_name in get_user_groups(email) @@ -131,7 +120,7 @@ def is_user_in_group(db: Session, email: str, group_name: str) -> models.Group: @lru_cache def get_user_groups(email: str) -> list[str]: """ - given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. User does not need to be active. + given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. """ if not email or not len(email) or "@" not in email: return [] email = email.lower() @@ -155,11 +144,11 @@ def get_group(db: Session, group_name: str) -> models.Group: return db.query(models.Group).filter(models.Group.id == group_name).first() -def create_or_get_user(db: Session, author_id: str, is_active: bool = models.User.is_active.default.arg) -> models.User: +def create_or_get_user(db: Session, author_id: str) -> models.User: if type(author_id) == str: author_id = author_id.lower() db_user = db.query(models.User).filter(models.User.email == author_id).first() if not db_user: - db_user = models.User(email=author_id, is_active=is_active) + db_user = models.User(email=author_id) db.add(db_user) db.commit() db.refresh(db_user) @@ -182,14 +171,12 @@ def upsert_group(db: Session, group_name: str, description: str, orchestrator: s return db_group -def upsert_user(db: Session, email: str, active: bool): +def upsert_user(db: Session, email: str): db_user = db.query(models.User).filter(models.User.email == email).first() if db_user is None: - db_user = models.User(email=email, is_active=active) + db_user = models.User(email=email) db.add(db_user) - else: - db_user.is_active = active - db.commit() + db.commit() return db_user @@ -208,9 +195,6 @@ def upsert_user_groups(db: Session): # delete all user-groups relationships db.query(models.association_table_user_groups).delete() - # set all users to inactive - db.query(models.User).update({models.User.is_active: False}) - # create a map of group_id -> domains and another of domain -> groups group_domains = defaultdict(set) domain_groups = defaultdict(list) @@ -227,7 +211,7 @@ def upsert_user_groups(db: Session): # integrity checks for group_in_domains in group_domains: if group_in_domains not in db_groups: - logger.error(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.") + logger.warning(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.") # reinsert users in their EXPLICITLY DEFINED groups # domain groups are check live, as there may be new users that are not explicitly registered but belong to a domain @@ -235,13 +219,12 @@ def upsert_user_groups(db: Session): explicit_groups = explicit_groups or [] logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}") - # upsert active user - db_user = upsert_user(db, email, active=True) + db_user = upsert_user(db, email) # connect users to groups for group_id in explicit_groups: if group_id not in db_groups: - logger.error(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.") + logger.warning(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.") continue db_groups[group_id].users.append(db_user) diff --git a/src/db/models.py b/src/db/models.py index afadec6..d8b12c8 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -73,7 +73,6 @@ class User(Base): __tablename__ = "users" email = Column(String, primary_key=True, index=True) - is_active = Column(Boolean, default=False) archives = relationship("Archive", back_populates="author") sheets = relationship("Sheet", back_populates="author") diff --git a/src/db/user_state.py b/src/db/user_state.py index 99c68ac..28074aa 100644 --- a/src/db/user_state.py +++ b/src/db/user_state.py @@ -1,9 +1,11 @@ +from typing import Dict, Set import sqlalchemy from sqlalchemy.orm import Session from sqlalchemy import func from db import crud, models from datetime import datetime +from shared.user_groups import GroupPermissions class UserState: @@ -11,10 +13,26 @@ class UserState: Manage a user's state and permissions """ - def __init__(self, db: Session, email: str, active=False): + def __init__(self, db: Session, email: str): self.db = db self.email = email - self.active = active + + @property + def permissions(self) -> Dict[str, GroupPermissions]: + """ + Returns a dict of all group permissions and a special {"all": read/archive_url/archive_sheet} key + """ + if not hasattr(self, '_permissions'): + self._permissions = {} + self._permissions["all"] = GroupPermissions( + read=self.read, + archive_url=self.archive_url, + archive_sheet=self.archive_sheet, + ) + for group in self.user_groups: + if not group.permissions: continue + self._permissions[group.id] = GroupPermissions(**group.permissions) + return self._permissions @property def user_groups_names(self): @@ -31,7 +49,52 @@ class UserState: return self._user_groups @property - def allowed_frequencies(self): + def read(self) -> Set[str] | bool: + """ + Read can be a list of group names or True, if all can be read. + """ + if not hasattr(self, '_read'): + self._read = set() + for group in self.user_groups: + if not group.permissions: continue + group_read_permissions = group.permissions.get("read", []) + if "all" in group_read_permissions: + self._read = True + return self._read + else: + self._read.update(group.permissions.get("read", [])) + return self._read + + @property + def archive_url(self) -> bool: + """ + Archive URL permission + """ + if not hasattr(self, '_archive_url'): + self._archive_url = False + for group in self.user_groups: + if not group.permissions: continue + if group.permissions.get("archive_url", False): + self._archive_url = True + return self._archive_url + return self._archive_url + + @property + def archive_sheet(self) -> bool: + """ + Archive sheet permission + """ + if not hasattr(self, '_archive_sheet'): + self._archive_sheet = False + for group in self.user_groups: + if not group.permissions: continue + if group.permissions.get("archive_sheet", False): + self._archive_sheet = True + return self._archive_sheet + return self._archive_sheet + + @property + def sheet_frequency(self): if not hasattr(self, '_sheet_frequency'): self._sheet_frequency = set() for group in self.user_groups: @@ -40,22 +103,31 @@ class UserState: return self._sheet_frequency @property - def sheet_quota(self): + def max_sheets(self): """ infer the user's sheet quota from the groups -1 means unlimited """ - if not hasattr(self, '_sheet_quota'): - self._sheet_quota = 0 + if not hasattr(self, '_max_sheets'): + self._max_sheets = 0 for group in self.user_groups: if not group.permissions: continue max_sheets = group.permissions.get("max_sheets", 0) if max_sheets == -1: - self._sheet_quota = -1 - return self._sheet_quota - self._sheet_quota = max(self._sheet_quota, max_sheets) + self._max_sheets = -1 + return self._max_sheets + self._max_sheets = max(self._max_sheets, max_sheets) - return self._sheet_quota + return self._max_sheets + + @property + def active(self) -> bool: + """ + A user is active if they can read/archive anything + """ + if not hasattr(self, '_active'): + self._active = bool(self.read or self.archive_url or self.archive_sheet) + return self._active def in_group(self, group_id: str) -> bool: return group_id in self.user_groups_names @@ -64,11 +136,11 @@ class UserState: """ checks if a user has reached their sheet quota """ - if self.sheet_quota == -1: return True + if self.max_sheets == -1: return True user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email).count() - return user_sheets < self.sheet_quota + return user_sheets < self.max_sheets def has_quota_max_monthly_urls(self) -> bool: """ @@ -137,4 +209,4 @@ class UserState: """ checks if a user is allowed to create a sheet with this frequency """ - return frequency in self.allowed_frequencies + return frequency in self.sheet_frequency diff --git a/src/endpoints/default.py b/src/endpoints/default.py index 62d1174..616ea8c 100644 --- a/src/endpoints/default.py +++ b/src/endpoints/default.py @@ -1,4 +1,5 @@ +from typing import Dict from fastapi import APIRouter, Depends, Request, HTTPException from fastapi.responses import FileResponse, JSONResponse from sqlalchemy.orm import Session @@ -8,7 +9,8 @@ from core.logging import log_error from db import crud, schemas from db.database import get_db_dependency from db.user_state import UserState -from web.security import get_user_auth, bearer_security, get_active_user_state +from web.security import get_user_auth, bearer_security, get_user_state +from shared.user_groups import GroupPermissions default_router = APIRouter() @@ -32,27 +34,22 @@ async def health(): @default_router.get("/user/active", summary="Check if the user is active and can use the tool.") # TODO: reorder db dependencies to after auth -async def active(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> schemas.ActiveUser: - return {"active": crud.is_active_user(db, email)} +async def active( + user: UserState = Depends(get_user_state), +) -> schemas.ActiveUser: + return {"active": user.active} -@default_router.get("/groups") +@default_router.get("/groups", deprecated=True) # DEPRECATED, only used by extension def get_user_groups(email=Depends(get_user_auth)) -> list[str]: return crud.get_user_groups(email) @default_router.get("/permissions") def get_user_groups( - user: UserState = Depends(get_active_user_state), -) -> list[str]: - return JSONResponse({ - "groups": user.user_groups_names, - "allowedFrequencies": list(user.allowed_frequencies), - "sheet_quota": user.sheet_quota, - "max_monthly_urls": user.max_monthly_urls, #TODO - "max_monthly_mbs": user.max_monthly_mbs, # TODO - #TODO: should this return - }) + user: UserState = Depends(get_user_state), +) -> Dict[str, GroupPermissions]: + return user.permissions @default_router.get('/favicon.ico', include_in_schema=False) diff --git a/src/endpoints/sheet.py b/src/endpoints/sheet.py index 263966a..07d3ee3 100644 --- a/src/endpoints/sheet.py +++ b/src/endpoints/sheet.py @@ -6,7 +6,7 @@ from sqlalchemy import exc from sqlalchemy.orm import Session from db.user_state import UserState -from web.security import token_api_key_auth, get_active_user_auth, get_active_user_state +from web.security import token_api_key_auth, get_user_auth, get_user_state from db import schemas, crud from db.database import get_db_dependency from worker.main import create_sheet_task @@ -17,7 +17,7 @@ sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"] @sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.") def create_sheet( sheet: schemas.SheetAdd, - user: UserState = Depends(get_active_user_state), + user: UserState = Depends(get_user_state), db: Session = Depends(get_db_dependency), ) -> schemas.SheetResponse: @@ -28,7 +28,7 @@ def create_sheet( raise HTTPException(status_code=429, detail="User has reached their sheet quota.") if not user.is_sheet_frequency_allowed(sheet.frequency): - raise HTTPException(status_code=422, detail=f"Invalid frequency: {sheet.frequency}. Must be one of {user.allowed_frequencies}") + raise HTTPException(status_code=422, detail=f"Invalid frequency: {sheet.frequency}. Must be one of {user.sheet_frequency}") try: return crud.create_sheet(db, sheet.id, sheet.name, user.email, sheet.group_id, sheet.frequency) @@ -38,7 +38,7 @@ def create_sheet( @sheet_router.get("/mine", status_code=200, summary="Get the authenticated user's Google Sheets.") def get_user_sheets( - email=Depends(get_active_user_auth), + email=Depends(get_user_auth), db: Session = Depends(get_db_dependency) ) -> list[schemas.SheetResponse]: return crud.get_user_sheets(db, email) @@ -47,7 +47,7 @@ def get_user_sheets( @sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.") def delete_sheet( id: str, - email=Depends(get_active_user_auth), + email=Depends(get_user_auth), db: Session = Depends(get_db_dependency), ) -> schemas.TaskDelete: return JSONResponse({ @@ -59,7 +59,7 @@ def delete_sheet( @sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.") def archive_user_sheet( id: str, - user: UserState = Depends(get_active_user_state), + user: UserState = Depends(get_user_state), db: Session = Depends(get_db_dependency), ) -> schemas.Task: diff --git a/src/migrations/versions/a23aaf3ae930_drop_active_column.py b/src/migrations/versions/a23aaf3ae930_drop_active_column.py new file mode 100644 index 0000000..912f408 --- /dev/null +++ b/src/migrations/versions/a23aaf3ae930_drop_active_column.py @@ -0,0 +1,34 @@ +"""drop active column + +Revision ID: a23aaf3ae930 +Revises: 89121d2c96d8 +Create Date: 2025-02-04 12:19:20.753570 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'a23aaf3ae930' +down_revision = '89121d2c96d8' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + columns = [col['name'] for col in inspector.get_columns('users')] + + if 'is_active' in columns: + op.drop_column('users', 'is_active') + + +def downgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + columns = [col['name'] for col in inspector.get_columns('users')] + + if 'is_active' not in columns: + op.add_column('users', sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.false())) diff --git a/src/shared/user_groups.py b/src/shared/user_groups.py index d4ee02f..12e4836 100644 --- a/src/shared/user_groups.py +++ b/src/shared/user_groups.py @@ -31,7 +31,7 @@ class UserGroups: class GroupPermissions(BaseModel): - read: Set[str] = Field(default_factory=list) + read: Set[str] | bool = Field(default_factory=list) archive_url: bool = False archive_sheet: bool = False sheet_frequency: Set[str] = Field(default_factory=list) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 8091062..58ce781 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -76,11 +76,10 @@ def client(app): @pytest.fixture() def app_with_auth(app, db_session): - from web.security import get_token_or_user_auth, get_user_auth, get_active_user_auth, get_active_user_state + from web.security import get_token_or_user_auth, get_user_auth, get_user_state app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com" app.dependency_overrides[get_user_auth] = lambda: "morty@example.com" - app.dependency_overrides[get_active_user_auth] = lambda: "morty@example.com" - app.dependency_overrides[get_active_user_state] = lambda: UserState(db_session, "morty@example.com", active=True) + app.dependency_overrides[get_user_state] = lambda: UserState(db_session, "morty@example.com") return app diff --git a/src/tests/db/test_crud.py b/src/tests/db/test_crud.py index 79c6894..59b76d8 100644 --- a/src/tests/db/test_crud.py +++ b/src/tests/db/test_crud.py @@ -303,20 +303,6 @@ def test_create_tag(db_session): assert db_session.query(models.Tag).count() == 2 -def test_is_active_user(test_data, db_session): - from db import crud - - assert crud.is_active_user(db_session, "") == False - assert crud.is_active_user(db_session, "example.com") == False - assert crud.is_active_user(db_session, "unknown@example.com") == True - assert crud.is_active_user(db_session, "ANYONE@example.com") == True - assert crud.is_active_user(db_session, "ANYONE@birdy.com") == True - assert crud.is_active_user(db_session, "rick@example.com") == True - assert crud.is_active_user(db_session, "RICK@example.com") == True - assert crud.is_active_user(db_session, "summer@herself.com") == False - assert crud.is_active_user(db_session, "rick@not-in-groups.com") == False - - def test_is_user_in_group(test_data, db_session): from db import crud from core.config import ALLOW_ANY_EMAIL @@ -363,7 +349,7 @@ def test_get_group(test_data, db_session): assert crud.get_group(db_session, "spaceship") is not None assert crud.get_group(db_session, "interdimensional") is not None assert crud.get_group(db_session, "animated-characters") is not None - assert crud.get_group(db_session, "non-existant!@#!%!") is None + assert crud.get_group(db_session, "non-existent!@#!%!") is None def test_create_or_get_user(test_data, db_session): @@ -374,19 +360,12 @@ def test_create_or_get_user(test_data, db_session): # already exists assert (u1 := crud.create_or_get_user(db_session, "rick@example.com")) is not None assert u1.email == "rick@example.com" - assert u1.is_active == True - # new active - assert (u2 := crud.create_or_get_user(db_session, "beth@example.com", is_active=True)) is not None + # new user + assert (u2 := crud.create_or_get_user(db_session, "beth@example.com")) is not None assert u2.email == "beth@example.com" - assert u2.is_active == True - # new not active - assert (u3 := crud.create_or_get_user(db_session, "not-active@example.com")) is not None - assert u3.email == "not-active@example.com" - assert u3.is_active == False - - assert db_session.query(models.User).count() == 5 + assert db_session.query(models.User).count() == 4 def test_upsert_group(test_data, db_session): diff --git a/src/tests/endpoints/test_sheet.py b/src/tests/endpoints/test_sheet.py index e9949a2..71df69d 100644 --- a/src/tests/endpoints/test_sheet.py +++ b/src/tests/endpoints/test_sheet.py @@ -54,9 +54,9 @@ def test_create_sheet_endpoint(app_with_auth, db_session): assert response.json() == {"detail": "User does not have access to this group."} # switch to jerry who's got less quota/permissions - from web.security import get_active_user_state + from web.security import get_user_state from db.user_state import UserState - app_with_auth.dependency_overrides[get_active_user_state] = lambda: UserState(db_session, "jerry@example.com", active=True) + app_with_auth.dependency_overrides[get_user_state] = lambda: UserState(db_session, "jerry@example.com") client_jerry = TestClient(app_with_auth) # frequency not allowed diff --git a/src/tests/web/test_security.py b/src/tests/web/test_security.py index f82874c..c7427d1 100644 --- a/src/tests/web/test_security.py +++ b/src/tests/web/test_security.py @@ -40,24 +40,6 @@ async def test_get_user_auth(m1): assert await get_user_auth(good_user) == "summer@example.com" -@patch("web.security.authenticate_user", return_value=(True, "summer@example.com")) -@pytest.mark.asyncio -async def test_get_active_user_auth_inactive(m1, db_session): - from web.security import get_active_user_auth - - # inactive at first - creds = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good") - with pytest.raises(HTTPException): - await get_active_user_auth(creds) - - from db import models - db_session.add(models.User(email="summer@example.com", is_active=True)) - db_session.commit() - assert await get_active_user_auth(creds) == "summer@example.com" - - - - @patch("web.security.secure_compare", return_value=False) @pytest.mark.asyncio async def test_token_api_key_auth_exception(m1): diff --git a/src/web/security.py b/src/web/security.py index cbb4cae..772fcbd 100644 --- a/src/web/security.py +++ b/src/web/security.py @@ -57,15 +57,6 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear ) -async def get_active_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): - # validates Bearer token and Active User status - email = await get_user_auth(credentials) - with get_db() as db: - if crud.is_active_user(db, email): - return email - raise HTTPException(status_code=403, detail="User is not active") - - def authenticate_user(access_token): # https://cloud.google.com/docs/authentication/token-types#access if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token" @@ -87,6 +78,6 @@ def authenticate_user(access_token): return False, "exception occurred" -def get_active_user_state(email=Depends(get_active_user_auth)): +def get_user_state(email=Depends(get_user_auth)): with get_db() as db: - return UserState(db, email, active=True) \ No newline at end of file + return UserState(db, email) \ No newline at end of file