introduces group/global usage & permissions, validates in endpoints and tests endpoints

This commit is contained in:
msramalho
2025-02-06 18:41:12 +00:00
parent 2b8c48af1b
commit 5344cc56e7
10 changed files with 252 additions and 52 deletions

View File

@@ -3,7 +3,7 @@ from functools import lru_cache
from sqlalchemy.orm import Session, load_only
from sqlalchemy import Column, or_, func
from loguru import logger
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from core.config import ALLOW_ANY_EMAIL
from db.database import get_db
@@ -51,7 +51,7 @@ def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int
def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]):
db_task = models.Archive(id=task.id, url=task.url, result=task.result, public=task.public, author_id=task.author_id, group_id=task.group_id)
db_task = models.Archive(id=task.id, url=task.url, result=task.result, public=task.public, author_id=task.author_id, group_id=task.group_id, sheet_id=task.sheet_id)
db_task.tags = tags
db_task.urls = urls
db.add(db_task)
@@ -246,8 +246,15 @@ def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_archived_at.desc()).all()
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_url_archived_at.desc()).all()
def update_sheet_last_url_archived_at(db: Session, sheet_id: str) -> bool:
    """
    Stamp a sheet's `last_url_archived_at` with the current time and commit.

    :param db: open database session used for the lookup and the commit
    :param sheet_id: id of the sheet to update
    :return: True when a sheet with this id exists and was updated, False otherwise
    """
    db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
    if not db_sheet:
        return False
    # the column is declared DateTime(timezone=True): store an aware UTC value
    # instead of a naive, server-local datetime.now()
    db_sheet.last_url_archived_at = datetime.now(timezone.utc)
    db.commit()
    return True
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()

View File

@@ -25,16 +25,15 @@ association_table_user_groups = Table(
Column("group_id", ForeignKey("groups.id")),
)
# data model tables
class Archive(Base):
__tablename__ = "archives"
id = Column(String, primary_key=True, index=True)
url = Column(String, index=True)
result = Column(JSON, default=None)
public = Column(Boolean, default=True) # if public=false, access to group and author
public = Column(Boolean, default=True) # if public=false, access by group and author
deleted = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
@@ -102,8 +101,9 @@ class Sheet(Base):
author_id = Column(String, ForeignKey("users.email"))
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
# TODO: stats is not needed, is it?
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...")
last_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())

View File

@@ -1,6 +1,6 @@
from typing import Annotated
from annotated_types import Len
from pydantic import BaseModel, field_validator
from pydantic import BaseModel
from datetime import datetime
@@ -21,6 +21,7 @@ class ArchiveCreate(BaseModel):
group_id: str | None = None
tags: set[Tag] | None = set()
rearchive: bool = True
sheet_id: str | None = None
# urls: list = []
@@ -97,9 +98,8 @@ class SheetAdd(BaseModel):
class SheetResponse(SheetAdd):
author_id: str
stats: dict | None
last_archived_at: datetime | None
created_at: datetime
last_url_archived_at: datetime | None
class ArchiveTrigger(BaseModel):

View File

@@ -29,6 +29,11 @@ class UserState:
read_public=self.read_public,
archive_url=self.archive_url,
archive_sheet=self.archive_sheet,
# below are relevant only for /url endpoints
max_archive_lifespan_months=self.max_archive_lifespan_months,
max_monthly_urls=self.max_monthly_urls,
max_monthly_mbs=self.max_monthly_mbs,
priority=self.priority
)
for group in self.user_groups:
if not group.permissions: continue
@@ -117,6 +122,34 @@ class UserState:
self._sheet_frequency.update(group.permissions.get("sheet_frequency", None))
return self._sheet_frequency
@property
def max_archive_lifespan_months(self) -> int:
    """
    Aggregated `max_archive_lifespan_months` permission across the user's groups.

    Computed once through `_helper_for_grouping_max_numerical_permissions`
    and then cached on the instance as `_max_archive_lifespan_months`.
    """
    try:
        return self._max_archive_lifespan_months
    except AttributeError:
        # first access: nothing cached yet
        self._max_archive_lifespan_months = self._helper_for_grouping_max_numerical_permissions("max_archive_lifespan_months")
    return self._max_archive_lifespan_months
@property
def max_monthly_urls(self) -> int:
    """
    Aggregated `max_monthly_urls` permission across the user's groups.

    Computed once through `_helper_for_grouping_max_numerical_permissions`
    and then cached on the instance as `_max_monthly_urls`.
    """
    try:
        return self._max_monthly_urls
    except AttributeError:
        # first access: nothing cached yet
        self._max_monthly_urls = self._helper_for_grouping_max_numerical_permissions("max_monthly_urls")
    return self._max_monthly_urls
@property
def max_monthly_mbs(self) -> int:
    """
    Aggregated `max_monthly_mbs` permission across the user's groups.

    Computed once through `_helper_for_grouping_max_numerical_permissions`
    and then cached on the instance as `_max_monthly_mbs`.
    """
    try:
        return self._max_monthly_mbs
    except AttributeError:
        # first access: nothing cached yet
        self._max_monthly_mbs = self._helper_for_grouping_max_numerical_permissions("max_monthly_mbs")
    return self._max_monthly_mbs
@property
def priority(self) -> str:
    """
    Priority level for the user: "high" if any of their groups declares
    priority "high" in its permissions, otherwise "low".

    Groups without permissions are ignored; the result is cached on the
    instance as `_priority`.
    """
    try:
        return self._priority
    except AttributeError:
        # first access: start from the lowest level, upgrade below if needed
        self._priority = "low"
    declared_levels = [
        group.permissions.get("priority", "low")
        for group in self.user_groups
        if group.permissions
    ]
    if "high" in declared_levels:
        self._priority = "high"
    return self._priority
@property
def active(self) -> bool:
"""
@@ -125,34 +158,114 @@ class UserState:
if not hasattr(self, '_active'):
self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet)
return self._active
def _helper_for_grouping_max_numerical_permissions(self, permission_name: str) -> int:
    """
    Combine one numerical permission over all of the user's groups.

    A value of -1 means "no restrictions": as soon as one group carries it,
    -1 is returned. Otherwise the largest value found wins. Both the starting
    value and the value assumed for groups that lack the permission are the
    default declared on GroupPermissions for `permission_name`.
    """
    fallback = GroupPermissions.model_fields[permission_name].default
    highest = fallback
    for permissions in (group.permissions for group in self.user_groups):
        if not permissions:
            # groups without permissions do not take part
            continue
        candidate = permissions.get(permission_name, fallback)
        if candidate == -1:
            # unrestricted in one group means unrestricted overall
            return -1
        highest = max(highest, candidate)
    return highest
def in_group(self, group_id: str) -> bool:
    """Tell whether `group_id` is contained in `self.user_groups_names`."""
    known_group_names = self.user_groups_names
    return group_id in known_group_names
def usage(self) -> Dict:
    """
    Summarise what the user has consumed.

    URLs and MBs are counted over the archives the user created in the
    current calendar month; sheets are counted over all sheets the user owns
    (no date filter). Each figure is reported as a grand total and per group.

    :return: dict with keys "total_sheets", "monthly_urls", "monthly_mbs" and
        "groups"; "groups" maps a group id ("" when the row has no group) to a
        dict with the same three counters for that group.
    """
    # take one snapshot so month and year cannot disagree when the call
    # straddles a month/year rollover
    # NOTE(review): this is naive server-local time while Archive.created_at is
    # declared timezone-aware; confirm the month boundary is meant to be local
    now = datetime.now()
    current_month, current_year = now.month, now.year

    # count all of the user's sheets, per group
    user_sheets = self.db.query(
        models.Sheet.group_id,
        func.count(models.Sheet.id).label('sheet_count')
    ).filter(models.Sheet.author_id == self.email).group_by(models.Sheet.group_id).all()
    sheets_by_group = {sheet.group_id: sheet.sheet_count for sheet in user_sheets}

    # count the user's URLs and sum their archived bytes for this month, per group;
    # archives without a metadata.total_bytes value contribute 0 bytes
    urls_by_group = self.db.query(
        models.Archive.group_id,
        func.count(models.Archive.id).label('url_count'),
        func.coalesce(func.sum(
            func.coalesce(
                func.cast(
                    func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
                    sqlalchemy.Integer
                ), 0
            )
        ), 0).label('total_bytes')
    ).filter(
        models.Archive.author_id == self.email,
        func.extract('month', models.Archive.created_at) == current_month,
        func.extract('year', models.Archive.created_at) == current_year
    ).group_by(models.Archive.group_id).all()

    # merge the two queries, keyed by group id ("" stands for "no group")
    usage_by_group = {
        (url.group_id or ""): {
            "monthly_urls": url.url_count,
            "monthly_mbs": int(url.total_bytes / 1024 / 1024),
            "total_sheets": 0
        }
        for url in urls_by_group
    }
    for group_id, sheet_count in sheets_by_group.items():
        # groups that only have sheets get zeroed monthly counters
        group_usage = usage_by_group.setdefault(group_id or "", {
            "monthly_urls": 0,
            "monthly_mbs": 0,
            "total_sheets": 0
        })
        group_usage["total_sheets"] = sheet_count

    # calculate totals; MBs are derived from the summed bytes so per-group
    # truncation does not accumulate
    total_sheets = sum(sheet.sheet_count for sheet in user_sheets)
    total_bytes = sum(url.total_bytes for url in urls_by_group)
    total_urls = sum(url.url_count for url in urls_by_group)

    return {
        "total_sheets": total_sheets,
        "monthly_urls": total_urls,
        "monthly_mbs": int(total_bytes / 1024 / 1024),
        "groups": usage_by_group
    }
def has_quota_monthly_sheets(self, group_id: str) -> bool:
"""
checks if a user has reached their sheet quota for a given group
"""
if group_id not in self.permissions:
if group_id not in self.permissions:
return False
user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email, models.Sheet.group_id == group_id).count()
sheet_quota = self.permissions[group_id].max_sheets
if sheet_quota == -1:
if sheet_quota == -1:
return True
return user_sheets < sheet_quota
def has_quota_max_monthly_urls(self) -> bool:
def has_quota_max_monthly_urls(self, group_id:str) -> bool:
"""
checks if a user has reached their monthly url quota
checks if a user has reached their monthly url quota for a group, if global then group should be empty string
"""
quota = 0
for group in self.user_groups:
if not group.permissions: continue
max_monthly_urls = group.permissions.get("max_monthly_urls", 0)
if max_monthly_urls == -1: return True
quota = max(quota, max_monthly_urls)
if not group_id:
quota = self.max_monthly_urls
else:
if group_id not in self.permissions: return False
quota = self.permissions[group_id].max_monthly_urls
current_month = datetime.now().month
current_year = datetime.now().year
@@ -164,16 +277,16 @@ class UserState:
return user_urls < quota
def has_quota_max_monthly_mbs(self) -> bool:
def has_quota_max_monthly_mbs(self, group_id:str) -> bool:
"""
checks if a user has reached their monthly mb quota
checks if a user has reached their monthly MBs quota for a group, if global then group should be empty string
"""
quota = 0
for group in self.user_groups:
if not group.permissions: continue
max_monthly_mbs = group.permissions.get("max_monthly_mbs", 0)
if max_monthly_mbs == -1: return True
quota = max(quota, max_monthly_mbs)
if not group_id:
quota = self.max_monthly_mbs
else:
if group_id not in self.permissions: return False
quota = self.permissions[group_id].max_monthly_mbs
current_month = datetime.now().month
current_year = datetime.now().year
@@ -196,20 +309,20 @@ class UserState:
user_mbs = int(user_bytes / 1024 / 1024)
return user_mbs < quota
def can_manually_trigger(self, group_id:str) -> bool:
def can_manually_trigger(self, group_id: str) -> bool:
"""
checks if a user is allowed to manually trigger a sheet
"""
if group_id not in self.permissions:
if group_id not in self.permissions:
return False
return self.permissions[group_id].manually_trigger_sheet
def is_sheet_frequency_allowed(self, group_id:str, frequency: str) -> bool:
def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool:
"""
checks if a user is allowed to create a sheet with this frequency for this group
"""
if group_id not in self.permissions:
if group_id not in self.permissions:
return False
return frequency in self.permissions[group_id].sheet_frequency

View File

@@ -39,13 +39,21 @@ async def active(
return {"active": user.active}
# TODO: test
@default_router.get("/user/permissions", summary="Get the user's global 'all' permissions and the permissions for each group they belong to.")
def get_user_permissions(
    user: UserState = Depends(get_user_state),
) -> Dict[str, GroupPermissions]:
    """Return the `permissions` mapping of the UserState resolved by the `get_user_state` dependency."""
    permissions_by_group = user.permissions
    return permissions_by_group
@default_router.get("/user/usage", summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.")
def get_user_usage(
    user: UserState = Depends(get_user_state),
):
    """
    Return `user.usage()` for the UserState resolved by the `get_user_state` dependency.

    Raises an HTTP 403 when the user is not active.
    """
    if user.active:
        return user.usage()
    raise HTTPException(status_code=403, detail="User is not active.")
@default_router.get('/favicon.ico', include_in_schema=False)
async def favicon():

View File

@@ -13,6 +13,7 @@ from db import crud, schemas
from db.database import get_db_dependency
from worker.main import create_archive_task
from urllib.parse import urlparse
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
@@ -25,14 +26,18 @@ def archive_url(
) -> schemas.Task:
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
parsed_url = urlparse(archive.url)
if not all([parsed_url.scheme, parsed_url.netloc]):
raise HTTPException(status_code=400, detail="Invalid URL received.")
if email != ALLOW_ANY_EMAIL:
user = UserState(db, email)
if not user.has_quota_max_monthly_urls():
raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.")
if not user.has_quota_max_monthly_mbs():
raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.")
if archive.group_id and not user.in_group(archive.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
if not user.has_quota_max_monthly_urls(archive.group_id):
raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.")
if not user.has_quota_max_monthly_mbs(archive.group_id):
raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.")
# TODO: deprecate ArchiveCreate
backwards_compatible_archive = schemas.ArchiveCreate(

View File

@@ -0,0 +1,32 @@
"""rename sheets last_archived col
Revision ID: 1636724ec4b1
Revises: a23aaf3ae930
Create Date: 2025-02-05 19:19:01.984396
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '1636724ec4b1'
down_revision = 'a23aaf3ae930'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Rename `sheets.last_archived_at` to `last_url_archived_at`, only if the old column is present."""
    sheet_columns = sa.inspect(op.get_bind()).get_columns('sheets')
    existing_names = [column['name'] for column in sheet_columns]
    if 'last_archived_at' not in existing_names:
        # nothing to rename
        return
    op.alter_column('sheets', 'last_archived_at', new_column_name='last_url_archived_at')
def downgrade() -> None:
    """Rename `sheets.last_url_archived_at` back to `last_archived_at`, only if the new column is present."""
    sheet_columns = sa.inspect(op.get_bind()).get_columns('sheets')
    existing_names = [column['name'] for column in sheet_columns]
    if 'last_url_archived_at' not in existing_names:
        # nothing to rename
        return
    op.alter_column('sheets', 'last_url_archived_at', new_column_name='last_archived_at')

View File

@@ -29,8 +29,7 @@ def test_create_sheet_endpoint(app_with_auth, db_session):
assert response.status_code == 201
j = response.json()
assert datetime.fromisoformat(j.pop("created_at"))
assert datetime.fromisoformat(j.pop("last_archived_at"))
assert j.pop("stats") == {}
assert datetime.fromisoformat(j.pop("last_url_archived_at"))
assert j.pop("author_id") == 'morty@example.com'
assert j == good_data
@@ -95,16 +94,15 @@ def test_get_user_sheets_endpoint(client_with_auth, db_session):
assert isinstance(r, list)
assert len(r) == 2
assert datetime.fromisoformat(r[0].pop("created_at"))
assert datetime.fromisoformat(r[0].pop("last_archived_at"))
assert datetime.fromisoformat(r[0].pop("last_url_archived_at"))
assert datetime.fromisoformat(r[1].pop("created_at"))
assert datetime.fromisoformat(r[1].pop("last_archived_at"))
assert datetime.fromisoformat(r[1].pop("last_url_archived_at"))
assert r[0] == {
'id': '123',
'author_id': 'morty@example.com',
'frequency': 'hourly',
'group_id': 'spaceship',
'name': 'Test Sheet 1',
'stats': {},
}
assert r[1] == {
'id': '456',
@@ -112,7 +110,6 @@ def test_get_user_sheets_endpoint(client_with_auth, db_session):
'frequency': 'daily',
'group_id': 'interdimensional',
'name': 'Test Sheet 2',
'stats': {},
}

View File

@@ -7,36 +7,67 @@ def test_archive_url_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/url/archive")
@patch("endpoints.url.UserState")
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
def test_archive_url(m1, client_with_auth):
def test_archive_url(m1, m2, client_with_auth):
m_user_state = MagicMock()
m2.return_value = m_user_state
# url is too short
response = client_with_auth.post("/url/archive", json={"url": "bad"})
assert response.status_code == 422
assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
m1.assert_not_called()
# url is invalid
response = client_with_auth.post("/url/archive", json={"url": "example.com"})
assert response.status_code == 400
assert response.json()["detail"] == "Invalid URL received."
# valid request
m_user_state.has_quota_max_monthly_urls.return_value = True
m_user_state.has_quota_max_monthly_mbs.return_value = True
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
m1.assert_called_once()
called_val = m1.call_args.args[0]
assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True}
assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True, "sheet_id":None}
m_user_state.has_quota_max_monthly_urls.assert_called_once()
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
# user is not in group
m_user_state.in_group.return_value = False
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "new-group"})
assert response.status_code == 403
assert response.json()["detail"] == "User does not have access to this group."
m_user_state.in_group.assert_called_once_with("new-group")
# user is in group
m_user_state.in_group.return_value = True
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
assert m1.call_count == 2
called_val = m1.call_args.args[0]
assert json.loads(called_val)["group_id"] == "spaceship"
m_user_state.in_group.assert_called_with("spaceship")
# user is over monthly URL quota
m_user_state.has_quota_max_monthly_urls.return_value = False
m_user_state.has_quota_max_monthly_mbs.return_value = True
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
assert response.status_code == 429
assert response.json()["detail"] == "User has reached their monthly URL quota."
m_user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")
# user is over monthly MB quota
m_user_state.has_quota_max_monthly_urls.return_value = True
m_user_state.has_quota_max_monthly_mbs.return_value = False
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spacesuit"})
assert response.status_code == 429
assert response.json()["detail"] == "User has reached their monthly MB quota."
m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
@patch("endpoints.url.UserState")
def test_archive_url_quotas(m1, client_with_auth):

View File

@@ -19,6 +19,7 @@ from core.logging import log_error
settings = get_settings()
celery = Celery(__name__)
celery.conf.broker_url = settings.CELERY_BROKER_URL
celery.conf.result_backend = settings.CELERY_RESULT_BACKEND
@@ -48,6 +49,7 @@ def create_archive_task(self, archive_json: str):
return Metadata.choose_most_complete([a.result for a in archives])
orchestrator = choose_orchestrator(archive.group_id, archive.author_id)
logger.info(f"Using orchestrator {orchestrator=}")
result = orchestrator.feed_item(Metadata().set_url(url))
try:
@@ -59,7 +61,7 @@ def create_archive_task(self, archive_json: str):
raise e
return result.to_dict()
#TODO: refactor how user-groups are loaded and orchestrators chosen
@celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
def create_sheet_task(self, sheet_json: str):
sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
@@ -79,7 +81,8 @@ def create_sheet_task(self, sheet_json: str):
continue
try:
#TODO: remove public from sheet in new refactor
insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid())
# NOTE: the sheet's last_url_archived_at is updated further down in this task (crud.update_sheet_last_url_archived_at) when stats["archived"] > 0
insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid(), sheet.sheet_id)
stats["archived"] += 1
except exc.IntegrityError as e:
logger.warning(f"cached result detected: {e}")
@@ -89,6 +92,10 @@ def create_sheet_task(self, sheet_json: str):
stats["failed"] += 1
stats["errors"].append(str(e))
if stats["archived"] > 0:
with get_db() as session:
crud.update_sheet_last_url_archived_at(session, sheet.sheet_id)
logger.info(f"SHEET DONE {sheet=}")
return {"success": True, "sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "time": datetime.datetime.now().isoformat(), **stats}
@@ -165,7 +172,7 @@ def is_group_invalid_for_user(public: bool, group_id: str, author_id: str):
return False
def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_id: str, author_id: str, task_id: str) -> str:
def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_id: str, author_id: str, task_id: str, sheet_id:str="") -> str:
logger.info(f"INSERTING {public=} {group_id=} {author_id=} {tags=} into {task_id}")
assert result, f"UNABLE TO archive: {result.get_url() if result else result}"
with get_db() as session:
@@ -175,7 +182,7 @@ def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_
# create DB TAGs if needed
db_tags = [crud.create_tag(session, tag) for tag in tags]
# insert archive
db_task = crud.create_task(session, task=schemas.ArchiveCreate(id=task_id, url=result.get_url(), result=json.loads(result.to_json()), public=public, author_id=author_id, group_id=group_id), tags=db_tags, urls=get_all_urls(result))
db_task = crud.create_task(session, task=schemas.ArchiveCreate(id=task_id, url=result.get_url(), result=json.loads(result.to_json()), public=public, author_id=author_id, group_id=group_id, sheet_id=sheet_id), tags=db_tags, urls=get_all_urls(result))
logger.debug(f"Added {db_task.id=} to database on {db_task.created_at} ({db_task.author_id})")
return db_task.id