Format and lint web directory (#67)

This commit is contained in:
Michael Plunkett
2025-03-10 12:45:19 -05:00
committed by GitHub
parent 1ca0ae2fb2
commit b50ca91d89
20 changed files with 761 additions and 309 deletions

View File

@@ -61,18 +61,18 @@ repos:
- --profile=black - --profile=black
- --line-length=80 - --line-length=80
# - repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
# rev: v0.9.7 rev: v0.9.7
# hooks: hooks:
# - id: ruff - id: ruff
# types_or: [python,pyi] types_or: [python,pyi]
# args: args:
# - --fix - --fix
# - --select=B,C,E,F,W,B9 - --select=B,C,E,F,W,B9
# - --line-length=80 - --line-length=80
# - --ignore=E203,E402,E501,E261 - --ignore=E203,E402,E501,E261
# - id: ruff-format - id: ruff-format
# types_or: [ python,pyi] types_or: [ python,pyi]
# args: args:
# - --target-version=py310 - --target-version=py310
# - --line-length=80 - --line-length=80

View File

@@ -14,9 +14,7 @@ config.set_main_option("sqlalchemy.url", get_settings().DATABASE_PATH)
# This line sets up loggers basically. # This line sets up loggers basically.
if config.config_file_name is not None: if config.config_file_name is not None:
# disable_existing_loggers prevents loguru disabling # disable_existing_loggers prevents loguru disabling
fileConfig( fileConfig(config.config_file_name, disable_existing_loggers=False)
config.config_file_name, disable_existing_loggers=False
)
# add your model's MetaData object here # add your model's MetaData object here
# for 'autogenerate' support # for 'autogenerate' support

View File

@@ -20,7 +20,7 @@ def generate_uuid():
return str(uuid.uuid4()) return str(uuid.uuid4())
# many to many association tables # many-to-many association tables
association_table_archive_tags = Table( association_table_archive_tags = Table(
"mtm_archives_tags", "mtm_archives_tags",
Base.metadata, Base.metadata,

View File

@@ -14,8 +14,8 @@ def test_update_sheet_last_url_archived_at(db_session):
assert isinstance(test_sheet.last_url_archived_at, datetime) assert isinstance(test_sheet.last_url_archived_at, datetime)
before = test_sheet.last_url_archived_at before = test_sheet.last_url_archived_at
assert ( assert (
worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123") worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123")
is True is True
) )
db_session.refresh(test_sheet) db_session.refresh(test_sheet)
assert isinstance(test_sheet.last_url_archived_at, datetime) assert isinstance(test_sheet.last_url_archived_at, datetime)
@@ -23,10 +23,10 @@ def test_update_sheet_last_url_archived_at(db_session):
# Test non-existent sheet # Test non-existent sheet
assert ( assert (
worker_crud.update_sheet_last_url_archived_at( worker_crud.update_sheet_last_url_archived_at(
db_session, "non-existent-sheet" db_session, "non-existent-sheet"
) )
is False is False
) )
@@ -42,14 +42,14 @@ def test_create_or_get_user(test_data, db_session):
# already exists # already exists
assert ( assert (
u1 := worker_crud.create_or_get_user(db_session, "rick@example.com") u1 := worker_crud.create_or_get_user(db_session, "rick@example.com")
) is not None ) is not None
assert u1.email == "rick@example.com" assert u1.email == "rick@example.com"
# new user # new user
assert ( assert (
u2 := worker_crud.create_or_get_user(db_session, "beth@example.com") u2 := worker_crud.create_or_get_user(db_session, "beth@example.com")
) is not None ) is not None
assert u2.email == "beth@example.com" assert u2.email == "beth@example.com"
assert db_session.query(models.User).count() == 4 assert db_session.query(models.User).count() == 4
@@ -64,8 +64,8 @@ def test_create_tag(db_session):
assert create_tag.id == "tag-101" assert create_tag.id == "tag-101"
assert db_session.query(models.Tag).count() == 1 assert db_session.query(models.Tag).count() == 1
assert ( assert (
db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first()
== create_tag == create_tag
) )
# same id does not add new db entry # same id does not add new db entry

View File

@@ -712,7 +712,11 @@ async def test_find_by_store_until(async_db_session):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_sheets_by_id_hash(async_db_session): async def test_get_sheets_by_id_hash(async_db_session):
author_emails = ["rick@example.com", "morty@example.com", "jerry@example.com"] author_emails = [
"rick@example.com",
"morty@example.com",
"jerry@example.com",
]
# Add test data # Add test data
sheets = [ sheets = [

View File

@@ -7,6 +7,7 @@ import alembic.config
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from app.web.main import app_factory
from app.web.utils.metrics import EXCEPTION_COUNTER from app.web.utils.metrics import EXCEPTION_COUNTER
@@ -59,7 +60,6 @@ def test_serve_local_archive_logic(get_settings):
try: try:
# modify the settings # modify the settings
get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test" get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test"
from app.web.main import app_factory
app = app_factory(get_settings) app = app_factory(get_settings)

View File

@@ -8,7 +8,10 @@ API_DESCRIPTION = """
- You can use this API to archive single URLs or entire Google Sheets. - You can use this API to archive single URLs or entire Google Sheets.
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously. - Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
""" """
BREAKING_CHANGES = {"minVersion": "0.4.0", "message": "The latest update has breaking changes, please update the extension to the most recent version."} BREAKING_CHANGES = {
"minVersion": "0.4.0",
"message": "The latest update has breaking changes, please update the extension to the most recent version.",
}
# changing this will corrupt the database logic # changing this will corrupt the database logic
ALLOW_ANY_EMAIL = "*" ALLOW_ANY_EMAIL = "*"

View File

@@ -1,15 +1,26 @@
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import lru_cache from typing import Any, Type
from cachetools import LRUCache, cached from cachetools import LRUCache, cached
from cachetools.keys import hashkey from cachetools.keys import hashkey
from loguru import logger from loguru import logger
from sqlalchemy import Column, func, or_, select from sqlalchemy import (
Column,
ColumnElement,
ScalarResult,
false,
func,
not_,
or_,
select,
true,
)
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, load_only from sqlalchemy.orm import Session, load_only
from app.shared.db import models from app.shared.db import models
from app.shared.db.models import Archive, Group
from app.shared.settings import get_settings from app.shared.settings import get_settings
from app.shared.user_groups import UserGroups from app.shared.user_groups import UserGroups
from app.shared.utils.misc import fnv1a_hash_mod from app.shared.utils.misc import fnv1a_hash_mod
@@ -23,24 +34,48 @@ DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
def get_limit(user_limit: int): def get_limit(user_limit: int):
return max(1, min(user_limit, DATABASE_QUERY_LIMIT)) return max(1, min(user_limit, DATABASE_QUERY_LIMIT))
# --------------- TASK = Archive # --------------- TASK = Archive
def base_query(db: Session): def base_query(db: Session):
# NOTE: load_only is for optimization and not obfuscation, use .with_entities() if needed # NOTE: load_only is for optimization and not obfuscation, use
return db.query(models.Archive)\ # .with_entities() if needed
.filter(models.Archive.deleted == False)\ return (
.options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result, models.Archive.store_until)) db.query(models.Archive)
.filter(not_(models.Archive.deleted))
.options(
load_only(
models.Archive.id,
models.Archive.created_at,
models.Archive.url,
models.Archive.result,
models.Archive.store_until,
)
)
)
def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool | set[str], read_public: bool, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False) -> list[models.Archive]: def search_archives_by_url(
# searches for partial URLs, if email is * no ownership (or read/read_public) filtering happens db: Session,
url: str,
email: str,
read_groups: bool | set[str],
read_public: bool,
skip: int = 0,
limit: int = 100,
archived_after: datetime = None,
archived_before: datetime = None,
absolute_search: bool = False,
) -> list[Type[Archive]]:
# searches for partial URLs, if email is * no ownership
# (or read/read_public) filtering happens
query = base_query(db) query = base_query(db)
if email != ALLOW_ANY_EMAIL: if email != ALLOW_ANY_EMAIL:
or_filters = [models.Archive.author_id == email] or_filters = [models.Archive.author_id == email]
if read_public: if read_public:
or_filters.append(models.Archive.public == True) or_filters.append(models.Archive.public.is_(true()))
if read_groups == True: if read_groups is True:
or_filters.append(models.Archive.group_id.isnot(None)) or_filters.append(models.Archive.group_id.isnot(None))
else: else:
or_filters.append(models.Archive.group_id.in_(read_groups)) or_filters.append(models.Archive.group_id.in_(read_groups))
@@ -48,21 +83,43 @@ def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool
if absolute_search: if absolute_search:
query = query.filter(models.Archive.url == url) query = query.filter(models.Archive.url == url)
else: else:
query = query.filter(models.Archive.url.like(f'%{url}%')) query = query.filter(models.Archive.url.like(f"%{url}%"))
if archived_after: if archived_after:
query = query.filter(models.Archive.created_at > archived_after) query = query.filter(models.Archive.created_at > archived_after)
if archived_before: if archived_before:
query = query.filter(models.Archive.created_at < archived_before) query = query.filter(models.Archive.created_at < archived_before)
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all() return (
query.order_by(models.Archive.created_at.desc())
.offset(skip)
.limit(get_limit(limit))
.all()
)
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100): def search_archives_by_email(
return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all() db: Session, email: str, skip: int = 0, limit: int = 100
):
return (
base_query(db)
.filter(models.Archive.author_id == email)
.order_by(models.Archive.created_at.desc())
.offset(skip)
.limit(get_limit(limit))
.all()
)
def soft_delete_archive(db: Session, id: str, email: str) -> bool: def soft_delete_archive(db: Session, id: str, email: str) -> bool:
# TODO: implement hard-delete with cronjob that deletes from S3 # TODO: implement hard-delete with cronjob that deletes from S3
db_archive = db.query(models.Archive).filter(models.Archive.id == id, models.Archive.author_id == email, models.Archive.deleted == False).first() db_archive = (
db.query(models.Archive)
.filter(
models.Archive.id == id,
models.Archive.author_id == email,
models.Archive.deleted.is_(false()),
)
.first()
)
if db_archive: if db_archive:
db_archive.deleted = True db_archive.deleted = True
db.commit() db.commit()
@@ -83,22 +140,29 @@ def count_users(db: Session):
def count_by_user_since(db: Session, seconds_delta: int = 15): def count_by_user_since(db: Session, seconds_delta: int = 15):
time_threshold = datetime.now() - timedelta(seconds=seconds_delta) time_threshold = datetime.now() - timedelta(seconds=seconds_delta)
return db.query(models.Archive.author_id, func.count().label('total'))\ return (
.filter(models.Archive.created_at >= time_threshold)\ db.query(models.Archive.author_id, func.count().label("total"))
.group_by(models.Archive.author_id)\ .filter(models.Archive.created_at >= time_threshold)
.order_by(func.count().desc())\ .group_by(models.Archive.author_id)
.limit(500).all() .order_by(func.count().desc())
.limit(500)
.all()
)
async def find_by_store_until(db: AsyncSession, store_until_is_before: datetime) -> list[models.Archive]: async def find_by_store_until(
db: AsyncSession, store_until_is_before: datetime
) -> ScalarResult[Archive]:
res = await db.execute( res = await db.execute(
select(models.Archive) select(models.Archive).filter(
.filter(models.Archive.deleted == False, models.Archive.store_until < store_until_is_before) models.Archive.deleted.is_(false()),
models.Archive.store_until < store_until_is_before,
)
) )
return res.scalars() return res.scalars()
async def soft_delete_expired_archives(db: AsyncSession) -> dict: async def soft_delete_expired_archives(db: AsyncSession) -> int:
to_delete = await find_by_store_until(db, datetime.now()) to_delete = await find_by_store_until(db, datetime.now())
counter = 0 counter = 0
for archive in to_delete: for archive in to_delete:
@@ -106,47 +170,86 @@ async def soft_delete_expired_archives(db: AsyncSession) -> dict:
counter += 1 counter += 1
await db.commit() await db.commit()
return counter return counter
# --------------- TAG # --------------- TAG
async def get_group_priority_async(db: AsyncSession, group_id: str) -> dict: async def get_group_priority_async(db: AsyncSession, group_id: str) -> dict:
db_group = await db.get(models.Group, group_id) db_group = await db.get(models.Group, group_id)
priority = db_group.permissions.get("priority", "low") if db_group else "low" priority = (
db_group.permissions.get("priority", "low") if db_group else "low"
)
return convert_priority_to_queue_dict(priority) return convert_priority_to_queue_dict(priority)
@cached(cache=LRUCache(maxsize=128), key=lambda db, email: hashkey(email)) @cached(cache=LRUCache(maxsize=128), key=lambda db, email: hashkey(email))
def get_user_group_names(db: Session, email: str) -> list[str]: def get_user_group_names(
db: Session, email: str
) -> list[Any] | list[ColumnElement[Any]]:
""" """
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. given an email retrieves the user groups from the DB and then the
email-domain groups from a global variable, the email does not need to
belong to an existing user.
""" """
# TODO: the read: [group1, group2] permissions don't currently work # TODO: the read: [group1, group2] permissions don't currently work
if not email or not len(email) or "@" not in email: return [] if not email or not len(email) or "@" not in email:
return []
# get user groups # get user groups
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all() user_groups = (
db.query(models.association_table_user_groups)
.filter_by(user_id=email)
.with_entities(Column("group_id"))
.all()
)
user_level_groups_names = [g[0] for g in user_groups] user_level_groups_names = [g[0] for g in user_groups]
# get domain groups # get domain groups
domain = email.split('@')[1] domain = email.split("@")[1]
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all() domain_level_groups = (
db.query(models.Group.id)
.filter(models.Group.domains.contains(domain))
.with_entities(Column("id"))
.all()
)
domain_level_groups_names = [g[0] for g in domain_level_groups] domain_level_groups_names = [g[0] for g in domain_level_groups]
return list(set(user_level_groups_names + domain_level_groups_names)) return list(set(user_level_groups_names + domain_level_groups_names))
def get_user_groups_by_name(db: Session, groups: list[str]) -> list[models.Group]: def get_user_groups_by_name(
return db.query(models.Group).filter( db: Session, groups: list[str]
models.Group.id.in_(groups) ) -> list[Type[Group]]:
).all() return db.query(models.Group).filter(models.Group.id.in_(groups)).all()
# --------------- INIT User-Groups # --------------- INIT User-Groups
def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, service_account_email: str, permissions: dict, domains: list) -> models.Group: def upsert_group(
db_group = db.query(models.Group).filter(models.Group.id == group_name).first() db: Session,
group_name: str,
description: str,
orchestrator: str,
orchestrator_sheet: str,
service_account_email: str,
permissions: dict,
domains: list,
) -> models.Group:
db_group = (
db.query(models.Group).filter(models.Group.id == group_name).first()
)
if db_group is None: if db_group is None:
db_group = models.Group(id=group_name, description=description, orchestrator=orchestrator, orchestrator_sheet=orchestrator_sheet, service_account_email=service_account_email, permissions=permissions, domains=domains) db_group = models.Group(
id=group_name,
description=description,
orchestrator=orchestrator,
orchestrator_sheet=orchestrator_sheet,
service_account_email=service_account_email,
permissions=permissions,
domains=domains,
)
db.add(db_group) db.add(db_group)
else: else:
db_group.description = description db_group.description = description
@@ -173,6 +276,7 @@ def upsert_user(db: Session, email: str):
def upsert_user_groups(db: Session): def upsert_user_groups(db: Session):
def display_email_pii(email: str): def display_email_pii(email: str):
return f"'{email[0:3]}...@{email.split('@')[1]}'" return f"'{email[0:3]}...@{email.split('@')[1]}'"
""" """
reads the user_groups yaml file and inserts any new users, groups, reads the user_groups yaml file and inserts any new users, groups,
along with new participation of users in groups along with new participation of users in groups
@@ -196,16 +300,30 @@ def upsert_user_groups(db: Session):
# upsert groups and save a map of groupid -> dbobject # upsert groups and save a map of groupid -> dbobject
for group_id, g in ug.groups.items(): for group_id, g in ug.groups.items():
upsert_group(db, group_id, g.description, g.orchestrator, g.orchestrator_sheet, g.service_account_email, json.loads(g.permissions.model_dump_json()), list(group_domains.get(group_id, []))) upsert_group(
db_groups: dict[str, models.Group] = {g.id: g for g in db.query(models.Group).all()} db,
group_id,
g.description,
g.orchestrator,
g.orchestrator_sheet,
g.service_account_email,
json.loads(g.permissions.model_dump_json()),
list(group_domains.get(group_id, [])),
)
db_groups: dict[str, models.Group] = {
g.id: g for g in db.query(models.Group).all()
}
# integrity checks # integrity checks
for group_in_domains in group_domains: for group_in_domains in group_domains:
if group_in_domains not in db_groups: if group_in_domains not in db_groups:
logger.warning(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.") logger.warning(
f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work."
)
# reinsert users in their EXPLICITLY DEFINED groups # reinsert users in their EXPLICITLY DEFINED groups
# domain groups are check live, as there may be new users that are not explicitly registered but belong to a domain # domain groups are check live, as there may be new users that are not
# explicitly registered but belong to a domain
for email, explicit_groups in ug.users.items(): for email, explicit_groups in ug.users.items():
explicit_groups = explicit_groups or [] explicit_groups = explicit_groups or []
logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}") logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}")
@@ -215,7 +333,9 @@ def upsert_user_groups(db: Session):
# connect users to groups # connect users to groups
for group_id in explicit_groups: for group_id in explicit_groups:
if group_id not in db_groups: if group_id not in db_groups:
logger.warning(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.") logger.warning(
f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}."
)
continue continue
db_groups[group_id].users.append(db_user) db_groups[group_id].users.append(db_user)
@@ -223,12 +343,27 @@ def upsert_user_groups(db: Session):
count_user_groups = db.query(models.association_table_user_groups).count() count_user_groups = db.query(models.association_table_user_groups).count()
count_groups = db.query(func.count(models.Group.id)).scalar() count_groups = db.query(func.count(models.Group.id)).scalar()
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].") logger.success(
f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}]."
)
# --------------- SHEET # --------------- SHEET
def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: str, frequency: str): def create_sheet(
db_sheet = models.Sheet(id=sheet_id, name=name, author_id=email, group_id=group_id, frequency=frequency) db: Session,
sheet_id: str,
name: str,
email: str,
group_id: str,
frequency: str,
):
db_sheet = models.Sheet(
id=sheet_id,
name=name,
author_id=email,
group_id=group_id,
frequency=frequency,
)
db.add(db_sheet) db.add(db_sheet)
db.commit() db.commit()
db.refresh(db_sheet) db.refresh(db_sheet)
@@ -236,20 +371,31 @@ def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: st
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet: def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first() return (
db.query(models.Sheet)
.filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id)
.first()
)
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]: def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_url_archived_at.desc()).all() return (
db.query(models.Sheet)
.filter(models.Sheet.author_id == email)
.order_by(models.Sheet.last_url_archived_at.desc())
.all()
)
async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: int) -> list[models.Sheet]: async def get_sheets_by_id_hash(
db: AsyncSession, frequency: str, modulo: str, id_hash: int
) -> list[models.Sheet]:
result = await db.execute( result = await db.execute(
select(models.Sheet).filter(models.Sheet.frequency == frequency) select(models.Sheet).filter(models.Sheet.frequency == frequency)
) )
filtered = [] filtered = []
for sheet in result.scalars(): for sheet in result.scalars():
if fnv1a_hash_mod(sheet.id, modulo) == id_hash: if fnv1a_hash_mod(sheet.id, int(modulo)) == id_hash:
filtered.append(sheet) filtered.append(sheet)
return filtered return filtered
@@ -257,7 +403,9 @@ async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, i
async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict: async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
time_threshold = datetime.now() - timedelta(days=inactivity_days) time_threshold = datetime.now() - timedelta(days=inactivity_days)
result = await db.execute( result = await db.execute(
select(models.Sheet).filter(models.Sheet.last_url_archived_at < time_threshold) select(models.Sheet).filter(
models.Sheet.last_url_archived_at < time_threshold
)
) )
deleted = defaultdict(list) deleted = defaultdict(list)
for sheet in result.scalars(): for sheet in result.scalars():
@@ -268,7 +416,11 @@ async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool: def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first() db_sheet = (
db.query(models.Sheet)
.filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email)
.first()
)
if db_sheet: if db_sheet:
db.delete(db_sheet) db.delete(db_sheet)
db.commit() db.commit()

View File

@@ -1,4 +1,3 @@
from datetime import datetime from datetime import datetime
from typing import Dict, Set from typing import Dict, Set
@@ -21,14 +20,15 @@ class UserState:
def __init__(self, db: Session, email: str): def __init__(self, db: Session, email: str):
self.db = db self.db = db
self.email = email.lower() self.email = email.lower()
self._permissions = {}
@property @property
def permissions(self) -> Dict[str, GroupInfo]: def permissions(self) -> Dict[str, GroupInfo]:
""" """
Returns a dict of all group permissions and a special {"all": read/archive_url/archive_sheet} key Returns a dict of all group permissions and a special
{"all": read/archive_url/archive_sheet} key
""" """
if not hasattr(self, '_permissions'): if not self._permissions:
self._permissions = {}
self._permissions["all"] = GroupInfo( self._permissions["all"] = GroupInfo(
read=self.read, read=self.read,
read_public=self.read_public, read_public=self.read_public,
@@ -38,23 +38,33 @@ class UserState:
max_archive_lifespan_months=self.max_archive_lifespan_months, max_archive_lifespan_months=self.max_archive_lifespan_months,
max_monthly_urls=self.max_monthly_urls, max_monthly_urls=self.max_monthly_urls,
max_monthly_mbs=self.max_monthly_mbs, max_monthly_mbs=self.max_monthly_mbs,
priority=self.priority priority=self.priority,
) )
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
self._permissions[group.id] = GroupInfo(**group.permissions, description=group.description, service_account_email=group.service_account_email) continue
self._permissions[group.id] = GroupInfo(
**group.permissions,
description=group.description,
service_account_email=group.service_account_email,
)
return self._permissions return self._permissions
@property @property
def user_groups_names(self): def user_groups_names(self):
if not hasattr(self, '_user_groups_names'): if not hasattr(self, "_user_groups_names"):
self._user_groups_names = crud.get_user_group_names(self.db, self.email) + ["default"] # TODO: Define hidden properties in __init__ method
self._user_groups_names = crud.get_user_group_names(
self.db, self.email
) + ["default"]
return self._user_groups_names return self._user_groups_names
@property @property
def user_groups(self): def user_groups(self):
if not hasattr(self, '_user_groups'): if not hasattr(self, "_user_groups"):
self._user_groups = crud.get_user_groups_by_name(self.db, self.user_groups_names) self._user_groups = crud.get_user_groups_by_name(
self.db, self.user_groups_names
)
return self._user_groups return self._user_groups
@property @property
@@ -62,10 +72,11 @@ class UserState:
""" """
Read can be a list of group names or True, if all can be read. Read can be a list of group names or True, if all can be read.
""" """
if not hasattr(self, '_read'): if not hasattr(self, "_read"):
self._read = set() self._read = set()
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
continue
group_read_permissions = group.permissions.get("read", []) group_read_permissions = group.permissions.get("read", [])
if "all" in group_read_permissions: if "all" in group_read_permissions:
self._read = True self._read = True
@@ -79,10 +90,11 @@ class UserState:
""" """
Read public permission Read public permission
""" """
if not hasattr(self, '_read_public'): if not hasattr(self, "_read_public"):
self._read_public = False self._read_public = False
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
continue
if group.permissions.get("read_public", False): if group.permissions.get("read_public", False):
self._read_public = True self._read_public = True
return self._read_public return self._read_public
@@ -93,10 +105,11 @@ class UserState:
""" """
Archive URL permission Archive URL permission
""" """
if not hasattr(self, '_archive_url'): if not hasattr(self, "_archive_url"):
self._archive_url = False self._archive_url = False
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
continue
if group.permissions.get("archive_url", False): if group.permissions.get("archive_url", False):
self._archive_url = True self._archive_url = True
return self._archive_url return self._archive_url
@@ -107,10 +120,11 @@ class UserState:
""" """
Archive sheet permission Archive sheet permission
""" """
if not hasattr(self, '_archive_sheet'): if not hasattr(self, "_archive_sheet"):
self._archive_sheet = False self._archive_sheet = False
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
continue
if group.permissions.get("archive_sheet", False): if group.permissions.get("archive_sheet", False):
self._archive_sheet = True self._archive_sheet = True
return self._archive_sheet return self._archive_sheet
@@ -118,37 +132,53 @@ class UserState:
@property @property
def sheet_frequency(self): def sheet_frequency(self):
if not hasattr(self, '_sheet_frequency'): if not hasattr(self, "_sheet_frequency"):
self._sheet_frequency = set() self._sheet_frequency = set()
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
self._sheet_frequency.update(group.permissions.get("sheet_frequency", None)) continue
self._sheet_frequency.update(
group.permissions.get("sheet_frequency", None)
)
return self._sheet_frequency return self._sheet_frequency
@property @property
def max_archive_lifespan_months(self) -> int: def max_archive_lifespan_months(self) -> int:
if not hasattr(self, '_max_archive_lifespan_months'): if not hasattr(self, "_max_archive_lifespan_months"):
self._max_archive_lifespan_months = self._helper_for_grouping_max_numerical_permissions("max_archive_lifespan_months") self._max_archive_lifespan_months = (
self._helper_for_grouping_max_numerical_permissions(
"max_archive_lifespan_months"
)
)
return self._max_archive_lifespan_months return self._max_archive_lifespan_months
@property @property
def max_monthly_urls(self) -> int: def max_monthly_urls(self) -> int:
if not hasattr(self, '_max_monthly_urls'): if not hasattr(self, "_max_monthly_urls"):
self._max_monthly_urls = self._helper_for_grouping_max_numerical_permissions("max_monthly_urls") self._max_monthly_urls = (
self._helper_for_grouping_max_numerical_permissions(
"max_monthly_urls"
)
)
return self._max_monthly_urls return self._max_monthly_urls
@property @property
def max_monthly_mbs(self) -> int: def max_monthly_mbs(self) -> int:
if not hasattr(self, '_max_monthly_mbs'): if not hasattr(self, "_max_monthly_mbs"):
self._max_monthly_mbs = self._helper_for_grouping_max_numerical_permissions("max_monthly_mbs") self._max_monthly_mbs = (
self._helper_for_grouping_max_numerical_permissions(
"max_monthly_mbs"
)
)
return self._max_monthly_mbs return self._max_monthly_mbs
@property @property
def priority(self) -> str: def priority(self) -> str:
if not hasattr(self, '_priority'): if not hasattr(self, "_priority"):
self._priority = "low" self._priority = "low"
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
continue
if group.permissions.get("priority", self._priority) == "high": if group.permissions.get("priority", self._priority) == "high":
self._priority = "high" self._priority = "high"
break break
@@ -159,18 +189,28 @@ class UserState:
""" """
A user is active if they can read/archive anything A user is active if they can read/archive anything
""" """
if not hasattr(self, '_active'): if not hasattr(self, "_active"):
self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet) self._active = bool(
self.read
or self.read_public
or self.archive_url
or self.archive_sheet
)
return self._active return self._active
def _helper_for_grouping_max_numerical_permissions(self, permission_name: str) -> int: def _helper_for_grouping_max_numerical_permissions(
self, permission_name: str
) -> int:
""" """
Iterates one of the numerical permissions where -1 means no restrictions and returns either -1 or the maximum value, defaults according to GroupPermissions Iterates one of the numerical permissions where -1 means no restrictions
and returns either -1 or the maximum value, defaults according to
GroupPermissions
""" """
default = GroupPermissions.model_fields[permission_name].default default = GroupPermissions.model_fields[permission_name].default
max_value = default max_value = default
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
continue
group_value = group.permissions.get(permission_name, default) group_value = group.permissions.get(permission_name, default)
if group_value == -1: if group_value == -1:
max_value = -1 max_value = -1
@@ -181,43 +221,65 @@ class UserState:
def in_group(self, group_id: str) -> bool: def in_group(self, group_id: str) -> bool:
return group_id in self.user_groups_names return group_id in self.user_groups_names
def usage(self) -> Dict: def usage(self) -> UsageResponse:
""" """
returns the monthly quotas for the URLs/MBs and the totals for Sheets Returns the monthly quotas for the URLs/MBs and the totals for Sheets
""" """
current_month = datetime.now().month current_month = datetime.now().month
current_year = datetime.now().year current_year = datetime.now().year
# find and sum all user sheets over this month # find and sum all user sheets over this month
user_sheets = self.db.query( user_sheets = (
models.Sheet.group_id, self.db.query(
func.count(models.Sheet.id).label('sheet_count') models.Sheet.group_id,
).filter(models.Sheet.author_id == self.email).group_by(models.Sheet.group_id).all() func.count(models.Sheet.id).label("sheet_count"),
)
.filter(models.Sheet.author_id == self.email)
.group_by(models.Sheet.group_id)
.all()
)
sheets_by_group = {sheet.group_id: sheet.sheet_count for sheet in user_sheets} sheets_by_group = {
sheet.group_id: sheet.sheet_count for sheet in user_sheets
}
# find and sum all user urls over this month # find and sum all user urls over this month
urls_by_group = self.db.query( urls_by_group = (
models.Archive.group_id, self.db.query(
func.count(models.Archive.id).label('url_count'), models.Archive.group_id,
func.coalesce(func.sum( func.count(models.Archive.id).label("url_count"),
func.coalesce( func.coalesce(
func.cast( func.sum(
func.json_extract(models.Archive.result, '$.metadata.total_bytes'), func.coalesce(
sqlalchemy.Integer func.cast(
), 0 func.json_extract(
) models.Archive.result,
), 0).label('total_bytes') "$.metadata.total_bytes",
).filter( ),
models.Archive.author_id == self.email, sqlalchemy.Integer,
func.extract('month', models.Archive.created_at) == current_month, ),
func.extract('year', models.Archive.created_at) == current_year 0,
).group_by(models.Archive.group_id).all() )
),
0,
).label("total_bytes"),
)
.filter(
models.Archive.author_id == self.email,
func.extract("month", models.Archive.created_at)
== current_month,
func.extract("year", models.Archive.created_at) == current_year,
)
.group_by(models.Archive.group_id)
.all()
)
# merge the two queries # merge the two queries
usage_by_group: Dict[str, Usage] = { usage_by_group: Dict[str, Usage] = {
(url.group_id or ""): (url.group_id or ""): Usage(
Usage(monthly_urls=url.url_count, monthly_mbs=int(url.total_bytes / 1024 / 1024)) monthly_urls=url.url_count,
monthly_mbs=int(url.total_bytes / 1024 / 1024),
)
for url in urls_by_group for url in urls_by_group
} }
for group_id, sheet_count in sheets_by_group.items(): for group_id, sheet_count in sheets_by_group.items():
@@ -236,7 +298,7 @@ class UserState:
monthly_urls=total_urls, monthly_urls=total_urls,
monthly_mbs=int(total_bytes / 1024 / 1024), monthly_mbs=int(total_bytes / 1024 / 1024),
total_sheets=total_sheets, total_sheets=total_sheets,
groups=usage_by_group groups=usage_by_group,
) )
def has_quota_monthly_sheets(self, group_id: str) -> bool: def has_quota_monthly_sheets(self, group_id: str) -> bool:
@@ -246,7 +308,14 @@ class UserState:
if group_id not in self.permissions: if group_id not in self.permissions:
return False return False
user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email, models.Sheet.group_id == group_id).count() user_sheets = (
self.db.query(models.Sheet)
.filter(
models.Sheet.author_id == self.email,
models.Sheet.group_id == group_id,
)
.count()
)
sheet_quota = self.permissions[group_id].max_sheets sheet_quota = self.permissions[group_id].max_sheets
if sheet_quota == -1: if sheet_quota == -1:
@@ -255,13 +324,15 @@ class UserState:
def has_quota_max_monthly_urls(self, group_id: str) -> bool: def has_quota_max_monthly_urls(self, group_id: str) -> bool:
""" """
checks if a user has reached their monthly url quota for a group, if global then group should be empty string Checks if a user has reached their monthly url quota for a group, if
global then group should be empty string
""" """
quota = 0 quota = 0
if not group_id: if not group_id:
quota = self.max_monthly_urls quota = self.max_monthly_urls
else: else:
if group_id not in self.permissions: return False if group_id not in self.permissions:
return False
quota = self.permissions[group_id].max_monthly_urls quota = self.permissions[group_id].max_monthly_urls
if quota == -1: if quota == -1:
@@ -269,24 +340,31 @@ class UserState:
current_month = datetime.now().month current_month = datetime.now().month
current_year = datetime.now().year current_year = datetime.now().year
user_urls = self.db.query(models.Archive).filter( user_urls = (
models.Archive.author_id == self.email, self.db.query(models.Archive)
models.Archive.group_id == group_id, .filter(
func.extract('month', models.Archive.created_at) == current_month, models.Archive.author_id == self.email,
func.extract('year', models.Archive.created_at) == current_year models.Archive.group_id == group_id,
).count() func.extract("month", models.Archive.created_at)
== current_month,
func.extract("year", models.Archive.created_at) == current_year,
)
.count()
)
return user_urls < quota return user_urls < quota
def has_quota_max_monthly_mbs(self, group_id: str) -> bool: def has_quota_max_monthly_mbs(self, group_id: str) -> bool:
""" """
checks if a user has reached their monthly MBs quota for a group, if global then group should be empty string Checks if a user has reached their monthly MBs quota for a group, if
global then group should be empty string
""" """
quota = 0 quota = 0
if not group_id: if not group_id:
quota = self.max_monthly_mbs quota = self.max_monthly_mbs
else: else:
if group_id not in self.permissions: return False if group_id not in self.permissions:
return False
quota = self.permissions[group_id].max_monthly_mbs quota = self.permissions[group_id].max_monthly_mbs
if quota == -1: if quota == -1:
@@ -296,19 +374,34 @@ class UserState:
current_year = datetime.now().year current_year = datetime.now().year
# find and sum all user bytes over this month # find and sum all user bytes over this month
user_bytes = self.db.query(models.Archive).filter( user_bytes = (
models.Archive.author_id == self.email, self.db.query(models.Archive)
models.Archive.group_id == group_id, .filter(
func.extract('month', models.Archive.created_at) == current_month, models.Archive.author_id == self.email,
func.extract('year', models.Archive.created_at) == current_year models.Archive.group_id == group_id,
).with_entities(func.coalesce(func.sum( func.extract("month", models.Archive.created_at)
func.coalesce( == current_month,
func.cast( func.extract("year", models.Archive.created_at) == current_year,
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
sqlalchemy.Integer
), 0
) )
), 0).label('total')).scalar() .with_entities(
func.coalesce(
func.sum(
func.coalesce(
func.cast(
func.json_extract(
models.Archive.result,
"$.metadata.total_bytes",
),
sqlalchemy.Integer,
),
0,
)
),
0,
).label("total")
)
.scalar()
)
# convert bytes to mb # convert bytes to mb
user_mbs = int(user_bytes / 1024 / 1024) user_mbs = int(user_bytes / 1024 / 1024)
@@ -316,7 +409,7 @@ class UserState:
def can_manually_trigger(self, group_id: str) -> bool: def can_manually_trigger(self, group_id: str) -> bool:
""" """
checks if a user is allowed to manually trigger a sheet Checks if a user is allowed to manually trigger a sheet
""" """
if group_id not in self.permissions: if group_id not in self.permissions:
return False return False
@@ -325,18 +418,21 @@ class UserState:
def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool: def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool:
""" """
checks if a user is allowed to create a sheet with this frequency for this group Checks if a user is allowed to create a sheet with this frequency for
this group
""" """
if group_id not in self.permissions: if group_id not in self.permissions:
return False return False
return frequency in self.permissions[group_id].sheet_frequency return frequency in self.permissions[group_id].sheet_frequency
def priority_group(self, group_id: str) -> str: def priority_group(self, group_id: str) -> dict:
priority = "low" priority = "low"
for group in self.user_groups: for group in self.user_groups:
if group.id != group_id: continue if group.id != group_id:
if not group.permissions: continue continue
if not group.permissions:
continue
priority = group.permissions.get("priority", priority) priority = group.permissions.get("priority", priority)
break break
return convert_priority_to_queue_dict(priority) return convert_priority_to_queue_dict(priority)

View File

@@ -1,4 +1,4 @@
from http import HTTPStatus
from typing import Dict from typing import Dict
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
@@ -15,38 +15,50 @@ default_router = APIRouter()
@default_router.get("/") @default_router.get("/")
async def home(): async def home() -> JSONResponse:
return JSONResponse({"version": VERSION, "breakingChanges": BREAKING_CHANGES}) return JSONResponse(
{"version": VERSION, "breakingChanges": BREAKING_CHANGES}
)
@default_router.get("/health") @default_router.get("/health")
async def health(): async def health() -> JSONResponse:
return JSONResponse({"status": "ok"}) return JSONResponse({"status": "ok"})
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.") @default_router.get(
"/user/active", summary="Check if the user is active and can use the tool."
)
async def active( async def active(
user: UserState = Depends(get_user_state), user: UserState = Depends(get_user_state),
) -> ActiveUser: ) -> ActiveUser:
return {"active": user.active} return ActiveUser(active=user.active)
@default_router.get("/user/permissions", summary="Get the user's global 'all' permissions and the permissions for each group they belong to.") @default_router.get(
"/user/permissions",
summary="Get the user's global 'all' permissions and the permissions for each group they belong to.",
)
def get_user_permissions( def get_user_permissions(
user: UserState = Depends(get_user_state), user: UserState = Depends(get_user_state),
) -> Dict[str, GroupInfo]: ) -> Dict[str, GroupInfo]:
return user.permissions return user.permissions
@default_router.get("/user/usage", summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.")
@default_router.get(
"/user/usage",
summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.",
)
def get_user_usage( def get_user_usage(
user: UserState = Depends(get_user_state), user: UserState = Depends(get_user_state),
) -> UsageResponse: ) -> UsageResponse:
if not user.active: if not user.active:
raise HTTPException(status_code=403, detail="User is not active.") raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="User is not active."
)
return user.usage() return user.usage()
@default_router.get("/favicon.ico", include_in_schema=False)
@default_router.get('/favicon.ico', include_in_schema=False)
async def favicon() -> FileResponse: async def favicon() -> FileResponse:
return FileResponse("app/web/static/favicon.ico") return FileResponse("app/web/static/favicon.ico")

View File

@@ -1,4 +1,5 @@
import json import json
from http import HTTPStatus
import sqlalchemy import sqlalchemy
from auto_archiver.core import Metadata from auto_archiver.core import Metadata
@@ -16,26 +17,39 @@ from app.web.config import ALLOW_ANY_EMAIL
from app.web.security import token_api_key_auth from app.web.security import token_api_key_auth
interoperability_router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."]) interoperability_router = APIRouter(
prefix="/interop", tags=["Interoperability endpoints."]
)
# ----- endpoint to submit data archived elsewhere # ----- endpoint to submit data archived elsewhere
@interoperability_router.post("/submit-archive", status_code=201, summary="Submit a manual archive entry, for data that was archived elsewhere.") @interoperability_router.post(
"/submit-archive",
status_code=HTTPStatus.CREATED,
summary="Submit a manual archive entry, for data that was archived elsewhere.",
)
def submit_manual_archive( def submit_manual_archive(
manual: schemas.SubmitManualArchive, manual: schemas.SubmitManualArchive,
auth=Depends(token_api_key_auth), auth=Depends(token_api_key_auth),
db: Session = Depends(get_db_dependency) db: Session = Depends(get_db_dependency),
): ):
try: try:
result: Metadata = Metadata.from_json(manual.result) result: Metadata = Metadata.from_json(manual.result)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
log_error(e) log_error(e)
raise HTTPException(status_code=422, detail="Invalid JSON in result field.") raise HTTPException(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
detail="Invalid JSON in result field.",
) from e
manual.author_id = manual.author_id or ALLOW_ANY_EMAIL manual.author_id = manual.author_id or ALLOW_ANY_EMAIL
manual.tags.add("manual") manual.tags.add("manual")
store_until = business_logic.get_store_archive_until_or_never(db, manual.group_id) store_until = business_logic.get_store_archive_until_or_never(
logger.debug(f"[MANUAL ARCHIVE] {manual.author_id} {manual.url} {store_until}") db, manual.group_id
)
logger.debug(
f"[MANUAL ARCHIVE] {manual.author_id} {manual.url} {store_until}"
)
try: try:
archive = schemas.ArchiveCreate( archive = schemas.ArchiveCreate(
@@ -51,8 +65,15 @@ def submit_manual_archive(
) )
db_archive = worker_crud.store_archived_url(db, archive) db_archive = worker_crud.store_archived_url(db, archive)
logger.debug(f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}") logger.debug(
return JSONResponse({"id": db_archive.id}, status_code=201) f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}"
)
return JSONResponse(
{"id": db_archive.id}, status_code=HTTPStatus.CREATED
)
except sqlalchemy.exc.IntegrityError as e: except sqlalchemy.exc.IntegrityError as e:
log_error(e) log_error(e)
raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error, likely duplicate urls.") raise HTTPException(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
detail="Cannot insert into DB due to integrity error, likely duplicate urls.",
) from e

View File

@@ -1,81 +1,134 @@
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sqlalchemy import exc from sqlalchemy import exc
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.shared import schemas
from app.shared.db.database import get_db_dependency from app.shared.db.database import get_db_dependency
from app.shared.schemas import (
DeleteResponse,
SheetAdd,
SheetResponse,
SubmitSheet,
)
from app.shared.task_messaging import get_celery from app.shared.task_messaging import get_celery
from app.web.db import crud from app.web.db import crud
from app.web.db.user_state import UserState from app.web.db.user_state import UserState
from app.web.security import get_user_state from app.web.security import get_user_state
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"]) sheet_router = APIRouter(
prefix="/sheet", tags=["Google Spreadsheet operations"]
)
celery = get_celery() celery = get_celery()
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
@sheet_router.post(
"/create",
status_code=HTTPStatus.CREATED,
summary="Store a new Google Sheet for regular archiving.",
)
def create_sheet( def create_sheet(
sheet: schemas.SheetAdd, sheet: SheetAdd,
user: UserState = Depends(get_user_state), user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency), db: Session = Depends(get_db_dependency),
) -> schemas.SheetResponse: ) -> SheetResponse:
if not user.in_group(sheet.group_id): if not user.in_group(sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.") raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="User does not have access to this group.",
)
if not user.has_quota_monthly_sheets(sheet.group_id): if not user.has_quota_monthly_sheets(sheet.group_id):
raise HTTPException(status_code=429, detail="User has reached their sheet quota for this group.") raise HTTPException(
status_code=HTTPStatus.TOO_MANY_REQUESTS,
detail="User has reached their sheet quota for this group.",
)
if not user.is_sheet_frequency_allowed(sheet.group_id, sheet.frequency): if not user.is_sheet_frequency_allowed(sheet.group_id, sheet.frequency):
raise HTTPException(status_code=422, detail="Invalid frequency selected for this group.") raise HTTPException(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
detail="Invalid frequency selected for this group.",
)
try: try:
return crud.create_sheet(db, sheet.id, sheet.name, user.email, sheet.group_id, sheet.frequency) return crud.create_sheet(
db,
sheet.id,
sheet.name,
user.email,
sheet.group_id,
sheet.frequency,
)
except exc.IntegrityError as e: except exc.IntegrityError as e:
raise HTTPException(status_code=400, detail="Sheet with this ID is already being archived.") from e raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Sheet with this ID is already being archived.",
) from e
@sheet_router.get("/mine", status_code=200, summary="Get the authenticated user's Google Sheets.") @sheet_router.get(
"/mine",
status_code=HTTPStatus.OK,
summary="Get the authenticated user's Google Sheets.",
)
def get_user_sheets( def get_user_sheets(
user: UserState = Depends(get_user_state), user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency) db: Session = Depends(get_db_dependency),
) -> list[schemas.SheetResponse]: ) -> list[SheetResponse]:
return crud.get_user_sheets(db, user.email) return crud.get_user_sheets(db, user.email)
@sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.") @sheet_router.delete("/{sheet_id}", summary="Delete a Google Sheet by ID.")
def delete_sheet( def delete_sheet(
id: str, sheet_id: str,
user: UserState = Depends(get_user_state), user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency), db: Session = Depends(get_db_dependency),
) -> schemas.DeleteResponse: ) -> DeleteResponse:
return JSONResponse({ return DeleteResponse(
"id": id, id=sheet_id, deleted=crud.delete_sheet(db, sheet_id, user.email)
"deleted": crud.delete_sheet(db, id, user.email) )
})
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.") @sheet_router.post(
"/{sheet_id}/archive",
status_code=HTTPStatus.CREATED,
summary="Trigger an archiving task for a GSheet you own.",
response_description="task_id for the archiving task.",
)
def archive_user_sheet( def archive_user_sheet(
id: str, sheet_id: str,
user: UserState = Depends(get_user_state), user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency), db: Session = Depends(get_db_dependency),
) -> schemas.Task: ) -> JSONResponse:
sheet = crud.get_user_sheet(db, user.email, sheet_id=sheet_id)
sheet = crud.get_user_sheet(db, user.email, sheet_id=id)
if not sheet: if not sheet:
raise HTTPException(status_code=403, detail="No access to this sheet.") raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="No access to this sheet."
)
if not user.in_group(sheet.group_id): if not user.in_group(sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.") raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="User does not have access to this group.",
)
if not user.can_manually_trigger(sheet.group_id): if not user.can_manually_trigger(sheet.group_id):
raise HTTPException(status_code=429, detail="User cannot manually trigger sheet archiving in this group.") raise HTTPException(
status_code=HTTPStatus.TOO_MANY_REQUESTS,
detail="User cannot manually trigger sheet archiving in this group.",
)
group_queue = user.priority_group(sheet.group_id) group_queue = user.priority_group(sheet.group_id)
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=id, author_id=user.email, group_id=sheet.group_id).model_dump_json()]).apply_async(**group_queue) task = celery.signature(
"create_sheet_task",
args=[
SubmitSheet(
sheet_id=sheet_id, author_id=user.email, group_id=sheet.group_id
).model_dump_json()
],
).apply_async(**group_queue)
return JSONResponse({"id": task.id}, status_code=201) return JSONResponse({"id": task.id}, status_code=HTTPStatus.CREATED)

View File

@@ -14,8 +14,14 @@ task_router = APIRouter(prefix="/task", tags=["Async task operations"])
celery = get_celery() celery = get_celery()
@task_router.get("/{task_id}", summary="Check the status of an async task by its id, works for URLs and Sheet tasks.")
def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult: @task_router.get(
"/{task_id}",
summary="Check the status of an async task by its id, works for URLs and Sheet tasks.",
)
def get_status(
task_id, email=Depends(get_token_or_user_auth)
) -> schemas.TaskResult:
task = AsyncResult(task_id, app=celery) task = AsyncResult(task_id, app=celery)
try: try:
if task.status == "FAILURE": if task.status == "FAILURE":
@@ -24,17 +30,17 @@ def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskRe
# https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult # https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
raise task.result raise task.result
response = { response = {"id": task_id, "status": task.status, "result": task.result}
"id": task_id, return JSONResponse(
"status": task.status, jsonable_encoder(
"result": task.result response,
} exclude_unset=True,
return JSONResponse(jsonable_encoder(response, exclude_unset=True, custom_encoder={bytes: custom_jsonable_encoder})) custom_encoder={bytes: custom_jsonable_encoder},
)
)
except Exception as e: except Exception as e:
log_error(e) log_error(e)
return JSONResponse({ return JSONResponse(
"id": task_id, {"id": task_id, "status": "FAILURE", "result": {"error": str(e)}}
"status": "FAILURE", )
"result": {"error": str(e)}
})

View File

@@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from http import HTTPStatus
from urllib.parse import urlparse from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from app.shared import schemas from app.shared import schemas
from app.shared.db.database import get_db_dependency from app.shared.db.database import get_db_dependency
from app.shared.schemas import DeleteResponse
from app.shared.task_messaging import get_celery from app.shared.task_messaging import get_celery
from app.web.config import ALLOW_ANY_EMAIL from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud from app.web.db import crud
@@ -21,65 +22,106 @@ url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
celery = get_celery() celery = get_celery()
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
@url_router.post(
"/archive",
status_code=HTTPStatus.CREATED,
summary="Submit a single URL archive request, starts an archiving task.",
response_description="task_id for the archiving task, will match the archive id.",
)
def archive_url( def archive_url(
archive: schemas.ArchiveTrigger, archive: schemas.ArchiveTrigger,
email=Depends(get_token_or_user_auth), email=Depends(get_token_or_user_auth),
db: Session = Depends(get_db_dependency) db: Session = Depends(get_db_dependency),
) -> schemas.Task: ) -> JSONResponse:
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}") logger.info(
f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}"
)
parsed_url = urlparse(archive.url) parsed_url = urlparse(archive.url)
if not all([parsed_url.scheme, parsed_url.netloc]): if not all([parsed_url.scheme, parsed_url.netloc]):
raise HTTPException(status_code=400, detail="Invalid URL received.") raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail="Invalid URL received."
)
archive_create = schemas.ArchiveCreate(**archive.model_dump()) archive_create = schemas.ArchiveCreate(**archive.model_dump())
if email != ALLOW_ANY_EMAIL: if email != ALLOW_ANY_EMAIL:
archive_create.author_id = email archive_create.author_id = email
user = UserState(db, email) user = UserState(db, email)
if archive.group_id and not user.in_group(archive.group_id): if archive.group_id and not user.in_group(archive.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.") raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="User does not have access to this group.",
)
if not user.has_quota_max_monthly_urls(archive.group_id): if not user.has_quota_max_monthly_urls(archive.group_id):
raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.") raise HTTPException(
status_code=HTTPStatus.TOO_MANY_REQUESTS,
detail="User has reached their monthly URL quota.",
)
if not user.has_quota_max_monthly_mbs(archive.group_id): if not user.has_quota_max_monthly_mbs(archive.group_id):
raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.") raise HTTPException(
status_code=HTTPStatus.TOO_MANY_REQUESTS,
detail="User has reached their monthly MB quota.",
)
group_queue = user.priority_group(archive_create.group_id) group_queue = user.priority_group(archive_create.group_id)
else: else:
archive_create.author_id = archive.author_id or email archive_create.author_id = archive.author_id or email
group_queue = convert_priority_to_queue_dict("high") group_queue = convert_priority_to_queue_dict("high")
task = celery.signature(
task = celery.signature("create_archive_task", args=[archive_create.model_dump_json()]).apply_async(**group_queue) "create_archive_task", args=[archive_create.model_dump_json()]
).apply_async(**group_queue)
task_response = schemas.Task(id=task.id) task_response = schemas.Task(id=task.id)
return JSONResponse(task_response.model_dump(), status_code=201) return JSONResponse(
task_response.model_dump(), status_code=HTTPStatus.CREATED
)
@url_router.get("/search", summary="Search for archive entries by URL.") @url_router.get("/search", summary="Search for archive entries by URL.")
def search_by_url( def search_by_url(
url: str, skip: int = 0, limit: int = 25, url: str,
archived_after: datetime = None, archived_before: datetime = None, skip: int = 0,
db: Session = Depends(get_db_dependency), limit: int = 25,
email: str = Depends(get_token_or_user_auth) archived_after: datetime = None,
archived_before: datetime = None,
db: Session = Depends(get_db_dependency),
email: str = Depends(get_token_or_user_auth),
) -> list[schemas.ArchiveResult]: ) -> list[schemas.ArchiveResult]:
read_groups, read_public = False, False read_groups, read_public = False, False
if email != ALLOW_ANY_EMAIL: if email != ALLOW_ANY_EMAIL:
user = UserState(db, email) user = UserState(db, email)
if not user.read and not user.read_public: if not user.read and not user.read_public:
raise HTTPException(status_code=403, detail="User does not have read access.") raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="User does not have read access.",
)
read_groups = user.read read_groups = user.read
read_public = user.read_public read_public = user.read_public
return crud.search_archives_by_url(db, url.strip(), email, read_groups, read_public, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before) return crud.search_archives_by_url(
db,
url.strip(),
email,
read_groups,
read_public,
skip=skip,
limit=limit,
archived_after=archived_after,
archived_before=archived_before,
)
@url_router.delete("/{id}", summary="Delete a single URL archive by id.") @url_router.delete(
"/{archive_id}", summary="Delete a single URL archive by id."
)
def delete_archive( def delete_archive(
id:str, archive_id: str,
user: UserState = Depends(get_user_state), user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency) db: Session = Depends(get_db_dependency),
) -> schemas.DeleteResponse: ) -> DeleteResponse:
logger.info(f"deleting url archive task {id} request by {user.email}") logger.info(
return JSONResponse({ f"deleting url archive task {archive_id} request by {user.email}"
"id": id, )
"deleted": crud.soft_delete_archive(db, id, user.email) return DeleteResponse(
}) id=archive_id,
deleted=crud.soft_delete_archive(db, archive_id, user.email),
)

View File

@@ -6,7 +6,7 @@ from fastapi.staticfiles import StaticFiles
from loguru import logger from loguru import logger
from prometheus_fastapi_instrumentator import Instrumentator from prometheus_fastapi_instrumentator import Instrumentator
from app.shared.settings import get_settings from app.shared.settings import Settings, get_settings
from app.shared.task_messaging import get_celery from app.shared.task_messaging import get_celery
from app.web.config import API_DESCRIPTION, VERSION from app.web.config import API_DESCRIPTION, VERSION
from app.web.endpoints.default import default_router from app.web.endpoints.default import default_router
@@ -21,13 +21,22 @@ from app.web.security import token_api_key_auth
celery = get_celery() celery = get_celery()
def app_factory(settings = get_settings()):
def app_factory(settings: Settings = None):
# TODO: Create dev, test, and prod versions of settings that do not have
# TODO: to be passed in as a parameter
if settings is None:
settings = get_settings()
app = FastAPI( app = FastAPI(
title="Auto-Archiver API", title="Auto-Archiver API",
description=API_DESCRIPTION, description=API_DESCRIPTION,
version=VERSION, version=VERSION,
contact={"name": "GitHub", "url": "https://github.com/bellingcat/auto-archiver-api"}, contact={
lifespan=lifespan "name": "GitHub",
"url": "https://github.com/bellingcat/auto-archiver-api",
},
lifespan=lifespan,
) )
app.add_middleware( app.add_middleware(
@@ -46,14 +55,30 @@ def app_factory(settings = get_settings()):
app.include_router(interoperability_router) app.include_router(interoperability_router)
# prometheus exposed in /metrics with authentication # prometheus exposed in /metrics with authentication
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health", "/openapi.json", "/favicon.ico"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)]) Instrumentator(
should_group_status_codes=False,
excluded_handlers=[
"/metrics",
"/health",
"/openapi.json",
"/favicon.ico",
],
).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
if settings.SERVE_LOCAL_ARCHIVE: if settings.SERVE_LOCAL_ARCHIVE:
local_dir = settings.SERVE_LOCAL_ARCHIVE local_dir = settings.SERVE_LOCAL_ARCHIVE
if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")): if not os.path.isdir(local_dir) and os.path.isdir(
local_dir.replace("/app", ".")
):
local_dir = local_dir.replace("/app", ".") local_dir = local_dir.replace("/app", ".")
if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir): if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
logger.warning(f"MOUNTing local archive, use this in development only {settings.SERVE_LOCAL_ARCHIVE}") logger.warning(
app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE) f"MOUNTing local archive, use this in development only {settings.SERVE_LOCAL_ARCHIVE}"
)
app.mount(
settings.SERVE_LOCAL_ARCHIVE,
StaticFiles(directory=local_dir),
name=settings.SERVE_LOCAL_ARCHIVE,
)
return app return app

View File

@@ -1,4 +1,3 @@
import traceback import traceback
from fastapi import Request from fastapi import Request
@@ -11,23 +10,30 @@ from app.web.utils.metrics import EXCEPTION_COUNTER
async def logging_middleware(request: Request, call_next): async def logging_middleware(request: Request, call_next):
try: try:
response = await call_next(request) response = await call_next(request)
#TODO: use Origin to have summary prometheus metrics on where requests come from # TODO: use Origin to have summary prometheus metrics on where requests come from
# origin = request.headers.get("origin") # origin = request.headers.get("origin")
logger.info(f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}") logger.info(
f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}"
)
return response return response
except Exception as e: except Exception as e:
location = f"{request.method} {request.url._url}" location = f"{request.method} {request.url._url}"
await increase_exceptions_counter(e, location) await increase_exceptions_counter(e, location)
logger.info(f"{request.client.host}:{request.client.port} {location} - {e.__class__.__name__} {e}") logger.info(
f"{request.client.host}:{request.client.port} {location} - {e.__class__.__name__} {e}"
)
raise e raise e
async def increase_exceptions_counter(e: Exception, location:str="cronjob"):
async def increase_exceptions_counter(e: Exception, location: str = "cronjob"):
if location == "cronjob": if location == "cronjob":
try: try:
last_trace = traceback.extract_tb(e.__traceback__)[-1] last_trace = traceback.extract_tb(e.__traceback__)[-1]
_file, _line, func_name, _text = last_trace _file, _line, func_name, _text = last_trace
location = func_name location = func_name
except Exception as e: except Exception as e:
logger.error(f"Unable to get function name from cronjob exception traceback: {e}") logger.error(
f"Unable to get function name from cronjob exception traceback: {e}"
)
EXCEPTION_COUNTER.labels(type=e.__class__.__name__, location=location).inc() EXCEPTION_COUNTER.labels(type=e.__class__.__name__, location=location).inc()
log_error(e) log_error(e)

View File

@@ -1,4 +1,5 @@
import secrets import secrets
from http import HTTPStatus
import requests import requests
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
@@ -16,7 +17,7 @@ settings = get_settings()
bearer_security = HTTPBearer() bearer_security = HTTPBearer()
def secure_compare(token, api_key): def secure_compare(token, api_key) -> bool:
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8")) return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
@@ -24,9 +25,13 @@ def secure_compare(token, api_key):
def api_key_auth(api_key): def api_key_auth(api_key):
assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars" assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars"
async def auth(bearer: HTTPAuthorizationCredentials = Depends(bearer_security), auto_error=True): async def auth(
bearer: HTTPAuthorizationCredentials = Depends(bearer_security),
auto_error=True,
):
is_correct = secure_compare(bearer.credentials, api_key) is_correct = secure_compare(bearer.credentials, api_key)
if is_correct: return True if is_correct:
return True
if auto_error: if auto_error:
raise HTTPException( raise HTTPException(
@@ -38,17 +43,22 @@ def api_key_auth(api_key):
return auth return auth
# --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus # --- Token Auth for AA itself to query the API, AA setup tool and Prometheus
token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN) token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN)
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): async def get_token_or_user_auth(
credentials: HTTPAuthorizationCredentials = Depends(bearer_security),
):
# tries to use the static API_KEY and defaults to google JWT auth # tries to use the static API_KEY and defaults to google JWT auth
if await token_api_key_auth(credentials, auto_error=False): return ALLOW_ANY_EMAIL if await token_api_key_auth(credentials, auto_error=False):
return ALLOW_ANY_EMAIL
return await get_user_auth(credentials) return await get_user_auth(credentials)
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): async def get_user_auth(
credentials: HTTPAuthorizationCredentials = Depends(bearer_security),
):
# validates the Bearer token in the case that it requires it # validates the Bearer token in the case that it requires it
valid_user, info = authenticate_user(credentials.credentials) valid_user, info = authenticate_user(credentials.credentials)
if valid_user: if valid_user:
@@ -61,26 +71,37 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
) )
def authenticate_user(access_token): def authenticate_user(access_token) -> (bool, str):
# https://cloud.google.com/docs/authentication/token-types#access # https://cloud.google.com/docs/authentication/token-types#access
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token" if not isinstance(access_token, str) or len(access_token) < 10:
r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token": access_token}) return False, "invalid access_token"
if r.status_code != 200: return False, "invalid token" r = requests.get(
"https://oauth2.googleapis.com/tokeninfo",
{"access_token": access_token},
)
if r.status_code != HTTPStatus.OK:
return False, "invalid token"
try: try:
j = r.json() j = r.json()
if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS: if (
return False, f"token does not belong to valid APP_ID" j.get("azp") not in settings.CHROME_APP_IDS
and j.get("aud") not in settings.CHROME_APP_IDS
):
return False, "token does not belong to valid APP_ID"
if j.get("email") in settings.BLOCKED_EMAILS: if j.get("email") in settings.BLOCKED_EMAILS:
return False, f"email '{j.get('email')}' not allowed" return False, f"email '{j.get('email')}' not allowed"
if j.get("email_verified") != "true": if j.get("email_verified") != "true":
return False, f"email '{j.get('email')}' not verified" return False, f"email '{j.get('email')}' not verified"
if int(j.get("expires_in", -1)) <= 0: if int(j.get("expires_in", -1)) <= 0:
return False, "Token expired" return False, "Token expired"
return True, j.get('email').lower() return True, j.get("email").lower()
except Exception as e: except Exception as e:
logger.warning(f"AUTH EXCEPTION occurred: {e}") logger.warning(f"AUTH EXCEPTION occurred: {e}")
return False, "exception occurred" return False, "exception occurred"
def get_user_state(email:str=Depends(get_user_auth), db:Session=Depends(get_db_dependency)): def get_user_state(
email: str = Depends(get_user_auth),
db: Session = Depends(get_db_dependency),
) -> UserState:
return UserState(db, email) return UserState(db, email)

View File

@@ -15,27 +15,25 @@ from app.web.db import crud
EXCEPTION_COUNTER = Counter( EXCEPTION_COUNTER = Counter(
"exceptions", "exceptions",
"Number of times a certain exception has occurred.", "Number of times a certain exception has occurred.",
labelnames=["type", "location"] labelnames=["type", "location"],
) )
WORKER_EXCEPTION = Counter( WORKER_EXCEPTION = Counter(
"worker_exceptions_total", "worker_exceptions_total",
"Number of times a certain exception has occurred on the worker.", "Number of times a certain exception has occurred on the worker.",
labelnames=["type", "exception", "task", "traceback"] labelnames=["type", "exception", "task", "traceback"],
) )
DISK_UTILIZATION = Gauge( DISK_UTILIZATION = Gauge(
"disk_utilization", "disk_utilization", "Disk utilization in GB", labelnames=["type"]
"Disk utilization in GB",
labelnames=["type"]
) )
DATABASE_METRICS = Gauge( DATABASE_METRICS = Gauge(
"database_metrics", "database_metrics",
"Database metric readings at a certain point in time", "Database metric readings at a certain point in time",
labelnames=["query"] labelnames=["query"],
) )
DATABASE_METRICS_COUNTER = Counter( DATABASE_METRICS_COUNTER = Counter(
"database_metrics_counter", "database_metrics_counter",
"Database metrics that increase over time", "Database metrics that increase over time",
labelnames=["query", "user"] labelnames=["query", "user"],
) )
@@ -48,7 +46,12 @@ async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL: str):
message = PubSubExceptions.get_message() message = PubSubExceptions.get_message()
if message and message["type"] == "message": if message and message["type"] == "message":
data = json.loads(message["data"].decode("utf-8")) data = json.loads(message["data"].decode("utf-8"))
WORKER_EXCEPTION.labels(type=data["type"], exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc() WORKER_EXCEPTION.labels(
type=data["type"],
exception=data["exception"],
task=data["task"],
traceback=data["traceback"],
).inc()
await asyncio.sleep(1) await asyncio.sleep(1)
@@ -59,12 +62,19 @@ async def measure_regular_metrics(sqlite_db_url: str, repeat_in_seconds: int):
try: try:
fs = os.stat(sqlite_db_url.replace("sqlite:///", "")) fs = os.stat(sqlite_db_url.replace("sqlite:///", ""))
DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30)) DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30))
except Exception as e: log_error(e) except Exception as e:
log_error(e)
with get_db() as db: with get_db() as db:
DATABASE_METRICS.labels(query="count_archives").set(crud.count_archives(db)) DATABASE_METRICS.labels(query="count_archives").set(
DATABASE_METRICS.labels(query="count_archive_urls").set(crud.count_archive_urls(db)) crud.count_archives(db)
)
DATABASE_METRICS.labels(query="count_archive_urls").set(
crud.count_archive_urls(db)
)
DATABASE_METRICS.labels(query="count_users").set(crud.count_users(db)) DATABASE_METRICS.labels(query="count_users").set(crud.count_users(db))
for user in crud.count_by_user_since(db, repeat_in_seconds): for user in crud.count_by_user_since(db, repeat_in_seconds):
DATABASE_METRICS_COUNTER.labels(query="count_by_user", user=user.author_id).inc(user.total) DATABASE_METRICS_COUNTER.labels(
query="count_by_user", user=user.author_id
).inc(user.total)

View File

@@ -5,12 +5,12 @@ from fastapi.encoders import jsonable_encoder
def custom_jsonable_encoder(obj): def custom_jsonable_encoder(obj):
if isinstance(obj, bytes): if isinstance(obj, bytes):
return base64.b64encode(obj).decode('utf-8') return base64.b64encode(obj).decode("utf-8")
return jsonable_encoder(obj) return jsonable_encoder(obj)
def convert_priority_to_queue_dict(priority: str) -> dict: def convert_priority_to_queue_dict(priority: str) -> dict:
return { return {
"priority": 0 if priority == "high" else 10, "priority": 0 if priority == "high" else 10,
"queue": f"{priority}_priority" "queue": f"{priority}_priority",
} }

View File

@@ -38,6 +38,9 @@ pythonpath = "."
[tool.coverage.run] [tool.coverage.run]
omit = ["app/migrations/*"] omit = ["app/migrations/*"]
[tool.ruff.lint.flake8-bugbear]
extend-immutable-calls = ["fastapi.Depends", "fastapi.Query"]
[tool.poetry.group.worker.dependencies] [tool.poetry.group.worker.dependencies]
watchdog = ">=6.0.0,<7.0.0" watchdog = ">=6.0.0,<7.0.0"
setuptools = "^75.8.0" setuptools = "^75.8.0"