mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-08 03:28:35 +03:00
introduces group/global usage & permissions, validates in endpoints and tests endpoints
This commit is contained in:
@@ -3,7 +3,7 @@ from functools import lru_cache
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
from sqlalchemy import Column, or_, func
|
||||
from loguru import logger
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
from db.database import get_db
|
||||
@@ -51,7 +51,7 @@ def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int
|
||||
|
||||
|
||||
def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]):
|
||||
db_task = models.Archive(id=task.id, url=task.url, result=task.result, public=task.public, author_id=task.author_id, group_id=task.group_id)
|
||||
db_task = models.Archive(id=task.id, url=task.url, result=task.result, public=task.public, author_id=task.author_id, group_id=task.group_id, sheet_id=task.sheet_id)
|
||||
db_task.tags = tags
|
||||
db_task.urls = urls
|
||||
db.add(db_task)
|
||||
@@ -246,8 +246,15 @@ def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
|
||||
|
||||
|
||||
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
|
||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_archived_at.desc()).all()
|
||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_url_archived_at.desc()).all()
|
||||
|
||||
def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
|
||||
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
|
||||
if db_sheet:
|
||||
db_sheet.last_url_archived_at = datetime.now()
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
|
||||
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
|
||||
|
||||
@@ -25,16 +25,15 @@ association_table_user_groups = Table(
|
||||
Column("group_id", ForeignKey("groups.id")),
|
||||
)
|
||||
|
||||
|
||||
# data model tables
|
||||
|
||||
|
||||
class Archive(Base):
|
||||
__tablename__ = "archives"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
url = Column(String, index=True)
|
||||
result = Column(JSON, default=None)
|
||||
public = Column(Boolean, default=True) # if public=false, access to group and author
|
||||
public = Column(Boolean, default=True) # if public=false, access by group and author
|
||||
deleted = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
@@ -102,8 +101,9 @@ class Sheet(Base):
|
||||
author_id = Column(String, ForeignKey("users.email"))
|
||||
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
|
||||
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
|
||||
# TODO: stats is not needed, is it?
|
||||
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...")
|
||||
last_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
|
||||
last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Annotated
|
||||
from annotated_types import Len
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ class ArchiveCreate(BaseModel):
|
||||
group_id: str | None = None
|
||||
tags: set[Tag] | None = set()
|
||||
rearchive: bool = True
|
||||
sheet_id: str | None = None
|
||||
# urls: list = []
|
||||
|
||||
|
||||
@@ -97,9 +98,8 @@ class SheetAdd(BaseModel):
|
||||
|
||||
class SheetResponse(SheetAdd):
|
||||
author_id: str
|
||||
stats: dict | None
|
||||
last_archived_at: datetime | None
|
||||
created_at: datetime
|
||||
last_url_archived_at: datetime | None
|
||||
|
||||
|
||||
class ArchiveTrigger(BaseModel):
|
||||
|
||||
@@ -29,6 +29,11 @@ class UserState:
|
||||
read_public=self.read_public,
|
||||
archive_url=self.archive_url,
|
||||
archive_sheet=self.archive_sheet,
|
||||
# below are relevant only for /url endpoints
|
||||
max_archive_lifespan_months=self.max_archive_lifespan_months,
|
||||
max_monthly_urls=self.max_monthly_urls,
|
||||
max_monthly_mbs=self.max_monthly_mbs,
|
||||
priority=self.priority
|
||||
)
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
@@ -117,6 +122,34 @@ class UserState:
|
||||
self._sheet_frequency.update(group.permissions.get("sheet_frequency", None))
|
||||
return self._sheet_frequency
|
||||
|
||||
@property
|
||||
def max_archive_lifespan_months(self) -> int:
|
||||
if not hasattr(self, '_max_archive_lifespan_months'):
|
||||
self._max_archive_lifespan_months = self._helper_for_grouping_max_numerical_permissions("max_archive_lifespan_months")
|
||||
return self._max_archive_lifespan_months
|
||||
|
||||
@property
|
||||
def max_monthly_urls(self) -> int:
|
||||
if not hasattr(self, '_max_monthly_urls'):
|
||||
self._max_monthly_urls = self._helper_for_grouping_max_numerical_permissions("max_monthly_urls")
|
||||
return self._max_monthly_urls
|
||||
|
||||
@property
|
||||
def max_monthly_mbs(self) -> int:
|
||||
if not hasattr(self, '_max_monthly_mbs'):
|
||||
self._max_monthly_mbs = self._helper_for_grouping_max_numerical_permissions("max_monthly_mbs")
|
||||
return self._max_monthly_mbs
|
||||
|
||||
@property
|
||||
def priority(self) -> str:
|
||||
if not hasattr(self, '_priority'):
|
||||
self._priority = "low"
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if group.permissions.get("priority", "low") == "high":
|
||||
self._priority = "high"
|
||||
return self._priority
|
||||
|
||||
@property
|
||||
def active(self) -> bool:
|
||||
"""
|
||||
@@ -125,34 +158,114 @@ class UserState:
|
||||
if not hasattr(self, '_active'):
|
||||
self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet)
|
||||
return self._active
|
||||
|
||||
def _helper_for_grouping_max_numerical_permissions(self, permission_name: str) -> int:
|
||||
"""
|
||||
Iterates one of the numerical permissions where -1 means no restrictions and returns either -1 or the maximum value, defaults according to GroupPermissions
|
||||
"""
|
||||
default = GroupPermissions.model_fields[permission_name].default
|
||||
max_value = default
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
group_value = group.permissions.get(permission_name, default)
|
||||
if group_value == -1:
|
||||
max_value = -1
|
||||
return max_value
|
||||
max_value = max(max_value, group_value)
|
||||
return max_value
|
||||
|
||||
def in_group(self, group_id: str) -> bool:
|
||||
return group_id in self.user_groups_names
|
||||
|
||||
def usage(self) -> Dict:
|
||||
"""
|
||||
returns the monthly quotas for the URLs/MBs and the totals for Sheets
|
||||
"""
|
||||
current_month = datetime.now().month
|
||||
current_year = datetime.now().year
|
||||
|
||||
# find and sum all user sheets over this month
|
||||
user_sheets = self.db.query(
|
||||
models.Sheet.group_id,
|
||||
func.count(models.Sheet.id).label('sheet_count')
|
||||
).filter(models.Sheet.author_id == self.email).group_by(models.Sheet.group_id).all()
|
||||
|
||||
sheets_by_group = {sheet.group_id: sheet.sheet_count for sheet in user_sheets}
|
||||
|
||||
# find and sum all user urls over this month
|
||||
urls_by_group = self.db.query(
|
||||
models.Archive.group_id,
|
||||
func.count(models.Archive.id).label('url_count'),
|
||||
func.coalesce(func.sum(
|
||||
func.coalesce(
|
||||
func.cast(
|
||||
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
|
||||
sqlalchemy.Integer
|
||||
), 0
|
||||
)
|
||||
), 0).label('total_bytes')
|
||||
).filter(
|
||||
models.Archive.author_id == self.email,
|
||||
func.extract('month', models.Archive.created_at) == current_month,
|
||||
func.extract('year', models.Archive.created_at) == current_year
|
||||
).group_by(models.Archive.group_id).all()
|
||||
|
||||
# merge the two queries
|
||||
usage_by_group = {
|
||||
(url.group_id or ""): {
|
||||
"monthly_urls": url.url_count,
|
||||
"monthly_mbs": int(url.total_bytes / 1024 / 1024),
|
||||
"total_sheets": 0
|
||||
}
|
||||
for url in urls_by_group
|
||||
}
|
||||
for group_id, sheet_count in sheets_by_group.items():
|
||||
group_id = group_id or ""
|
||||
if group_id in usage_by_group:
|
||||
usage_by_group[group_id]["total_sheets"] = sheet_count
|
||||
else:
|
||||
usage_by_group[group_id] = {
|
||||
"monthly_urls": 0,
|
||||
"monthly_mbs": 0,
|
||||
"total_sheets": sheet_count
|
||||
}
|
||||
|
||||
# calculate totals
|
||||
total_sheets = sum([sheet.sheet_count for sheet in user_sheets])
|
||||
total_bytes = sum([url.total_bytes for url in urls_by_group])
|
||||
total_urls = sum([url.url_count for url in urls_by_group])
|
||||
|
||||
return {
|
||||
"total_sheets": total_sheets,
|
||||
"monthly_urls": total_urls,
|
||||
"monthly_mbs": int(total_bytes / 1024 / 1024),
|
||||
"groups": usage_by_group
|
||||
}
|
||||
|
||||
def has_quota_monthly_sheets(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user has reached their sheet quota for a given group
|
||||
"""
|
||||
if group_id not in self.permissions:
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
|
||||
user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email, models.Sheet.group_id == group_id).count()
|
||||
|
||||
|
||||
sheet_quota = self.permissions[group_id].max_sheets
|
||||
if sheet_quota == -1:
|
||||
if sheet_quota == -1:
|
||||
return True
|
||||
return user_sheets < sheet_quota
|
||||
|
||||
def has_quota_max_monthly_urls(self) -> bool:
|
||||
def has_quota_max_monthly_urls(self, group_id:str) -> bool:
|
||||
"""
|
||||
checks if a user has reached their monthly url quota
|
||||
checks if a user has reached their monthly url quota for a group, if global then group should be empty string
|
||||
"""
|
||||
quota = 0
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
max_monthly_urls = group.permissions.get("max_monthly_urls", 0)
|
||||
if max_monthly_urls == -1: return True
|
||||
quota = max(quota, max_monthly_urls)
|
||||
if not group_id:
|
||||
quota = self.max_monthly_urls
|
||||
else:
|
||||
if group_id not in self.permissions: return False
|
||||
quota = self.permissions[group_id].max_monthly_urls
|
||||
|
||||
current_month = datetime.now().month
|
||||
current_year = datetime.now().year
|
||||
@@ -164,16 +277,16 @@ class UserState:
|
||||
|
||||
return user_urls < quota
|
||||
|
||||
def has_quota_max_monthly_mbs(self) -> bool:
|
||||
def has_quota_max_monthly_mbs(self, group_id:str) -> bool:
|
||||
"""
|
||||
checks if a user has reached their monthly mb quota
|
||||
checks if a user has reached their monthly MBs quota for a group, if global then group should be empty string
|
||||
"""
|
||||
quota = 0
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
max_monthly_mbs = group.permissions.get("max_monthly_mbs", 0)
|
||||
if max_monthly_mbs == -1: return True
|
||||
quota = max(quota, max_monthly_mbs)
|
||||
if not group_id:
|
||||
quota = self.max_monthly_mbs
|
||||
else:
|
||||
if group_id not in self.permissions: return False
|
||||
quota = self.permissions[group_id].max_monthly_mbs
|
||||
|
||||
current_month = datetime.now().month
|
||||
current_year = datetime.now().year
|
||||
@@ -196,20 +309,20 @@ class UserState:
|
||||
user_mbs = int(user_bytes / 1024 / 1024)
|
||||
return user_mbs < quota
|
||||
|
||||
def can_manually_trigger(self, group_id:str) -> bool:
|
||||
def can_manually_trigger(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user is allowed to manually trigger a sheet
|
||||
"""
|
||||
if group_id not in self.permissions:
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
|
||||
|
||||
return self.permissions[group_id].manually_trigger_sheet
|
||||
|
||||
def is_sheet_frequency_allowed(self, group_id:str, frequency: str) -> bool:
|
||||
def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool:
|
||||
"""
|
||||
checks if a user is allowed to create a sheet with this frequency for this group
|
||||
"""
|
||||
if group_id not in self.permissions:
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
|
||||
|
||||
return frequency in self.permissions[group_id].sheet_frequency
|
||||
|
||||
@@ -39,13 +39,21 @@ async def active(
|
||||
return {"active": user.active}
|
||||
|
||||
|
||||
# TODO: test
|
||||
@default_router.get("/user/permissions", summary="Get the user's global 'all' permissions and the permissions for each group they belong to.")
|
||||
def get_user_permissions(
|
||||
user: UserState = Depends(get_user_state),
|
||||
) -> Dict[str, GroupPermissions]:
|
||||
return user.permissions
|
||||
|
||||
@default_router.get("/user/usage", summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.")
|
||||
def get_user_usage(
|
||||
user: UserState = Depends(get_user_state),
|
||||
):
|
||||
if not user.active:
|
||||
raise HTTPException(status_code=403, detail="User is not active.")
|
||||
return user.usage()
|
||||
|
||||
|
||||
|
||||
@default_router.get('/favicon.ico', include_in_schema=False)
|
||||
async def favicon():
|
||||
|
||||
@@ -13,6 +13,7 @@ from db import crud, schemas
|
||||
from db.database import get_db_dependency
|
||||
|
||||
from worker.main import create_archive_task
|
||||
from urllib.parse import urlparse
|
||||
|
||||
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
|
||||
|
||||
@@ -25,14 +26,18 @@ def archive_url(
|
||||
) -> schemas.Task:
|
||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
|
||||
|
||||
parsed_url = urlparse(archive.url)
|
||||
if not all([parsed_url.scheme, parsed_url.netloc]):
|
||||
raise HTTPException(status_code=400, detail="Invalid URL received.")
|
||||
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
user = UserState(db, email)
|
||||
if not user.has_quota_max_monthly_urls():
|
||||
raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.")
|
||||
if not user.has_quota_max_monthly_mbs():
|
||||
raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.")
|
||||
if archive.group_id and not user.in_group(archive.group_id):
|
||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||
if not user.has_quota_max_monthly_urls(archive.group_id):
|
||||
raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.")
|
||||
if not user.has_quota_max_monthly_mbs(archive.group_id):
|
||||
raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.")
|
||||
|
||||
# TODO: deprecate ArchiveCreate
|
||||
backwards_compatible_archive = schemas.ArchiveCreate(
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""rename sheets last_archived col
|
||||
|
||||
Revision ID: 1636724ec4b1
|
||||
Revises: a23aaf3ae930
|
||||
Create Date: 2025-02-05 19:19:01.984396
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '1636724ec4b1'
|
||||
down_revision = 'a23aaf3ae930'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('sheets')]
|
||||
if 'last_archived_at' in columns:
|
||||
op.alter_column('sheets', 'last_archived_at', new_column_name='last_url_archived_at')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('sheets')]
|
||||
if 'last_url_archived_at' in columns:
|
||||
op.alter_column('sheets', 'last_url_archived_at', new_column_name='last_archived_at')
|
||||
@@ -29,8 +29,7 @@ def test_create_sheet_endpoint(app_with_auth, db_session):
|
||||
assert response.status_code == 201
|
||||
j = response.json()
|
||||
assert datetime.fromisoformat(j.pop("created_at"))
|
||||
assert datetime.fromisoformat(j.pop("last_archived_at"))
|
||||
assert j.pop("stats") == {}
|
||||
assert datetime.fromisoformat(j.pop("last_url_archived_at"))
|
||||
assert j.pop("author_id") == 'morty@example.com'
|
||||
assert j == good_data
|
||||
|
||||
@@ -95,16 +94,15 @@ def test_get_user_sheets_endpoint(client_with_auth, db_session):
|
||||
assert isinstance(r, list)
|
||||
assert len(r) == 2
|
||||
assert datetime.fromisoformat(r[0].pop("created_at"))
|
||||
assert datetime.fromisoformat(r[0].pop("last_archived_at"))
|
||||
assert datetime.fromisoformat(r[0].pop("last_url_archived_at"))
|
||||
assert datetime.fromisoformat(r[1].pop("created_at"))
|
||||
assert datetime.fromisoformat(r[1].pop("last_archived_at"))
|
||||
assert datetime.fromisoformat(r[1].pop("last_url_archived_at"))
|
||||
assert r[0] == {
|
||||
'id': '123',
|
||||
'author_id': 'morty@example.com',
|
||||
'frequency': 'hourly',
|
||||
'group_id': 'spaceship',
|
||||
'name': 'Test Sheet 1',
|
||||
'stats': {},
|
||||
}
|
||||
assert r[1] == {
|
||||
'id': '456',
|
||||
@@ -112,7 +110,6 @@ def test_get_user_sheets_endpoint(client_with_auth, db_session):
|
||||
'frequency': 'daily',
|
||||
'group_id': 'interdimensional',
|
||||
'name': 'Test Sheet 2',
|
||||
'stats': {},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -7,36 +7,67 @@ def test_archive_url_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.post, "/url/archive")
|
||||
|
||||
|
||||
@patch("endpoints.url.UserState")
|
||||
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
|
||||
def test_archive_url(m1, client_with_auth):
|
||||
def test_archive_url(m1, m2, client_with_auth):
|
||||
m_user_state = MagicMock()
|
||||
m2.return_value = m_user_state
|
||||
|
||||
# url is too short
|
||||
response = client_with_auth.post("/url/archive", json={"url": "bad"})
|
||||
assert response.status_code == 422
|
||||
assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
|
||||
m1.assert_not_called()
|
||||
|
||||
# url is invalid
|
||||
response = client_with_auth.post("/url/archive", json={"url": "example.com"})
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Invalid URL received."
|
||||
|
||||
# valid request
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = True
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = True
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-456-789'}
|
||||
|
||||
m1.assert_called_once()
|
||||
called_val = m1.call_args.args[0]
|
||||
assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True}
|
||||
assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True, "sheet_id":None}
|
||||
m_user_state.has_quota_max_monthly_urls.assert_called_once()
|
||||
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
|
||||
|
||||
# user is not in group
|
||||
m_user_state.in_group.return_value = False
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "new-group"})
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "User does not have access to this group."
|
||||
m_user_state.in_group.assert_called_once_with("new-group")
|
||||
|
||||
# user is in group
|
||||
m_user_state.in_group.return_value = True
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-456-789'}
|
||||
|
||||
assert m1.call_count == 2
|
||||
called_val = m1.call_args.args[0]
|
||||
assert json.loads(called_val)["group_id"] == "spaceship"
|
||||
m_user_state.in_group.assert_called_with("spaceship")
|
||||
|
||||
# user is over monthly URL quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = False
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = True
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
|
||||
assert response.status_code == 429
|
||||
assert response.json()["detail"] == "User has reached their monthly URL quota."
|
||||
m_user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")
|
||||
|
||||
# user is over monthly MB quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = True
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = False
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spacesuit"})
|
||||
assert response.status_code == 429
|
||||
assert response.json()["detail"] == "User has reached their monthly MB quota."
|
||||
m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
|
||||
|
||||
@patch("endpoints.url.UserState")
|
||||
def test_archive_url_quotas(m1, client_with_auth):
|
||||
|
||||
@@ -19,6 +19,7 @@ from core.logging import log_error
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
celery = Celery(__name__)
|
||||
celery.conf.broker_url = settings.CELERY_BROKER_URL
|
||||
celery.conf.result_backend = settings.CELERY_RESULT_BACKEND
|
||||
@@ -48,6 +49,7 @@ def create_archive_task(self, archive_json: str):
|
||||
return Metadata.choose_most_complete([a.result for a in archives])
|
||||
|
||||
orchestrator = choose_orchestrator(archive.group_id, archive.author_id)
|
||||
logger.info(f"Using orchestrator {orchestrator=}")
|
||||
result = orchestrator.feed_item(Metadata().set_url(url))
|
||||
|
||||
try:
|
||||
@@ -59,7 +61,7 @@ def create_archive_task(self, archive_json: str):
|
||||
raise e
|
||||
return result.to_dict()
|
||||
|
||||
|
||||
#TODO: refactor how user-groups are loaded and orchestrators chosen
|
||||
@celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
|
||||
def create_sheet_task(self, sheet_json: str):
|
||||
sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
|
||||
@@ -79,7 +81,8 @@ def create_sheet_task(self, sheet_json: str):
|
||||
continue
|
||||
try:
|
||||
#TODO: remove public from sheet in new refactor
|
||||
insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid())
|
||||
#TODO: update the sheets table with the current date if any new archive was done
|
||||
insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid(), sheet.sheet_id)
|
||||
stats["archived"] += 1
|
||||
except exc.IntegrityError as e:
|
||||
logger.warning(f"cached result detected: {e}")
|
||||
@@ -89,6 +92,10 @@ def create_sheet_task(self, sheet_json: str):
|
||||
stats["failed"] += 1
|
||||
stats["errors"].append(str(e))
|
||||
|
||||
if stats["archived"] > 0:
|
||||
with get_db() as session:
|
||||
crud.update_sheet_last_url_archived_at(session, sheet.sheet_id)
|
||||
|
||||
logger.info(f"SHEET DONE {sheet=}")
|
||||
return {"success": True, "sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "time": datetime.datetime.now().isoformat(), **stats}
|
||||
|
||||
@@ -165,7 +172,7 @@ def is_group_invalid_for_user(public: bool, group_id: str, author_id: str):
|
||||
return False
|
||||
|
||||
|
||||
def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_id: str, author_id: str, task_id: str) -> str:
|
||||
def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_id: str, author_id: str, task_id: str, sheet_id:str="") -> str:
|
||||
logger.info(f"INSERTING {public=} {group_id=} {author_id=} {tags=} into {task_id}")
|
||||
assert result, f"UNABLE TO archive: {result.get_url() if result else result}"
|
||||
with get_db() as session:
|
||||
@@ -175,7 +182,7 @@ def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_
|
||||
# create DB TAGs if needed
|
||||
db_tags = [crud.create_tag(session, tag) for tag in tags]
|
||||
# insert archive
|
||||
db_task = crud.create_task(session, task=schemas.ArchiveCreate(id=task_id, url=result.get_url(), result=json.loads(result.to_json()), public=public, author_id=author_id, group_id=group_id), tags=db_tags, urls=get_all_urls(result))
|
||||
db_task = crud.create_task(session, task=schemas.ArchiveCreate(id=task_id, url=result.get_url(), result=json.loads(result.to_json()), public=public, author_id=author_id, group_id=group_id, sheet_id=sheet_id), tags=db_tags, urls=get_all_urls(result))
|
||||
logger.debug(f"Added {db_task.id=} to database on {db_task.created_at} ({db_task.author_id})")
|
||||
return db_task.id
|
||||
|
||||
|
||||
Reference in New Issue
Block a user