mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-12 05:28:34 +03:00
major refactor of structure for worker V web: docker/app/secrets/envs/...
This commit is contained in:
0
app/shared/db/__init__.py
Normal file
0
app/shared/db/__init__.py
Normal file
314
app/shared/db/crud.py
Normal file
314
app/shared/db/crud.py
Normal file
@@ -0,0 +1,314 @@
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
from sqlalchemy import Column, or_, func, select
|
||||
from loguru import logger
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
from app.shared.db.database import get_db
|
||||
from app.shared.db import models
|
||||
from app.shared import schemas
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.user_groups import UserGroups
|
||||
from app.shared.utils.misc import fnv1a_hash_mod
|
||||
|
||||
DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
|
||||
|
||||
|
||||
def get_limit(user_limit: int):
|
||||
return max(1, min(user_limit, DATABASE_QUERY_LIMIT))
|
||||
|
||||
# --------------- TASK = Archive
|
||||
|
||||
def base_query(db: Session):
|
||||
# NOTE: load_only is for optimization and not obfuscation, use .with_entities() if needed
|
||||
return db.query(models.Archive)\
|
||||
.filter(models.Archive.deleted == False)\
|
||||
.options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result, models.Archive.store_until))
|
||||
|
||||
def get_archive(db: Session, id: str, email: str):
|
||||
query = base_query(db).filter(models.Archive.id == id)
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
groups = get_user_groups(email)
|
||||
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
||||
return query.first()
|
||||
|
||||
def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False)-> list[models.Archive]:
|
||||
# searches for partial URLs, if email is * no ownership filtering happens
|
||||
query = base_query(db)
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
groups = get_user_groups(email)
|
||||
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
||||
if absolute_search:
|
||||
query = query.filter(models.Archive.url == url)
|
||||
else:
|
||||
query = query.filter(models.Archive.url.like(f'%{url}%'))
|
||||
if archived_after:
|
||||
query = query.filter(models.Archive.created_at > archived_after)
|
||||
if archived_before:
|
||||
query = query.filter(models.Archive.created_at < archived_before)
|
||||
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
|
||||
|
||||
|
||||
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
|
||||
return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
|
||||
|
||||
#TODO: rename task to archive
|
||||
def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]) -> models.Archive:
|
||||
db_task = models.Archive(id=task.id, url=task.url, result=task.result, public=task.public, author_id=task.author_id, group_id=task.group_id, sheet_id=task.sheet_id, store_until=task.store_until)
|
||||
db_task.tags = tags
|
||||
db_task.urls = urls
|
||||
db.add(db_task)
|
||||
db.commit()
|
||||
db.refresh(db_task)
|
||||
return db_task
|
||||
|
||||
|
||||
def soft_delete_task(db: Session, task_id: str, email: str) -> bool:
|
||||
# TODO: implement hard-delete with cronjob that deletes from S3
|
||||
db_task = db.query(models.Archive).filter(models.Archive.id == task_id, models.Archive.author_id == email, models.Archive.deleted == False).first()
|
||||
if db_task:
|
||||
db_task.deleted = True
|
||||
db.commit()
|
||||
return db_task is not None
|
||||
|
||||
|
||||
def count_archives(db: Session):
|
||||
return db.query(func.count(models.Archive.id)).scalar()
|
||||
|
||||
|
||||
def count_archive_urls(db: Session):
|
||||
return db.query(func.count(models.ArchiveUrl.url)).scalar()
|
||||
|
||||
|
||||
def count_users(db: Session):
|
||||
return db.query(func.count(models.User.email)).scalar()
|
||||
|
||||
|
||||
def count_by_user_since(db: Session, seconds_delta: int = 15):
|
||||
time_threshold = datetime.now() - timedelta(seconds=seconds_delta)
|
||||
return db.query(models.Archive.author_id, func.count().label('total'))\
|
||||
.filter(models.Archive.created_at >= time_threshold)\
|
||||
.group_by(models.Archive.author_id)\
|
||||
.order_by(func.count().desc())\
|
||||
.limit(500).all()
|
||||
|
||||
async def find_by_store_until(db: AsyncSession, store_until_is_before:datetime) -> dict:
|
||||
res = await db.execute(
|
||||
select(models.Archive)
|
||||
.filter(models.Archive.deleted ==False, models.Archive.store_until < store_until_is_before)
|
||||
)
|
||||
return res.scalars()
|
||||
|
||||
async def soft_delete_expired_archives(db: AsyncSession) -> dict:
|
||||
to_delete = await find_by_store_until(db, datetime.now())
|
||||
counter = 0
|
||||
for archive in to_delete:
|
||||
archive.deleted = True
|
||||
counter += 1
|
||||
await db.commit()
|
||||
return counter
|
||||
# --------------- TAG
|
||||
|
||||
|
||||
def create_tag(db: Session, tag: str) -> models.Tag:
|
||||
db_tag = db.query(models.Tag).filter(models.Tag.id == tag).first()
|
||||
if not db_tag:
|
||||
db_tag = models.Tag(id=tag)
|
||||
db.add(db_tag)
|
||||
db.commit()
|
||||
db.refresh(db_tag)
|
||||
return db_tag
|
||||
|
||||
|
||||
def is_user_in_group(db: Session, email: str, group_name: str) -> models.Group:
|
||||
if email == ALLOW_ANY_EMAIL: return True
|
||||
return len(group_name) and len(email) and group_name in get_user_groups(email)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_user_groups(email: str) -> list[str]:
|
||||
"""
|
||||
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user.
|
||||
"""
|
||||
if not email or not len(email) or "@" not in email: return []
|
||||
|
||||
with get_db() as db:
|
||||
# get user groups
|
||||
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
|
||||
user_level_groups_names = [g[0] for g in user_groups]
|
||||
|
||||
# get domain groups
|
||||
domain = email.split('@')[1]
|
||||
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
|
||||
domain_level_groups_names = [g[0] for g in domain_level_groups]
|
||||
|
||||
return list(set(user_level_groups_names + domain_level_groups_names))
|
||||
|
||||
|
||||
# --------------- INIT User-Groups
|
||||
|
||||
def get_group(db: Session, group_name: str) -> models.Group:
|
||||
return db.query(models.Group).filter(models.Group.id == group_name).first()
|
||||
|
||||
|
||||
def create_or_get_user(db: Session, author_id: str) -> models.User:
|
||||
if type(author_id) == str: author_id = author_id.lower()
|
||||
db_user = db.query(models.User).filter(models.User.email == author_id).first()
|
||||
if not db_user:
|
||||
db_user = models.User(email=author_id)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
return db_user
|
||||
|
||||
|
||||
def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, permissions: dict, domains: list) -> models.Group:
|
||||
db_group = db.query(models.Group).filter(models.Group.id == group_name).first()
|
||||
if db_group is None:
|
||||
db_group = models.Group(id=group_name, description=description, orchestrator=orchestrator, orchestrator_sheet=orchestrator_sheet, permissions=permissions, domains=domains)
|
||||
db.add(db_group)
|
||||
else:
|
||||
db_group.description = description
|
||||
db_group.orchestrator = orchestrator
|
||||
db_group.orchestrator_sheet = orchestrator_sheet
|
||||
db_group.permissions = permissions
|
||||
db_group.domains = domains
|
||||
db.commit()
|
||||
db.refresh(db_group)
|
||||
return db_group
|
||||
|
||||
|
||||
def upsert_user(db: Session, email: str):
|
||||
email = email.lower()
|
||||
db_user = db.query(models.User).filter(models.User.email == email).first()
|
||||
if db_user is None:
|
||||
db_user = models.User(email=email)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
return db_user
|
||||
|
||||
|
||||
def upsert_user_groups(db: Session):
|
||||
def display_email_pii(email: str):
|
||||
return f"'{email[0:3]}...@{email.split('@')[1]}'"
|
||||
"""
|
||||
reads the user_groups yaml file and inserts any new users, groups,
|
||||
along with new participation of users in groups
|
||||
"""
|
||||
logger.debug("Updating user-groups configuration.")
|
||||
filename = get_settings().USER_GROUPS_FILENAME
|
||||
|
||||
ug = UserGroups(filename)
|
||||
|
||||
# delete all user-groups relationships
|
||||
db.query(models.association_table_user_groups).delete()
|
||||
|
||||
# create a map of group_id -> domains and another of domain -> groups
|
||||
group_domains = defaultdict(set)
|
||||
domain_groups = defaultdict(list)
|
||||
for domain, explicit_groups in ug.domains.items():
|
||||
domain_groups[domain] = list(set(explicit_groups))
|
||||
for group in explicit_groups:
|
||||
group_domains[group].add(domain)
|
||||
import json
|
||||
# upsert groups and save a map of groupid -> dbobject
|
||||
for group_id, g in ug.groups.items():
|
||||
upsert_group(db, group_id, g.description, g.orchestrator, g.orchestrator_sheet, json.loads(g.permissions.model_dump_json()), list(group_domains.get(group_id, [])))
|
||||
db_groups: dict[str, models.Group] = {g.id: g for g in db.query(models.Group).all()}
|
||||
|
||||
# integrity checks
|
||||
for group_in_domains in group_domains:
|
||||
if group_in_domains not in db_groups:
|
||||
logger.warning(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.")
|
||||
|
||||
# reinsert users in their EXPLICITLY DEFINED groups
|
||||
# domain groups are check live, as there may be new users that are not explicitly registered but belong to a domain
|
||||
for email, explicit_groups in ug.users.items():
|
||||
explicit_groups = explicit_groups or []
|
||||
logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}")
|
||||
|
||||
db_user = upsert_user(db, email)
|
||||
|
||||
# connect users to groups
|
||||
for group_id in explicit_groups:
|
||||
if group_id not in db_groups:
|
||||
logger.warning(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.")
|
||||
continue
|
||||
db_groups[group_id].users.append(db_user)
|
||||
|
||||
db.commit()
|
||||
count_user_groups = db.query(models.association_table_user_groups).count()
|
||||
count_groups = db.query(func.count(models.Group.id)).scalar()
|
||||
|
||||
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
|
||||
|
||||
|
||||
# --------------- SHEET
|
||||
def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: str, frequency: str):
|
||||
db_sheet = models.Sheet(id=sheet_id, name=name, author_id=email, group_id=group_id, frequency=frequency)
|
||||
db.add(db_sheet)
|
||||
db.commit()
|
||||
db.refresh(db_sheet)
|
||||
return db_sheet
|
||||
|
||||
|
||||
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
|
||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
|
||||
|
||||
|
||||
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
|
||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_url_archived_at.desc()).all()
|
||||
|
||||
|
||||
async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: str) -> list[models.Sheet]:
|
||||
result = await db.execute(
|
||||
select(models.Sheet).filter(models.Sheet.frequency == frequency)
|
||||
)
|
||||
filtered = []
|
||||
for sheet in result.scalars():
|
||||
if fnv1a_hash_mod(sheet.id, modulo) == id_hash:
|
||||
filtered.append(sheet)
|
||||
return filtered
|
||||
|
||||
async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
|
||||
time_threshold = datetime.now() - timedelta(days=inactivity_days)
|
||||
result = await db.execute(
|
||||
select(models.Sheet).filter(models.Sheet.last_url_archived_at < time_threshold)
|
||||
)
|
||||
deleted = defaultdict(list)
|
||||
for sheet in result.scalars():
|
||||
await db.delete(sheet)
|
||||
deleted[sheet.author_id].append(sheet)
|
||||
await db.commit()
|
||||
return dict(deleted)
|
||||
|
||||
def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
|
||||
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
|
||||
if db_sheet:
|
||||
db_sheet.last_url_archived_at = datetime.now()
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
|
||||
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
|
||||
if db_sheet:
|
||||
db.delete(db_sheet)
|
||||
db.commit()
|
||||
return db_sheet is not None
|
||||
|
||||
|
||||
#--- Celery worker tasks
|
||||
|
||||
|
||||
def store_archived_url(db: Session, archive: schemas.ArchiveCreate) -> models.Archive:
|
||||
# create and load user, tags, if needed
|
||||
create_or_get_user(db, archive.author_id)
|
||||
db_tags = [create_tag(db, tag) for tag in archive.tags]
|
||||
# insert everything
|
||||
db_task = create_task(db, task=archive, tags=db_tags, urls=archive.urls)
|
||||
return db_task
|
||||
67
app/shared/db/database.py
Normal file
67
app/shared/db/database.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from functools import lru_cache
|
||||
from sqlalchemy import Engine, create_engine, event, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine, async_sessionmaker
|
||||
|
||||
from app.shared.settings import get_settings
|
||||
|
||||
|
||||
@lru_cache
|
||||
def make_engine(database_url: str):
|
||||
engine = create_engine(database_url, connect_args={"check_same_thread": False})
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(conn, _) -> None:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def make_session_local(engine: Engine):
|
||||
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
return session_local
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db():
|
||||
session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
|
||||
try: yield session
|
||||
finally: session.close()
|
||||
|
||||
|
||||
def get_db_dependency():
|
||||
# to use with Depends and ensure proper session closing
|
||||
with get_db() as db:
|
||||
yield db
|
||||
|
||||
def wal_checkpoint():
|
||||
# WAL checkpointing, make sure the .sqlite file receives the latest changes
|
||||
# to be called at startup as it halts writes
|
||||
with get_db() as db:
|
||||
db.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
|
||||
|
||||
|
||||
# ASYNC connections
|
||||
async def make_async_engine(database_url: str) -> AsyncEngine:
|
||||
engine = create_async_engine(database_url, connect_args={"check_same_thread": False})
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(lambda sync_conn: sync_conn.execute(text("PRAGMA journal_mode=WAL;")))
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
async def make_async_session_local(engine: AsyncEngine) -> AsyncSession:
|
||||
return async_sessionmaker(engine, expire_on_commit=False, autoflush=False, autocommit=False)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_async():
|
||||
engine = await make_async_engine(get_settings().ASYNC_DATABASE_PATH)
|
||||
async_session = await make_async_session_local(engine)
|
||||
async with async_session() as session:
|
||||
try: yield session
|
||||
finally: await engine.dispose()
|
||||
113
app/shared/db/models.py
Normal file
113
app/shared/db/models.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
import uuid
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def generate_uuid():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
# many to many association tables
|
||||
association_table_archive_tags = Table(
|
||||
"mtm_archives_tags",
|
||||
Base.metadata,
|
||||
Column("archive_id", ForeignKey("archives.id")),
|
||||
Column("tag_id", ForeignKey("tags.id")),
|
||||
)
|
||||
association_table_user_groups = Table(
|
||||
"mtm_users_groups",
|
||||
Base.metadata,
|
||||
Column("user_id", ForeignKey("users.email")),
|
||||
Column("group_id", ForeignKey("groups.id")),
|
||||
)
|
||||
|
||||
|
||||
# data model tables
|
||||
class Archive(Base):
|
||||
__tablename__ = "archives"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
url = Column(String, index=True)
|
||||
result = Column(JSON, default=None)
|
||||
public = Column(Boolean, default=True) # if public=false, access by group and author
|
||||
deleted = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
store_until = Column(DateTime(timezone=True), default=None)
|
||||
|
||||
group_id = Column(String, ForeignKey("groups.id"), default=None)
|
||||
author_id = Column(String, ForeignKey("users.email"))
|
||||
sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
|
||||
|
||||
tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags)
|
||||
group = relationship("Group", back_populates="archives")
|
||||
author = relationship("User", back_populates="archives")
|
||||
urls = relationship("ArchiveUrl", back_populates="archive")
|
||||
sheet = relationship("Sheet", back_populates="archives")
|
||||
|
||||
|
||||
class ArchiveUrl(Base):
|
||||
__tablename__ = "archive_urls"
|
||||
|
||||
url = Column(String, primary_key=True, index=True)
|
||||
archive_id = Column(String, ForeignKey("archives.id"), primary_key=True)
|
||||
key = Column(String, default=None)
|
||||
|
||||
archive = relationship("Archive", back_populates="urls")
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
email = Column(String, primary_key=True, index=True)
|
||||
|
||||
archives = relationship("Archive", back_populates="author")
|
||||
sheets = relationship("Sheet", back_populates="author")
|
||||
groups = relationship("Group", back_populates="users", secondary=association_table_user_groups)
|
||||
|
||||
|
||||
class Group(Base):
|
||||
__tablename__ = "groups"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
description = Column(String, default=None)
|
||||
orchestrator = Column(String, default=None)
|
||||
orchestrator_sheet = Column(String, default=None)
|
||||
permissions = Column(JSON, default={})
|
||||
domains = Column(JSON, default=[])
|
||||
|
||||
archives = relationship("Archive", back_populates="group")
|
||||
sheets = relationship("Sheet", back_populates="group")
|
||||
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
|
||||
|
||||
|
||||
class Sheet(Base):
|
||||
__tablename__ = "sheets"
|
||||
|
||||
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
|
||||
name = Column(String, default=None)
|
||||
author_id = Column(String, ForeignKey("users.email"))
|
||||
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
|
||||
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
|
||||
# TODO: stats is not needed, is it?
|
||||
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...")
|
||||
last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
group = relationship("Group", back_populates="sheets")
|
||||
author = relationship("User", back_populates="sheets")
|
||||
archives = relationship("Archive", back_populates="sheet")
|
||||
328
app/shared/db/user_state.py
Normal file
328
app/shared/db/user_state.py
Normal file
@@ -0,0 +1,328 @@
|
||||
|
||||
from typing import Dict, Set
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from datetime import datetime
|
||||
|
||||
from app.shared.db import crud, models
|
||||
from app.shared.user_groups import GroupInfo, GroupPermissions
|
||||
from app.shared.schemas import Usage, UsageResponse
|
||||
|
||||
class UserState:
|
||||
"""
|
||||
Manage a user's state and permissions
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session, email: str):
|
||||
self.db = db
|
||||
self.email = email.lower()
|
||||
|
||||
@property
|
||||
def permissions(self) -> Dict[str, GroupInfo]:
|
||||
"""
|
||||
Returns a dict of all group permissions and a special {"all": read/archive_url/archive_sheet} key
|
||||
"""
|
||||
if not hasattr(self, '_permissions'):
|
||||
self._permissions = {}
|
||||
self._permissions["all"] = GroupInfo(
|
||||
read=self.read,
|
||||
read_public=self.read_public,
|
||||
archive_url=self.archive_url,
|
||||
archive_sheet=self.archive_sheet,
|
||||
# below are relevant only for /url endpoints
|
||||
max_archive_lifespan_months=self.max_archive_lifespan_months,
|
||||
max_monthly_urls=self.max_monthly_urls,
|
||||
max_monthly_mbs=self.max_monthly_mbs,
|
||||
priority=self.priority
|
||||
)
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
self._permissions[group.id] = GroupInfo(**group.permissions, description=group.description)
|
||||
return self._permissions
|
||||
|
||||
@property
|
||||
def user_groups_names(self):
|
||||
if not hasattr(self, '_user_groups_names'):
|
||||
self._user_groups_names = crud.get_user_groups(self.email) + ["default"]
|
||||
return self._user_groups_names
|
||||
|
||||
@property
|
||||
def user_groups(self):
|
||||
if not hasattr(self, '_user_groups'):
|
||||
self._user_groups = self.db.query(models.Group).filter(
|
||||
models.Group.id.in_(self.user_groups_names)
|
||||
).all()
|
||||
return self._user_groups
|
||||
|
||||
@property
|
||||
def read(self) -> Set[str] | bool:
|
||||
"""
|
||||
Read can be a list of group names or True, if all can be read.
|
||||
"""
|
||||
if not hasattr(self, '_read'):
|
||||
self._read = set()
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
group_read_permissions = group.permissions.get("read", [])
|
||||
if "all" in group_read_permissions:
|
||||
self._read = True
|
||||
return self._read
|
||||
else:
|
||||
self._read.update(group.permissions.get("read", []))
|
||||
return self._read
|
||||
|
||||
@property
|
||||
def read_public(self) -> bool:
|
||||
"""
|
||||
Read public permission
|
||||
"""
|
||||
if not hasattr(self, '_read_public'):
|
||||
self._read_public = False
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if group.permissions.get("read_public", False):
|
||||
self._read_public = True
|
||||
return self._read_public
|
||||
return self._read_public
|
||||
|
||||
@property
|
||||
def archive_url(self) -> bool:
|
||||
"""
|
||||
Archive URL permission
|
||||
"""
|
||||
if not hasattr(self, '_archive_url'):
|
||||
self._archive_url = False
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if group.permissions.get("archive_url", False):
|
||||
self._archive_url = True
|
||||
return self._archive_url
|
||||
return self._archive_url
|
||||
|
||||
@property
|
||||
def archive_sheet(self) -> bool:
|
||||
"""
|
||||
Archive sheet permission
|
||||
"""
|
||||
if not hasattr(self, '_archive_sheet'):
|
||||
self._archive_sheet = False
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if group.permissions.get("archive_sheet", False):
|
||||
self._archive_sheet = True
|
||||
return self._archive_sheet
|
||||
return self._archive_sheet
|
||||
|
||||
@property
|
||||
def sheet_frequency(self):
|
||||
if not hasattr(self, '_sheet_frequency'):
|
||||
self._sheet_frequency = set()
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
self._sheet_frequency.update(group.permissions.get("sheet_frequency", None))
|
||||
return self._sheet_frequency
|
||||
|
||||
@property
|
||||
def max_archive_lifespan_months(self) -> int:
|
||||
if not hasattr(self, '_max_archive_lifespan_months'):
|
||||
self._max_archive_lifespan_months = self._helper_for_grouping_max_numerical_permissions("max_archive_lifespan_months")
|
||||
return self._max_archive_lifespan_months
|
||||
|
||||
@property
|
||||
def max_monthly_urls(self) -> int:
|
||||
if not hasattr(self, '_max_monthly_urls'):
|
||||
self._max_monthly_urls = self._helper_for_grouping_max_numerical_permissions("max_monthly_urls")
|
||||
return self._max_monthly_urls
|
||||
|
||||
@property
|
||||
def max_monthly_mbs(self) -> int:
|
||||
if not hasattr(self, '_max_monthly_mbs'):
|
||||
self._max_monthly_mbs = self._helper_for_grouping_max_numerical_permissions("max_monthly_mbs")
|
||||
return self._max_monthly_mbs
|
||||
|
||||
@property
|
||||
def priority(self) -> str:
|
||||
if not hasattr(self, '_priority'):
|
||||
self._priority = "low"
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if group.permissions.get("priority", "low") == "high":
|
||||
self._priority = "high"
|
||||
return self._priority
|
||||
|
||||
@property
|
||||
def active(self) -> bool:
|
||||
"""
|
||||
A user is active if they can read/archive anything
|
||||
"""
|
||||
if not hasattr(self, '_active'):
|
||||
self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet)
|
||||
return self._active
|
||||
|
||||
def _helper_for_grouping_max_numerical_permissions(self, permission_name: str) -> int:
|
||||
"""
|
||||
Iterates one of the numerical permissions where -1 means no restrictions and returns either -1 or the maximum value, defaults according to GroupPermissions
|
||||
"""
|
||||
default = GroupPermissions.model_fields[permission_name].default
|
||||
max_value = default
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
group_value = group.permissions.get(permission_name, default)
|
||||
if group_value == -1:
|
||||
max_value = -1
|
||||
return max_value
|
||||
max_value = max(max_value, group_value)
|
||||
return max_value
|
||||
|
||||
def in_group(self, group_id: str) -> bool:
|
||||
return group_id in self.user_groups_names
|
||||
|
||||
def usage(self) -> Dict:
|
||||
"""
|
||||
returns the monthly quotas for the URLs/MBs and the totals for Sheets
|
||||
"""
|
||||
current_month = datetime.now().month
|
||||
current_year = datetime.now().year
|
||||
|
||||
# find and sum all user sheets over this month
|
||||
user_sheets = self.db.query(
|
||||
models.Sheet.group_id,
|
||||
func.count(models.Sheet.id).label('sheet_count')
|
||||
).filter(models.Sheet.author_id == self.email).group_by(models.Sheet.group_id).all()
|
||||
|
||||
sheets_by_group = {sheet.group_id: sheet.sheet_count for sheet in user_sheets}
|
||||
|
||||
# find and sum all user urls over this month
|
||||
urls_by_group = self.db.query(
|
||||
models.Archive.group_id,
|
||||
func.count(models.Archive.id).label('url_count'),
|
||||
func.coalesce(func.sum(
|
||||
func.coalesce(
|
||||
func.cast(
|
||||
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
|
||||
sqlalchemy.Integer
|
||||
), 0
|
||||
)
|
||||
), 0).label('total_bytes')
|
||||
).filter(
|
||||
models.Archive.author_id == self.email,
|
||||
func.extract('month', models.Archive.created_at) == current_month,
|
||||
func.extract('year', models.Archive.created_at) == current_year
|
||||
).group_by(models.Archive.group_id).all()
|
||||
|
||||
# merge the two queries
|
||||
usage_by_group: Dict[str, Usage] = {
|
||||
(url.group_id or ""):
|
||||
Usage(monthly_urls=url.url_count, monthly_mbs=int(url.total_bytes / 1024 / 1024))
|
||||
for url in urls_by_group
|
||||
}
|
||||
for group_id, sheet_count in sheets_by_group.items():
|
||||
group_id = group_id or ""
|
||||
if group_id in usage_by_group:
|
||||
usage_by_group[group_id].total_sheets = sheet_count
|
||||
else:
|
||||
usage_by_group[group_id] = Usage(total_sheets=sheet_count)
|
||||
|
||||
# calculate totals
|
||||
total_sheets = sum([sheet.sheet_count for sheet in user_sheets])
|
||||
total_bytes = sum([url.total_bytes for url in urls_by_group])
|
||||
total_urls = sum([url.url_count for url in urls_by_group])
|
||||
|
||||
return UsageResponse(
|
||||
monthly_urls=total_urls,
|
||||
monthly_mbs=int(total_bytes / 1024 / 1024),
|
||||
total_sheets=total_sheets,
|
||||
groups=usage_by_group
|
||||
)
|
||||
|
||||
def has_quota_monthly_sheets(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user has reached their sheet quota for a given group
|
||||
"""
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
|
||||
user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email, models.Sheet.group_id == group_id).count()
|
||||
|
||||
sheet_quota = self.permissions[group_id].max_sheets
|
||||
if sheet_quota == -1:
|
||||
return True
|
||||
return user_sheets < sheet_quota
|
||||
|
||||
def has_quota_max_monthly_urls(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user has reached their monthly url quota for a group, if global then group should be empty string
|
||||
"""
|
||||
quota = 0
|
||||
if not group_id:
|
||||
quota = self.max_monthly_urls
|
||||
else:
|
||||
if group_id not in self.permissions: return False
|
||||
quota = self.permissions[group_id].max_monthly_urls
|
||||
|
||||
if quota == -1:
|
||||
return True
|
||||
|
||||
current_month = datetime.now().month
|
||||
current_year = datetime.now().year
|
||||
user_urls = self.db.query(models.Archive).filter(
|
||||
models.Archive.author_id == self.email,
|
||||
func.extract('month', models.Archive.created_at) == current_month,
|
||||
func.extract('year', models.Archive.created_at) == current_year
|
||||
).count()
|
||||
|
||||
return user_urls < quota
|
||||
|
||||
def has_quota_max_monthly_mbs(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user has reached their monthly MBs quota for a group, if global then group should be empty string
|
||||
"""
|
||||
quota = 0
|
||||
if not group_id:
|
||||
quota = self.max_monthly_mbs
|
||||
else:
|
||||
if group_id not in self.permissions: return False
|
||||
quota = self.permissions[group_id].max_monthly_mbs
|
||||
|
||||
if quota == -1:
|
||||
return True
|
||||
|
||||
current_month = datetime.now().month
|
||||
current_year = datetime.now().year
|
||||
|
||||
# find and sum all user bytes over this month
|
||||
user_bytes = self.db.query(models.Archive).filter(
|
||||
models.Archive.author_id == self.email,
|
||||
func.extract('month', models.Archive.created_at) == current_month,
|
||||
func.extract('year', models.Archive.created_at) == current_year
|
||||
).with_entities(func.coalesce(func.sum(
|
||||
func.coalesce(
|
||||
func.cast(
|
||||
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
|
||||
sqlalchemy.Integer
|
||||
), 0
|
||||
)
|
||||
), 0).label('total')).scalar()
|
||||
|
||||
# convert bytes to mb
|
||||
user_mbs = int(user_bytes / 1024 / 1024)
|
||||
return user_mbs < quota
|
||||
|
||||
def can_manually_trigger(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user is allowed to manually trigger a sheet
|
||||
"""
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
|
||||
return self.permissions[group_id].manually_trigger_sheet
|
||||
|
||||
def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool:
|
||||
"""
|
||||
checks if a user is allowed to create a sheet with this frequency for this group
|
||||
"""
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
|
||||
return frequency in self.permissions[group_id].sheet_frequency
|
||||
Reference in New Issue
Block a user