From 809438fbb9f4156b7b879505dbe8849f6fe6c705 Mon Sep 17 00:00:00 2001 From: msramalho <19508417+msramalho@users.noreply.github.com> Date: Tue, 4 Feb 2025 15:40:20 +0000 Subject: [PATCH] introduces user.read_public drops unused endpoints --- src/db/crud.py | 5 +-- src/db/user_state.py | 19 +++++++++-- src/endpoints/url.py | 54 ++++++++++++++++--------------- src/shared/user_groups.py | 4 +-- src/tests/conftest.py | 2 +- src/tests/db/test_crud.py | 9 +++--- src/tests/endpoints/test_url.py | 56 --------------------------------- src/tests/user-groups.test.yaml | 10 +++--- src/tests/web/test_main.py | 4 +-- src/web/security.py | 2 +- 10 files changed, 61 insertions(+), 104 deletions(-) diff --git a/src/db/crud.py b/src/db/crud.py index d09a4c8..d9e27d9 100644 --- a/src/db/crud.py +++ b/src/db/crud.py @@ -22,7 +22,6 @@ def get_limit(user_limit: int): def get_archive(db: Session, id: str, email: str): - email = email.lower() query = base_query(db).filter(models.Archive.id == id) if email != ALLOW_ANY_EMAIL: groups = get_user_groups(email) @@ -34,7 +33,6 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim # searches for partial URLs, if email is * no ownership filtering happens query = base_query(db) if email != ALLOW_ANY_EMAIL: - email = email.lower() groups = get_user_groups(email) query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups))) if absolute_search: @@ -49,7 +47,6 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100): - email = email.lower() return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all() @@ -123,7 +120,6 @@ def get_user_groups(email: str) -> list[str]: given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. """ if not email or not len(email) or "@" not in email: return [] - email = email.lower() with get_db() as db: # get user groups @@ -172,6 +168,7 @@ def upsert_group(db: Session, group_name: str, description: str, orchestrator: s def upsert_user(db: Session, email: str): + email = email.lower() db_user = db.query(models.User).filter(models.User.email == email).first() if db_user is None: db_user = models.User(email=email) diff --git a/src/db/user_state.py b/src/db/user_state.py index 28074aa..e3af199 100644 --- a/src/db/user_state.py +++ b/src/db/user_state.py @@ -15,7 +15,7 @@ class UserState: def __init__(self, db: Session, email: str): self.db = db - self.email = email + self.email = email.lower() @property def permissions(self) -> Dict[str, GroupPermissions]: @@ -26,6 +26,7 @@ class UserState: self._permissions = {} self._permissions["all"] = GroupPermissions( read=self.read, + read_public=self.read_public, archive_url=self.archive_url, archive_sheet=self.archive_sheet, ) @@ -65,6 +66,20 @@ class UserState: self._read.update(group.permissions.get("read", [])) return self._read + @property + def read_public(self) -> bool: + """ + Read public permission + """ + if not hasattr(self, '_read_public'): + self._read_public = False + for group in self.user_groups: + if not group.permissions: continue + if group.permissions.get("read_public", False): + self._read_public = True + return self._read_public + return self._read_public + @property def archive_url(self) -> bool: """ @@ -126,7 +141,7 @@ class UserState: A user is active if they can read/archive anything """ if not hasattr(self, '_active'): - self._active = bool(self.read or self.archive_url or self.archive_sheet) + self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet) return self._active def in_group(self, group_id: str) -> bool: diff --git a/src/endpoints/url.py b/src/endpoints/url.py index 58cf3c4..d32e6f6 100644 --- a/src/endpoints/url.py +++ b/src/endpoints/url.py @@ -4,11 +4,13 @@ from fastapi.responses import JSONResponse from datetime import datetime from loguru import logger -from web.security import get_user_auth, get_token_or_user_auth +from core.config import ALLOW_ANY_EMAIL +from db.user_state import UserState +from web.security import get_token_or_user_auth, get_user_state from sqlalchemy.orm import Session from db import crud, schemas -from db.database import get_db, get_db_dependency +from db.database import get_db_dependency from worker.main import create_archive_task @@ -18,16 +20,19 @@ url_router = APIRouter(prefix="/url", tags=["Single URL operations"]) @url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.") def archive_url( archive: schemas.ArchiveTrigger, - email=Depends(get_token_or_user_auth) + email=Depends(get_token_or_user_auth), + db: Session = Depends(get_db_dependency) ) -> schemas.Task: logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}") - # TODO: implement quota - - if archive.group_id: - with get_db() as db: - if not crud.is_user_in_group(db, email, archive.group_id): - raise HTTPException(status_code=403, detail="User does not have access to this group.") + if email != ALLOW_ANY_EMAIL: + user = UserState(db, email) + if not user.has_quota_max_monthly_urls(): + raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.") + if not user.has_quota_max_monthly_mbs(): + raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.") + if archive.group_id and not user.in_group(archive.group_id): + raise HTTPException(status_code=403, detail="User does not have access to this group.") # TODO: deprecate ArchiveCreate backwards_compatible_archive = schemas.ArchiveCreate( @@ -47,28 +52,25 @@ def search_by_url( url: str, skip: int = 0, limit: int = 25, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db_dependency), - email=Depends(get_token_or_user_auth) + email: str = Depends(get_token_or_user_auth) ) -> list[schemas.ArchiveResult]: + + if email != ALLOW_ANY_EMAIL: + user = UserState(db, email) + if not user.read and not user.read_public: + raise HTTPException(status_code=403, detail="User does not have read access.") + return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before) -@url_router.get("/latest", summary="Fetch latest URL archives for the authenticated user.") -def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> list[schemas.ArchiveResult]: - return crud.search_archives_by_email(db, email, skip=skip, limit=limit) - -# TODO: find out where/if this is used, tests are also disabled -# @url_router.get("/{id}", summary="Fetch a single URL archive by the associated id.") -# def lookup(id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)) -> schemas.ArchiveResult: -# archive = crud.get_archive(db, id, email) -# if archive is None: -# raise HTTPException(status_code=404, detail="Archive not found") -# return archive - - @url_router.delete("/{id}", summary="Delete a single URL archive by id.") -def delete_task(id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> schemas.TaskDelete: - logger.info(f"deleting url archive task {id} request by {email}") +def delete_task( + id:str, + user: UserState = Depends(get_user_state), + db: Session = Depends(get_db_dependency) +) -> schemas.TaskDelete: + logger.info(f"deleting url archive task {id} request by {user.email}") return JSONResponse({ "id": id, - "deleted": crud.soft_delete_task(db, id, email) + "deleted": crud.soft_delete_task(db, id, user.email) }) diff --git a/src/shared/user_groups.py b/src/shared/user_groups.py index 12e4836..71a9216 100644 --- a/src/shared/user_groups.py +++ b/src/shared/user_groups.py @@ -32,6 +32,7 @@ class UserGroups: class GroupPermissions(BaseModel): read: Set[str] | bool = Field(default_factory=list) + read_public: bool = False archive_url: bool = False archive_sheet: bool = False sheet_frequency: Set[str] = Field(default_factory=list) @@ -49,8 +50,7 @@ class GroupPermissions(BaseModel): @field_validator('sheet_frequency', mode='before') def validate_sheet_frequency(cls, v): - if not v: - raise ValueError("sheet_frequency should have at least one value.") + if not v: return [] allowed = ["daily", "hourly"] for k in v: if k not in allowed: diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 58ce781..854bd20 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -79,7 +79,7 @@ def app_with_auth(app, db_session): from web.security import get_token_or_user_auth, get_user_auth, get_user_state app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com" app.dependency_overrides[get_user_auth] = lambda: "morty@example.com" - app.dependency_overrides[get_user_state] = lambda: UserState(db_session, "morty@example.com") + app.dependency_overrides[get_user_state] = lambda: UserState(db_session, "MORTY@example.com") return app diff --git a/src/tests/db/test_crud.py b/src/tests/db/test_crud.py index 59b76d8..9517bd2 100644 --- a/src/tests/db/test_crud.py +++ b/src/tests/db/test_crud.py @@ -145,7 +145,6 @@ def test_search_archives_by_email(test_data, db_session): # lower/upper case assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34 - assert len(crud.search_archives_by_email(db_session, "RICK@example.com")) == 34 # ALLOW_ANY_EMAIL is not a user assert len(crud.search_archives_by_email(db_session, ALLOW_ANY_EMAIL)) == 0 @@ -314,7 +313,7 @@ def test_is_user_in_group(test_data, db_session): ("rick@example.com", "spaceship", True), ("rick@example.com", "SPACESHIP", False), - ("RICK@example.com", "interdimensional", True), + ("rick@example.com", "interdimensional", True), ("rick@example.com", "animated-characters", True), ("rick@example.com", "the-jerrys-club", False), @@ -329,14 +328,14 @@ def test_is_user_in_group(test_data, db_session): ("rick@example.com", "animated-characters", True), ("morty@example.com", "animated-characters", True), ("jerry@example.com", "animated-characters", True), - ("ANYONE@example.com", "animated-characters", True), - ("ANYONE@birdy.com", "animated-characters", True), + ("anyone@example.com", "animated-characters", True), + ("anyone@birdy.com", "animated-characters", True), ("summer@herself.com", "animated-characters", False), ("rick@example.com", "", False), ("", "spaceship", False), - ("BADEMAILexample.com", "spaceship", False), + ("bademailexample.com", "spaceship", False), ] for email, group, expected in test_pairs: print(f"{email} in {group} == {expected}") diff --git a/src/tests/endpoints/test_url.py b/src/tests/endpoints/test_url.py index b23b07e..282769d 100644 --- a/src/tests/endpoints/test_url.py +++ b/src/tests/endpoints/test_url.py @@ -83,62 +83,6 @@ def test_search_by_url(client_with_auth, db_session): assert len(response.json()) == 10 -def test_latest_unauthenticated(client, test_no_auth): - test_no_auth(client.get, "/url/latest") - - -def test_latest(client_with_auth, db_session): - response = client_with_auth.get("/url/latest") - assert response.status_code == 200 - assert response.json() == [] - - from db import crud, schemas - for i in range(11): - crud.create_task(db_session, ArchiveCreate(id=f"latest-456-{i}", url="https://example.com", result={}, public=True, author_id="morty@example.com" if i < 10 else "rick@example.com", group_id=None), [], []) - #NB: this insertion is too fast for the ordering to be correct as they are within the same second - - # user must exist for /latest to work - crud.create_or_get_user(db_session, "morty@example.com") - - response = client_with_auth.get("/url/latest") - assert response.status_code == 200 - assert len(j := response.json()) == 10 - assert "latest-456-0" in [i["id"] for i in j] - assert "latest-456-9" in [i["id"] for i in j] - assert "latest-456-10" not in [i["id"] for i in j] - assert j[0].keys() == schemas.ArchiveResult.model_fields.keys() - - response = client_with_auth.get("/url/latest?limit=5") - assert response.status_code == 200 - assert len(response.json()) == 5 - - response = client_with_auth.get("/url/latest?skip=5&limit=2") - assert response.status_code == 200 - assert len(response.json()) == 2 - - -# # TODO: find out where/if this is used, tests are also disabled - -# def test_lookup_unauthenticated(client, test_no_auth): -# test_no_auth(client.get, "/url/123-456-789") - -# def test_lookup(client_with_auth, db_session): -# response = client_with_auth.get("/url/lookup-123-456-789") -# assert response.status_code == 404 -# assert response.json() == {"detail": "Archive not found"} - -# from db import crud, schemas -# crud.create_task(db_session, ArchiveCreate(id="lookup-123-456-789", url="https://example.com", result={}, public=True, author_id="rick@example.com", group_id=None), [], []) - -# response = client_with_auth.get("/url/lookup-123-456-789") -# assert response.status_code == 200 -# j = response.json() -# assert j.keys() == schemas.ArchiveResult.model_fields.keys() -# assert j["id"] == "lookup-123-456-789" -# assert j["url"] == "https://example.com" -# assert j["result"] == {} - - def test_delete_task_unauthenticated(client, test_no_auth): test_no_auth(client.delete, "/url/123-456-789") diff --git a/src/tests/user-groups.test.yaml b/src/tests/user-groups.test.yaml index b2abf43..4e33cbd 100644 --- a/src/tests/user-groups.test.yaml +++ b/src/tests/user-groups.test.yaml @@ -75,10 +75,10 @@ groups: permissions: read: [] archive_url: true - archive_sheet: true - sheet_frequency: ["daily"] - max_sheets: 1 + archive_sheet: false + sheet_frequency: [] + max_sheets: 0 max_archive_lifespan_months: 12 - max_monthly_urls: 1 - max_monthly_mbs: 1 + max_monthly_urls: 10 + max_monthly_mbs: 50 priority: "low" \ No newline at end of file diff --git a/src/tests/web/test_main.py b/src/tests/web/test_main.py index e880311..7e3b77e 100644 --- a/src/tests/web/test_main.py +++ b/src/tests/web/test_main.py @@ -17,12 +17,12 @@ def test_alembic(db_session): alembic.config.main(argv=['--raiseerr', 'upgrade', 'head']) alembic.config.main(argv=['--raiseerr', 'downgrade', 'base']) -@patch("endpoints.default.crud.get_user_groups", side_effect=Exception('mocked error')) +@patch("endpoints.default.crud.soft_delete_task", side_effect=Exception('mocked error')) def test_logging_middleware(m1, client_with_auth): from utils.metrics import EXCEPTION_COUNTER assert len(EXCEPTION_COUNTER.collect()[0].samples) == 0 with pytest.raises(Exception, match="mocked error"): - client_with_auth.get("/groups") + client_with_auth.delete("/url/123") # creates one empty and one from above assert len(EXCEPTION_COUNTER.collect()[0].samples) == 2 diff --git a/src/web/security.py b/src/web/security.py index 772fcbd..85ceae4 100644 --- a/src/web/security.py +++ b/src/web/security.py @@ -48,7 +48,7 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear # validates the Bearer token in the case that it requires it valid_user, info = authenticate_user(credentials.credentials) if valid_user: - return info + return info.lower() logger.debug(f"TOKEN FAILURE: {valid_user=} {info=}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED,