From b50ca91d8966153c3b110c70642b92dd9e4e6973 Mon Sep 17 00:00:00 2001 From: Michael Plunkett <5885605+michplunkett@users.noreply.github.com> Date: Mon, 10 Mar 2025 12:45:19 -0500 Subject: [PATCH] Format and lint `web` directory (#67) --- .pre-commit-config.yaml | 30 +-- app/migrations/env.py | 4 +- app/shared/db/models.py | 2 +- app/tests/shared/db/test_worker_crud.py | 24 +-- app/tests/web/db/test_crud.py | 6 +- app/tests/web/test_main.py | 2 +- app/web/config.py | 5 +- app/web/db/crud.py | 256 +++++++++++++++++----- app/web/db/user_state.py | 272 ++++++++++++++++-------- app/web/endpoints/default.py | 34 ++- app/web/endpoints/interoperability.py | 39 +++- app/web/endpoints/sheet.py | 115 +++++++--- app/web/endpoints/task.py | 32 +-- app/web/endpoints/url.py | 98 ++++++--- app/web/main.py | 41 +++- app/web/middleware.py | 18 +- app/web/security.py | 51 +++-- app/web/utils/metrics.py | 34 +-- app/web/utils/misc.py | 4 +- pyproject.toml | 3 + 20 files changed, 761 insertions(+), 309 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index edb6bcf..6707b21 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,18 +61,18 @@ repos: - --profile=black - --line-length=80 -# - repo: https://github.com/astral-sh/ruff-pre-commit -# rev: v0.9.7 -# hooks: -# - id: ruff -# types_or: [python,pyi] -# args: -# - --fix -# - --select=B,C,E,F,W,B9 -# - --line-length=80 -# - --ignore=E203,E402,E501,E261 -# - id: ruff-format -# types_or: [ python,pyi] -# args: -# - --target-version=py310 -# - --line-length=80 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.7 + hooks: + - id: ruff + types_or: [python,pyi] + args: + - --fix + - --select=B,C,E,F,W,B9 + - --line-length=80 + - --ignore=E203,E402,E501,E261 + - id: ruff-format + types_or: [ python,pyi] + args: + - --target-version=py310 + - --line-length=80 diff --git a/app/migrations/env.py b/app/migrations/env.py index 54c1f48..fc63a92 100644 --- a/app/migrations/env.py +++ b/app/migrations/env.py @@ -14,9 +14,7 @@ config.set_main_option("sqlalchemy.url", get_settings().DATABASE_PATH) # This line sets up loggers basically. if config.config_file_name is not None: # disable_existing_loggers prevents loguru disabling - fileConfig( - config.config_file_name, disable_existing_loggers=False - ) + fileConfig(config.config_file_name, disable_existing_loggers=False) # add your model's MetaData object here # for 'autogenerate' support diff --git a/app/shared/db/models.py b/app/shared/db/models.py index ca58506..8acedc1 100644 --- a/app/shared/db/models.py +++ b/app/shared/db/models.py @@ -20,7 +20,7 @@ def generate_uuid(): return str(uuid.uuid4()) -# many to many association tables +# many-to-many association tables association_table_archive_tags = Table( "mtm_archives_tags", Base.metadata, diff --git a/app/tests/shared/db/test_worker_crud.py b/app/tests/shared/db/test_worker_crud.py index 2258781..4e5a434 100644 --- a/app/tests/shared/db/test_worker_crud.py +++ b/app/tests/shared/db/test_worker_crud.py @@ -14,8 +14,8 @@ def test_update_sheet_last_url_archived_at(db_session): assert isinstance(test_sheet.last_url_archived_at, datetime) before = test_sheet.last_url_archived_at assert ( - worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123") - is True + worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123") + is True ) db_session.refresh(test_sheet) assert isinstance(test_sheet.last_url_archived_at, datetime) @@ -23,10 +23,10 @@ def test_update_sheet_last_url_archived_at(db_session): # Test non-existent sheet assert ( - worker_crud.update_sheet_last_url_archived_at( - db_session, "non-existent-sheet" - ) - is False + worker_crud.update_sheet_last_url_archived_at( + db_session, "non-existent-sheet" + ) + is False ) @@ -42,14 +42,14 @@ def test_create_or_get_user(test_data, db_session): # already exists assert ( - u1 := worker_crud.create_or_get_user(db_session, "rick@example.com") - ) is not None + u1 := worker_crud.create_or_get_user(db_session, "rick@example.com") + ) is not None assert u1.email == "rick@example.com" # new user assert ( - u2 := worker_crud.create_or_get_user(db_session, "beth@example.com") - ) is not None + u2 := worker_crud.create_or_get_user(db_session, "beth@example.com") + ) is not None assert u2.email == "beth@example.com" assert db_session.query(models.User).count() == 4 @@ -64,8 +64,8 @@ def test_create_tag(db_session): assert create_tag.id == "tag-101" assert db_session.query(models.Tag).count() == 1 assert ( - db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() - == create_tag + db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() + == create_tag ) # same id does not add new db entry diff --git a/app/tests/web/db/test_crud.py b/app/tests/web/db/test_crud.py index 676aa24..7b569b6 100644 --- a/app/tests/web/db/test_crud.py +++ b/app/tests/web/db/test_crud.py @@ -712,7 +712,11 @@ async def test_find_by_store_until(async_db_session): @pytest.mark.asyncio async def test_get_sheets_by_id_hash(async_db_session): - author_emails = ["rick@example.com", "morty@example.com", "jerry@example.com"] + author_emails = [ + "rick@example.com", + "morty@example.com", + "jerry@example.com", + ] # Add test data sheets = [ diff --git a/app/tests/web/test_main.py b/app/tests/web/test_main.py index eb985a8..1b6e86b 100644 --- a/app/tests/web/test_main.py +++ b/app/tests/web/test_main.py @@ -7,6 +7,7 @@ import alembic.config import pytest from fastapi.testclient import TestClient +from app.web.main import app_factory from app.web.utils.metrics import EXCEPTION_COUNTER @@ -59,7 +60,6 @@ def test_serve_local_archive_logic(get_settings): try: # modify the settings get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test" - from app.web.main import app_factory app = app_factory(get_settings) diff --git a/app/web/config.py b/app/web/config.py index b795d88..359a9a8 100644 --- a/app/web/config.py +++ b/app/web/config.py @@ -8,7 +8,10 @@ API_DESCRIPTION = """ - You can use this API to archive single URLs or entire Google Sheets. - Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously. """ -BREAKING_CHANGES = {"minVersion": "0.4.0", "message": "The latest update has breaking changes, please update the extension to the most recent version."} +BREAKING_CHANGES = { + "minVersion": "0.4.0", + "message": "The latest update has breaking changes, please update the extension to the most recent version.", +} # changing this will corrupt the database logic ALLOW_ANY_EMAIL = "*" diff --git a/app/web/db/crud.py b/app/web/db/crud.py index b33faa2..308c526 100644 --- a/app/web/db/crud.py +++ b/app/web/db/crud.py @@ -1,15 +1,26 @@ from collections import defaultdict from datetime import datetime, timedelta -from functools import lru_cache +from typing import Any, Type from cachetools import LRUCache, cached from cachetools.keys import hashkey from loguru import logger -from sqlalchemy import Column, func, or_, select +from sqlalchemy import ( + Column, + ColumnElement, + ScalarResult, + false, + func, + not_, + or_, + select, + true, +) from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session, load_only from app.shared.db import models +from app.shared.db.models import Archive, Group from app.shared.settings import get_settings from app.shared.user_groups import UserGroups from app.shared.utils.misc import fnv1a_hash_mod @@ -23,24 +34,48 @@ DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT def get_limit(user_limit: int): return max(1, min(user_limit, DATABASE_QUERY_LIMIT)) + # --------------- TASK = Archive def base_query(db: Session): - # NOTE: load_only is for optimization and not obfuscation, use .with_entities() if needed - return db.query(models.Archive)\ - .filter(models.Archive.deleted == False)\ - .options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result, models.Archive.store_until)) + # NOTE: load_only is for optimization and not obfuscation, use + # .with_entities() if needed + return ( + db.query(models.Archive) + .filter(not_(models.Archive.deleted)) + .options( + load_only( + models.Archive.id, + models.Archive.created_at, + models.Archive.url, + models.Archive.result, + models.Archive.store_until, + ) + ) + ) -def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool | set[str], read_public: bool, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False) -> list[models.Archive]: - # searches for partial URLs, if email is * no ownership (or read/read_public) filtering happens +def search_archives_by_url( + db: Session, + url: str, + email: str, + read_groups: bool | set[str], + read_public: bool, + skip: int = 0, + limit: int = 100, + archived_after: datetime = None, + archived_before: datetime = None, + absolute_search: bool = False, +) -> list[Type[Archive]]: + # searches for partial URLs, if email is * no ownership + # (or read/read_public) filtering happens query = base_query(db) if email != ALLOW_ANY_EMAIL: or_filters = [models.Archive.author_id == email] if read_public: - or_filters.append(models.Archive.public == True) - if read_groups == True: + or_filters.append(models.Archive.public.is_(true())) + if read_groups is True: or_filters.append(models.Archive.group_id.isnot(None)) else: or_filters.append(models.Archive.group_id.in_(read_groups)) @@ -48,21 +83,43 @@ def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool if absolute_search: query = query.filter(models.Archive.url == url) else: - query = query.filter(models.Archive.url.like(f'%{url}%')) + query = query.filter(models.Archive.url.like(f"%{url}%")) if archived_after: query = query.filter(models.Archive.created_at > archived_after) if archived_before: query = query.filter(models.Archive.created_at < archived_before) - return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all() + return ( + query.order_by(models.Archive.created_at.desc()) + .offset(skip) + .limit(get_limit(limit)) + .all() + ) -def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100): - return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all() +def search_archives_by_email( + db: Session, email: str, skip: int = 0, limit: int = 100 +): + return ( + base_query(db) + .filter(models.Archive.author_id == email) + .order_by(models.Archive.created_at.desc()) + .offset(skip) + .limit(get_limit(limit)) + .all() + ) def soft_delete_archive(db: Session, id: str, email: str) -> bool: # TODO: implement hard-delete with cronjob that deletes from S3 - db_archive = db.query(models.Archive).filter(models.Archive.id == id, models.Archive.author_id == email, models.Archive.deleted == False).first() + db_archive = ( + db.query(models.Archive) + .filter( + models.Archive.id == id, + models.Archive.author_id == email, + models.Archive.deleted.is_(false()), + ) + .first() + ) if db_archive: db_archive.deleted = True db.commit() @@ -83,22 +140,29 @@ def count_users(db: Session): def count_by_user_since(db: Session, seconds_delta: int = 15): time_threshold = datetime.now() - timedelta(seconds=seconds_delta) - return db.query(models.Archive.author_id, func.count().label('total'))\ - .filter(models.Archive.created_at >= time_threshold)\ - .group_by(models.Archive.author_id)\ - .order_by(func.count().desc())\ - .limit(500).all() + return ( + db.query(models.Archive.author_id, func.count().label("total")) + .filter(models.Archive.created_at >= time_threshold) + .group_by(models.Archive.author_id) + .order_by(func.count().desc()) + .limit(500) + .all() + ) -async def find_by_store_until(db: AsyncSession, store_until_is_before: datetime) -> list[models.Archive]: +async def find_by_store_until( + db: AsyncSession, store_until_is_before: datetime +) -> ScalarResult[Archive]: res = await db.execute( - select(models.Archive) - .filter(models.Archive.deleted == False, models.Archive.store_until < store_until_is_before) + select(models.Archive).filter( + models.Archive.deleted.is_(false()), + models.Archive.store_until < store_until_is_before, + ) ) return res.scalars() -async def soft_delete_expired_archives(db: AsyncSession) -> dict: +async def soft_delete_expired_archives(db: AsyncSession) -> int: to_delete = await find_by_store_until(db, datetime.now()) counter = 0 for archive in to_delete: @@ -106,47 +170,86 @@ async def soft_delete_expired_archives(db: AsyncSession) -> dict: counter += 1 await db.commit() return counter + + # --------------- TAG async def get_group_priority_async(db: AsyncSession, group_id: str) -> dict: db_group = await db.get(models.Group, group_id) - priority = db_group.permissions.get("priority", "low") if db_group else "low" + priority = ( + db_group.permissions.get("priority", "low") if db_group else "low" + ) return convert_priority_to_queue_dict(priority) @cached(cache=LRUCache(maxsize=128), key=lambda db, email: hashkey(email)) -def get_user_group_names(db: Session, email: str) -> list[str]: +def get_user_group_names( + db: Session, email: str +) -> list[Any] | list[ColumnElement[Any]]: """ - given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. + given an email retrieves the user groups from the DB and then the + email-domain groups from a global variable, the email does not need to + belong to an existing user. """ # TODO: the read: [group1, group2] permissions don't currently work - if not email or not len(email) or "@" not in email: return [] + if not email or not len(email) or "@" not in email: + return [] # get user groups - user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all() + user_groups = ( + db.query(models.association_table_user_groups) + .filter_by(user_id=email) + .with_entities(Column("group_id")) + .all() + ) user_level_groups_names = [g[0] for g in user_groups] # get domain groups - domain = email.split('@')[1] - domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all() + domain = email.split("@")[1] + domain_level_groups = ( + db.query(models.Group.id) + .filter(models.Group.domains.contains(domain)) + .with_entities(Column("id")) + .all() + ) domain_level_groups_names = [g[0] for g in domain_level_groups] return list(set(user_level_groups_names + domain_level_groups_names)) -def get_user_groups_by_name(db: Session, groups: list[str]) -> list[models.Group]: - return db.query(models.Group).filter( - models.Group.id.in_(groups) - ).all() +def get_user_groups_by_name( + db: Session, groups: list[str] +) -> list[Type[Group]]: + return db.query(models.Group).filter(models.Group.id.in_(groups)).all() + # --------------- INIT User-Groups -def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, service_account_email: str, permissions: dict, domains: list) -> models.Group: - db_group = db.query(models.Group).filter(models.Group.id == group_name).first() +def upsert_group( + db: Session, + group_name: str, + description: str, + orchestrator: str, + orchestrator_sheet: str, + service_account_email: str, + permissions: dict, + domains: list, +) -> models.Group: + db_group = ( + db.query(models.Group).filter(models.Group.id == group_name).first() + ) if db_group is None: - db_group = models.Group(id=group_name, description=description, orchestrator=orchestrator, orchestrator_sheet=orchestrator_sheet, service_account_email=service_account_email, permissions=permissions, domains=domains) + db_group = models.Group( + id=group_name, + description=description, + orchestrator=orchestrator, + orchestrator_sheet=orchestrator_sheet, + service_account_email=service_account_email, + permissions=permissions, + domains=domains, + ) db.add(db_group) else: db_group.description = description @@ -173,6 +276,7 @@ def upsert_user(db: Session, email: str): def upsert_user_groups(db: Session): def display_email_pii(email: str): return f"'{email[0:3]}...@{email.split('@')[1]}'" + """ reads the user_groups yaml file and inserts any new users, groups, along with new participation of users in groups @@ -196,16 +300,30 @@ def upsert_user_groups(db: Session): # upsert groups and save a map of groupid -> dbobject for group_id, g in ug.groups.items(): - upsert_group(db, group_id, g.description, g.orchestrator, g.orchestrator_sheet, g.service_account_email, json.loads(g.permissions.model_dump_json()), list(group_domains.get(group_id, []))) - db_groups: dict[str, models.Group] = {g.id: g for g in db.query(models.Group).all()} + upsert_group( + db, + group_id, + g.description, + g.orchestrator, + g.orchestrator_sheet, + g.service_account_email, + json.loads(g.permissions.model_dump_json()), + list(group_domains.get(group_id, [])), + ) + db_groups: dict[str, models.Group] = { + g.id: g for g in db.query(models.Group).all() + } # integrity checks for group_in_domains in group_domains: if group_in_domains not in db_groups: - logger.warning(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.") + logger.warning( + f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work." + ) # reinsert users in their EXPLICITLY DEFINED groups - # domain groups are check live, as there may be new users that are not explicitly registered but belong to a domain + # domain groups are check live, as there may be new users that are not + # explicitly registered but belong to a domain for email, explicit_groups in ug.users.items(): explicit_groups = explicit_groups or [] logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}") @@ -215,7 +333,9 @@ def upsert_user_groups(db: Session): # connect users to groups for group_id in explicit_groups: if group_id not in db_groups: - logger.warning(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.") + logger.warning( + f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}." + ) continue db_groups[group_id].users.append(db_user) @@ -223,12 +343,27 @@ def upsert_user_groups(db: Session): count_user_groups = db.query(models.association_table_user_groups).count() count_groups = db.query(func.count(models.Group.id)).scalar() - logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].") + logger.success( + f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}]." + ) # --------------- SHEET -def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: str, frequency: str): - db_sheet = models.Sheet(id=sheet_id, name=name, author_id=email, group_id=group_id, frequency=frequency) +def create_sheet( + db: Session, + sheet_id: str, + name: str, + email: str, + group_id: str, + frequency: str, +): + db_sheet = models.Sheet( + id=sheet_id, + name=name, + author_id=email, + group_id=group_id, + frequency=frequency, + ) db.add(db_sheet) db.commit() db.refresh(db_sheet) @@ -236,20 +371,31 @@ def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: st def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet: - return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first() + return ( + db.query(models.Sheet) + .filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id) + .first() + ) def get_user_sheets(db: Session, email: str) -> list[models.Sheet]: - return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_url_archived_at.desc()).all() + return ( + db.query(models.Sheet) + .filter(models.Sheet.author_id == email) + .order_by(models.Sheet.last_url_archived_at.desc()) + .all() + ) -async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: int) -> list[models.Sheet]: +async def get_sheets_by_id_hash( + db: AsyncSession, frequency: str, modulo: str, id_hash: int +) -> list[models.Sheet]: result = await db.execute( select(models.Sheet).filter(models.Sheet.frequency == frequency) ) filtered = [] for sheet in result.scalars(): - if fnv1a_hash_mod(sheet.id, modulo) == id_hash: + if fnv1a_hash_mod(sheet.id, int(modulo)) == id_hash: filtered.append(sheet) return filtered @@ -257,7 +403,9 @@ async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, i async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict: time_threshold = datetime.now() - timedelta(days=inactivity_days) result = await db.execute( - select(models.Sheet).filter(models.Sheet.last_url_archived_at < time_threshold) + select(models.Sheet).filter( + models.Sheet.last_url_archived_at < time_threshold + ) ) deleted = defaultdict(list) for sheet in result.scalars(): @@ -268,7 +416,11 @@ async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict: def delete_sheet(db: Session, sheet_id: str, email: str) -> bool: - db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first() + db_sheet = ( + db.query(models.Sheet) + .filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email) + .first() + ) if db_sheet: db.delete(db_sheet) db.commit() diff --git a/app/web/db/user_state.py b/app/web/db/user_state.py index 384b0b6..67160db 100644 --- a/app/web/db/user_state.py +++ b/app/web/db/user_state.py @@ -1,4 +1,3 @@ - from datetime import datetime from typing import Dict, Set @@ -21,14 +20,15 @@ class UserState: def __init__(self, db: Session, email: str): self.db = db self.email = email.lower() + self._permissions = {} @property def permissions(self) -> Dict[str, GroupInfo]: """ - Returns a dict of all group permissions and a special {"all": read/archive_url/archive_sheet} key + Returns a dict of all group permissions and a special + {"all": read/archive_url/archive_sheet} key """ - if not hasattr(self, '_permissions'): - self._permissions = {} + if not self._permissions: self._permissions["all"] = GroupInfo( read=self.read, read_public=self.read_public, @@ -38,23 +38,33 @@ class UserState: max_archive_lifespan_months=self.max_archive_lifespan_months, max_monthly_urls=self.max_monthly_urls, max_monthly_mbs=self.max_monthly_mbs, - priority=self.priority + priority=self.priority, ) for group in self.user_groups: - if not group.permissions: continue - self._permissions[group.id] = GroupInfo(**group.permissions, description=group.description, service_account_email=group.service_account_email) + if not group.permissions: + continue + self._permissions[group.id] = GroupInfo( + **group.permissions, + description=group.description, + service_account_email=group.service_account_email, + ) return self._permissions @property def user_groups_names(self): - if not hasattr(self, '_user_groups_names'): - self._user_groups_names = crud.get_user_group_names(self.db, self.email) + ["default"] + if not hasattr(self, "_user_groups_names"): + # TODO: Define hidden properties in __init__ method + self._user_groups_names = crud.get_user_group_names( + self.db, self.email + ) + ["default"] return self._user_groups_names @property def user_groups(self): - if not hasattr(self, '_user_groups'): - self._user_groups = crud.get_user_groups_by_name(self.db, self.user_groups_names) + if not hasattr(self, "_user_groups"): + self._user_groups = crud.get_user_groups_by_name( + self.db, self.user_groups_names + ) return self._user_groups @property @@ -62,10 +72,11 @@ class UserState: """ Read can be a list of group names or True, if all can be read. """ - if not hasattr(self, '_read'): + if not hasattr(self, "_read"): self._read = set() for group in self.user_groups: - if not group.permissions: continue + if not group.permissions: + continue group_read_permissions = group.permissions.get("read", []) if "all" in group_read_permissions: self._read = True @@ -79,10 +90,11 @@ class UserState: """ Read public permission """ - if not hasattr(self, '_read_public'): + if not hasattr(self, "_read_public"): self._read_public = False for group in self.user_groups: - if not group.permissions: continue + if not group.permissions: + continue if group.permissions.get("read_public", False): self._read_public = True return self._read_public @@ -93,10 +105,11 @@ class UserState: """ Archive URL permission """ - if not hasattr(self, '_archive_url'): + if not hasattr(self, "_archive_url"): self._archive_url = False for group in self.user_groups: - if not group.permissions: continue + if not group.permissions: + continue if group.permissions.get("archive_url", False): self._archive_url = True return self._archive_url @@ -107,10 +120,11 @@ class UserState: """ Archive sheet permission """ - if not hasattr(self, '_archive_sheet'): + if not hasattr(self, "_archive_sheet"): self._archive_sheet = False for group in self.user_groups: - if not group.permissions: continue + if not group.permissions: + continue if group.permissions.get("archive_sheet", False): self._archive_sheet = True return self._archive_sheet @@ -118,37 +132,53 @@ class UserState: @property def sheet_frequency(self): - if not hasattr(self, '_sheet_frequency'): + if not hasattr(self, "_sheet_frequency"): self._sheet_frequency = set() for group in self.user_groups: - if not group.permissions: continue - self._sheet_frequency.update(group.permissions.get("sheet_frequency", None)) + if not group.permissions: + continue + self._sheet_frequency.update( + group.permissions.get("sheet_frequency", None) + ) return self._sheet_frequency @property def max_archive_lifespan_months(self) -> int: - if not hasattr(self, '_max_archive_lifespan_months'): - self._max_archive_lifespan_months = self._helper_for_grouping_max_numerical_permissions("max_archive_lifespan_months") + if not hasattr(self, "_max_archive_lifespan_months"): + self._max_archive_lifespan_months = ( + self._helper_for_grouping_max_numerical_permissions( + "max_archive_lifespan_months" + ) + ) return self._max_archive_lifespan_months @property def max_monthly_urls(self) -> int: - if not hasattr(self, '_max_monthly_urls'): - self._max_monthly_urls = self._helper_for_grouping_max_numerical_permissions("max_monthly_urls") + if not hasattr(self, "_max_monthly_urls"): + self._max_monthly_urls = ( + self._helper_for_grouping_max_numerical_permissions( + "max_monthly_urls" + ) + ) return self._max_monthly_urls @property def max_monthly_mbs(self) -> int: - if not hasattr(self, '_max_monthly_mbs'): - self._max_monthly_mbs = self._helper_for_grouping_max_numerical_permissions("max_monthly_mbs") + if not hasattr(self, "_max_monthly_mbs"): + self._max_monthly_mbs = ( + self._helper_for_grouping_max_numerical_permissions( + "max_monthly_mbs" + ) + ) return self._max_monthly_mbs @property def priority(self) -> str: - if not hasattr(self, '_priority'): + if not hasattr(self, "_priority"): self._priority = "low" for group in self.user_groups: - if not group.permissions: continue + if not group.permissions: + continue if group.permissions.get("priority", self._priority) == "high": self._priority = "high" break @@ -159,18 +189,28 @@ class UserState: """ A user is active if they can read/archive anything """ - if not hasattr(self, '_active'): - self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet) + if not hasattr(self, "_active"): + self._active = bool( + self.read + or self.read_public + or self.archive_url + or self.archive_sheet + ) return self._active - def _helper_for_grouping_max_numerical_permissions(self, permission_name: str) -> int: + def _helper_for_grouping_max_numerical_permissions( + self, permission_name: str + ) -> int: """ - Iterates one of the numerical permissions where -1 means no restrictions and returns either -1 or the maximum value, defaults according to GroupPermissions + Iterates one of the numerical permissions where -1 means no restrictions + and returns either -1 or the maximum value, defaults according to + GroupPermissions """ default = GroupPermissions.model_fields[permission_name].default max_value = default for group in self.user_groups: - if not group.permissions: continue + if not group.permissions: + continue group_value = group.permissions.get(permission_name, default) if group_value == -1: max_value = -1 @@ -181,43 +221,65 @@ class UserState: def in_group(self, group_id: str) -> bool: return group_id in self.user_groups_names - def usage(self) -> Dict: + def usage(self) -> UsageResponse: """ - returns the monthly quotas for the URLs/MBs and the totals for Sheets + Returns the monthly quotas for the URLs/MBs and the totals for Sheets """ current_month = datetime.now().month current_year = datetime.now().year # find and sum all user sheets over this month - user_sheets = self.db.query( - models.Sheet.group_id, - func.count(models.Sheet.id).label('sheet_count') - ).filter(models.Sheet.author_id == self.email).group_by(models.Sheet.group_id).all() + user_sheets = ( + self.db.query( + models.Sheet.group_id, + func.count(models.Sheet.id).label("sheet_count"), + ) + .filter(models.Sheet.author_id == self.email) + .group_by(models.Sheet.group_id) + .all() + ) - sheets_by_group = {sheet.group_id: sheet.sheet_count for sheet in user_sheets} + sheets_by_group = { + sheet.group_id: sheet.sheet_count for sheet in user_sheets + } # find and sum all user urls over this month - urls_by_group = self.db.query( - models.Archive.group_id, - func.count(models.Archive.id).label('url_count'), - func.coalesce(func.sum( + urls_by_group = ( + self.db.query( + models.Archive.group_id, + func.count(models.Archive.id).label("url_count"), func.coalesce( - func.cast( - func.json_extract(models.Archive.result, '$.metadata.total_bytes'), - sqlalchemy.Integer - ), 0 - ) - ), 0).label('total_bytes') - ).filter( - models.Archive.author_id == self.email, - func.extract('month', models.Archive.created_at) == current_month, - func.extract('year', models.Archive.created_at) == current_year - ).group_by(models.Archive.group_id).all() + func.sum( + func.coalesce( + func.cast( + func.json_extract( + models.Archive.result, + "$.metadata.total_bytes", + ), + sqlalchemy.Integer, + ), + 0, + ) + ), + 0, + ).label("total_bytes"), + ) + .filter( + models.Archive.author_id == self.email, + func.extract("month", models.Archive.created_at) + == current_month, + func.extract("year", models.Archive.created_at) == current_year, + ) + .group_by(models.Archive.group_id) + .all() + ) # merge the two queries usage_by_group: Dict[str, Usage] = { - (url.group_id or ""): - Usage(monthly_urls=url.url_count, monthly_mbs=int(url.total_bytes / 1024 / 1024)) + (url.group_id or ""): Usage( + monthly_urls=url.url_count, + monthly_mbs=int(url.total_bytes / 1024 / 1024), + ) for url in urls_by_group } for group_id, sheet_count in sheets_by_group.items(): @@ -236,7 +298,7 @@ class UserState: monthly_urls=total_urls, monthly_mbs=int(total_bytes / 1024 / 1024), total_sheets=total_sheets, - groups=usage_by_group + groups=usage_by_group, ) def has_quota_monthly_sheets(self, group_id: str) -> bool: @@ -246,7 +308,14 @@ class UserState: if group_id not in self.permissions: return False - user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email, models.Sheet.group_id == group_id).count() + user_sheets = ( + self.db.query(models.Sheet) + .filter( + models.Sheet.author_id == self.email, + models.Sheet.group_id == group_id, + ) + .count() + ) sheet_quota = self.permissions[group_id].max_sheets if sheet_quota == -1: @@ -255,13 +324,15 @@ class UserState: def has_quota_max_monthly_urls(self, group_id: str) -> bool: """ - checks if a user has reached their monthly url quota for a group, if global then group should be empty string + Checks if a user has reached their monthly url quota for a group, if + global then group should be empty string """ quota = 0 if not group_id: quota = self.max_monthly_urls else: - if group_id not in self.permissions: return False + if group_id not in self.permissions: + return False quota = self.permissions[group_id].max_monthly_urls if quota == -1: @@ -269,24 +340,31 @@ class UserState: current_month = datetime.now().month current_year = datetime.now().year - user_urls = self.db.query(models.Archive).filter( - models.Archive.author_id == self.email, - models.Archive.group_id == group_id, - func.extract('month', models.Archive.created_at) == current_month, - func.extract('year', models.Archive.created_at) == current_year - ).count() + user_urls = ( + self.db.query(models.Archive) + .filter( + models.Archive.author_id == self.email, + models.Archive.group_id == group_id, + func.extract("month", models.Archive.created_at) + == current_month, + func.extract("year", models.Archive.created_at) == current_year, + ) + .count() + ) return user_urls < quota def has_quota_max_monthly_mbs(self, group_id: str) -> bool: """ - checks if a user has reached their monthly MBs quota for a group, if global then group should be empty string + Checks if a user has reached their monthly MBs quota for a group, if + global then group should be empty string """ quota = 0 if not group_id: quota = self.max_monthly_mbs else: - if group_id not in self.permissions: return False + if group_id not in self.permissions: + return False quota = self.permissions[group_id].max_monthly_mbs if quota == -1: @@ -296,19 +374,34 @@ class UserState: current_year = datetime.now().year # find and sum all user bytes over this month - user_bytes = self.db.query(models.Archive).filter( - models.Archive.author_id == self.email, - models.Archive.group_id == group_id, - func.extract('month', models.Archive.created_at) == current_month, - func.extract('year', models.Archive.created_at) == current_year - ).with_entities(func.coalesce(func.sum( - func.coalesce( - func.cast( - func.json_extract(models.Archive.result, '$.metadata.total_bytes'), - sqlalchemy.Integer - ), 0 + user_bytes = ( + self.db.query(models.Archive) + .filter( + models.Archive.author_id == self.email, + models.Archive.group_id == group_id, + func.extract("month", models.Archive.created_at) + == current_month, + func.extract("year", models.Archive.created_at) == current_year, ) - ), 0).label('total')).scalar() + .with_entities( + func.coalesce( + func.sum( + func.coalesce( + func.cast( + func.json_extract( + models.Archive.result, + "$.metadata.total_bytes", + ), + sqlalchemy.Integer, + ), + 0, + ) + ), + 0, + ).label("total") + ) + .scalar() + ) # convert bytes to mb user_mbs = int(user_bytes / 1024 / 1024) @@ -316,7 +409,7 @@ class UserState: def can_manually_trigger(self, group_id: str) -> bool: """ - checks if a user is allowed to manually trigger a sheet + Checks if a user is allowed to manually trigger a sheet """ if group_id not in self.permissions: return False @@ -325,18 +418,21 @@ class UserState: def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool: """ - checks if a user is allowed to create a sheet with this frequency for this group + Checks if a user is allowed to create a sheet with this frequency for + this group """ if group_id not in self.permissions: return False return frequency in self.permissions[group_id].sheet_frequency - def priority_group(self, group_id: str) -> str: + def priority_group(self, group_id: str) -> dict: priority = "low" for group in self.user_groups: - if group.id != group_id: continue - if not group.permissions: continue + if group.id != group_id: + continue + if not group.permissions: + continue priority = group.permissions.get("priority", priority) break return convert_priority_to_queue_dict(priority) diff --git a/app/web/endpoints/default.py b/app/web/endpoints/default.py index cd23d13..ce3edd2 100644 --- a/app/web/endpoints/default.py +++ b/app/web/endpoints/default.py @@ -1,4 +1,4 @@ - +from http import HTTPStatus from typing import Dict from fastapi import APIRouter, Depends, HTTPException @@ -15,38 +15,50 @@ default_router = APIRouter() @default_router.get("/") -async def home(): - return JSONResponse({"version": VERSION, "breakingChanges": BREAKING_CHANGES}) +async def home() -> JSONResponse: + return JSONResponse( + {"version": VERSION, "breakingChanges": BREAKING_CHANGES} + ) @default_router.get("/health") -async def health(): +async def health() -> JSONResponse: return JSONResponse({"status": "ok"}) -@default_router.get("/user/active", summary="Check if the user is active and can use the tool.") +@default_router.get( + "/user/active", summary="Check if the user is active and can use the tool." +) async def active( user: UserState = Depends(get_user_state), ) -> ActiveUser: - return {"active": user.active} + return ActiveUser(active=user.active) -@default_router.get("/user/permissions", summary="Get the user's global 'all' permissions and the permissions for each group they belong to.") +@default_router.get( + "/user/permissions", + summary="Get the user's global 'all' permissions and the permissions for each group they belong to.", +) def get_user_permissions( user: UserState = Depends(get_user_state), ) -> Dict[str, GroupInfo]: return user.permissions -@default_router.get("/user/usage", summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.") + +@default_router.get( + "/user/usage", + summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.", +) def get_user_usage( user: UserState = Depends(get_user_state), ) -> UsageResponse: if not user.active: - raise HTTPException(status_code=403, detail="User is not active.") + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, detail="User is not active." + ) return user.usage() - -@default_router.get('/favicon.ico', include_in_schema=False) +@default_router.get("/favicon.ico", include_in_schema=False) async def favicon() -> FileResponse: return FileResponse("app/web/static/favicon.ico") diff --git a/app/web/endpoints/interoperability.py b/app/web/endpoints/interoperability.py index 7892bde..86ea037 100644 --- a/app/web/endpoints/interoperability.py +++ b/app/web/endpoints/interoperability.py @@ -1,4 +1,5 @@ import json +from http import HTTPStatus import sqlalchemy from auto_archiver.core import Metadata @@ -16,26 +17,39 @@ from app.web.config import ALLOW_ANY_EMAIL from app.web.security import token_api_key_auth -interoperability_router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."]) +interoperability_router = APIRouter( + prefix="/interop", tags=["Interoperability endpoints."] +) # ----- endpoint to submit data archived elsewhere -@interoperability_router.post("/submit-archive", status_code=201, summary="Submit a manual archive entry, for data that was archived elsewhere.") +@interoperability_router.post( + "/submit-archive", + status_code=HTTPStatus.CREATED, + summary="Submit a manual archive entry, for data that was archived elsewhere.", +) def submit_manual_archive( manual: schemas.SubmitManualArchive, auth=Depends(token_api_key_auth), - db: Session = Depends(get_db_dependency) + db: Session = Depends(get_db_dependency), ): try: result: Metadata = Metadata.from_json(manual.result) except json.JSONDecodeError as e: log_error(e) - raise HTTPException(status_code=422, detail="Invalid JSON in result field.") + raise HTTPException( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + detail="Invalid JSON in result field.", + ) from e manual.author_id = manual.author_id or ALLOW_ANY_EMAIL manual.tags.add("manual") - store_until = business_logic.get_store_archive_until_or_never(db, manual.group_id) - logger.debug(f"[MANUAL ARCHIVE] {manual.author_id} {manual.url} {store_until}") + store_until = business_logic.get_store_archive_until_or_never( + db, manual.group_id + ) + logger.debug( + f"[MANUAL ARCHIVE] {manual.author_id} {manual.url} {store_until}" + ) try: archive = schemas.ArchiveCreate( @@ -51,8 +65,15 @@ def submit_manual_archive( ) db_archive = worker_crud.store_archived_url(db, archive) - logger.debug(f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}") - return JSONResponse({"id": db_archive.id}, status_code=201) + logger.debug( + f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}" + ) + return JSONResponse( + {"id": db_archive.id}, status_code=HTTPStatus.CREATED + ) except sqlalchemy.exc.IntegrityError as e: log_error(e) - raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error, likely duplicate urls.") + raise HTTPException( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + detail="Cannot insert into DB due to integrity error, likely duplicate urls.", + ) from e diff --git a/app/web/endpoints/sheet.py b/app/web/endpoints/sheet.py index d8c089a..5247107 100644 --- a/app/web/endpoints/sheet.py +++ b/app/web/endpoints/sheet.py @@ -1,81 +1,134 @@ +from http import HTTPStatus from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import JSONResponse from sqlalchemy import exc from sqlalchemy.orm import Session -from app.shared import schemas from app.shared.db.database import get_db_dependency +from app.shared.schemas import ( + DeleteResponse, + SheetAdd, + SheetResponse, + SubmitSheet, +) from app.shared.task_messaging import get_celery from app.web.db import crud from app.web.db.user_state import UserState from app.web.security import get_user_state -sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"]) +sheet_router = APIRouter( + prefix="/sheet", tags=["Google Spreadsheet operations"] +) celery = get_celery() -@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.") + +@sheet_router.post( + "/create", + status_code=HTTPStatus.CREATED, + summary="Store a new Google Sheet for regular archiving.", +) def create_sheet( - sheet: schemas.SheetAdd, + sheet: SheetAdd, user: UserState = Depends(get_user_state), db: Session = Depends(get_db_dependency), -) -> schemas.SheetResponse: - +) -> SheetResponse: if not user.in_group(sheet.group_id): - raise HTTPException(status_code=403, detail="User does not have access to this group.") + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, + detail="User does not have access to this group.", + ) if not user.has_quota_monthly_sheets(sheet.group_id): - raise HTTPException(status_code=429, detail="User has reached their sheet quota for this group.") + raise HTTPException( + status_code=HTTPStatus.TOO_MANY_REQUESTS, + detail="User has reached their sheet quota for this group.", + ) if not user.is_sheet_frequency_allowed(sheet.group_id, sheet.frequency): - raise HTTPException(status_code=422, detail="Invalid frequency selected for this group.") + raise HTTPException( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + detail="Invalid frequency selected for this group.", + ) try: - return crud.create_sheet(db, sheet.id, sheet.name, user.email, sheet.group_id, sheet.frequency) + return crud.create_sheet( + db, + sheet.id, + sheet.name, + user.email, + sheet.group_id, + sheet.frequency, + ) except exc.IntegrityError as e: - raise HTTPException(status_code=400, detail="Sheet with this ID is already being archived.") from e + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="Sheet with this ID is already being archived.", + ) from e -@sheet_router.get("/mine", status_code=200, summary="Get the authenticated user's Google Sheets.") +@sheet_router.get( + "/mine", + status_code=HTTPStatus.OK, + summary="Get the authenticated user's Google Sheets.", +) def get_user_sheets( user: UserState = Depends(get_user_state), - db: Session = Depends(get_db_dependency) -) -> list[schemas.SheetResponse]: + db: Session = Depends(get_db_dependency), +) -> list[SheetResponse]: return crud.get_user_sheets(db, user.email) -@sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.") +@sheet_router.delete("/{sheet_id}", summary="Delete a Google Sheet by ID.") def delete_sheet( - id: str, + sheet_id: str, user: UserState = Depends(get_user_state), db: Session = Depends(get_db_dependency), -) -> schemas.DeleteResponse: - return JSONResponse({ - "id": id, - "deleted": crud.delete_sheet(db, id, user.email) - }) +) -> DeleteResponse: + return DeleteResponse( + id=sheet_id, deleted=crud.delete_sheet(db, sheet_id, user.email) + ) -@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.") +@sheet_router.post( + "/{sheet_id}/archive", + status_code=HTTPStatus.CREATED, + summary="Trigger an archiving task for a GSheet you own.", + response_description="task_id for the archiving task.", +) def archive_user_sheet( - id: str, + sheet_id: str, user: UserState = Depends(get_user_state), db: Session = Depends(get_db_dependency), -) -> schemas.Task: - - sheet = crud.get_user_sheet(db, user.email, sheet_id=id) +) -> JSONResponse: + sheet = crud.get_user_sheet(db, user.email, sheet_id=sheet_id) if not sheet: - raise HTTPException(status_code=403, detail="No access to this sheet.") + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, detail="No access to this sheet." + ) if not user.in_group(sheet.group_id): - raise HTTPException(status_code=403, detail="User does not have access to this group.") + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, + detail="User does not have access to this group.", + ) if not user.can_manually_trigger(sheet.group_id): - raise HTTPException(status_code=429, detail="User cannot manually trigger sheet archiving in this group.") + raise HTTPException( + status_code=HTTPStatus.TOO_MANY_REQUESTS, + detail="User cannot manually trigger sheet archiving in this group.", + ) group_queue = user.priority_group(sheet.group_id) - task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=id, author_id=user.email, group_id=sheet.group_id).model_dump_json()]).apply_async(**group_queue) + task = celery.signature( + "create_sheet_task", + args=[ + SubmitSheet( + sheet_id=sheet_id, author_id=user.email, group_id=sheet.group_id + ).model_dump_json() + ], + ).apply_async(**group_queue) - return JSONResponse({"id": task.id}, status_code=201) + return JSONResponse({"id": task.id}, status_code=HTTPStatus.CREATED) diff --git a/app/web/endpoints/task.py b/app/web/endpoints/task.py index 3f2ff94..e9da444 100644 --- a/app/web/endpoints/task.py +++ b/app/web/endpoints/task.py @@ -14,8 +14,14 @@ task_router = APIRouter(prefix="/task", tags=["Async task operations"]) celery = get_celery() -@task_router.get("/{task_id}", summary="Check the status of an async task by its id, works for URLs and Sheet tasks.") -def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult: + +@task_router.get( + "/{task_id}", + summary="Check the status of an async task by its id, works for URLs and Sheet tasks.", +) +def get_status( + task_id, email=Depends(get_token_or_user_auth) +) -> schemas.TaskResult: task = AsyncResult(task_id, app=celery) try: if task.status == "FAILURE": @@ -24,17 +30,17 @@ def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskRe # https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult raise task.result - response = { - "id": task_id, - "status": task.status, - "result": task.result - } - return JSONResponse(jsonable_encoder(response, exclude_unset=True, custom_encoder={bytes: custom_jsonable_encoder})) + response = {"id": task_id, "status": task.status, "result": task.result} + return JSONResponse( + jsonable_encoder( + response, + exclude_unset=True, + custom_encoder={bytes: custom_jsonable_encoder}, + ) + ) except Exception as e: log_error(e) - return JSONResponse({ - "id": task_id, - "status": "FAILURE", - "result": {"error": str(e)} - }) + return JSONResponse( + {"id": task_id, "status": "FAILURE", "result": {"error": str(e)}} + ) diff --git a/app/web/endpoints/url.py b/app/web/endpoints/url.py index 8307c2d..10b0adc 100644 --- a/app/web/endpoints/url.py +++ b/app/web/endpoints/url.py @@ -1,5 +1,5 @@ - from datetime import datetime +from http import HTTPStatus from urllib.parse import urlparse from fastapi import APIRouter, Depends, HTTPException @@ -9,6 +9,7 @@ from sqlalchemy.orm import Session from app.shared import schemas from app.shared.db.database import get_db_dependency +from app.shared.schemas import DeleteResponse from app.shared.task_messaging import get_celery from app.web.config import ALLOW_ANY_EMAIL from app.web.db import crud @@ -21,65 +22,106 @@ url_router = APIRouter(prefix="/url", tags=["Single URL operations"]) celery = get_celery() -@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.") + +@url_router.post( + "/archive", + status_code=HTTPStatus.CREATED, + summary="Submit a single URL archive request, starts an archiving task.", + response_description="task_id for the archiving task, will match the archive id.", +) def archive_url( archive: schemas.ArchiveTrigger, email=Depends(get_token_or_user_auth), - db: Session = Depends(get_db_dependency) -) -> schemas.Task: - logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}") + db: Session = Depends(get_db_dependency), +) -> JSONResponse: + logger.info( + f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}" + ) parsed_url = urlparse(archive.url) if not all([parsed_url.scheme, parsed_url.netloc]): - raise HTTPException(status_code=400, detail="Invalid URL received.") + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, detail="Invalid URL received." + ) archive_create = schemas.ArchiveCreate(**archive.model_dump()) if email != ALLOW_ANY_EMAIL: archive_create.author_id = email user = UserState(db, email) if archive.group_id and not user.in_group(archive.group_id): - raise HTTPException(status_code=403, detail="User does not have access to this group.") + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, + detail="User does not have access to this group.", + ) if not user.has_quota_max_monthly_urls(archive.group_id): - raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.") + raise HTTPException( + status_code=HTTPStatus.TOO_MANY_REQUESTS, + detail="User has reached their monthly URL quota.", + ) if not user.has_quota_max_monthly_mbs(archive.group_id): - raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.") + raise HTTPException( + status_code=HTTPStatus.TOO_MANY_REQUESTS, + detail="User has reached their monthly MB quota.", + ) group_queue = user.priority_group(archive_create.group_id) else: archive_create.author_id = archive.author_id or email group_queue = convert_priority_to_queue_dict("high") - - task = celery.signature("create_archive_task", args=[archive_create.model_dump_json()]).apply_async(**group_queue) + task = celery.signature( + "create_archive_task", args=[archive_create.model_dump_json()] + ).apply_async(**group_queue) task_response = schemas.Task(id=task.id) - return JSONResponse(task_response.model_dump(), status_code=201) + return JSONResponse( + task_response.model_dump(), status_code=HTTPStatus.CREATED + ) @url_router.get("/search", summary="Search for archive entries by URL.") def search_by_url( - url: str, skip: int = 0, limit: int = 25, - archived_after: datetime = None, archived_before: datetime = None, - db: Session = Depends(get_db_dependency), - email: str = Depends(get_token_or_user_auth) + url: str, + skip: int = 0, + limit: int = 25, + archived_after: datetime = None, + archived_before: datetime = None, + db: Session = Depends(get_db_dependency), + email: str = Depends(get_token_or_user_auth), ) -> list[schemas.ArchiveResult]: - read_groups, read_public = False, False if email != ALLOW_ANY_EMAIL: user = UserState(db, email) if not user.read and not user.read_public: - raise HTTPException(status_code=403, detail="User does not have read access.") + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, + detail="User does not have read access.", + ) read_groups = user.read read_public = user.read_public - return crud.search_archives_by_url(db, url.strip(), email, read_groups, read_public, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before) + return crud.search_archives_by_url( + db, + url.strip(), + email, + read_groups, + read_public, + skip=skip, + limit=limit, + archived_after=archived_after, + archived_before=archived_before, + ) -@url_router.delete("/{id}", summary="Delete a single URL archive by id.") +@url_router.delete( + "/{archive_id}", summary="Delete a single URL archive by id." +) def delete_archive( - id:str, + archive_id: str, user: UserState = Depends(get_user_state), - db: Session = Depends(get_db_dependency) -) -> schemas.DeleteResponse: - logger.info(f"deleting url archive task {id} request by {user.email}") - return JSONResponse({ - "id": id, - "deleted": crud.soft_delete_archive(db, id, user.email) - }) + db: Session = Depends(get_db_dependency), +) -> DeleteResponse: + logger.info( + f"deleting url archive task {archive_id} request by {user.email}" + ) + return DeleteResponse( + id=archive_id, + deleted=crud.soft_delete_archive(db, archive_id, user.email), + ) diff --git a/app/web/main.py b/app/web/main.py index 69af5c6..525eff4 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -6,7 +6,7 @@ from fastapi.staticfiles import StaticFiles from loguru import logger from prometheus_fastapi_instrumentator import Instrumentator -from app.shared.settings import get_settings +from app.shared.settings import Settings, get_settings from app.shared.task_messaging import get_celery from app.web.config import API_DESCRIPTION, VERSION from app.web.endpoints.default import default_router @@ -21,13 +21,22 @@ from app.web.security import token_api_key_auth celery = get_celery() -def app_factory(settings = get_settings()): + +def app_factory(settings: Settings = None): + # TODO: Create dev, test, and prod versions of settings that do not have + # TODO: to be passed in as a parameter + if settings is None: + settings = get_settings() + app = FastAPI( title="Auto-Archiver API", description=API_DESCRIPTION, version=VERSION, - contact={"name": "GitHub", "url": "https://github.com/bellingcat/auto-archiver-api"}, - lifespan=lifespan + contact={ + "name": "GitHub", + "url": "https://github.com/bellingcat/auto-archiver-api", + }, + lifespan=lifespan, ) app.add_middleware( @@ -46,14 +55,30 @@ def app_factory(settings = get_settings()): app.include_router(interoperability_router) # prometheus exposed in /metrics with authentication - Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health", "/openapi.json", "/favicon.ico"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)]) + Instrumentator( + should_group_status_codes=False, + excluded_handlers=[ + "/metrics", + "/health", + "/openapi.json", + "/favicon.ico", + ], + ).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)]) if settings.SERVE_LOCAL_ARCHIVE: local_dir = settings.SERVE_LOCAL_ARCHIVE - if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")): + if not os.path.isdir(local_dir) and os.path.isdir( + local_dir.replace("/app", ".") + ): local_dir = local_dir.replace("/app", ".") if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir): - logger.warning(f"MOUNTing local archive, use this in development only {settings.SERVE_LOCAL_ARCHIVE}") - app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE) + logger.warning( + f"MOUNTing local archive, use this in development only {settings.SERVE_LOCAL_ARCHIVE}" + ) + app.mount( + settings.SERVE_LOCAL_ARCHIVE, + StaticFiles(directory=local_dir), + name=settings.SERVE_LOCAL_ARCHIVE, + ) return app diff --git a/app/web/middleware.py b/app/web/middleware.py index 5ddca4b..47c07b3 100644 --- a/app/web/middleware.py +++ b/app/web/middleware.py @@ -1,4 +1,3 @@ - import traceback from fastapi import Request @@ -11,23 +10,30 @@ from app.web.utils.metrics import EXCEPTION_COUNTER async def logging_middleware(request: Request, call_next): try: response = await call_next(request) - #TODO: use Origin to have summary prometheus metrics on where requests come from + # TODO: use Origin to have summary prometheus metrics on where requests come from # origin = request.headers.get("origin") - logger.info(f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}") + logger.info( + f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}" + ) return response except Exception as e: location = f"{request.method} {request.url._url}" await increase_exceptions_counter(e, location) - logger.info(f"{request.client.host}:{request.client.port} {location} - {e.__class__.__name__} {e}") + logger.info( + f"{request.client.host}:{request.client.port} {location} - {e.__class__.__name__} {e}" + ) raise e -async def increase_exceptions_counter(e: Exception, location:str="cronjob"): + +async def increase_exceptions_counter(e: Exception, location: str = "cronjob"): if location == "cronjob": try: last_trace = traceback.extract_tb(e.__traceback__)[-1] _file, _line, func_name, _text = last_trace location = func_name except Exception as e: - logger.error(f"Unable to get function name from cronjob exception traceback: {e}") + logger.error( + f"Unable to get function name from cronjob exception traceback: {e}" + ) EXCEPTION_COUNTER.labels(type=e.__class__.__name__, location=location).inc() log_error(e) diff --git a/app/web/security.py b/app/web/security.py index 494e094..5850ad3 100644 --- a/app/web/security.py +++ b/app/web/security.py @@ -1,4 +1,5 @@ import secrets +from http import HTTPStatus import requests from fastapi import Depends, HTTPException, status @@ -16,7 +17,7 @@ settings = get_settings() bearer_security = HTTPBearer() -def secure_compare(token, api_key): +def secure_compare(token, api_key) -> bool: return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8")) @@ -24,9 +25,13 @@ def secure_compare(token, api_key): def api_key_auth(api_key): assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars" - async def auth(bearer: HTTPAuthorizationCredentials = Depends(bearer_security), auto_error=True): + async def auth( + bearer: HTTPAuthorizationCredentials = Depends(bearer_security), + auto_error=True, + ): is_correct = secure_compare(bearer.credentials, api_key) - if is_correct: return True + if is_correct: + return True if auto_error: raise HTTPException( @@ -38,17 +43,22 @@ def api_key_auth(api_key): return auth -# --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus +# --- Token Auth for AA itself to query the API, AA setup tool and Prometheus token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN) -async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): +async def get_token_or_user_auth( + credentials: HTTPAuthorizationCredentials = Depends(bearer_security), +): # tries to use the static API_KEY and defaults to google JWT auth - if await token_api_key_auth(credentials, auto_error=False): return ALLOW_ANY_EMAIL + if await token_api_key_auth(credentials, auto_error=False): + return ALLOW_ANY_EMAIL return await get_user_auth(credentials) -async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): +async def get_user_auth( + credentials: HTTPAuthorizationCredentials = Depends(bearer_security), +): # validates the Bearer token in the case that it requires it valid_user, info = authenticate_user(credentials.credentials) if valid_user: @@ -61,26 +71,37 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear ) -def authenticate_user(access_token): +def authenticate_user(access_token) -> (bool, str): # https://cloud.google.com/docs/authentication/token-types#access - if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token" - r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token": access_token}) - if r.status_code != 200: return False, "invalid token" + if not isinstance(access_token, str) or len(access_token) < 10: + return False, "invalid access_token" + r = requests.get( + "https://oauth2.googleapis.com/tokeninfo", + {"access_token": access_token}, + ) + if r.status_code != HTTPStatus.OK: + return False, "invalid token" try: j = r.json() - if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS: - return False, f"token does not belong to valid APP_ID" + if ( + j.get("azp") not in settings.CHROME_APP_IDS + and j.get("aud") not in settings.CHROME_APP_IDS + ): + return False, "token does not belong to valid APP_ID" if j.get("email") in settings.BLOCKED_EMAILS: return False, f"email '{j.get('email')}' not allowed" if j.get("email_verified") != "true": return False, f"email '{j.get('email')}' not verified" if int(j.get("expires_in", -1)) <= 0: return False, "Token expired" - return True, j.get('email').lower() + return True, j.get("email").lower() except Exception as e: logger.warning(f"AUTH EXCEPTION occurred: {e}") return False, "exception occurred" -def get_user_state(email:str=Depends(get_user_auth), db:Session=Depends(get_db_dependency)): +def get_user_state( + email: str = Depends(get_user_auth), + db: Session = Depends(get_db_dependency), +) -> UserState: return UserState(db, email) diff --git a/app/web/utils/metrics.py b/app/web/utils/metrics.py index d8026a1..04b496f 100644 --- a/app/web/utils/metrics.py +++ b/app/web/utils/metrics.py @@ -15,27 +15,25 @@ from app.web.db import crud EXCEPTION_COUNTER = Counter( "exceptions", "Number of times a certain exception has occurred.", - labelnames=["type", "location"] + labelnames=["type", "location"], ) WORKER_EXCEPTION = Counter( "worker_exceptions_total", "Number of times a certain exception has occurred on the worker.", - labelnames=["type", "exception", "task", "traceback"] + labelnames=["type", "exception", "task", "traceback"], ) DISK_UTILIZATION = Gauge( - "disk_utilization", - "Disk utilization in GB", - labelnames=["type"] + "disk_utilization", "Disk utilization in GB", labelnames=["type"] ) DATABASE_METRICS = Gauge( "database_metrics", "Database metric readings at a certain point in time", - labelnames=["query"] + labelnames=["query"], ) DATABASE_METRICS_COUNTER = Counter( "database_metrics_counter", "Database metrics that increase over time", - labelnames=["query", "user"] + labelnames=["query", "user"], ) @@ -48,7 +46,12 @@ async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL: str): message = PubSubExceptions.get_message() if message and message["type"] == "message": data = json.loads(message["data"].decode("utf-8")) - WORKER_EXCEPTION.labels(type=data["type"], exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc() + WORKER_EXCEPTION.labels( + type=data["type"], + exception=data["exception"], + task=data["task"], + traceback=data["traceback"], + ).inc() await asyncio.sleep(1) @@ -59,12 +62,19 @@ async def measure_regular_metrics(sqlite_db_url: str, repeat_in_seconds: int): try: fs = os.stat(sqlite_db_url.replace("sqlite:///", "")) DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30)) - except Exception as e: log_error(e) + except Exception as e: + log_error(e) with get_db() as db: - DATABASE_METRICS.labels(query="count_archives").set(crud.count_archives(db)) - DATABASE_METRICS.labels(query="count_archive_urls").set(crud.count_archive_urls(db)) + DATABASE_METRICS.labels(query="count_archives").set( + crud.count_archives(db) + ) + DATABASE_METRICS.labels(query="count_archive_urls").set( + crud.count_archive_urls(db) + ) DATABASE_METRICS.labels(query="count_users").set(crud.count_users(db)) for user in crud.count_by_user_since(db, repeat_in_seconds): - DATABASE_METRICS_COUNTER.labels(query="count_by_user", user=user.author_id).inc(user.total) + DATABASE_METRICS_COUNTER.labels( + query="count_by_user", user=user.author_id + ).inc(user.total) diff --git a/app/web/utils/misc.py b/app/web/utils/misc.py index 16a6591..f78ae1e 100644 --- a/app/web/utils/misc.py +++ b/app/web/utils/misc.py @@ -5,12 +5,12 @@ from fastapi.encoders import jsonable_encoder def custom_jsonable_encoder(obj): if isinstance(obj, bytes): - return base64.b64encode(obj).decode('utf-8') + return base64.b64encode(obj).decode("utf-8") return jsonable_encoder(obj) def convert_priority_to_queue_dict(priority: str) -> dict: return { "priority": 0 if priority == "high" else 10, - "queue": f"{priority}_priority" + "queue": f"{priority}_priority", } diff --git a/pyproject.toml b/pyproject.toml index 75f6e63..1a849ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,9 @@ pythonpath = "." [tool.coverage.run] omit = ["app/migrations/*"] +[tool.ruff.lint.flake8-bugbear] +extend-immutable-calls = ["fastapi.Depends", "fastapi.Query"] + [tool.poetry.group.worker.dependencies] watchdog = ">=6.0.0,<7.0.0" setuptools = "^75.8.0"