WIP decoupling worker/web, cleaning worker code

This commit is contained in:
msramalho
2025-02-08 00:40:35 +00:00
parent 46a5c1a260
commit 9a62f3ff59
19 changed files with 194 additions and 247 deletions

View File

@@ -6,15 +6,15 @@ from fastapi import FastAPI
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi_utils.tasks import repeat_every from fastapi_utils.tasks import repeat_every
from loguru import logger from loguru import logger
from sqlalchemy import text
from db import crud, models, schemas from db import crud, models, schemas
from db.database import get_db, get_db_async, make_engine, wal_checkpoint from db.database import get_db, get_db_async, make_engine, wal_checkpoint
from shared.settings import get_settings from shared.settings import get_settings
from shared.task_messaging import get_celery
from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
from worker.main import create_sheet_task
from fastapi_mail import FastMail, MessageSchema, MessageType from fastapi_mail import FastMail, MessageSchema, MessageType
celery = get_celery()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@@ -25,7 +25,7 @@ async def lifespan(app: FastAPI):
models.Base.metadata.create_all(bind=engine) models.Base.metadata.create_all(bind=engine)
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head']) alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
logging.getLogger("uvicorn.access").disabled = True # loguru logging.getLogger("uvicorn.access").disabled = True # loguru
asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL, get_settings().CELERY_BROKER_URL)) asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL))
asyncio.create_task(repeat_measure_regular_metrics()) asyncio.create_task(repeat_measure_regular_metrics())
with get_db() as db: with get_db() as db:
crud.upsert_user_groups(db) crud.upsert_user_groups(db)
@@ -72,7 +72,9 @@ async def archive_sheets_cronjob(frequency: str, interval: int, current_time_uni
async with get_db_async() as db: async with get_db_async() as db:
sheets = await crud.get_sheets_by_id_hash(db, frequency, interval, current_time_unit) sheets = await crud.get_sheets_by_id_hash(db, frequency, interval, current_time_unit)
for s in sheets: for s in sheets:
task = create_sheet_task.apply_async(args=[schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group=s.group_id).model_dump_json()])
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group=s.group_id).model_dump_json()]).apply_async()
triggered_jobs.append({"sheet_id": s.id, "task_id": task.id}) triggered_jobs.append({"sheet_id": s.id, "task_id": task.id})
logger.info(f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}") logger.info(f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}")

View File

@@ -9,7 +9,6 @@ logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
def log_error(e: Exception, traceback_str: str = None, extra:str = ""): def log_error(e: Exception, traceback_str: str = None, extra:str = ""):
# EXCEPTION_COUNTER.labels(type(e).__name__).inc()
if not traceback_str: traceback_str = traceback.format_exc() if not traceback_str: traceback_str = traceback.format_exc()
if extra: extra = f"{extra}\n" if extra: extra = f"{extra}\n"
logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}") logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")
@@ -21,6 +20,6 @@ async def logging_middleware(request: Request, call_next):
return response return response
except Exception as e: except Exception as e:
from utils.metrics import EXCEPTION_COUNTER from utils.metrics import EXCEPTION_COUNTER
EXCEPTION_COUNTER.labels(type(e).__name__).inc() EXCEPTION_COUNTER.labels(type=e.__class__.__name__).inc()
log_error(e) log_error(e)
raise e raise e

View File

@@ -100,7 +100,7 @@ def base_query(db: Session):
# --------------- TAG # --------------- TAG
def create_tag(db: Session, tag: str): def create_tag(db: Session, tag: str) -> models.Tag:
db_tag = db.query(models.Tag).filter(models.Tag.id == tag).first() db_tag = db.query(models.Tag).filter(models.Tag.id == tag).first()
if not db_tag: if not db_tag:
db_tag = models.Tag(id=tag) db_tag = models.Tag(id=tag)

View File

@@ -48,7 +48,7 @@ async def make_async_engine(database_url: str) -> AsyncEngine:
engine = create_async_engine(database_url, connect_args={"check_same_thread": False}) engine = create_async_engine(database_url, connect_args={"check_same_thread": False})
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(lambda sync_conn: sync_conn.execute("PRAGMA journal_mode=WAL;")) await conn.run_sync(lambda sync_conn: sync_conn.execute(text("PRAGMA journal_mode=WAL;")))
return engine return engine

View File

@@ -11,35 +11,13 @@ class Tag(BaseModel):
model_config = {"from_attributes": True} model_config = {"from_attributes": True}
__hash__ = object.__hash__ __hash__ = object.__hash__
class ArchiveCreate(BaseModel):
id: str | None = None
url: str
result: dict | None = None
public: bool = True
author_id: str | None = None
group_id: str | None = None
tags: set[Tag] | None = set()
rearchive: bool = True
sheet_id: str | None = None
# urls: list = []
class Archive(ArchiveCreate):
created_at: datetime
updated_at: datetime | None
deleted: bool
model_config = {"from_attributes": True}
class SubmitSheet(BaseModel): class SubmitSheet(BaseModel):
sheet_name: str | None = None sheet_name: str | None = None
sheet_id: str | None = None sheet_id: str | None = None
header: int = 1 header: int = 1
public: bool = False public: bool = False
author_id: str | None = None author_id: str | None = None
group_id: str | None = None group_id: str | None
tags: set[str] | None = set() tags: set[str] | None = set()
columns: dict | None = {} # TODO: implement columns: dict | None = {} # TODO: implement
@@ -103,10 +81,25 @@ class SheetResponse(SheetAdd):
class ArchiveTrigger(BaseModel): class ArchiveTrigger(BaseModel):
author_id: str | None = None
url: Annotated[str, Len(min_length=5)] url: Annotated[str, Len(min_length=5)]
public: bool = True public: bool = False
group_id: Annotated[str, Len(min_length=1)] | None = None group_id: Annotated[str, Len(min_length=1)] = "default"
tags: set[Tag] | None = set() tags: set[Tag] | None = None
class ArchiveCreate(ArchiveTrigger):
id: str | None = None
result: dict | None = None
sheet_id: str | None = None
urls: list | None = None
class Archive(ArchiveCreate):
created_at: datetime
updated_at: datetime | None
deleted: bool
model_config = {"from_attributes": True}
class Usage(BaseModel): class Usage(BaseModel):
monthly_urls: int = 0 monthly_urls: int = 0

View File

@@ -261,6 +261,9 @@ class UserState:
if group_id not in self.permissions: return False if group_id not in self.permissions: return False
quota = self.permissions[group_id].max_monthly_urls quota = self.permissions[group_id].max_monthly_urls
if quota == -1:
return True
current_month = datetime.now().month current_month = datetime.now().month
current_year = datetime.now().year current_year = datetime.now().year
user_urls = self.db.query(models.Archive).filter( user_urls = self.db.query(models.Archive).filter(
@@ -282,6 +285,9 @@ class UserState:
if group_id not in self.permissions: return False if group_id not in self.permissions: return False
quota = self.permissions[group_id].max_monthly_mbs quota = self.permissions[group_id].max_monthly_mbs
if quota == -1:
return True
current_month = datetime.now().month current_month = datetime.now().month
current_year = datetime.now().year current_year = datetime.now().year

View File

@@ -1,5 +0,0 @@
from endpoints.default import default_router
from endpoints.url import url_router
from endpoints.task import task_router
from endpoints.interoperability import interoperability_router
from endpoints.sheet import sheet_router

View File

@@ -6,13 +6,14 @@ from sqlalchemy import exc
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from db.user_state import UserState from db.user_state import UserState
from shared.task_messaging import get_celery
from web.security import token_api_key_auth, get_user_state from web.security import token_api_key_auth, get_user_state
from db import schemas, crud from db import schemas, crud
from db.database import get_db_dependency from db.database import get_db_dependency
from worker.main import create_sheet_task
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"]) sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
celery = get_celery()
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.") @sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
def create_sheet( def create_sheet(
@@ -73,7 +74,7 @@ def archive_user_sheet(
if not user.can_manually_trigger(sheet.group_id): if not user.can_manually_trigger(sheet.group_id):
raise HTTPException(status_code=429, detail="User cannot manually trigger sheet archiving in this group.") raise HTTPException(status_code=429, detail="User cannot manually trigger sheet archiving in this group.")
task = create_sheet_task.delay(schemas.SubmitSheet(sheet_id=id, author_id=user.email, group=sheet.group_id).model_dump_json()) task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=id, author_id=user.email, group=sheet.group_id).model_dump_json()]).delay()
return JSONResponse({"id": task.id}, status_code=201) return JSONResponse({"id": task.id}, status_code=201)
@@ -86,5 +87,5 @@ def archive_sheet(
sheet.author_id = sheet.author_id or "api-endpoint" sheet.author_id = sheet.author_id or "api-endpoint"
if not sheet.sheet_id: if not sheet.sheet_id:
raise HTTPException(status_code=422, detail=f"sheet id is required") raise HTTPException(status_code=422, detail=f"sheet id is required")
task = create_sheet_task.delay(sheet.model_dump_json()) task = celery.signature("create_sheet_task", args=[sheet.model_dump_json()]).delay()
return JSONResponse({"id": task.id}, status_code=201) return JSONResponse({"id": task.id}, status_code=201)

View File

@@ -4,16 +4,17 @@ from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from loguru import logger from loguru import logger
from shared.task_messaging import get_celery
from web.security import get_token_or_user_auth from web.security import get_token_or_user_auth
from db import schemas from db import schemas
from core.logging import log_error from core.logging import log_error
from worker.main import celery
from utils.misc import custom_jsonable_encoder from utils.misc import custom_jsonable_encoder
task_router = APIRouter(prefix="/task", tags=["Async task operations"]) task_router = APIRouter(prefix="/task", tags=["Async task operations"])
celery = get_celery()
@task_router.get("/{task_id}", summary="Check the status of an async task by its id, works for URLs and Sheet tasks.") @task_router.get("/{task_id}", summary="Check the status of an async task by its id, works for URLs and Sheet tasks.")
def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult: def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult:

View File

@@ -6,17 +6,18 @@ from datetime import datetime
from loguru import logger from loguru import logger
from core.config import ALLOW_ANY_EMAIL from core.config import ALLOW_ANY_EMAIL
from db.user_state import UserState from db.user_state import UserState
from shared.task_messaging import get_celery
from web.security import get_token_or_user_auth, get_user_state from web.security import get_token_or_user_auth, get_user_state
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from db import crud, schemas from db import crud, schemas
from db.database import get_db_dependency from db.database import get_db_dependency
from worker.main import create_archive_task
from urllib.parse import urlparse from urllib.parse import urlparse
url_router = APIRouter(prefix="/url", tags=["Single URL operations"]) url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
celery = get_celery()
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.") @url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
def archive_url( def archive_url(
@@ -24,6 +25,7 @@ def archive_url(
email=Depends(get_token_or_user_auth), email=Depends(get_token_or_user_auth),
db: Session = Depends(get_db_dependency) db: Session = Depends(get_db_dependency)
) -> schemas.Task: ) -> schemas.Task:
archive.author_id = email
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}") logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
parsed_url = urlparse(archive.url) parsed_url = urlparse(archive.url)
@@ -39,15 +41,9 @@ def archive_url(
if not user.has_quota_max_monthly_mbs(archive.group_id): if not user.has_quota_max_monthly_mbs(archive.group_id):
raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.") raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.")
# TODO: deprecate ArchiveCreate archive_create = schemas.ArchiveCreate(**archive.model_dump())
backwards_compatible_archive = schemas.ArchiveCreate(
url=archive.url,
author_id=email,
group_id=archive.group_id,
public=archive.public,
)
task = create_archive_task.delay(backwards_compatible_archive.model_dump_json()) task = celery.signature("create_archive_task", args=[archive_create.model_dump_json()]).delay()
task_response = schemas.Task(id=task.id) task_response = schemas.Task(id=task.id)
return JSONResponse(task_response.model_dump(), status_code=201) return JSONResponse(task_response.model_dump(), status_code=201)

View File

@@ -16,6 +16,7 @@ class Settings(BaseSettings):
SHEET_ORCHESTRATION_YAML : str = "secrets/orchestration-sheet.yaml" SHEET_ORCHESTRATION_YAML : str = "secrets/orchestration-sheet.yaml"
# cronjobs # cronjobs
#TODO: disable by default?
CRON_ARCHIVE_SHEETS: bool = False CRON_ARCHIVE_SHEETS: bool = False
CRON_DELETE_STALE_SHEETS: bool = True CRON_DELETE_STALE_SHEETS: bool = True
DELETE_STALE_SHEETS_DAYS: int = 14 DELETE_STALE_SHEETS_DAYS: int = 14

View File

@@ -0,0 +1,18 @@
from functools import lru_cache
from celery import Celery
import redis
from shared.settings import get_settings
@lru_cache
def get_celery(name:str="") -> Celery:
return Celery(
name,
broker_url=get_settings().CELERY_BROKER_URL,
result_backend=get_settings().CELERY_RESULT_BACKEND,
)
def get_redis() -> redis.Redis:
return redis.Redis.from_url(get_settings().CELERY_BROKER_URL)

View File

@@ -92,7 +92,7 @@ def client_with_auth(app_with_auth):
@pytest.fixture() @pytest.fixture()
def app_with_token(app): def app_with_token(app):
from web.security import token_api_key_auth,get_token_or_user_auth from web.security import token_api_key_auth, get_token_or_user_auth
app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL
app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL
return app return app

View File

@@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime
import json import json
from unittest.mock import patch from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@@ -145,15 +145,21 @@ def test_delete_sheet_endpoint(client_with_auth, db_session):
class TestArchiveUserSheetEndpoint: class TestArchiveUserSheetEndpoint:
@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-taskid", status="PENDING", result="")) @patch("endpoints.sheet.celery", return_value=MagicMock())
def test_normal_flow(self, m1, client_with_auth, db_session): def test_normal_flow(self, m_celery, client_with_auth, db_session):
from db import models from db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")) db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly"))
db_session.commit() db_session.commit()
m_signature = MagicMock()
m_signature.delay.return_value = TaskResult(id="123-taskid", status="PENDING", result="")
m_celery.signature.return_value = m_signature
r = client_with_auth.post("/sheet/123-sheet-id/archive") r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 201 assert r.status_code == 201
assert r.json() == {"id": "123-taskid"} assert r.json() == {"id": "123-taskid"}
m1.assert_called_once() m_celery.signature.assert_called_once()
m_signature.delay.assert_called_once()
def test_token_auth(self, client_with_token, test_no_auth): def test_token_auth(self, client_with_token, test_no_auth):
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive") test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
@@ -198,23 +204,30 @@ class TestTokenArchiveEndpoint:
assert r.status_code == 422 assert r.status_code == 422
assert r.json() == {"detail": "sheet id is required"} assert r.json() == {"detail": "sheet id is required"}
@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result="")) @patch("endpoints.sheet.celery", return_value=MagicMock())
def test_normal_flow(self, m1, client_with_token): def test_normal_flow(self, m_celery, client_with_token):
m_signature = MagicMock()
m_signature.delay.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
m_celery.signature.return_value = m_signature
# minimum data # minimum data
response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id"}) response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id"})
assert response.status_code == 201 assert response.status_code == 201
assert response.json() == {'id': '123-456-789'} assert response.json() == {'id': '123-456-789'}
m1.assert_called_once() m_celery.signature.assert_called_once()
called_val = m1.call_args.args[0] m_signature.delay.assert_called_once()
assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": None, "public": False, "author_id": "api-endpoint", "group_id": None, "tags": [], "columns": {}, "header": 1} called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_sheet_task"
assert json.loads(called_val[1]['args'][0]) == {"sheet_id": "123-sheet-id", "sheet_name": None, "public": False, "author_id": "api-endpoint", "group_id": None, "tags": [], "columns": {}, "header": 1}
# maximum data # maximum data
response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "author_id": "birdman@example.com", "header": 2, "public": True, "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}}) response = client_with_token.post("/sheet/archive", json={"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "author_id": "birdman@example.com", "header": 2, "public": True, "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}})
assert response.status_code == 201 assert response.status_code == 201
assert response.json() == {'id': '123-456-789'} assert response.json() == {'id': '123-456-789'}
m1.call_count == 2 m_celery.signature.call_count == 2
called_val = m1.call_args.args[0] m_signature.delay.call_count == 2
assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "public": True, "author_id": "birdman@example.com", "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}, "header": 2} called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_sheet_task"
assert json.loads(called_val[1]['args'][0]) == {"sheet_id": "123-sheet-id", "sheet_name": "768-sheet-name", "public": True, "author_id": "birdman@example.com", "group_id": "456-group-id", "tags": ["tag1"], "columns": {"col1": "type1"}, "header": 2}

View File

@@ -3,13 +3,18 @@ from unittest.mock import MagicMock, patch
from db.schemas import ArchiveCreate, TaskResult from db.schemas import ArchiveCreate, TaskResult
def test_archive_url_unauthenticated(client, test_no_auth): def test_archive_url_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/url/archive") test_no_auth(client.post, "/url/archive")
@patch("endpoints.url.UserState") @patch("endpoints.url.UserState")
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result="")) @patch("endpoints.url.celery", return_value=MagicMock())
def test_archive_url(m1, m2, client_with_auth): def test_archive_url(m_celery, m2, client_with_auth):
m_signature = MagicMock()
m_signature.delay.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
m_celery.signature.return_value = m_signature
m_user_state = MagicMock() m_user_state = MagicMock()
m2.return_value = m_user_state m2.return_value = m_user_state
@@ -17,7 +22,7 @@ def test_archive_url(m1, m2, client_with_auth):
response = client_with_auth.post("/url/archive", json={"url": "bad"}) response = client_with_auth.post("/url/archive", json={"url": "bad"})
assert response.status_code == 422 assert response.status_code == 422
assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters' assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
m1.assert_not_called() m_celery.signature.assert_not_called()
# url is invalid # url is invalid
response = client_with_auth.post("/url/archive", json={"url": "example.com"}) response = client_with_auth.post("/url/archive", json={"url": "example.com"})
@@ -30,9 +35,11 @@ def test_archive_url(m1, m2, client_with_auth):
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"}) response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 201 assert response.status_code == 201
assert response.json() == {'id': '123-456-789'} assert response.json() == {'id': '123-456-789'}
m1.assert_called_once() m_celery.signature.assert_called_once()
called_val = m1.call_args.args[0] m_signature.delay.assert_called_once()
assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True, "sheet_id":None} called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_archive_task"
assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "sheet_id": None}
m_user_state.has_quota_max_monthly_urls.assert_called_once() m_user_state.has_quota_max_monthly_urls.assert_called_once()
m_user_state.has_quota_max_monthly_mbs.assert_called_once() m_user_state.has_quota_max_monthly_mbs.assert_called_once()
@@ -48,9 +55,10 @@ def test_archive_url(m1, m2, client_with_auth):
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"}) response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
assert response.status_code == 201 assert response.status_code == 201
assert response.json() == {'id': '123-456-789'} assert response.json() == {'id': '123-456-789'}
assert m1.call_count == 2 assert m_celery.signature.call_count == 2
called_val = m1.call_args.args[0] assert m_signature.delay.call_count == 2
assert json.loads(called_val)["group_id"] == "spaceship" called_val = m_celery.signature.call_args
assert json.loads(called_val[1]['args'][0])["group_id"] == "spaceship"
m_user_state.in_group.assert_called_with("spaceship") m_user_state.in_group.assert_called_with("spaceship")
# user is over monthly URL quota # user is over monthly URL quota
@@ -68,6 +76,9 @@ def test_archive_url(m1, m2, client_with_auth):
assert response.status_code == 429 assert response.status_code == 429
assert response.json()["detail"] == "User has reached their monthly MB quota." assert response.json()["detail"] == "User has reached their monthly MB quota."
m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit") m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
assert m_celery.signature.call_count == 2
assert m_signature.delay.call_count == 2
@patch("endpoints.url.UserState") @patch("endpoints.url.UserState")
def test_archive_url_quotas(m1, client_with_auth): def test_archive_url_quotas(m1, client_with_auth):
@@ -89,15 +100,25 @@ def test_archive_url_quotas(m1, client_with_auth):
assert response.json()["detail"] == "User has reached their monthly MB quota." assert response.json()["detail"] == "User has reached their monthly MB quota."
m_user_state.has_quota_max_monthly_mbs.assert_called_once() m_user_state.has_quota_max_monthly_mbs.assert_called_once()
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
def test_archive_url_with_api_token(m1, client_with_token): @patch("endpoints.url.celery", return_value=MagicMock())
def test_archive_url_with_api_token(m_celery, client_with_token):
m_signature = MagicMock()
m_signature.delay.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
m_celery.signature.return_value = m_signature
response = client_with_token.post("/url/archive", json={"url": "https://example.com"}) response = client_with_token.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 201 assert response.status_code == 201
assert response.json() == {'id': '123-456-789'} assert response.json() == {'id': '123-456-789'}
m_celery.signature.assert_called_once()
m_signature.delay.assert_called_once()
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_archive_task"
def test_search_by_url_unauthenticated(client, test_no_auth): def test_search_by_url_unauthenticated(client, test_no_auth):
test_no_auth(client.get, "/url/search") test_no_auth(client.get, "/url/search")
def test_search_by_url(client_with_auth, client_with_token, db_session): def test_search_by_url(client_with_auth, client_with_token, db_session):
# tests the search endpoint, including through some db data for the endpoint params # tests the search endpoint, including through some db data for the endpoint params
response = client_with_auth.get("/url/search") response = client_with_auth.get("/url/search")
@@ -111,7 +132,7 @@ def test_search_by_url(client_with_auth, client_with_token, db_session):
from db import crud, schemas from db import crud, schemas
for i in range(11): for i in range(11):
crud.create_task(db_session, ArchiveCreate(id=f"url-456-{i}", url="https://example.com" if i < 10 else "https://something-else.com", result={}, public=True, author_id="rick@example.com", group_id=None), [], []) crud.create_task(db_session, ArchiveCreate(id=f"url-456-{i}", url="https://example.com" if i < 10 else "https://something-else.com", result={}, public=True, author_id="rick@example.com", group_id=None), [], [])
#NB: this insertion is too fast for the ordering to be correct as they are within the same second # NB: this insertion is too fast for the ordering to be correct as they are within the same second
response = client_with_auth.get("/url/search?url=https://example.com") response = client_with_auth.get("/url/search?url=https://example.com")
assert response.status_code == 200 assert response.status_code == 200
@@ -142,6 +163,7 @@ def test_search_by_url(client_with_auth, client_with_token, db_session):
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()) == 10 assert len(response.json()) == 10
@patch("endpoints.url.UserState") @patch("endpoints.url.UserState")
def test_search_no_read_access(mock_user_state, client_with_auth): def test_search_no_read_access(mock_user_state, client_with_auth):
mock_user_state.return_value.read = False mock_user_state.return_value.read = False

View File

@@ -21,7 +21,7 @@ class Test_create_archive_task():
@patch("worker.main.insert_result_into_db") @patch("worker.main.insert_result_into_db")
@patch("worker.main.is_group_invalid_for_user", return_value=None) @patch("worker.main.is_group_invalid_for_user", return_value=None)
@patch("worker.main.choose_orchestrator") # @patch("worker.main.choose_orchestrator")
@patch("celery.app.task.Task.request") @patch("celery.app.task.Task.request")
def test_success(self, m_req, m_choose, m_is_group, m_insert, worker_init, db_session): def test_success(self, m_req, m_choose, m_is_group, m_insert, worker_init, db_session):
from worker.main import create_archive_task from worker.main import create_archive_task
@@ -46,7 +46,7 @@ class Test_create_archive_task():
@patch("worker.main.insert_result_into_db", side_effect=Exception) @patch("worker.main.insert_result_into_db", side_effect=Exception)
@patch("worker.main.is_group_invalid_for_user", return_value=False) @patch("worker.main.is_group_invalid_for_user", return_value=False)
@patch("worker.main.choose_orchestrator") # @patch("worker.main.choose_orchestrator")
def test_raise_db_error(self, m_choose, m_is_group, m_insert, worker_init): def test_raise_db_error(self, m_choose, m_is_group, m_insert, worker_init):
from worker.main import create_archive_task from worker.main import create_archive_task
mock_orchestrator = self.mock_orchestrator_choice(m_choose) mock_orchestrator = self.mock_orchestrator_choice(m_choose)
@@ -123,47 +123,6 @@ class Test_create_sheet_task():
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0 assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
def test_choose_orchestrator(worker_init):
from worker.main import choose_orchestrator
assert choose_orchestrator(None, "rick@example.com").__class__.__name__ == "ArchivingOrchestrator"
@patch("worker.main.get_user_first_group", return_value="does-not-exist")
def test_choose_orchestrator_assertion(worker_init):
from worker.main import choose_orchestrator
with pytest.raises(Exception):
choose_orchestrator(None, "rick@example.com")
@patch("worker.main.read_user_groups")
def test_get_user_first_group(m_read_user_groups, worker_init):
from worker.main import get_user_first_group
m_read_user_groups.return_value = {"users": {}}
assert get_user_first_group("email1") == "default"
m_read_user_groups.return_value = {"users": {"email1": []}}
assert get_user_first_group("email1") == "default"
m_read_user_groups.return_value = {"users": {"email1": ["group1", "group2"]}}
assert get_user_first_group("email1") == "group1"
def test_is_group_invalid_for_user(worker_init, db_session):
from worker.main import is_group_invalid_for_user
from db.crud import upsert_user_groups
upsert_user_groups(db_session)
assert is_group_invalid_for_user(True, "", "") == False
assert is_group_invalid_for_user(False, "", "") == False
assert is_group_invalid_for_user(False, "default", "") == "User is not part of default, no permission"
assert is_group_invalid_for_user(False, "spaceship", "jerry@example.com") == "User jerry@example.com is not part of spaceship, no permission"
assert is_group_invalid_for_user(False, "spaceship", "rick@example.com") == False
def test_get_all_urls(worker_init, db_session): def test_get_all_urls(worker_init, db_session):
from worker.main import get_all_urls from worker.main import get_all_urls
from auto_archiver import Metadata from auto_archiver import Metadata

View File

@@ -8,18 +8,19 @@ import redis
from db import crud from db import crud
from db.database import get_db from db.database import get_db
from core.logging import log_error from core.logging import log_error
from shared.task_messaging import get_redis
# Custom metrics # Custom metrics
EXCEPTION_COUNTER = Counter( EXCEPTION_COUNTER = Counter(
"exceptions", "exceptions",
"Number of times a certain exception has occurred.", "Number of times a certain exception has occurred.",
labelnames=["types"] labelnames=["type"]
) )
WORKER_EXCEPTION = Counter( WORKER_EXCEPTION = Counter(
"worker_exceptions_total", "worker_exceptions_total",
"Number of times a certain exception has occurred on the worker.", "Number of times a certain exception has occurred on the worker.",
labelnames=["types", "exception", "task", "traceback"] labelnames=["type", "exception", "task", "traceback"]
) )
DISK_UTILIZATION = Gauge( DISK_UTILIZATION = Gauge(
"disk_utilization", "disk_utilization",
@@ -38,16 +39,16 @@ DATABASE_METRICS_COUNTER = Counter(
) )
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL, CELERY_BROKER_URL): async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL):
# Subscribe to Redis channel and increment the counter for each exception with info on the exception and task # Subscribe to Redis channel and increment the counter for each exception with info on the exception and task
Rdis = redis.Redis.from_url(CELERY_BROKER_URL) Redis = get_redis()
PubSubExceptions = Rdis.pubsub() PubSubExceptions = Redis.pubsub()
PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL) PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL)
while True: while True:
message = PubSubExceptions.get_message() message = PubSubExceptions.get_message()
if message and message["type"] == "message": if message and message["type"] == "message":
data = json.loads(message["data"].decode("utf-8")) data = json.loads(message["data"].decode("utf-8"))
WORKER_EXCEPTION.labels(types=type(data["exception"]).__name__, exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc() WORKER_EXCEPTION.labels(type=data["type"], exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc()
await asyncio.sleep(1) await asyncio.sleep(1)

View File

@@ -12,7 +12,8 @@ from sqlalchemy.orm import Session
from loguru import logger from loguru import logger
from core.logging import logging_middleware, log_error from core.logging import logging_middleware, log_error
from worker.main import create_archive_task, create_sheet_task, celery, insert_result_into_db from shared.task_messaging import get_celery
from worker.main import insert_result_into_db
from db import crud, models, schemas from db import crud, models, schemas
from web.security import get_user_auth, token_api_key_auth, get_token_or_user_auth from web.security import get_user_auth, token_api_key_auth, get_token_or_user_auth
@@ -23,8 +24,13 @@ from shared.settings import get_settings
from auto_archiver import Metadata from auto_archiver import Metadata
from endpoints import default_router, url_router, sheet_router, task_router, interoperability_router from endpoints.default import default_router
from endpoints.url import url_router
from endpoints.sheet import sheet_router
from endpoints.task import task_router
from endpoints.interoperability import interoperability_router
celery = get_celery()
def app_factory(settings = get_settings()): def app_factory(settings = get_settings()):
app = FastAPI( app = FastAPI(
@@ -84,7 +90,8 @@ def app_factory(settings = get_settings()):
if type(url) != str or len(url) <= 5: if type(url) != str or len(url) <= 5:
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}") raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
logger.info("creating task") logger.info("creating task")
task = create_archive_task.delay(archive.model_dump_json())
task = celery.signature("create_archive_task", args=[archive.model_dump_json()]).delay()
return JSONResponse({"id": task.id}) return JSONResponse({"id": task.id})
@@ -139,7 +146,7 @@ def app_factory(settings = get_settings()):
raise HTTPException(status_code=422, detail=f"sheet name or id is required") raise HTTPException(status_code=422, detail=f"sheet name or id is required")
if not crud.is_user_in_group(db, email, sheet.group_id): if not crud.is_user_in_group(db, email, sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.") raise HTTPException(status_code=403, detail="User does not have access to this group.")
task = create_sheet_task.delay(sheet.model_dump_json()) task = celery.signature("create_sheet_task", args=[sheet.model_dump_json()]).delay()
return JSONResponse({"id": task.id}) return JSONResponse({"id": task.id})
@@ -149,7 +156,8 @@ def app_factory(settings = get_settings()):
sheet.author_id = sheet.author_id or "api-endpoint" sheet.author_id = sheet.author_id or "api-endpoint"
if not sheet.sheet_name and not sheet.sheet_id: if not sheet.sheet_name and not sheet.sheet_id:
raise HTTPException(status_code=422, detail=f"sheet name or id is required") raise HTTPException(status_code=422, detail=f"sheet name or id is required")
task = create_sheet_task.delay(sheet.model_dump_json())
task = celery.signature("create_sheet_task", args=[sheet.model_dump_json()]).delay()
return JSONResponse({"id": task.id}) return JSONResponse({"id": task.id})
# ----- endpoint to submit data archived elsewhere # ----- endpoint to submit data archived elsewhere

View File

@@ -1,78 +1,56 @@
from functools import lru_cache
import traceback, yaml, datetime import traceback, yaml, datetime
from typing import List, Set from typing import List, Set
from celery import Celery from celery.signals import task_failure
from celery.signals import task_failure, worker_init
from auto_archiver import Config, ArchivingOrchestrator, Metadata from auto_archiver import Config, ArchivingOrchestrator, Metadata
from auto_archiver.core import Media from auto_archiver.core import Media
from loguru import logger from loguru import logger
from db import crud, schemas, models from db import crud, schemas, models
from db.database import get_db from db.database import get_db
from shared.task_messaging import get_celery, get_redis
from shared.settings import get_settings from shared.settings import get_settings
import json import json
import redis
from sqlalchemy import exc from sqlalchemy import exc
from core.logging import log_error from core.logging import log_error
settings = get_settings() settings = get_settings()
celery = get_celery("worker")
Redis = get_redis()
celery = Celery(__name__)
celery.conf.broker_url = settings.CELERY_BROKER_URL
celery.conf.result_backend = settings.CELERY_RESULT_BACKEND
USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
Rdis = redis.Redis.from_url(celery.conf.broker_url)
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 3})
def create_archive_task(self, archive_json: str): def create_archive_task(self, archive_json: str):
logger.info(archive_json)
archive = schemas.ArchiveCreate.model_validate_json(archive_json) archive = schemas.ArchiveCreate.model_validate_json(archive_json)
logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}")
#TODO: move group checks out of here
invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)
if invalid:
raise Exception(invalid) # marks task FAILED, saves the Exception as result
url = archive.url # call auto-archiver
logger.info(f"{url=} {archive=}") orchestrator = load_orchestrator(archive.group_id)
result = orchestrator.feed_item(Metadata().set_url(archive.url))
# TODO: re-evaluate if this logic is to be used # prepare for DB
if not archive.rearchive: assert result, f"UNABLE TO archive: {archive.url}"
with get_db() as session: archive.id = self.request.id
archives = crud.search_archives_by_url(session, url, archive.author_id, absolute_search=True) archive.urls = get_all_urls(result)
if len(archives): archive.result = json.loads(result.to_json())
logger.info(f"Skipping {url=} as it was already archived")
return Metadata.choose_most_complete([a.result for a in archives])
orchestrator = choose_orchestrator(archive.group_id, archive.author_id) insert_result_into_db(archive)
logger.info(f"Using orchestrator {orchestrator=}") return archive.result.to_dict() # TODO: is return used?
result = orchestrator.feed_item(Metadata().set_url(url))
try:
insert_result_into_db(result, archive.tags, archive.public, archive.group_id, archive.author_id, self.request.id)
except Exception as e:
# Log it, then raise again to store the error as the task result
log_error(e)
redis_publish_exception(e, self.name, traceback.format_exc())
raise e
return result.to_dict()
#TODO: refactor how user-groups are loaded and orchestrators chosen @celery.task(name="create_sheet_task", bind=True)
@celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
def create_sheet_task(self, sheet_json: str): def create_sheet_task(self, sheet_json: str):
sheet = schemas.SubmitSheet.model_validate_json(sheet_json) sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
sheet.tags.add("gsheet") sheet.tags.add("gsheet")
logger.info(f"SHEET START {sheet=}") logger.info(f"SHEET START {sheet=}")
config = Config()
# TODO: use choose_orchestrator and overwrite the feeder
# TODO: drop sheet_name and use only sheet_id (new endpoints/models) # TODO: drop sheet_name and use only sheet_id (new endpoints/models)
config.parse(use_cli=False, yaml_config_filename=get_settings().SHEET_ORCHESTRATION_YAML, overwrite_configs={"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}}) orchestrator = load_orchestrator(sheet.group_id, {"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}})
orchestrator = ArchivingOrchestrator(config)
stats = {"archived": 0, "failed": 0, "errors": []} stats = {"archived": 0, "failed": 0, "errors": []}
for result in orchestrator.feed(): for result in orchestrator.feed():
@@ -80,8 +58,8 @@ def create_sheet_task(self, sheet_json: str):
logger.error("Got empty result from feeder, an internal error must have occurred.") logger.error("Got empty result from feeder, an internal error must have occurred.")
continue continue
try: try:
#TODO: remove public from sheet in new refactor # TODO: remove public from sheet in new refactor
#TODO: update the sheets table with the current date if any new archive was done #TODO: use new insert_result_into_db
insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid(), sheet.sheet_id) insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid(), sheet.sheet_id)
stats["archived"] += 1 stats["archived"] += 1
except exc.IntegrityError as e: except exc.IntegrityError as e:
@@ -97,26 +75,20 @@ def create_sheet_task(self, sheet_json: str):
crud.update_sheet_last_url_archived_at(session, sheet.sheet_id) crud.update_sheet_last_url_archived_at(session, sheet.sheet_id)
logger.info(f"SHEET DONE {sheet=}") logger.info(f"SHEET DONE {sheet=}")
# TODO: use data model
return {"success": True, "sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "time": datetime.datetime.now().isoformat(), **stats} return {"success": True, "sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "time": datetime.datetime.now().isoformat(), **stats}
@task_failure.connect(sender=create_sheet_task) @task_failure.connect(sender=create_sheet_task)
@task_failure.connect(sender=create_archive_task) @task_failure.connect(sender=create_archive_task)
def task_failure_notifier(sender, **kwargs): def task_failure_notifier(sender, **kwargs):
# automatically capture exceptions in the worker tasks
logger.warning(f"⚠️ worker task failed: {sender.name}")
traceback_msg = "\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback']))) traceback_msg = "\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback'])))
logger.warning("😅 From task_failure_notifier ==> Task failed successfully!")
log_error(kwargs['exception'], traceback_msg, f"task_failure: {sender.name}") log_error(kwargs['exception'], traceback_msg, f"task_failure: {sender.name}")
redis_publish_exception(kwargs['exception'], sender.name, traceback_msg) redis_publish_exception(kwargs['exception'], sender.name, traceback_msg)
def choose_orchestrator(group, email):
global ORCHESTRATORS
if group not in ORCHESTRATORS: group = get_user_first_group(email)
assert group in ORCHESTRATORS, f"{group=} not in configurations"
logger.info(f"CHOOSE Orchestrator for {group=}, {email=}")
return ArchivingOrchestrator(ORCHESTRATORS.get(group))
def read_user_groups(): def read_user_groups():
# read yaml safely # read yaml safely
with open(USER_GROUPS_FILENAME) as inf: with open(USER_GROUPS_FILENAME) as inf:
@@ -127,52 +99,28 @@ def read_user_groups():
raise e raise e
def get_user_first_group(email): def load_orchestrator(group_id: str, overwrite_configs: dict = {}) -> ArchivingOrchestrator:
user_groups_yaml = read_user_groups()
groups = user_groups_yaml.get("users", {}).get(email, [])
if groups != None and len(groups):
return groups[0]
return "default"
def load_orchestrators():
global ORCHESTRATORS
ORCHESTRATORS = {}
"""
reads the orchestrators key in the config file to load different orchestrators for different groups
"""
user_groups_yaml = read_user_groups()
orchestrators_config = user_groups_yaml.get("orchestrators", {})
assert len(orchestrators_config), f"No orchestrators key found in {USER_GROUPS_FILENAME}. please see the example file"
assert "default" in orchestrators_config, "please include a 'default' orchestrator to be used when the user has no group"
logger.debug(f"Found {len(orchestrators_config)} group orchestrators.")
for group, config_filename in orchestrators_config.items():
config = Config()
config.parse(use_cli=False, yaml_config_filename=config_filename)
ORCHESTRATORS[group] = config
return ORCHESTRATORS
def is_group_invalid_for_user(public: bool, group_id: str, author_id: str):
"""
ensures that, if a group is specified, the user belongs to it.
if public is true the requirement is not needed
returns an error message if invalid, or False if all is good.
"""
if public: return False
if not group_id or len(group_id) == 0: return False
# otherwise group must match
with get_db() as session: with get_db() as session:
if not crud.is_user_in_group(session, author_id, group_id): orchestrator_fn = crud.get_group(session, group_id).orchestrator
logger.error(em := f"User {author_id} is not part of {group_id}, no permission") assert orchestrator_fn, f"no orchestrator found for {group_id}"
return em
return False config = Config()
config.parse(use_cli=False, yaml_config_filename=orchestrator_fn, overwrite_configs=overwrite_configs)
return ArchivingOrchestrator(config)
def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_id: str, author_id: str, task_id: str, sheet_id:str="") -> str: def insert_result_into_db(archive: schemas.ArchiveCreate) -> str:
with get_db() as session:
# create and load user, tags, if needed
crud.create_or_get_user(session, archive.author_id)
db_tags = [crud.create_tag(session, tag) for tag in archive.tags]
# insert everything
db_task = crud.create_task(session, task=archive, tags=db_tags, urls=archive.urls)
logger.debug(f"Added {db_task.id=} to database on {db_task.created_at} ({db_task.author_id})")
return db_task.id
def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_id: str, author_id: str, task_id: str, sheet_id: str = "") -> str:
logger.info(f"INSERTING {public=} {group_id=} {author_id=} {tags=} into {task_id}") logger.info(f"INSERTING {public=} {group_id=} {author_id=} {tags=} into {task_id}")
assert result, f"UNABLE TO archive: {result.get_url() if result else result}" assert result, f"UNABLE TO archive: {result.get_url() if result else result}"
with get_db() as session: with get_db() as session:
@@ -186,7 +134,7 @@ def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_
logger.debug(f"Added {db_task.id=} to database on {db_task.created_at} ({db_task.author_id})") logger.debug(f"Added {db_task.id=} to database on {db_task.created_at} ({db_task.author_id})")
return db_task.id return db_task.id
# TODO: this should live within the auto-archiver
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]: def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
db_urls = [] db_urls = []
for m in result.media: for m in result.media:
@@ -202,6 +150,7 @@ def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
return db_urls return db_urls
# TODO: this should live within the auto-archiver??
def convert_if_media(media): def convert_if_media(media):
if isinstance(media, Media): return media if isinstance(media, Media): return media
elif isinstance(media, dict): elif isinstance(media, dict):
@@ -214,24 +163,7 @@ def convert_if_media(media):
def redis_publish_exception(exception, task_name, traceback: str = ""): def redis_publish_exception(exception, task_name, traceback: str = ""):
REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL
try: try:
Rdis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps({"exception": exception, "task": task_name, "traceback": traceback}, default=str)) exception_data = {"task": task_name, "type": exception.__class__.__name__, "exception": exception, "traceback": traceback}
Redis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps(exception_data, default=str))
except Exception as e: except Exception as e:
log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}") log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}")
@worker_init.connect
def at_start(sender, **kwargs):
global ORCHESTRATORS
ORCHESTRATORS = {}
load_orchestrators()
logger.info("Orchestrators loaded successfully.")
@lru_cache
def get_url_orchestrator(group_name):
with get_db() as db:
group = crud.get_group(db, group_name)
assert group, f"Group {group_name} not found"
# config = Config()
# config.parse(use_cli=False, yaml_config_filename=group.orchestrator_sheet)
# return ArchivingOrchestrator(config)