WIP decoupling worker/web, cleaning worker code

This commit is contained in:
msramalho
2025-02-08 00:40:35 +00:00
parent 46a5c1a260
commit 9a62f3ff59
19 changed files with 194 additions and 247 deletions

View File

@@ -6,15 +6,15 @@ from fastapi import FastAPI
from contextlib import asynccontextmanager
from fastapi_utils.tasks import repeat_every
from loguru import logger
from sqlalchemy import text
from db import crud, models, schemas
from db.database import get_db, get_db_async, make_engine, wal_checkpoint
from shared.settings import get_settings
from shared.task_messaging import get_celery
from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
from worker.main import create_sheet_task
from fastapi_mail import FastMail, MessageSchema, MessageType
celery = get_celery()
@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -25,7 +25,7 @@ async def lifespan(app: FastAPI):
models.Base.metadata.create_all(bind=engine)
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
logging.getLogger("uvicorn.access").disabled = True # loguru
asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL, get_settings().CELERY_BROKER_URL))
asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL))
asyncio.create_task(repeat_measure_regular_metrics())
with get_db() as db:
crud.upsert_user_groups(db)
@@ -72,7 +72,9 @@ async def archive_sheets_cronjob(frequency: str, interval: int, current_time_uni
async with get_db_async() as db:
sheets = await crud.get_sheets_by_id_hash(db, frequency, interval, current_time_unit)
for s in sheets:
task = create_sheet_task.apply_async(args=[schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group=s.group_id).model_dump_json()])
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group=s.group_id).model_dump_json()]).apply_async()
triggered_jobs.append({"sheet_id": s.id, "task_id": task.id})
logger.info(f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}")

View File

@@ -9,7 +9,6 @@ logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
def log_error(e: Exception, traceback_str: str = None, extra:str = ""):
# EXCEPTION_COUNTER.labels(type(e).__name__).inc()
if not traceback_str: traceback_str = traceback.format_exc()
if extra: extra = f"{extra}\n"
logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")
@@ -21,6 +20,6 @@ async def logging_middleware(request: Request, call_next):
return response
except Exception as e:
from utils.metrics import EXCEPTION_COUNTER
EXCEPTION_COUNTER.labels(type(e).__name__).inc()
EXCEPTION_COUNTER.labels(type=e.__class__.__name__).inc()
log_error(e)
raise e

View File

@@ -100,7 +100,7 @@ def base_query(db: Session):
# --------------- TAG
def create_tag(db: Session, tag: str):
def create_tag(db: Session, tag: str) -> models.Tag:
db_tag = db.query(models.Tag).filter(models.Tag.id == tag).first()
if not db_tag:
db_tag = models.Tag(id=tag)

View File

@@ -48,7 +48,7 @@ async def make_async_engine(database_url: str) -> AsyncEngine:
engine = create_async_engine(database_url, connect_args={"check_same_thread": False})
async with engine.begin() as conn:
await conn.run_sync(lambda sync_conn: sync_conn.execute("PRAGMA journal_mode=WAL;"))
await conn.run_sync(lambda sync_conn: sync_conn.execute(text("PRAGMA journal_mode=WAL;")))
return engine

View File

@@ -11,35 +11,13 @@ class Tag(BaseModel):
model_config = {"from_attributes": True}
__hash__ = object.__hash__
class ArchiveCreate(BaseModel):
id: str | None = None
url: str
result: dict | None = None
public: bool = True
author_id: str | None = None
group_id: str | None = None
tags: set[Tag] | None = set()
rearchive: bool = True
sheet_id: str | None = None
# urls: list = []
class Archive(ArchiveCreate):
created_at: datetime
updated_at: datetime | None
deleted: bool
model_config = {"from_attributes": True}
class SubmitSheet(BaseModel):
sheet_name: str | None = None
sheet_id: str | None = None
header: int = 1
public: bool = False
author_id: str | None = None
group_id: str | None = None
group_id: str | None
tags: set[str] | None = set()
columns: dict | None = {} # TODO: implement
@@ -103,10 +81,25 @@ class SheetResponse(SheetAdd):
class ArchiveTrigger(BaseModel):
author_id: str | None = None
url: Annotated[str, Len(min_length=5)]
public: bool = True
group_id: Annotated[str, Len(min_length=1)] | None = None
tags: set[Tag] | None = set()
public: bool = False
group_id: Annotated[str, Len(min_length=1)] = "default"
tags: set[Tag] | None = None
class ArchiveCreate(ArchiveTrigger):
id: str | None = None
result: dict | None = None
sheet_id: str | None = None
urls: list | None = None
class Archive(ArchiveCreate):
created_at: datetime
updated_at: datetime | None
deleted: bool
model_config = {"from_attributes": True}
class Usage(BaseModel):
monthly_urls: int = 0

View File

@@ -260,6 +260,9 @@ class UserState:
else:
if group_id not in self.permissions: return False
quota = self.permissions[group_id].max_monthly_urls
if quota == -1:
return True
current_month = datetime.now().month
current_year = datetime.now().year
@@ -282,6 +285,9 @@ class UserState:
if group_id not in self.permissions: return False
quota = self.permissions[group_id].max_monthly_mbs
if quota == -1:
return True
current_month = datetime.now().month
current_year = datetime.now().year

View File

@@ -1,5 +0,0 @@
from endpoints.default import default_router
from endpoints.url import url_router
from endpoints.task import task_router
from endpoints.interoperability import interoperability_router
from endpoints.sheet import sheet_router

View File

@@ -6,13 +6,14 @@ from sqlalchemy import exc
from sqlalchemy.orm import Session
from db.user_state import UserState
from shared.task_messaging import get_celery
from web.security import token_api_key_auth, get_user_state
from db import schemas, crud
from db.database import get_db_dependency
from worker.main import create_sheet_task
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
celery = get_celery()
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
def create_sheet(
@@ -73,7 +74,7 @@ def archive_user_sheet(
if not user.can_manually_trigger(sheet.group_id):
raise HTTPException(status_code=429, detail="User cannot manually trigger sheet archiving in this group.")
task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=user.email, group=sheet.group_id).model_dump_json())
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=id, author_id=user.email, group=sheet.group_id).model_dump_json()]).delay()
return JSONResponse({"id": task.id}, status_code=201)
@@ -86,5 +87,5 @@ def archive_sheet(
sheet.author_id = sheet.author_id or "api-endpoint"
if not sheet.sheet_id:
raise HTTPException(status_code=422, detail=f"sheet id is required")
task = create_sheet_task.delay(sheet.model_dump_json())
task = celery.signature("create_sheet_task", args=[sheet.model_dump_json()]).delay()
return JSONResponse({"id": task.id}, status_code=201)

View File

@@ -4,16 +4,17 @@ from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from loguru import logger
from shared.task_messaging import get_celery
from web.security import get_token_or_user_auth
from db import schemas
from core.logging import log_error
from worker.main import celery
from utils.misc import custom_jsonable_encoder
task_router = APIRouter(prefix="/task", tags=["Async task operations"])
celery = get_celery()
@task_router.get("/{task_id}", summary="Check the status of an async task by its id, works for URLs and Sheet tasks.")
def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult:

View File

@@ -6,17 +6,18 @@ from datetime import datetime
from loguru import logger
from core.config import ALLOW_ANY_EMAIL
from db.user_state import UserState
from shared.task_messaging import get_celery
from web.security import get_token_or_user_auth, get_user_state
from sqlalchemy.orm import Session
from db import crud, schemas
from db.database import get_db_dependency
from worker.main import create_archive_task
from urllib.parse import urlparse
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
celery = get_celery()
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
def archive_url(
@@ -24,6 +25,7 @@ def archive_url(
email=Depends(get_token_or_user_auth),
db: Session = Depends(get_db_dependency)
) -> schemas.Task:
archive.author_id = email
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
parsed_url = urlparse(archive.url)
@@ -39,15 +41,9 @@ def archive_url(
if not user.has_quota_max_monthly_mbs(archive.group_id):
raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.")
# TODO: deprecate ArchiveCreate
backwards_compatible_archive = schemas.ArchiveCreate(
url=archive.url,
author_id=email,
group_id=archive.group_id,
public=archive.public,
)
archive_create = schemas.ArchiveCreate(**archive.model_dump())
task = create_archive_task.delay(backwards_compatible_archive.model_dump_json())
task = celery.signature("create_archive_task", args=[archive_create.model_dump_json()]).delay()
task_response = schemas.Task(id=task.id)
return JSONResponse(task_response.model_dump(), status_code=201)

View File

@@ -16,6 +16,7 @@ class Settings(BaseSettings):
SHEET_ORCHESTRATION_YAML : str = "secrets/orchestration-sheet.yaml"
# cronjobs
#TODO: disable by default?
CRON_ARCHIVE_SHEETS: bool = False
CRON_DELETE_STALE_SHEETS: bool = True
DELETE_STALE_SHEETS_DAYS: int = 14

View File

@@ -0,0 +1,18 @@
from functools import lru_cache
from celery import Celery
import redis
from shared.settings import get_settings
@lru_cache
def get_celery(name:str="") -> Celery:
return Celery(
name,
broker_url=get_settings().CELERY_BROKER_URL,
result_backend=get_settings().CELERY_RESULT_BACKEND,
)
def get_redis() -> redis.Redis:
return redis.Redis.from_url(get_settings().CELERY_BROKER_URL)

View File

@@ -92,7 +92,7 @@ def client_with_auth(app_with_auth):
@pytest.fixture()
def app_with_token(app):
from web.security import token_api_key_auth,get_token_or_user_auth
from web.security import token_api_key_auth, get_token_or_user_auth
app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL
app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL
return app

View File

@@ -1,6 +1,6 @@
from datetime import datetime
import json
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
@@ -145,15 +145,21 @@ def test_delete_sheet_endpoint(client_with_auth, db_session):
class TestArchiveUserSheetEndpoint:
@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-taskid", status="PENDING", result=""))
def test_normal_flow(self, m1, client_with_auth, db_session):
@patch("endpoints.sheet.celery", return_value=MagicMock())
def test_normal_flow(self, m_celery, client_with_auth, db_session):
from db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly"))
db_session.commit()
m_signature = MagicMock()
m_signature.delay.return_value = TaskResult(id="123-taskid", status="PENDING", result="")
m_celery.signature.return_value = m_signature
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 201
assert r.json() == {"id": "123-taskid"}
m1.assert_called_once()
m_celery.signature.assert_called_once()
m_signature.delay.assert_called_once()
def test_token_auth(self, client_with_token, test_no_auth):
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
@@ -198,23 +204,30 @@ class TestTokenArchiveEndpoint:
assert r.status_code == 422
assert r.json() == {"detail": "sheet id is required"}
@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
def test_normal_flow(self, m1, client_with_token):
@patch("endpoints.sheet.celery", return_value=MagicMock())
def test_normal_flow(self, m_celery, client_with_token):
m_signature = MagicMock()
m_signature.delay.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
m_celery.signature.return_value = m_signature
# minimum data
response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
m1.assert_called_once()
called_val = m1.call_args.args[0]
assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": None, "public": False, "author_id": "api-endpoint", "group_id": None, "tags": [], "columns": {}, "header": 1}
m_celery.signature.assert_called_once()
m_signature.delay.assert_called_once()
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_sheet_task"
assert json.loads(called_val[1]['args'][0]) == {"sheet_id": "123-sheet-id", "sheet_name": None, "public": False, "author_id": "api-endpoint", "group_id": None, "tags": [], "columns": {}, "header": 1}
# maximum data
response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "author_id": "birdman@example.com", "header": 2, "public": True, "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
m1.call_count == 2
called_val = m1.call_args.args[0]
assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "public": True, "author_id": "birdman@example.com", "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}, "header": 2}
m_celery.signature.call_count == 2
m_signature.delay.call_count == 2
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_sheet_task"
assert json.loads(called_val[1]['args'][0]) == {"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "public": True, "author_id": "birdman@example.com", "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}, "header": 2}

View File

@@ -3,13 +3,18 @@ from unittest.mock import MagicMock, patch
from db.schemas import ArchiveCreate, TaskResult
def test_archive_url_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/url/archive")
@patch("endpoints.url.UserState")
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
def test_archive_url(m1, m2, client_with_auth):
@patch("endpoints.url.celery", return_value=MagicMock())
def test_archive_url(m_celery, m2, client_with_auth):
m_signature = MagicMock()
m_signature.delay.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
m_celery.signature.return_value = m_signature
m_user_state = MagicMock()
m2.return_value = m_user_state
@@ -17,7 +22,7 @@ def test_archive_url(m1, m2, client_with_auth):
response = client_with_auth.post("/url/archive", json={"url": "bad"})
assert response.status_code == 422
assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
m1.assert_not_called()
m_celery.signature.assert_not_called()
# url is invalid
response = client_with_auth.post("/url/archive", json={"url": "example.com"})
@@ -30,9 +35,11 @@ def test_archive_url(m1, m2, client_with_auth):
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
m1.assert_called_once()
called_val = m1.call_args.args[0]
assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True, "sheet_id":None}
m_celery.signature.assert_called_once()
m_signature.delay.assert_called_once()
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_archive_task"
assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "sheet_id": None}
m_user_state.has_quota_max_monthly_urls.assert_called_once()
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
@@ -48,9 +55,10 @@ def test_archive_url(m1, m2, client_with_auth):
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
assert m1.call_count == 2
called_val = m1.call_args.args[0]
assert json.loads(called_val)["group_id"] == "spaceship"
assert m_celery.signature.call_count == 2
assert m_signature.delay.call_count == 2
called_val = m_celery.signature.call_args
assert json.loads(called_val[1]['args'][0])["group_id"] == "spaceship"
m_user_state.in_group.assert_called_with("spaceship")
# user is over monthly URL quota
@@ -68,6 +76,9 @@ def test_archive_url(m1, m2, client_with_auth):
assert response.status_code == 429
assert response.json()["detail"] == "User has reached their monthly MB quota."
m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
assert m_celery.signature.call_count == 2
assert m_signature.delay.call_count == 2
@patch("endpoints.url.UserState")
def test_archive_url_quotas(m1, client_with_auth):
@@ -89,15 +100,25 @@ def test_archive_url_quotas(m1, client_with_auth):
assert response.json()["detail"] == "User has reached their monthly MB quota."
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
def test_archive_url_with_api_token(m1, client_with_token):
@patch("endpoints.url.celery", return_value=MagicMock())
def test_archive_url_with_api_token(m_celery, client_with_token):
m_signature = MagicMock()
m_signature.delay.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
m_celery.signature.return_value = m_signature
response = client_with_token.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
m_celery.signature.assert_called_once()
m_signature.delay.assert_called_once()
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_archive_task"
def test_search_by_url_unauthenticated(client, test_no_auth):
test_no_auth(client.get, "/url/search")
def test_search_by_url(client_with_auth, client_with_token, db_session):
# tests the search endpoint, including through some db data for the endpoint params
response = client_with_auth.get("/url/search")
@@ -111,7 +132,7 @@ def test_search_by_url(client_with_auth, client_with_token, db_session):
from db import crud, schemas
for i in range(11):
crud.create_task(db_session, ArchiveCreate(id=f"url-456-{i}", url="https://example.com" if i < 10 else "https://something-else.com", result={}, public=True, author_id="rick@example.com", group_id=None), [], [])
#NB: this insertion is too fast for the ordering to be correct as they are within the same second
# NB: this insertion is too fast for the ordering to be correct as they are within the same second
response = client_with_auth.get("/url/search?url=https://example.com")
assert response.status_code == 200
@@ -142,6 +163,7 @@ def test_search_by_url(client_with_auth, client_with_token, db_session):
assert response.status_code == 200
assert len(response.json()) == 10
@patch("endpoints.url.UserState")
def test_search_no_read_access(mock_user_state, client_with_auth):
mock_user_state.return_value.read = False

View File

@@ -21,7 +21,7 @@ class Test_create_archive_task():
@patch("worker.main.insert_result_into_db")
@patch("worker.main.is_group_invalid_for_user", return_value=None)
@patch("worker.main.choose_orchestrator")
# @patch("worker.main.choose_orchestrator")
@patch("celery.app.task.Task.request")
def test_success(self, m_req, m_choose, m_is_group, m_insert, worker_init, db_session):
from worker.main import create_archive_task
@@ -46,7 +46,7 @@ class Test_create_archive_task():
@patch("worker.main.insert_result_into_db", side_effect=Exception)
@patch("worker.main.is_group_invalid_for_user", return_value=False)
@patch("worker.main.choose_orchestrator")
# @patch("worker.main.choose_orchestrator")
def test_raise_db_error(self, m_choose, m_is_group, m_insert, worker_init):
from worker.main import create_archive_task
mock_orchestrator = self.mock_orchestrator_choice(m_choose)
@@ -123,47 +123,6 @@ class Test_create_sheet_task():
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
def test_choose_orchestrator(worker_init):
from worker.main import choose_orchestrator
assert choose_orchestrator(None, "rick@example.com").__class__.__name__ == "ArchivingOrchestrator"
@patch("worker.main.get_user_first_group", return_value="does-not-exist")
def test_choose_orchestrator_assertion(worker_init):
from worker.main import choose_orchestrator
with pytest.raises(Exception):
choose_orchestrator(None, "rick@example.com")
@patch("worker.main.read_user_groups")
def test_get_user_first_group(m_read_user_groups, worker_init):
from worker.main import get_user_first_group
m_read_user_groups.return_value = {"users": {}}
assert get_user_first_group("email1") == "default"
m_read_user_groups.return_value = {"users": {"email1": []}}
assert get_user_first_group("email1") == "default"
m_read_user_groups.return_value = {"users": {"email1": ["group1", "group2"]}}
assert get_user_first_group("email1") == "group1"
def test_is_group_invalid_for_user(worker_init, db_session):
from worker.main import is_group_invalid_for_user
from db.crud import upsert_user_groups
upsert_user_groups(db_session)
assert is_group_invalid_for_user(True, "", "") == False
assert is_group_invalid_for_user(False, "", "") == False
assert is_group_invalid_for_user(False, "default", "") == "User is not part of default, no permission"
assert is_group_invalid_for_user(False, "spaceship", "jerry@example.com") == "User jerry@example.com is not part of spaceship, no permission"
assert is_group_invalid_for_user(False, "spaceship", "rick@example.com") == False
def test_get_all_urls(worker_init, db_session):
from worker.main import get_all_urls
from auto_archiver import Metadata

View File

@@ -8,18 +8,19 @@ import redis
from db import crud
from db.database import get_db
from core.logging import log_error
from shared.task_messaging import get_redis
# Custom metrics
EXCEPTION_COUNTER = Counter(
"exceptions",
"Number of times a certain exception has occurred.",
labelnames=["types"]
labelnames=["type"]
)
WORKER_EXCEPTION = Counter(
"worker_exceptions_total",
"Number of times a certain exception has occurred on the worker.",
labelnames=["types", "exception", "task", "traceback"]
labelnames=["type", "exception", "task", "traceback"]
)
DISK_UTILIZATION = Gauge(
"disk_utilization",
@@ -38,16 +39,16 @@ DATABASE_METRICS_COUNTER = Counter(
)
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL, CELERY_BROKER_URL):
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL):
# Subscribe to Redis channel and increment the counter for each exception with info on the exception and task
Rdis = redis.Redis.from_url(CELERY_BROKER_URL)
PubSubExceptions = Rdis.pubsub()
Redis = get_redis()
PubSubExceptions = Redis.pubsub()
PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL)
while True:
message = PubSubExceptions.get_message()
if message and message["type"] == "message":
data = json.loads(message["data"].decode("utf-8"))
WORKER_EXCEPTION.labels(types=type(data["exception"]).__name__, exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc()
WORKER_EXCEPTION.labels(type=data["type"], exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc()
await asyncio.sleep(1)

View File

@@ -12,7 +12,8 @@ from sqlalchemy.orm import Session
from loguru import logger
from core.logging import logging_middleware, log_error
from worker.main import create_archive_task, create_sheet_task, celery, insert_result_into_db
from shared.task_messaging import get_celery
from worker.main import insert_result_into_db
from db import crud, models, schemas
from web.security import get_user_auth, token_api_key_auth, get_token_or_user_auth
@@ -23,8 +24,13 @@ from shared.settings import get_settings
from auto_archiver import Metadata
from endpoints import default_router, url_router, sheet_router, task_router, interoperability_router
from endpoints.default import default_router
from endpoints.url import url_router
from endpoints.sheet import sheet_router
from endpoints.task import task_router
from endpoints.interoperability import interoperability_router
celery = get_celery()
def app_factory(settings = get_settings()):
app = FastAPI(
@@ -84,7 +90,8 @@ def app_factory(settings = get_settings()):
if type(url) != str or len(url) <= 5:
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
logger.info("creating task")
task = create_archive_task.delay(archive.model_dump_json())
task = celery.signature("create_archive_task", args=[archive.model_dump_json()]).delay()
return JSONResponse({"id": task.id})
@@ -139,7 +146,7 @@ def app_factory(settings = get_settings()):
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
if not crud.is_user_in_group(db, email, sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
task = create_sheet_task.delay(sheet.model_dump_json())
task = celery.signature("create_sheet_task", args=[sheet.model_dump_json()]).delay()
return JSONResponse({"id": task.id})
@@ -149,7 +156,8 @@ def app_factory(settings = get_settings()):
sheet.author_id = sheet.author_id or "api-endpoint"
if not sheet.sheet_name and not sheet.sheet_id:
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
task = create_sheet_task.delay(sheet.model_dump_json())
task = celery.signature("create_sheet_task", args=[sheet.model_dump_json()]).delay()
return JSONResponse({"id": task.id})
# ----- endpoint to submit data archived elsewhere

View File

@@ -1,78 +1,56 @@
from functools import lru_cache
import traceback, yaml, datetime
from typing import List, Set
from celery import Celery
from celery.signals import task_failure, worker_init
from celery.signals import task_failure
from auto_archiver import Config, ArchivingOrchestrator, Metadata
from auto_archiver.core import Media
from loguru import logger
from db import crud, schemas, models
from db.database import get_db
from shared.task_messaging import get_celery, get_redis
from shared.settings import get_settings
import json
import redis
from sqlalchemy import exc
from core.logging import log_error
settings = get_settings()
celery = get_celery("worker")
Redis = get_redis()
celery = Celery(__name__)
celery.conf.broker_url = settings.CELERY_BROKER_URL
celery.conf.result_backend = settings.CELERY_RESULT_BACKEND
USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
Rdis = redis.Redis.from_url(celery.conf.broker_url)
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 3})
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
def create_archive_task(self, archive_json: str):
logger.info(archive_json)
archive = schemas.ArchiveCreate.model_validate_json(archive_json)
logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}")
#TODO: move group checks out of here
invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)
if invalid:
raise Exception(invalid) # marks task FAILED, saves the Exception as result
url = archive.url
logger.info(f"{url=} {archive=}")
# call auto-archiver
orchestrator = load_orchestrator(archive.group_id)
result = orchestrator.feed_item(Metadata().set_url(archive.url))
# TODO: re-evaluate if this logic is to be used
if not archive.rearchive:
with get_db() as session:
archives = crud.search_archives_by_url(session, url, archive.author_id, absolute_search=True)
if len(archives):
logger.info(f"Skipping {url=} as it was already archived")
return Metadata.choose_most_complete([a.result for a in archives])
# prepare for DB
assert result, f"UNABLE TO archive: {archive.url}"
archive.id = self.request.id
archive.urls = get_all_urls(result)
archive.result = json.loads(result.to_json())
orchestrator = choose_orchestrator(archive.group_id, archive.author_id)
logger.info(f"Using orchestrator {orchestrator=}")
result = orchestrator.feed_item(Metadata().set_url(url))
insert_result_into_db(archive)
return archive.result.to_dict() # TODO: is return used?
try:
insert_result_into_db(result, archive.tags, archive.public, archive.group_id, archive.author_id, self.request.id)
except Exception as e:
# Log it, then raise again to store the error as the task result
log_error(e)
redis_publish_exception(e, self.name, traceback.format_exc())
raise e
return result.to_dict()
#TODO: refactor how user-groups are loaded and orchestrators chosen
@celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
@celery.task(name="create_sheet_task", bind=True)
def create_sheet_task(self, sheet_json: str):
sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
sheet.tags.add("gsheet")
logger.info(f"SHEET START {sheet=}")
config = Config()
# TODO: use choose_orchestrator and overwrite the feeder
# TODO: drop sheet_name and use only sheet_id (new endpoints/models)
config.parse(use_cli=False, yaml_config_filename=get_settings().SHEET_ORCHESTRATION_YAML, overwrite_configs={"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}})
orchestrator = ArchivingOrchestrator(config)
orchestrator = load_orchestrator(sheet.group_id, {"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}})
stats = {"archived": 0, "failed": 0, "errors": []}
for result in orchestrator.feed():
@@ -80,8 +58,8 @@ def create_sheet_task(self, sheet_json: str):
logger.error("Got empty result from feeder, an internal error must have occurred.")
continue
try:
#TODO: remove public from sheet in new refactor
#TODO: update the sheets table with the current date if any new archive was done
# TODO: remove public from sheet in new refactor
#TODO: use new insert_result_into_db
insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid(), sheet.sheet_id)
stats["archived"] += 1
except exc.IntegrityError as e:
@@ -97,26 +75,20 @@ def create_sheet_task(self, sheet_json: str):
crud.update_sheet_last_url_archived_at(session, sheet.sheet_id)
logger.info(f"SHEET DONE {sheet=}")
# TODO: use data model
return {"success": True, "sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "time": datetime.datetime.now().isoformat(), **stats}
@task_failure.connect(sender=create_sheet_task)
@task_failure.connect(sender=create_archive_task)
def task_failure_notifier(sender, **kwargs):
# automatically capture exceptions in the worker tasks
logger.warning(f"⚠️ worker task failed: {sender.name}")
traceback_msg = "\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback'])))
logger.warning("😅 From task_failure_notifier ==> Task failed successfully!")
log_error(kwargs['exception'], traceback_msg, f"task_failure: {sender.name}")
redis_publish_exception(kwargs['exception'], sender.name, traceback_msg)
def choose_orchestrator(group, email):
global ORCHESTRATORS
if group not in ORCHESTRATORS: group = get_user_first_group(email)
assert group in ORCHESTRATORS, f"{group=} not in configurations"
logger.info(f"CHOOSE Orchestrator for {group=}, {email=}")
return ArchivingOrchestrator(ORCHESTRATORS.get(group))
def read_user_groups():
# read yaml safely
with open(USER_GROUPS_FILENAME) as inf:
@@ -127,52 +99,28 @@ def read_user_groups():
raise e
def get_user_first_group(email):
user_groups_yaml = read_user_groups()
groups = user_groups_yaml.get("users", {}).get(email, [])
if groups != None and len(groups):
return groups[0]
return "default"
def load_orchestrators():
global ORCHESTRATORS
ORCHESTRATORS = {}
"""
reads the orchestrators key in the config file to load different orchestrators for different groups
"""
user_groups_yaml = read_user_groups()
orchestrators_config = user_groups_yaml.get("orchestrators", {})
assert len(orchestrators_config), f"No orchestrators key found in {USER_GROUPS_FILENAME}. please see the example file"
assert "default" in orchestrators_config, "please include a 'default' orchestrator to be used when the user has no group"
logger.debug(f"Found {len(orchestrators_config)} group orchestrators.")
for group, config_filename in orchestrators_config.items():
config = Config()
config.parse(use_cli=False, yaml_config_filename=config_filename)
ORCHESTRATORS[group] = config
return ORCHESTRATORS
def is_group_invalid_for_user(public: bool, group_id: str, author_id: str):
"""
ensures that, if a group is specified, the user belongs to it.
if public is true the requirement is not needed
returns an error message if invalid, or False if all is good.
"""
if public: return False
if not group_id or len(group_id) == 0: return False
# otherwise group must match
def load_orchestrator(group_id: str, overwrite_configs: dict = {}) -> ArchivingOrchestrator:
with get_db() as session:
if not crud.is_user_in_group(session, author_id, group_id):
logger.error(em := f"User {author_id} is not part of {group_id}, no permission")
return em
return False
orchestrator_fn = crud.get_group(session, group_id).orchestrator
assert orchestrator_fn, f"no orchestrator found for {group_id}"
config = Config()
config.parse(use_cli=False, yaml_config_filename=orchestrator_fn, overwrite_configs=overwrite_configs)
return ArchivingOrchestrator(config)
def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_id: str, author_id: str, task_id: str, sheet_id:str="") -> str:
def insert_result_into_db(archive: schemas.ArchiveCreate) -> str:
with get_db() as session:
# create and load user, tags, if needed
crud.create_or_get_user(session, archive.author_id)
db_tags = [crud.create_tag(session, tag) for tag in archive.tags]
# insert everything
db_task = crud.create_task(session, task=archive, tags=db_tags, urls=archive.urls)
logger.debug(f"Added {db_task.id=} to database on {db_task.created_at} ({db_task.author_id})")
return db_task.id
def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_id: str, author_id: str, task_id: str, sheet_id: str = "") -> str:
logger.info(f"INSERTING {public=} {group_id=} {author_id=} {tags=} into {task_id}")
assert result, f"UNABLE TO archive: {result.get_url() if result else result}"
with get_db() as session:
@@ -186,7 +134,7 @@ def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_
logger.debug(f"Added {db_task.id=} to database on {db_task.created_at} ({db_task.author_id})")
return db_task.id
# TODO: this should live within the auto-archiver
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
db_urls = []
for m in result.media:
@@ -202,6 +150,7 @@ def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
return db_urls
# TODO: this should live within the auto-archiver??
def convert_if_media(media):
if isinstance(media, Media): return media
elif isinstance(media, dict):
@@ -214,24 +163,7 @@ def convert_if_media(media):
def redis_publish_exception(exception, task_name, traceback: str = ""):
REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL
try:
Rdis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps({"exception": exception, "task": task_name, "traceback": traceback}, default=str))
exception_data = {"task": task_name, "type": exception.__class__.__name__, "exception": exception, "traceback": traceback}
Redis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps(exception_data, default=str))
except Exception as e:
log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}")
@worker_init.connect
def at_start(sender, **kwargs):
global ORCHESTRATORS
ORCHESTRATORS = {}
load_orchestrators()
logger.info("Orchestrators loaded successfully.")
@lru_cache
def get_url_orchestrator(group_name):
with get_db() as db:
group = crud.get_group(db, group_name)
assert group, f"Group {group_name} not found"
# config = Config()
# config.parse(use_cli=False, yaml_config_filename=group.orchestrator_sheet)
# return ArchivingOrchestrator(config)
log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}")