From 9a62f3ff59b683c9603fafa03a04412d03c1aa0b Mon Sep 17 00:00:00 2001 From: msramalho <19508417+msramalho@users.noreply.github.com> Date: Sat, 8 Feb 2025 00:40:35 +0000 Subject: [PATCH] WIP decoupling worker/web, cleaning worker code --- src/core/events.py | 10 +- src/core/logging.py | 3 +- src/db/crud.py | 2 +- src/db/database.py | 2 +- src/db/schemas.py | 45 ++++---- src/db/user_state.py | 6 + src/endpoints/__init__.py | 5 - src/endpoints/sheet.py | 7 +- src/endpoints/task.py | 3 +- src/endpoints/url.py | 14 +-- src/shared/settings.py | 1 + src/shared/task_messaging.py | 18 +++ src/tests/conftest.py | 2 +- src/tests/endpoints/test_sheet.py | 37 ++++-- src/tests/endpoints/test_url.py | 46 ++++++-- src/tests/worker/test_worker_main.py | 45 +------- src/utils/metrics.py | 13 ++- src/web/main.py | 18 ++- src/worker/main.py | 164 ++++++++------------------- 19 files changed, 194 insertions(+), 247 deletions(-) create mode 100644 src/shared/task_messaging.py diff --git a/src/core/events.py b/src/core/events.py index 76dc973..8336f7b 100644 --- a/src/core/events.py +++ b/src/core/events.py @@ -6,15 +6,15 @@ from fastapi import FastAPI from contextlib import asynccontextmanager from fastapi_utils.tasks import repeat_every from loguru import logger -from sqlalchemy import text from db import crud, models, schemas from db.database import get_db, get_db_async, make_engine, wal_checkpoint from shared.settings import get_settings +from shared.task_messaging import get_celery from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions -from worker.main import create_sheet_task from fastapi_mail import FastMail, MessageSchema, MessageType +celery = get_celery() @asynccontextmanager async def lifespan(app: FastAPI): @@ -25,7 +25,7 @@ async def lifespan(app: FastAPI): models.Base.metadata.create_all(bind=engine) alembic.config.main(argv=['--raiseerr', 'upgrade', 'head']) logging.getLogger("uvicorn.access").disabled = True # loguru - asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL, get_settings().CELERY_BROKER_URL)) + asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL)) asyncio.create_task(repeat_measure_regular_metrics()) with get_db() as db: crud.upsert_user_groups(db) @@ -72,7 +72,9 @@ async def archive_sheets_cronjob(frequency: str, interval: int, current_time_uni async with get_db_async() as db: sheets = await crud.get_sheets_by_id_hash(db, frequency, interval, current_time_unit) for s in sheets: - task = create_sheet_task.apply_async(args=[schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group=s.group_id).model_dump_json()]) + + task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group=s.group_id).model_dump_json()]).apply_async() + triggered_jobs.append({"sheet_id": s.id, "task_id": task.id}) logger.info(f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}") diff --git a/src/core/logging.py b/src/core/logging.py index 5ff03db..fe9f905 100644 --- a/src/core/logging.py +++ b/src/core/logging.py @@ -9,7 +9,6 @@ logger.add("logs/error_logs.log", retention="30 days", level="ERROR") def log_error(e: Exception, traceback_str: str = None, extra:str = ""): - # EXCEPTION_COUNTER.labels(type(e).__name__).inc() if not traceback_str: traceback_str = traceback.format_exc() if extra: extra = f"{extra}\n" logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}") @@ -21,6 +20,6 @@ async def logging_middleware(request: Request, call_next): return response except Exception as e: from utils.metrics import EXCEPTION_COUNTER - EXCEPTION_COUNTER.labels(type(e).__name__).inc() + EXCEPTION_COUNTER.labels(type=e.__class__.__name__).inc() log_error(e) raise e \ No newline at end of file diff --git a/src/db/crud.py b/src/db/crud.py index 8e3fb1d..68989b3 100644 --- a/src/db/crud.py +++ b/src/db/crud.py @@ -100,7 +100,7 @@ def base_query(db: Session): # --------------- TAG -def create_tag(db: Session, tag: str): +def create_tag(db: Session, tag: str) -> models.Tag: db_tag = db.query(models.Tag).filter(models.Tag.id == tag).first() if not db_tag: db_tag = models.Tag(id=tag) diff --git a/src/db/database.py b/src/db/database.py index 4555b61..f672b87 100644 --- a/src/db/database.py +++ b/src/db/database.py @@ -48,7 +48,7 @@ async def make_async_engine(database_url: str) -> AsyncEngine: engine = create_async_engine(database_url, connect_args={"check_same_thread": False}) async with engine.begin() as conn: - await conn.run_sync(lambda sync_conn: sync_conn.execute("PRAGMA journal_mode=WAL;")) + await conn.run_sync(lambda sync_conn: sync_conn.execute(text("PRAGMA journal_mode=WAL;"))) return engine diff --git a/src/db/schemas.py b/src/db/schemas.py index 424b9de..6600c63 100644 --- a/src/db/schemas.py +++ b/src/db/schemas.py @@ -11,35 +11,13 @@ class Tag(BaseModel): model_config = {"from_attributes": True} __hash__ = object.__hash__ - -class ArchiveCreate(BaseModel): - id: str | None = None - url: str - result: dict | None = None - public: bool = True - author_id: str | None = None - group_id: str | None = None - tags: set[Tag] | None = set() - rearchive: bool = True - sheet_id: str | None = None - # urls: list = [] - - -class Archive(ArchiveCreate): - created_at: datetime - updated_at: datetime | None - deleted: bool - - model_config = {"from_attributes": True} - - class SubmitSheet(BaseModel): sheet_name: str | None = None sheet_id: str | None = None header: int = 1 public: bool = False author_id: str | None = None - group_id: str | None = None + group_id: str | None tags: set[str] | None = set() columns: dict | None = {} # TODO: implement @@ -103,10 +81,25 @@ class SheetResponse(SheetAdd): class ArchiveTrigger(BaseModel): + author_id: str | None = None url: Annotated[str, Len(min_length=5)] - public: bool = True - group_id: Annotated[str, Len(min_length=1)] | None = None - tags: set[Tag] | None = set() + public: bool = False + group_id: Annotated[str, Len(min_length=1)] = "default" + tags: set[Tag] | None = None + +class ArchiveCreate(ArchiveTrigger): + id: str | None = None + result: dict | None = None + sheet_id: str | None = None + urls: list | None = None + +class Archive(ArchiveCreate): + created_at: datetime + updated_at: datetime | None + deleted: bool + + model_config = {"from_attributes": True} + class Usage(BaseModel): monthly_urls: int = 0 diff --git a/src/db/user_state.py b/src/db/user_state.py index 9ae3135..0091398 100644 --- a/src/db/user_state.py +++ b/src/db/user_state.py @@ -260,6 +260,9 @@ class UserState: else: if group_id not in self.permissions: return False quota = self.permissions[group_id].max_monthly_urls + + if quota == -1: + return True current_month = datetime.now().month current_year = datetime.now().year @@ -282,6 +285,9 @@ class UserState: if group_id not in self.permissions: return False quota = self.permissions[group_id].max_monthly_mbs + if quota == -1: + return True + current_month = datetime.now().month current_year = datetime.now().year diff --git a/src/endpoints/__init__.py b/src/endpoints/__init__.py index 1551fae..e69de29 100644 --- a/src/endpoints/__init__.py +++ b/src/endpoints/__init__.py @@ -1,5 +0,0 @@ -from endpoints.default import default_router -from endpoints.url import url_router -from endpoints.task import task_router -from endpoints.interoperability import interoperability_router -from endpoints.sheet import sheet_router \ No newline at end of file diff --git a/src/endpoints/sheet.py b/src/endpoints/sheet.py index b1a4230..95731f0 100644 --- a/src/endpoints/sheet.py +++ b/src/endpoints/sheet.py @@ -6,13 +6,14 @@ from sqlalchemy import exc from sqlalchemy.orm import Session from db.user_state import UserState +from shared.task_messaging import get_celery from web.security import token_api_key_auth, get_user_state from db import schemas, crud from db.database import get_db_dependency -from worker.main import create_sheet_task sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"]) +celery = get_celery() @sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.") def create_sheet( @@ -73,7 +74,7 @@ def archive_user_sheet( if not user.can_manually_trigger(sheet.group_id): raise HTTPException(status_code=429, detail="User cannot manually trigger sheet archiving in this group.") - task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=user.email, group=sheet.group_id).model_dump_json()) + task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=id, author_id=user.email, group=sheet.group_id).model_dump_json()]).delay() return JSONResponse({"id": task.id}, status_code=201) @@ -86,5 +87,5 @@ def archive_sheet( sheet.author_id = sheet.author_id or "api-endpoint" if not sheet.sheet_id: raise HTTPException(status_code=422, detail=f"sheet id is required") - task = create_sheet_task.delay(sheet.model_dump_json()) + task = celery.signature("create_sheet_task", args=[sheet.model_dump_json()]).delay() return JSONResponse({"id": task.id}, status_code=201) diff --git a/src/endpoints/task.py b/src/endpoints/task.py index a2250fd..0c7f1e3 100644 --- a/src/endpoints/task.py +++ b/src/endpoints/task.py @@ -4,16 +4,17 @@ from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from loguru import logger +from shared.task_messaging import get_celery from web.security import get_token_or_user_auth from db import schemas from core.logging import log_error -from worker.main import celery from utils.misc import custom_jsonable_encoder task_router = APIRouter(prefix="/task", tags=["Async task operations"]) +celery = get_celery() @task_router.get("/{task_id}", summary="Check the status of an async task by its id, works for URLs and Sheet tasks.") def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult: diff --git a/src/endpoints/url.py b/src/endpoints/url.py index 3d0aae1..0c35238 100644 --- a/src/endpoints/url.py +++ b/src/endpoints/url.py @@ -6,17 +6,18 @@ from datetime import datetime from loguru import logger from core.config import ALLOW_ANY_EMAIL from db.user_state import UserState +from shared.task_messaging import get_celery from web.security import get_token_or_user_auth, get_user_state from sqlalchemy.orm import Session from db import crud, schemas from db.database import get_db_dependency -from worker.main import create_archive_task from urllib.parse import urlparse url_router = APIRouter(prefix="/url", tags=["Single URL operations"]) +celery = get_celery() @url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.") def archive_url( @@ -24,6 +25,7 @@ def archive_url( email=Depends(get_token_or_user_auth), db: Session = Depends(get_db_dependency) ) -> schemas.Task: + archive.author_id = email logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}") parsed_url = urlparse(archive.url) @@ -39,15 +41,9 @@ def archive_url( if not user.has_quota_max_monthly_mbs(archive.group_id): raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.") - # TODO: deprecate ArchiveCreate - backwards_compatible_archive = schemas.ArchiveCreate( - url=archive.url, - author_id=email, - group_id=archive.group_id, - public=archive.public, - ) + archive_create = schemas.ArchiveCreate(**archive.model_dump()) - task = create_archive_task.delay(backwards_compatible_archive.model_dump_json()) + task = celery.signature("create_archive_task", args=[archive_create.model_dump_json()]).delay() task_response = schemas.Task(id=task.id) return JSONResponse(task_response.model_dump(), status_code=201) diff --git a/src/shared/settings.py b/src/shared/settings.py index 5d4b843..f7383dd 100644 --- a/src/shared/settings.py +++ b/src/shared/settings.py @@ -16,6 +16,7 @@ class Settings(BaseSettings): SHEET_ORCHESTRATION_YAML : str = "secrets/orchestration-sheet.yaml" # cronjobs + #TODO: disable by default? CRON_ARCHIVE_SHEETS: bool = False CRON_DELETE_STALE_SHEETS: bool = True DELETE_STALE_SHEETS_DAYS: int = 14 diff --git a/src/shared/task_messaging.py b/src/shared/task_messaging.py new file mode 100644 index 0000000..4b2e000 --- /dev/null +++ b/src/shared/task_messaging.py @@ -0,0 +1,18 @@ + +from functools import lru_cache +from celery import Celery +import redis + +from shared.settings import get_settings + +@lru_cache +def get_celery(name:str="") -> Celery: + return Celery( + name, + broker_url=get_settings().CELERY_BROKER_URL, + result_backend=get_settings().CELERY_RESULT_BACKEND, + ) + + +def get_redis() -> redis.Redis: + return redis.Redis.from_url(get_settings().CELERY_BROKER_URL) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 89160c1..33c2886 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -92,7 +92,7 @@ def client_with_auth(app_with_auth): @pytest.fixture() def app_with_token(app): - from web.security import token_api_key_auth,get_token_or_user_auth + from web.security import token_api_key_auth, get_token_or_user_auth app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL return app diff --git a/src/tests/endpoints/test_sheet.py b/src/tests/endpoints/test_sheet.py index 129cd2a..81d2cc3 100644 --- a/src/tests/endpoints/test_sheet.py +++ b/src/tests/endpoints/test_sheet.py @@ -1,6 +1,6 @@ from datetime import datetime import json -from unittest.mock import patch +from unittest.mock import MagicMock, patch from fastapi.testclient import TestClient @@ -145,15 +145,21 @@ def test_delete_sheet_endpoint(client_with_auth, db_session): class TestArchiveUserSheetEndpoint: - @patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-taskid", status="PENDING", result="")) - def test_normal_flow(self, m1, client_with_auth, db_session): + @patch("endpoints.sheet.celery", return_value=MagicMock()) + def test_normal_flow(self, m_celery, client_with_auth, db_session): from db import models db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")) db_session.commit() + + m_signature = MagicMock() + m_signature.delay.return_value = TaskResult(id="123-taskid", status="PENDING", result="") + m_celery.signature.return_value = m_signature + r = client_with_auth.post("/sheet/123-sheet-id/archive") assert r.status_code == 201 assert r.json() == {"id": "123-taskid"} - m1.assert_called_once() + m_celery.signature.assert_called_once() + m_signature.delay.assert_called_once() def test_token_auth(self, client_with_token, test_no_auth): test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive") @@ -198,23 +204,30 @@ class TestTokenArchiveEndpoint: assert r.status_code == 422 assert r.json() == {"detail": "sheet id is required"} - @patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result="")) - def test_normal_flow(self, m1, client_with_token): + @patch("endpoints.sheet.celery", return_value=MagicMock()) + def test_normal_flow(self, m_celery, client_with_token): + m_signature = MagicMock() + m_signature.delay.return_value = TaskResult(id="123-456-789", status="PENDING", result="") + m_celery.signature.return_value = m_signature # minimum data response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id"}) assert response.status_code == 201 assert response.json() == {'id': '123-456-789'} - m1.assert_called_once() - called_val = m1.call_args.args[0] - assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": None, "public": False, "author_id": "api-endpoint", "group_id": None, "tags": [], "columns": {}, "header": 1} + m_celery.signature.assert_called_once() + m_signature.delay.assert_called_once() + called_val = m_celery.signature.call_args + assert called_val[0][0] == "create_sheet_task" + assert json.loads(called_val[1]['args'][0]) == {"sheet_id": "123-sheet-id", "sheet_name": None, "public": False, "author_id": "api-endpoint", "group_id": None, "tags": [], "columns": {}, "header": 1} # maximum data response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "author_id": "birdman@example.com", "header": 2, "public": True, "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}}) assert response.status_code == 201 assert response.json() == {'id': '123-456-789'} - m1.call_count == 2 - called_val = m1.call_args.args[0] - assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "public": True, "author_id": "birdman@example.com", "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}, "header": 2} + m_celery.signature.call_count == 2 + m_signature.delay.call_count == 2 + called_val = m_celery.signature.call_args + assert called_val[0][0] == "create_sheet_task" + assert json.loads(called_val[1]['args'][0]) == {"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "public": True, "author_id": "birdman@example.com", "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}, "header": 2} diff --git a/src/tests/endpoints/test_url.py b/src/tests/endpoints/test_url.py index af198e3..3d85e71 100644 --- a/src/tests/endpoints/test_url.py +++ b/src/tests/endpoints/test_url.py @@ -3,13 +3,18 @@ from unittest.mock import MagicMock, patch from db.schemas import ArchiveCreate, TaskResult + def test_archive_url_unauthenticated(client, test_no_auth): test_no_auth(client.post, "/url/archive") @patch("endpoints.url.UserState") -@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result="")) -def test_archive_url(m1, m2, client_with_auth): +@patch("endpoints.url.celery", return_value=MagicMock()) +def test_archive_url(m_celery, m2, client_with_auth): + m_signature = MagicMock() + m_signature.delay.return_value = TaskResult(id="123-456-789", status="PENDING", result="") + m_celery.signature.return_value = m_signature + m_user_state = MagicMock() m2.return_value = m_user_state @@ -17,7 +22,7 @@ def test_archive_url(m1, m2, client_with_auth): response = client_with_auth.post("/url/archive", json={"url": "bad"}) assert response.status_code == 422 assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters' - m1.assert_not_called() + m_celery.signature.assert_not_called() # url is invalid response = client_with_auth.post("/url/archive", json={"url": "example.com"}) @@ -30,9 +35,11 @@ def test_archive_url(m1, m2, client_with_auth): response = client_with_auth.post("/url/archive", json={"url": "https://example.com"}) assert response.status_code == 201 assert response.json() == {'id': '123-456-789'} - m1.assert_called_once() - called_val = m1.call_args.args[0] - assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True, "sheet_id":None} + m_celery.signature.assert_called_once() + m_signature.delay.assert_called_once() + called_val = m_celery.signature.call_args + assert called_val[0][0] == "create_archive_task" + assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "sheet_id": None} m_user_state.has_quota_max_monthly_urls.assert_called_once() m_user_state.has_quota_max_monthly_mbs.assert_called_once() @@ -48,9 +55,10 @@ def test_archive_url(m1, m2, client_with_auth): response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"}) assert response.status_code == 201 assert response.json() == {'id': '123-456-789'} - assert m1.call_count == 2 - called_val = m1.call_args.args[0] - assert json.loads(called_val)["group_id"] == "spaceship" + assert m_celery.signature.call_count == 2 + assert m_signature.delay.call_count == 2 + called_val = m_celery.signature.call_args + assert json.loads(called_val[1]['args'][0])["group_id"] == "spaceship" m_user_state.in_group.assert_called_with("spaceship") # user is over monthly URL quota @@ -68,6 +76,9 @@ def test_archive_url(m1, m2, client_with_auth): assert response.status_code == 429 assert response.json()["detail"] == "User has reached their monthly MB quota." m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit") + assert m_celery.signature.call_count == 2 + assert m_signature.delay.call_count == 2 + @patch("endpoints.url.UserState") def test_archive_url_quotas(m1, client_with_auth): @@ -89,15 +100,25 @@ def test_archive_url_quotas(m1, client_with_auth): assert response.json()["detail"] == "User has reached their monthly MB quota." m_user_state.has_quota_max_monthly_mbs.assert_called_once() -@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result="")) -def test_archive_url_with_api_token(m1, client_with_token): + +@patch("endpoints.url.celery", return_value=MagicMock()) +def test_archive_url_with_api_token(m_celery, client_with_token): + m_signature = MagicMock() + m_signature.delay.return_value = TaskResult(id="123-456-789", status="PENDING", result="") + m_celery.signature.return_value = m_signature response = client_with_token.post("/url/archive", json={"url": "https://example.com"}) assert response.status_code == 201 assert response.json() == {'id': '123-456-789'} + m_celery.signature.assert_called_once() + m_signature.delay.assert_called_once() + called_val = m_celery.signature.call_args + assert called_val[0][0] == "create_archive_task" + def test_search_by_url_unauthenticated(client, test_no_auth): test_no_auth(client.get, "/url/search") + def test_search_by_url(client_with_auth, client_with_token, db_session): # tests the search endpoint, including through some db data for the endpoint params response = client_with_auth.get("/url/search") @@ -111,7 +132,7 @@ def test_search_by_url(client_with_auth, client_with_token, db_session): from db import crud, schemas for i in range(11): crud.create_task(db_session, ArchiveCreate(id=f"url-456-{i}", url="https://example.com" if i < 10 else "https://something-else.com", result={}, public=True, author_id="rick@example.com", group_id=None), [], []) - #NB: this insertion is too fast for the ordering to be correct as they are within the same second + # NB: this insertion is too fast for the ordering to be correct as they are within the same second response = client_with_auth.get("/url/search?url=https://example.com") assert response.status_code == 200 @@ -142,6 +163,7 @@ def test_search_by_url(client_with_auth, client_with_token, db_session): assert response.status_code == 200 assert len(response.json()) == 10 + @patch("endpoints.url.UserState") def test_search_no_read_access(mock_user_state, client_with_auth): mock_user_state.return_value.read = False diff --git a/src/tests/worker/test_worker_main.py b/src/tests/worker/test_worker_main.py index ffef233..3550173 100644 --- a/src/tests/worker/test_worker_main.py +++ b/src/tests/worker/test_worker_main.py @@ -21,7 +21,7 @@ class Test_create_archive_task(): @patch("worker.main.insert_result_into_db") @patch("worker.main.is_group_invalid_for_user", return_value=None) - @patch("worker.main.choose_orchestrator") + # @patch("worker.main.choose_orchestrator") @patch("celery.app.task.Task.request") def test_success(self, m_req, m_choose, m_is_group, m_insert, worker_init, db_session): from worker.main import create_archive_task @@ -46,7 +46,7 @@ class Test_create_archive_task(): @patch("worker.main.insert_result_into_db", side_effect=Exception) @patch("worker.main.is_group_invalid_for_user", return_value=False) - @patch("worker.main.choose_orchestrator") + # @patch("worker.main.choose_orchestrator") def test_raise_db_error(self, m_choose, m_is_group, m_insert, worker_init): from worker.main import create_archive_task mock_orchestrator = self.mock_orchestrator_choice(m_choose) @@ -123,47 +123,6 @@ class Test_create_sheet_task(): assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0 -def test_choose_orchestrator(worker_init): - from worker.main import choose_orchestrator - - assert choose_orchestrator(None, "rick@example.com").__class__.__name__ == "ArchivingOrchestrator" - - -@patch("worker.main.get_user_first_group", return_value="does-not-exist") -def test_choose_orchestrator_assertion(worker_init): - from worker.main import choose_orchestrator - - with pytest.raises(Exception): - choose_orchestrator(None, "rick@example.com") - - -@patch("worker.main.read_user_groups") -def test_get_user_first_group(m_read_user_groups, worker_init): - from worker.main import get_user_first_group - - m_read_user_groups.return_value = {"users": {}} - assert get_user_first_group("email1") == "default" - m_read_user_groups.return_value = {"users": {"email1": []}} - assert get_user_first_group("email1") == "default" - m_read_user_groups.return_value = {"users": {"email1": ["group1", "group2"]}} - assert get_user_first_group("email1") == "group1" - - -def test_is_group_invalid_for_user(worker_init, db_session): - from worker.main import is_group_invalid_for_user - from db.crud import upsert_user_groups - - upsert_user_groups(db_session) - - assert is_group_invalid_for_user(True, "", "") == False - assert is_group_invalid_for_user(False, "", "") == False - - assert is_group_invalid_for_user(False, "default", "") == "User is not part of default, no permission" - assert is_group_invalid_for_user(False, "spaceship", "jerry@example.com") == "User jerry@example.com is not part of spaceship, no permission" - - assert is_group_invalid_for_user(False, "spaceship", "rick@example.com") == False - - def test_get_all_urls(worker_init, db_session): from worker.main import get_all_urls from auto_archiver import Metadata diff --git a/src/utils/metrics.py b/src/utils/metrics.py index 8d513e6..05c0bb0 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -8,18 +8,19 @@ import redis from db import crud from db.database import get_db from core.logging import log_error +from shared.task_messaging import get_redis # Custom metrics EXCEPTION_COUNTER = Counter( "exceptions", "Number of times a certain exception has occurred.", - labelnames=["types"] + labelnames=["type"] ) WORKER_EXCEPTION = Counter( "worker_exceptions_total", "Number of times a certain exception has occurred on the worker.", - labelnames=["types", "exception", "task", "traceback"] + labelnames=["type", "exception", "task", "traceback"] ) DISK_UTILIZATION = Gauge( "disk_utilization", @@ -38,16 +39,16 @@ DATABASE_METRICS_COUNTER = Counter( ) -async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL, CELERY_BROKER_URL): +async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL): # Subscribe to Redis channel and increment the counter for each exception with info on the exception and task - Rdis = redis.Redis.from_url(CELERY_BROKER_URL) - PubSubExceptions = Rdis.pubsub() + Redis = get_redis() + PubSubExceptions = Redis.pubsub() PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL) while True: message = PubSubExceptions.get_message() if message and message["type"] == "message": data = json.loads(message["data"].decode("utf-8")) - WORKER_EXCEPTION.labels(types=type(data["exception"]).__name__, exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc() + WORKER_EXCEPTION.labels(type=data["type"], exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc() await asyncio.sleep(1) diff --git a/src/web/main.py b/src/web/main.py index f2020f0..8b721f9 100644 --- a/src/web/main.py +++ b/src/web/main.py @@ -12,7 +12,8 @@ from sqlalchemy.orm import Session from loguru import logger from core.logging import logging_middleware, log_error -from worker.main import create_archive_task, create_sheet_task, celery, insert_result_into_db +from shared.task_messaging import get_celery +from worker.main import insert_result_into_db from db import crud, models, schemas from web.security import get_user_auth, token_api_key_auth, get_token_or_user_auth @@ -23,8 +24,13 @@ from shared.settings import get_settings from auto_archiver import Metadata -from endpoints import default_router, url_router, sheet_router, task_router, interoperability_router +from endpoints.default import default_router +from endpoints.url import url_router +from endpoints.sheet import sheet_router +from endpoints.task import task_router +from endpoints.interoperability import interoperability_router +celery = get_celery() def app_factory(settings = get_settings()): app = FastAPI( @@ -84,7 +90,8 @@ def app_factory(settings = get_settings()): if type(url) != str or len(url) <= 5: raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}") logger.info("creating task") - task = create_archive_task.delay(archive.model_dump_json()) + + task = celery.signature("create_archive_task", args=[archive.model_dump_json()]).delay() return JSONResponse({"id": task.id}) @@ -139,7 +146,7 @@ def app_factory(settings = get_settings()): raise HTTPException(status_code=422, detail=f"sheet name or id is required") if not crud.is_user_in_group(db, email, sheet.group_id): raise HTTPException(status_code=403, detail="User does not have access to this group.") - task = create_sheet_task.delay(sheet.model_dump_json()) + task = celery.signature("create_sheet_task", args=[sheet.model_dump_json()]).delay() return JSONResponse({"id": task.id}) @@ -149,7 +156,8 @@ def app_factory(settings = get_settings()): sheet.author_id = sheet.author_id or "api-endpoint" if not sheet.sheet_name and not sheet.sheet_id: raise HTTPException(status_code=422, detail=f"sheet name or id is required") - task = create_sheet_task.delay(sheet.model_dump_json()) + + task = celery.signature("create_sheet_task", args=[sheet.model_dump_json()]).delay() return JSONResponse({"id": task.id}) # ----- endpoint to submit data archived elsewhere diff --git a/src/worker/main.py b/src/worker/main.py index 2896a23..27b261a 100644 --- a/src/worker/main.py +++ b/src/worker/main.py @@ -1,78 +1,56 @@ -from functools import lru_cache import traceback, yaml, datetime from typing import List, Set -from celery import Celery -from celery.signals import task_failure, worker_init +from celery.signals import task_failure from auto_archiver import Config, ArchivingOrchestrator, Metadata from auto_archiver.core import Media from loguru import logger from db import crud, schemas, models from db.database import get_db +from shared.task_messaging import get_celery, get_redis from shared.settings import get_settings import json -import redis from sqlalchemy import exc from core.logging import log_error + settings = get_settings() +celery = get_celery("worker") +Redis = get_redis() -celery = Celery(__name__) -celery.conf.broker_url = settings.CELERY_BROKER_URL -celery.conf.result_backend = settings.CELERY_RESULT_BACKEND USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME -Rdis = redis.Redis.from_url(celery.conf.broker_url) - -@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 3}) +@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0}) def create_archive_task(self, archive_json: str): + logger.info(archive_json) archive = schemas.ArchiveCreate.model_validate_json(archive_json) - logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}") - #TODO: move group checks out of here - invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id) - if invalid: - raise Exception(invalid) # marks task FAILED, saves the Exception as result - url = archive.url - logger.info(f"{url=} {archive=}") + # call auto-archiver + orchestrator = load_orchestrator(archive.group_id) + result = orchestrator.feed_item(Metadata().set_url(archive.url)) - # TODO: re-evaluate if this logic is to be used - if not archive.rearchive: - with get_db() as session: - archives = crud.search_archives_by_url(session, url, archive.author_id, absolute_search=True) - if len(archives): - logger.info(f"Skipping {url=} as it was already archived") - return Metadata.choose_most_complete([a.result for a in archives]) + # prepare for DB + assert result, f"UNABLE TO archive: {archive.url}" + archive.id = self.request.id + archive.urls = get_all_urls(result) + archive.result = json.loads(result.to_json()) - orchestrator = choose_orchestrator(archive.group_id, archive.author_id) - logger.info(f"Using orchestrator {orchestrator=}") - result = orchestrator.feed_item(Metadata().set_url(url)) + insert_result_into_db(archive) + return archive.result.to_dict() # TODO: is return used? - try: - insert_result_into_db(result, archive.tags, archive.public, archive.group_id, archive.author_id, self.request.id) - except Exception as e: - # Log it, then raise again to store the error as the task result - log_error(e) - redis_publish_exception(e, self.name, traceback.format_exc()) - raise e - return result.to_dict() -#TODO: refactor how user-groups are loaded and orchestrators chosen -@celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0}) +@celery.task(name="create_sheet_task", bind=True) def create_sheet_task(self, sheet_json: str): sheet = schemas.SubmitSheet.model_validate_json(sheet_json) sheet.tags.add("gsheet") logger.info(f"SHEET START {sheet=}") - config = Config() - # TODO: use choose_orchestrator and overwrite the feeder # TODO: drop sheet_name and use only sheet_id (new endpoints/models) - config.parse(use_cli=False, yaml_config_filename=get_settings().SHEET_ORCHESTRATION_YAML, overwrite_configs={"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}}) - orchestrator = ArchivingOrchestrator(config) + orchestrator = load_orchestrator(sheet.group_id, {"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}}) stats = {"archived": 0, "failed": 0, "errors": []} for result in orchestrator.feed(): @@ -80,8 +58,8 @@ def create_sheet_task(self, sheet_json: str): logger.error("Got empty result from feeder, an internal error must have occurred.") continue try: - #TODO: remove public from sheet in new refactor - #TODO: update the sheets table with the current date if any new archive was done + # TODO: remove public from sheet in new refactor + #TODO: use new insert_result_into_db insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid(), sheet.sheet_id) stats["archived"] += 1 except exc.IntegrityError as e: @@ -97,26 +75,20 @@ def create_sheet_task(self, sheet_json: str): crud.update_sheet_last_url_archived_at(session, sheet.sheet_id) logger.info(f"SHEET DONE {sheet=}") + # TODO: use data model return {"success": True, "sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "time": datetime.datetime.now().isoformat(), **stats} @task_failure.connect(sender=create_sheet_task) @task_failure.connect(sender=create_archive_task) def task_failure_notifier(sender, **kwargs): + # automatically capture exceptions in the worker tasks + logger.warning(f"⚠️ worker task failed: {sender.name}") traceback_msg = "\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback']))) - logger.warning("😅 From task_failure_notifier ==> Task failed successfully!") log_error(kwargs['exception'], traceback_msg, f"task_failure: {sender.name}") redis_publish_exception(kwargs['exception'], sender.name, traceback_msg) -def choose_orchestrator(group, email): - global ORCHESTRATORS - if group not in ORCHESTRATORS: group = get_user_first_group(email) - assert group in ORCHESTRATORS, f"{group=} not in configurations" - logger.info(f"CHOOSE Orchestrator for {group=}, {email=}") - return ArchivingOrchestrator(ORCHESTRATORS.get(group)) - - def read_user_groups(): # read yaml safely with open(USER_GROUPS_FILENAME) as inf: @@ -127,52 +99,28 @@ def read_user_groups(): raise e -def get_user_first_group(email): - user_groups_yaml = read_user_groups() - groups = user_groups_yaml.get("users", {}).get(email, []) - if groups != None and len(groups): - return groups[0] - return "default" - - -def load_orchestrators(): - global ORCHESTRATORS - ORCHESTRATORS = {} - """ - reads the orchestrators key in the config file to load different orchestrators for different groups - """ - user_groups_yaml = read_user_groups() - - orchestrators_config = user_groups_yaml.get("orchestrators", {}) - assert len(orchestrators_config), f"No orchestrators key found in {USER_GROUPS_FILENAME}. please see the example file" - assert "default" in orchestrators_config, "please include a 'default' orchestrator to be used when the user has no group" - logger.debug(f"Found {len(orchestrators_config)} group orchestrators.") - - for group, config_filename in orchestrators_config.items(): - config = Config() - config.parse(use_cli=False, yaml_config_filename=config_filename) - ORCHESTRATORS[group] = config - return ORCHESTRATORS - - -def is_group_invalid_for_user(public: bool, group_id: str, author_id: str): - """ - ensures that, if a group is specified, the user belongs to it. - if public is true the requirement is not needed - returns an error message if invalid, or False if all is good. - """ - if public: return False - if not group_id or len(group_id) == 0: return False - - # otherwise group must match +def load_orchestrator(group_id: str, overwrite_configs: dict = {}) -> ArchivingOrchestrator: with get_db() as session: - if not crud.is_user_in_group(session, author_id, group_id): - logger.error(em := f"User {author_id} is not part of {group_id}, no permission") - return em - return False + orchestrator_fn = crud.get_group(session, group_id).orchestrator + assert orchestrator_fn, f"no orchestrator found for {group_id}" + + config = Config() + config.parse(use_cli=False, yaml_config_filename=orchestrator_fn, overwrite_configs=overwrite_configs) + return ArchivingOrchestrator(config) -def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_id: str, author_id: str, task_id: str, sheet_id:str="") -> str: +def insert_result_into_db(archive: schemas.ArchiveCreate) -> str: + with get_db() as session: + # create and load user, tags, if needed + crud.create_or_get_user(session, archive.author_id) + db_tags = [crud.create_tag(session, tag) for tag in archive.tags] + # insert everything + db_task = crud.create_task(session, task=archive, tags=db_tags, urls=archive.urls) + logger.debug(f"Added {db_task.id=} to database on {db_task.created_at} ({db_task.author_id})") + return db_task.id + + +def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_id: str, author_id: str, task_id: str, sheet_id: str = "") -> str: logger.info(f"INSERTING {public=} {group_id=} {author_id=} {tags=} into {task_id}") assert result, f"UNABLE TO archive: {result.get_url() if result else result}" with get_db() as session: @@ -186,7 +134,7 @@ def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_ logger.debug(f"Added {db_task.id=} to database on {db_task.created_at} ({db_task.author_id})") return db_task.id - +# TODO: this should live within the auto-archiver def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]: db_urls = [] for m in result.media: @@ -202,6 +150,7 @@ def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]: return db_urls +# TODO: this should live within the auto-archiver?? def convert_if_media(media): if isinstance(media, Media): return media elif isinstance(media, dict): @@ -214,24 +163,7 @@ def convert_if_media(media): def redis_publish_exception(exception, task_name, traceback: str = ""): REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL try: - Rdis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps({"exception": exception, "task": task_name, "traceback": traceback}, default=str)) + exception_data = {"task": task_name, "type": exception.__class__.__name__, "exception": exception, "traceback": traceback} + Redis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps(exception_data, default=str)) except Exception as e: - log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}") - - -@worker_init.connect -def at_start(sender, **kwargs): - global ORCHESTRATORS - ORCHESTRATORS = {} - load_orchestrators() - logger.info("Orchestrators loaded successfully.") - -@lru_cache -def get_url_orchestrator(group_name): - with get_db() as db: - group = crud.get_group(db, group_name) - assert group, f"Group {group_name} not found" - - # config = Config() - # config.parse(use_cli=False, yaml_config_filename=group.orchestrator_sheet) - # return ArchivingOrchestrator(config) \ No newline at end of file + log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}") \ No newline at end of file