diff --git a/src/db/crud.py b/src/db/crud.py index 7f9c67f..aae2cd8 100644 --- a/src/db/crud.py +++ b/src/db/crud.py @@ -126,7 +126,8 @@ def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group: return len(group_name) and len(email) and group_name in get_user_groups(db, email) -def get_user_groups(db: Session, email: str): +#TODO: maybe this can be cached? what about the db session? +def get_user_groups(db: Session, email: str) -> list[str]: """ given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. User does not need to be active. """ @@ -135,15 +136,53 @@ def get_user_groups(db: Session, email: str): # get user groups user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all() - user_level_groups = [g[0] for g in user_groups] + user_level_groups_names = [g[0] for g in user_groups] # get domain groups domain = email.split('@')[1] domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all() - domain_level_groups = [g[0] for g in domain_level_groups] + domain_level_groups_names = [g[0] for g in domain_level_groups] - # combine and return - return list(set(user_level_groups + domain_level_groups)) + return list(set(user_level_groups_names + domain_level_groups_names)) + + +# --------------- SHEET + +def has_quota_sheet(db: Session, email: str, user_groups_names: list[str]) -> bool: + """ + checks if a user has reached their sheet quota + """ + user_sheets = db.query(models.Sheet).filter(models.Sheet.author_id == email).count() + + user_groups = db.query(models.Group).filter(models.Group.id.in_(user_groups_names)).all() + + quota = 0 + for group in user_groups: + active_sheets = group.permissions.get("active_sheets", 0) + if active_sheets == -1: return True + quota = max(quota, active_sheets) + return user_sheets < quota + + +def create_sheet(db: Session, sheet_id: str, sheet_name: str, email: str, group_id: str, frequency: str): + db_sheet = models.Sheet(id=sheet_id, name=sheet_name, author_id=email, group_id=group_id, frequency=frequency) + db.add(db_sheet) + db.commit() + db.refresh(db_sheet) + return db_sheet + +def get_user_sheets(db: Session, email: str) -> list[models.Sheet]: + return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_archived_at.desc()).all() + +def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet: + return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first() + +def delete_sheet(db: Session, sheet_id: str, email: str) -> bool: + db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first() + if db_sheet: + db.delete(db_sheet) + db.commit() + return db_sheet is not None # --------------- INIT User-Groups @@ -255,5 +294,5 @@ def upsert_user_groups(db: Session): db.commit() count_user_groups = db.query(models.association_table_user_groups).count() count_groups = db.query(func.count(models.Group.id)).scalar() - + logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].") diff --git a/src/db/models.py b/src/db/models.py index 193adba..f782588 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -6,9 +6,11 @@ import uuid Base = declarative_base() + def generate_uuid(): return str(uuid.uuid4()) + # many to many association tables association_table_archive_tags = Table( "mtm_archives_tags", @@ -24,24 +26,29 @@ association_table_user_groups = Table( ) # data model tables + + class Archive(Base): __tablename__ = "archives" id = Column(String, primary_key=True, index=True) url = Column(String, index=True) result = Column(JSON, default=None) - public = Column(Boolean, default=True) # if public=false, access to group and author + public = Column(Boolean, default=True) # if public=false, access to group and author deleted = Column(Boolean, default=False) created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) group_id = Column(String, ForeignKey("groups.id"), default=None) author_id = Column(String, ForeignKey("users.email")) + sheet_id = Column(String, ForeignKey("sheets.id"), default=None) tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags) group = relationship("Group", back_populates="archives") author = relationship("User", back_populates="archives") urls = relationship("ArchiveUrl", back_populates="archive") + sheet = relationship("Sheet", back_populates="archives") + class ArchiveUrl(Base): __tablename__ = "archive_urls" @@ -61,6 +68,7 @@ class Tag(Base): archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags) + class User(Base): __tablename__ = "users" @@ -68,8 +76,10 @@ class User(Base): is_active = Column(Boolean, default=False) archives = relationship("Archive", back_populates="author") + sheets = relationship("Sheet", back_populates="author") groups = relationship("Group", back_populates="users", secondary=association_table_user_groups) + class Group(Base): __tablename__ = "groups" @@ -81,4 +91,23 @@ class Group(Base): domains = Column(JSON, default=[]) archives = relationship("Archive", back_populates="group") - users = relationship("User", back_populates="groups", secondary=association_table_user_groups) \ No newline at end of file + sheets = relationship("Sheet", back_populates="group") + users = relationship("User", back_populates="groups", secondary=association_table_user_groups) + + +class Sheet(Base): + __tablename__ = "sheets" + + id = Column(String, primary_key=True, index=True, doc="Google Sheet ID") + name = Column(String, default=None) + author_id = Column(String, ForeignKey("users.email")) + group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.") + frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.") + stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...") + last_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.") + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + + group = relationship("Group", back_populates="sheets") + author = relationship("User", back_populates="sheets") + archives = relationship("Archive", back_populates="sheet") diff --git a/src/db/schemas.py b/src/db/schemas.py index 538609a..a67dac6 100644 --- a/src/db/schemas.py +++ b/src/db/schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, field_validator from datetime import datetime @@ -6,9 +6,10 @@ class Tag(BaseModel): id: str created_at: datetime - model_config = { "from_attributes": True } + model_config = {"from_attributes": True} __hash__ = object.__hash__ + class ArchiveCreate(BaseModel): id: str | None = None url: str @@ -26,7 +27,8 @@ class Archive(ArchiveCreate): updated_at: datetime | None deleted: bool - model_config = { "from_attributes": True } + model_config = {"from_attributes": True} + class SubmitSheet(BaseModel): sheet_name: str | None = None @@ -36,31 +38,70 @@ class SubmitSheet(BaseModel): author_id: str | None = None group_id: str | None = None tags: set[str] | None = set() - columns: dict | None = {} # TODO: implement + columns: dict | None = {} # TODO: implement + class SubmitManual(BaseModel): - result: str # should be a Metadata.to_json() + result: str # should be a Metadata.to_json() public: bool = False author_id: str | None = None group_id: str | None = None tags: set[str] | None = set() +# API REQUESTS BELOW +# TODO: replace existing schemas with these + + +class ArchiveUrl(BaseModel): + url: str + public: bool = False + author_id: str | None + group_id: str | None + tags: set[str] | None = set() + # API RESPONSES BELOW + + class ArchiveResult(BaseModel): id: str url: str result: dict created_at: datetime + class Task(BaseModel): id: str + class TaskResult(Task): status: str result: str + class TaskDelete(Task): deleted: bool + class ActiveUser(BaseModel): - active: bool \ No newline at end of file + active: bool + + +class SheetAdd(BaseModel): + id: str + name: str + group_id: str + frequency: str + + @field_validator('frequency') + def validate_frequency(cls, v): + valid_frequencies = {"hourly", "daily"} + if v not in {"hourly", "daily"}: + raise ValueError(f"Invalid frequency: {v}. Must be one of {valid_frequencies}.") + return v + + +class SheetResponse(SheetAdd): + author_id: str + stats: dict | None + last_archived_at: datetime | None + created_at: datetime diff --git a/src/endpoints/sheet.py b/src/endpoints/sheet.py index 5a32d4b..257ed92 100644 --- a/src/endpoints/sheet.py +++ b/src/endpoints/sheet.py @@ -2,23 +2,80 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import JSONResponse -from loguru import logger +from sqlalchemy import exc +from sqlalchemy.orm import Session -from core.config import ALLOW_ANY_EMAIL -from web.security import get_token_or_user_auth -from db import schemas +from web.security import token_api_key_auth, get_active_user_auth +from db import schemas, crud +from db.database import get_db_dependency from worker.main import create_sheet_task sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"]) -@sheet_router.post("/archive", status_code=201, summary="Submit a Google Sheet archive request, starts a sheet archiving task.", response_description="task_id for the archiving task.") -def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_token_or_user_auth)) -> schemas.Task: - logger.info(f"SHEET TASK for {sheet=}") - if email == ALLOW_ANY_EMAIL: - email = sheet.author_id or "api-endpoint" - sheet.author_id = email - if not sheet.sheet_name and not sheet.sheet_id: - raise HTTPException(status_code=422, detail=f"sheet name or id is required") +@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.") +def create_sheet( + sheet: schemas.SheetAdd, + email=Depends(get_active_user_auth), + db: Session = Depends(get_db_dependency), +) -> schemas.SheetResponse: + user_groups_names = crud.get_user_groups(db, email) + + if sheet.group_id not in user_groups_names: + raise HTTPException(status_code=403, detail="User does not have access to this group.") + + if not crud.has_quota_sheet(db, email, user_groups_names): + raise HTTPException(status_code=429, detail="User has reached their sheet quota.") + + try: + return crud.create_sheet(db, sheet.id, sheet.name, email, sheet.group_id, sheet.frequency) + except exc.IntegrityError as e: + raise HTTPException(status_code=400, detail="Sheet with this ID already exists.") from e + + +@sheet_router.get("/mine", status_code=200, summary="Get the authenticated user's Google Sheets.") +def get_user_sheets( + email=Depends(get_active_user_auth), + db: Session = Depends(get_db_dependency) +) -> list[schemas.SheetResponse]: + return crud.get_user_sheets(db, email) + + +@sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.") +def delete_sheet( + id: str, + email=Depends(get_active_user_auth), + db: Session = Depends(get_db_dependency), +) -> schemas.TaskDelete: + return JSONResponse({ + "id": id, + "deleted": crud.delete_sheet(db, id, email) + }) + + +@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.") +def archive_user_sheet( + id: str, + email=Depends(get_active_user_auth), + db: Session = Depends(get_db_dependency), +) -> schemas.Task: + + sheet = crud.get_user_sheet(db, email, sheet_id=id) + if not sheet: + raise HTTPException(status_code=403, detail="No access to this sheet.") + + task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=email, group=sheet.group_id).model_dump_json()) + + return JSONResponse({"id": task.id}, status_code=201) + + +@sheet_router.post("/archive", status_code=201, summary="Trigger an archiving task for any GSheet with an API token.", response_description="task_id for the archiving task.") +def archive_sheet( + sheet: schemas.SubmitSheet, #TODO: replace with simpler model + auth=Depends(token_api_key_auth) +) -> schemas.Task: + sheet.author_id = sheet.author_id or "api-endpoint" + if not sheet.sheet_id: + raise HTTPException(status_code=422, detail=f"sheet id is required") task = create_sheet_task.delay(sheet.model_dump_json()) - return JSONResponse({"id": task.id}, status_code=201) \ No newline at end of file + return JSONResponse({"id": task.id}, status_code=201) diff --git a/src/migrations/versions/89121d2c96d8_add_sheet_id_to_archive_table.py b/src/migrations/versions/89121d2c96d8_add_sheet_id_to_archive_table.py new file mode 100644 index 0000000..bbdbee1 --- /dev/null +++ b/src/migrations/versions/89121d2c96d8_add_sheet_id_to_archive_table.py @@ -0,0 +1,42 @@ +"""add sheet_id to archive table + +Revision ID: 89121d2c96d8 +Revises: fa012ec405b8 +Create Date: 2024-11-04 11:12:30.237299 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.engine.reflection import Inspector + + +# revision identifiers, used by Alembic. +revision = '89121d2c96d8' +down_revision = 'fa012ec405b8' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + inspector = Inspector.from_engine(conn) + columns = [col['name'] for col in inspector.get_columns('archives')] + + if 'sheet_id' not in columns: + with op.batch_alter_table('archives') as batch_op: + batch_op.add_column(sa.Column('sheet_id', sa.String(), nullable=True, default=None)) + batch_op.create_foreign_key('fk_sheet_id', 'sheets', ['sheet_id'], ['id']) + + +def downgrade() -> None: + conn = op.get_bind() + inspector = Inspector.from_engine(conn) + foreign_keys = [fk['name'] for fk in inspector.get_foreign_keys('archives')] + columns = [col['name'] for col in inspector.get_columns('archives')] + + with op.batch_alter_table('archives') as batch_op: + if 'fk_sheet_id' in foreign_keys: + batch_op.drop_constraint('fk_sheet_id', type_='foreignkey') + + if 'sheet_id' in columns: + batch_op.drop_column('sheet_id') diff --git a/src/migrations/versions/fa012ec405b8_add_columns_to_groups_table.py b/src/migrations/versions/fa012ec405b8_add_columns_to_groups_table.py index da77d41..be94c98 100644 --- a/src/migrations/versions/fa012ec405b8_add_columns_to_groups_table.py +++ b/src/migrations/versions/fa012ec405b8_add_columns_to_groups_table.py @@ -35,8 +35,11 @@ def upgrade() -> None: def downgrade() -> None: - op.drop_column('groups', 'description') - op.drop_column('groups', 'orchestrator') - op.drop_column('groups', 'orchestrator_sheet') - op.drop_column('groups', 'permissions') - op.drop_column('groups', 'domains') + conn = op.get_bind() + inspector = Inspector.from_engine(conn) + columns = [col['name'] for col in inspector.get_columns('groups')] + + column_names = ['description', 'orchestrator', 'orchestrator_sheet', 'permissions', 'domains'] + for column_name in column_names: + if column_name in columns: + op.drop_column('groups', column_name) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 97c18e8..6188d9f 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -16,6 +16,7 @@ def mock_logger_add(): def get_settings(): return Settings(_env_file=".env.test") + @pytest.fixture(autouse=True) def mock_settings(): with patch('shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings: @@ -26,7 +27,7 @@ def mock_settings(): def test_db(get_settings: Settings): from db.database import make_engine from db import models - + make_engine.cache_clear() engine = make_engine(get_settings.DATABASE_PATH) @@ -72,10 +73,10 @@ def client(app): @pytest.fixture() def app_with_auth(app): - from web.security import get_token_or_user_auth, get_user_auth, token_api_key_auth + from web.security import get_token_or_user_auth, get_user_auth, get_active_user_auth app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com" app.dependency_overrides[get_user_auth] = lambda: "morty@example.com" - app.dependency_overrides[token_api_key_auth] = lambda: "jerry@example.com" + app.dependency_overrides[get_active_user_auth] = lambda: "morty@example.com" return app @@ -85,6 +86,19 @@ def client_with_auth(app_with_auth): return client +@pytest.fixture() +def app_with_token(app): + from web.security import token_api_key_auth + app.dependency_overrides[token_api_key_auth] = lambda: "jerry@example.com" + return app + + +@pytest.fixture() +def client_with_token(app_with_token): + client = TestClient(app_with_token) + return client + + @pytest.fixture() def test_no_auth(): # reusable code to ensure a method/endpoint combination is unauthorized @@ -92,4 +106,4 @@ def test_no_auth(): response = http_method(endpoint) assert response.status_code == 403 assert response.json() == {"detail": "Not authenticated"} - return no_auth \ No newline at end of file + return no_auth diff --git a/src/tests/endpoints/test_default.py b/src/tests/endpoints/test_default.py index 6a585e6..9455bcb 100644 --- a/src/tests/endpoints/test_default.py +++ b/src/tests/endpoints/test_default.py @@ -101,10 +101,16 @@ def test_favicon(client_with_auth): assert r.headers["content-type"] == "image/vnd.microsoft.icon" +def test_endpoint_test_prometheus_no_auth(client, test_no_auth): + test_no_auth(client.get, "/metrics") + +def test_endpoint_test_prometheus_no_user_auth(client_with_auth, test_no_auth): + test_no_auth(client_with_auth.get, "/metrics") + @pytest.mark.asyncio -async def test_prometheus_metrics(test_data, client_with_auth, get_settings): +async def test_prometheus_metrics(test_data, client_with_token, get_settings): # before metrics calculation - r = client_with_auth.get("/metrics") + r = client_with_token.get("/metrics") assert r.status_code == 200 assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8" assert "disk_utilization" in r.text @@ -116,7 +122,7 @@ async def test_prometheus_metrics(test_data, client_with_auth, get_settings): # after metrics calculation from utils.metrics import measure_regular_metrics await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100) - r2 = client_with_auth.get("/metrics") + r2 = client_with_token.get("/metrics") assert 'disk_utilization{type="used"}' in r2.text assert 'disk_utilization{type="free"}' in r2.text assert 'disk_utilization{type="database"}' in r2.text @@ -130,7 +136,7 @@ async def test_prometheus_metrics(test_data, client_with_auth, get_settings): # 30s window, should not change the gauges nor the total in the counters from utils.metrics import measure_regular_metrics await measure_regular_metrics(get_settings.DATABASE_PATH, 30) - r3 = client_with_auth.get("/metrics") + r3 = client_with_token.get("/metrics") assert 'database_metrics{query="count_archives"} 100.0' in r3.text assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text assert 'database_metrics{query="count_users"} 4.0' in r3.text diff --git a/src/tests/endpoints/test_interopreability.py b/src/tests/endpoints/test_interopreability.py index 82136f0..8021bfe 100644 --- a/src/tests/endpoints/test_interopreability.py +++ b/src/tests/endpoints/test_interopreability.py @@ -5,15 +5,19 @@ def test_submit_manual_archive_unauthenticated(client, test_no_auth): test_no_auth(client.post, "/interop/submit-archive") -def test_submit_manual_archive(client_with_auth): +def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth): + test_no_auth(client_with_auth.post, "/interop/submit-archive") + + +def test_submit_manual_archive(client_with_token): aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": []}) - r = client_with_auth.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]}) + r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]}) assert r.status_code == 201 assert "id" in r.json() # cannot have the same URL twice aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]}) - r = client_with_auth.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]}) + r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]}) assert r.status_code == 422 assert r.json() == {"detail": "Cannot insert into DB due to integrity error"} diff --git a/src/tests/endpoints/test_sheet.py b/src/tests/endpoints/test_sheet.py index f3e2559..ef56361 100644 --- a/src/tests/endpoints/test_sheet.py +++ b/src/tests/endpoints/test_sheet.py @@ -1,46 +1,210 @@ +from datetime import datetime import json from unittest.mock import patch +from fastapi.testclient import TestClient + from db.schemas import TaskResult -def test_sheet_no_auth(client, test_no_auth): +def test_endpoints_no_auth(client, test_no_auth): + test_no_auth(client.post, "/sheet/create") + test_no_auth(client.get, "/sheet/mine") + test_no_auth(client.delete, "/sheet/123-sheet-id") + test_no_auth(client.post, "/sheet/123-sheet-id/archive") test_no_auth(client.post, "/sheet/archive") -@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result="")) -def test_sheet_rick(m1, client_with_auth): +def test_create_sheet_endpoint(app_with_auth): + client_with_auth = TestClient(app_with_auth) + good_data = { + "id": "123-sheet-id", + "name": "Test Sheet", + "group_id": "spaceship", + "frequency": "daily" + } - response = client_with_auth.post("/sheet/archive", json={"sheet_id": "123-sheet-id"}) + # with good data + response = client_with_auth.post("/sheet/create", json=good_data) assert response.status_code == 201 - assert response.json() == {'id': '123-456-789'} + j = response.json() + assert datetime.fromisoformat(j.pop("created_at")) + assert datetime.fromisoformat(j.pop("last_archived_at")) + assert j.pop("stats") == {} + assert j.pop("author_id") == 'morty@example.com' + assert j == good_data - m1.assert_called_once() - called_val = m1.call_args.args[0] - assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": None, "public": False, "author_id": "rick@example.com", "group_id": None, "tags": [], "columns": {}, "header": 1} + # already exists + response = client_with_auth.post("/sheet/create", json=good_data) + assert response.status_code == 400 + assert response.json() == {"detail": "Sheet with this ID already exists."} + # bad frequency + bad_data = good_data.copy() + bad_data["frequency"] = "every hour" + response = client_with_auth.post("/sheet/create", json=bad_data) + assert response.status_code == 422 + assert "Value error, Invalid frequency: every hour. Must be one of" in response.json()["detail"][0]["msg"] -def test_sheet_missing_sheet_data(client_with_auth): - r = client_with_auth.post("/sheet/archive", json={}) - assert r.status_code == 422 - assert r.json() == {"detail": "sheet name or id is required"} + # bad group + bad_data = good_data.copy() + bad_data["group_id"] = "not a group" + response = client_with_auth.post("/sheet/create", json=bad_data) + assert response.status_code == 403 + assert response.json() == {"detail": "User does not have access to this group."} - -@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-API-789", status="PENDING", result="")) -def test_sheet_api(m1, client): - - response = client.post("/sheet/archive", json={"sheet_name": "456-sheet_name-id"}, headers={"Authorization": "Bearer this_is_the_test_api_token"}) + # bad quota + jerry_data = good_data.copy() + jerry_data["group_id"] = "animated-characters" + jerry_data["id"] = "jerry-sheet-id" + from web.security import get_active_user_auth + app_with_auth.dependency_overrides[get_active_user_auth] = lambda: "jerry@example.com" + client_jerry = TestClient(app_with_auth) + response = client_jerry.post("/sheet/create", json=jerry_data) assert response.status_code == 201 - assert response.json() == {'id': '123-API-789'} - m1.assert_called_once() - called_val = m1.call_args.args[0] - assert json.loads(called_val) == {"sheet_name": "456-sheet_name-id", "sheet_id": None, "public": False, "author_id": "api-endpoint", "group_id": None, "tags": [], "columns": {}, "header": 1} + response = client_jerry.post("/sheet/create", json=jerry_data) + assert response.status_code == 429 + assert response.json() == {"detail": "User has reached their sheet quota."} - response = client.post("/sheet/archive", json={"sheet_id": "456-sheet-id", "author_id": "custom-author"}, headers={"Authorization": "Bearer this_is_the_test_api_token"}) - assert response.status_code == 201 - assert response.json() == {'id': '123-API-789'} - assert m1.call_count == 2 - called_val = m1.call_args.args[0] - assert json.loads(called_val) == {"sheet_id": "456-sheet-id", "sheet_name": None, "public": False, "author_id": "custom-author", "group_id": None, "tags": [], "columns": {}, "header": 1} +def test_get_user_sheets_endpoint(client_with_auth, db_session): + # no data + response = client_with_auth.get("/sheet/mine") + assert response.status_code == 200 + assert response.json() == [] + + # with data + from db import models + db_session.add( + models.Sheet(id="123", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly") + ) + db_session.commit() + db_session.add_all([ + models.Sheet(id="456", name="Test Sheet 2", author_id="morty@example.com", group_id="interdimensional", frequency="daily"), + models.Sheet(id="789", name="Test Sheet 3", author_id="rick@example.com", group_id="interdimensional", frequency="hourly"), + ]) + db_session.commit() + + response = client_with_auth.get("/sheet/mine") + assert response.status_code == 200 + r = response.json() + assert isinstance(r, list) + assert len(r) == 2 + assert datetime.fromisoformat(r[0].pop("created_at")) + assert datetime.fromisoformat(r[0].pop("last_archived_at")) + assert datetime.fromisoformat(r[1].pop("created_at")) + assert datetime.fromisoformat(r[1].pop("last_archived_at")) + assert r[0] == { + 'id': '123', + 'author_id': 'morty@example.com', + 'frequency': 'hourly', + 'group_id': 'spaceship', + 'name': 'Test Sheet 1', + 'stats': {}, + } + assert r[1] == { + 'id': '456', + 'author_id': 'morty@example.com', + 'frequency': 'daily', + 'group_id': 'interdimensional', + 'name': 'Test Sheet 2', + 'stats': {}, + } + + +def test_delete_sheet_endpoint(client_with_auth, db_session): + # missing sheet + response = client_with_auth.delete("/sheet/123-sheet-id") + assert response.status_code == 200 + assert response.json() == { + "id": "123-sheet-id", + "deleted": False + } + + # add sheets for deletion + from db import models + db_session.add_all([ + models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="daily"), + models.Sheet(id="456-sheet-id", name="Test Sheet 2", author_id="rick@example.com", group_id="spaceship", frequency="hourly"), + ]) + db_session.commit() + + # morty can delete his + response = client_with_auth.delete("/sheet/123-sheet-id") + assert response.status_code == 200 + assert response.json() == {"id": "123-sheet-id", "deleted": True} + # but only once + response = client_with_auth.delete("/sheet/123-sheet-id") + assert response.status_code == 200 + assert response.json() == {"id": "123-sheet-id", "deleted": False} + # and not rick's + response = client_with_auth.delete("/sheet/456-sheet-id") + assert response.status_code == 200 + assert response.json() == {"id": "456-sheet-id", "deleted": False} + + +# def test_archive_user_sheet_endpoint(client_with_auth): +# response = client_with_auth.post("/sheet/123-sheet-id/archive") +# assert response.status_code == 201 +# assert "id" in response.json() + + +class TestArchiveUserSheetEndpoint: + def test_token_auth(self, client_with_token, test_no_auth): + test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive") + + def test_missing_data(self, client_with_auth): + r = client_with_auth.post("/sheet/123-sheet-id/archive") + assert r.status_code == 403 + assert r.json() == {"detail": "No access to this sheet."} + + def test_no_access(self, client_with_auth, db_session): + from db import models + db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="rick@example.com", group_id="spaceship", frequency="hourly")) + db_session.commit() + r = client_with_auth.post("/sheet/123-sheet-id/archive") + assert r.status_code == 403 + assert r.json() == {"detail": "No access to this sheet."} + + @patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-taskid", status="PENDING", result="")) + def test_normal_flow(self, m1, client_with_auth, db_session): + from db import models + db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")) + db_session.commit() + r = client_with_auth.post("/sheet/123-sheet-id/archive") + assert r.status_code == 201 + assert r.json() == {"id": "123-taskid"} + m1.assert_called_once() + + +class TestTokenArchiveEndpoint: + + def test_user_auth(self, client_with_auth, test_no_auth): + test_no_auth(client_with_auth.post, "/sheet/archive") + + def test_missing_data(self, client_with_token): + r = client_with_token.post("/sheet/archive", json={}) + assert r.status_code == 422 + assert r.json() == {"detail": "sheet id is required"} + + @patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result="")) + def test_normal_flow(self, m1, client_with_token): + + # minimum data + response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id"}) + assert response.status_code == 201 + assert response.json() == {'id': '123-456-789'} + + m1.assert_called_once() + called_val = m1.call_args.args[0] + assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": None, "public": False, "author_id": "api-endpoint", "group_id": None, "tags": [], "columns": {}, "header": 1} + + # maximum data + response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "author_id": "birdman@example.com", "header": 2, "public": True, "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}}) + assert response.status_code == 201 + assert response.json() == {'id': '123-456-789'} + + m1.call_count == 2 + called_val = m1.call_args.args[0] + assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "public": True, "author_id": "birdman@example.com", "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}, "header": 2} diff --git a/src/tests/user-groups.test.yaml b/src/tests/user-groups.test.yaml index 612f09d..aa18c76 100644 --- a/src/tests/user-groups.test.yaml +++ b/src/tests/user-groups.test.yaml @@ -33,6 +33,7 @@ groups: active_sheets: -1 monthly_urls: all monthly_mbs: all + alowed_frequency: "hourly" interdimensional: description: "Interdimensional travelers" orchestrator: tests/orchestration.test.yaml @@ -42,12 +43,14 @@ groups: active_sheets: 5 monthly_urls: 1000 monthly_mbs: 1000 + alowed_frequency: "hourly" animated-characters: description: "Animated characters" orchestrator: tests/orchestration.test.yaml orchestrator_sheet: tests/orchestration.test.yaml permissions: read: ["animated-characters"] - active_sheets: -1 - monthly_urls: all - monthly_mbs: all \ No newline at end of file + active_sheets: 1 + monthly_urls: 2 + monthly_mbs: 10 + alowed_frequency: "daily" \ No newline at end of file diff --git a/src/web/security.py b/src/web/security.py index bfa678a..141cd1b 100644 --- a/src/web/security.py +++ b/src/web/security.py @@ -4,6 +4,8 @@ from fastapi import HTTPException, status, Depends from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from core.config import ALLOW_ANY_EMAIL from shared.settings import get_settings +from db.database import get_db +from db import crud settings = get_settings() bearer_security = HTTPBearer() @@ -54,6 +56,18 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear ) +async def get_active_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): + # validates Bearer token and Active User status + try: + email = await get_user_auth(credentials) + with get_db() as db: + if crud.is_active_user(db, email): + return email + raise HTTPException(status_code=403, detail="User is not active") + except HTTPException as e: + raise e + + def authenticate_user(access_token): # https://cloud.google.com/docs/authentication/token-types#access if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token" @@ -69,7 +83,7 @@ def authenticate_user(access_token): return False, f"email '{j.get('email')}' not verified" if int(j.get("expires_in", -1)) <= 0: return False, "Token expired" - return True, j.get('email') + return True, j.get('email').lower() except Exception as e: logger.warning(f"AUTH EXCEPTION occurred: {e}") return False, "exception occurred" diff --git a/src/worker/main.py b/src/worker/main.py index bde1073..8fe97d5 100644 --- a/src/worker/main.py +++ b/src/worker/main.py @@ -64,11 +64,13 @@ def create_sheet_task(self, sheet_json: str): sheet.tags.add("gsheet") logger.info(f"SHEET START {sheet=}") + #TODO: should this check live here? if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)): return {"error": em} config = Config() # TODO: use choose_orchestrator and overwrite the feeder + # TODO: drop sheet_name and use only sheet_id (new endpoints/models) config.parse(use_cli=False, yaml_config_filename=get_settings().SHEET_ORCHESTRATION_YAML, overwrite_configs={"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}}) orchestrator = ArchivingOrchestrator(config) @@ -78,6 +80,7 @@ def create_sheet_task(self, sheet_json: str): logger.error("Got empty result from feeder, an internal error must have occurred.") continue try: + #TODO: remove public from sheet in new refactor insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid()) stats["archived"] += 1 except exc.IntegrityError as e: