mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-12 05:28:34 +03:00
feat: /sheet endpoint and new security protocol
This commit is contained in:
113
src/worker.py
113
src/worker.py
@@ -1,22 +1,21 @@
|
||||
|
||||
import os, re, traceback, yaml, datetime
|
||||
import os, traceback, yaml, datetime
|
||||
from typing import List, Set
|
||||
|
||||
from celery import Celery, states
|
||||
from celery.exceptions import Ignore
|
||||
from celery import Celery
|
||||
from celery.signals import task_failure
|
||||
from auto_archiver import Config, ArchivingOrchestrator, Metadata
|
||||
# from auto_archiver.enrichers import ScreenshotEnricher
|
||||
from loguru import logger
|
||||
|
||||
from db import crud, schemas, models
|
||||
from db.database import engine, SessionLocal
|
||||
from db.database import SessionLocal
|
||||
from contextlib import contextmanager
|
||||
import json
|
||||
|
||||
celery = Celery(__name__)
|
||||
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379")
|
||||
celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379")
|
||||
USER_GROUPS_FILENAME=os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml")
|
||||
USER_GROUPS_FILENAME = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml")
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -25,51 +24,56 @@ def get_db():
|
||||
try: yield session
|
||||
finally: session.close()
|
||||
|
||||
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 5})
|
||||
|
||||
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 5})
|
||||
def create_archive_task(self, archive_json: str):
|
||||
|
||||
archive = schemas.ArchiveCreate.parse_raw(archive_json)
|
||||
if not archive.public and archive.group_id and len(archive.group_id) > 0:
|
||||
# ensure group is valid for user
|
||||
with get_db() as session:
|
||||
db_group = crud.get_group_for_user(session, archive.group_id, archive.author_id)
|
||||
if not db_group:
|
||||
logger.error(em := f"User {archive.author_id} is not part of {archive.group_id}, no permission")
|
||||
return {"error": em}
|
||||
|
||||
if (em := is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)): return {"error": em}
|
||||
|
||||
url = archive.url
|
||||
logger.info(f"{url=}")
|
||||
logger.info(f"{archive=}")
|
||||
logger.info(f"{url=} {archive=}")
|
||||
orchestrator = choose_orchestrator(archive.group_id, archive.author_id)
|
||||
result = orchestrator.feed_item(Metadata().set_url(url))
|
||||
if not result:
|
||||
logger.error(f"UNABLE TO archive: {url}")
|
||||
return {"error": "unable to archive"}
|
||||
|
||||
result_json = result.to_json()
|
||||
with get_db() as session:
|
||||
# create DB URLs
|
||||
db_urls = [models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}")) for i, m in enumerate(result.media) for url in m.urls]
|
||||
# create DB TAGs if needed
|
||||
db_tags = [crud.create_tag(session, tag) for tag in archive.tags]
|
||||
# insert archive
|
||||
db_task = crud.create_task(session, task=schemas.ArchiveCreate(id=self.request.id, url=url, result=json.loads(result_json), public=archive.public, author_id=archive.author_id, group_id=archive.group_id), tags=db_tags, urls=db_urls)
|
||||
logger.debug(f"Added {db_task.id=} to database on {db_task.created_at}")
|
||||
return result_json
|
||||
try:
|
||||
insert_result_into_db(result, archive.tags, archive.public, archive.group_id, archive.author_id, self.request.id)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(traceback.format_exc())
|
||||
return {"error": e}
|
||||
return result.to_json()
|
||||
|
||||
|
||||
@celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
|
||||
@celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
|
||||
def create_sheet_task(self, sheet_json: str):
|
||||
logger.info(f"STARTING {sheet_json}")
|
||||
sheet = schemas.SubmitSheet.parse_raw(sheet_json)
|
||||
sheet.tags.add("gsheet")
|
||||
logger.info(f"SHEET START {sheet=}")
|
||||
|
||||
if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)): return {"error": em}
|
||||
|
||||
config = Config()
|
||||
#TODO: use choose_orchestrator and overwrite the feeder
|
||||
config.parse(use_cli=False, yaml_config_filename="secrets/orchestration-sheet.yaml", overwrite_configs={"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}})
|
||||
orchestrator = ArchivingOrchestrator(config)
|
||||
# TODO: save into local DB
|
||||
orchestrator.feed()
|
||||
|
||||
return {"success": True, "sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "time": datetime.datetime.now().isoformat()}
|
||||
stats = {"archived": 0, "failed": 0, "errors": []}
|
||||
for result in orchestrator.feed():
|
||||
try:
|
||||
insert_result_into_db(result, sheet.tags, sheet.public, sheet.group_id, sheet.author_id, models.generate_uuid())
|
||||
stats["archived"]+=1
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(traceback.format_exc())
|
||||
stats["failed"]+=1
|
||||
stats["errors"].append(e)
|
||||
|
||||
logger.info(f"SHEET DONE {sheet=}")
|
||||
return {"success": True, "sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "time": datetime.datetime.now().isoformat(), **stats}
|
||||
|
||||
|
||||
@task_failure.connect(sender=create_sheet_task)
|
||||
@task_failure.connect(sender=create_archive_task)
|
||||
def task_failure_notifier(sender=None, **kwargs):
|
||||
logger.warning("😅 From task_failure_notifier ==> Task failed successfully! ")
|
||||
@@ -77,6 +81,7 @@ def task_failure_notifier(sender=None, **kwargs):
|
||||
logger.error(kwargs['traceback'])
|
||||
logger.error("\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback']))))
|
||||
|
||||
|
||||
def choose_orchestrator(group, email):
|
||||
global ORCHESTRATORS
|
||||
if group not in ORCHESTRATORS: group = get_user_first_group(email)
|
||||
@@ -84,6 +89,7 @@ def choose_orchestrator(group, email):
|
||||
logger.info(f"CHOOSE Orchestrator for {group=}, {email=}")
|
||||
return ArchivingOrchestrator(ORCHESTRATORS.get(group))
|
||||
|
||||
|
||||
def read_user_groups():
|
||||
# read yaml safely
|
||||
with open(USER_GROUPS_FILENAME) as inf:
|
||||
@@ -93,6 +99,7 @@ def read_user_groups():
|
||||
logger.error(f"could not open user groups filename {USER_GROUPS_FILENAME}: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
def get_user_first_group(email):
|
||||
user_groups_yaml = read_user_groups()
|
||||
groups = user_groups_yaml.get("users", {}).get(email, [])
|
||||
@@ -107,12 +114,12 @@ def load_orchestrators():
|
||||
reads the orchestrators key in the config file to load different orchestrators for different groups
|
||||
"""
|
||||
user_groups_yaml = read_user_groups()
|
||||
|
||||
|
||||
orchestrators_config = user_groups_yaml.get("orchestrators", {})
|
||||
assert len(orchestrators_config), f"No orchestrators key found in {USER_GROUPS_FILENAME}. please see the example file"
|
||||
assert "default" in orchestrators_config, "please include a 'default' orchestrator to be used when the user has no group"
|
||||
logger.debug(f"Found {len(orchestrators_config)} group orchestrators.")
|
||||
|
||||
|
||||
for group, config_filename in orchestrators_config.items():
|
||||
config = Config()
|
||||
config.parse(use_cli=False, yaml_config_filename=config_filename)
|
||||
@@ -120,7 +127,35 @@ def load_orchestrators():
|
||||
return ORCHESTRATORS
|
||||
|
||||
|
||||
## INIT
|
||||
def is_group_invalid_for_user(public: bool, group_id: str, author_id: str):
|
||||
"""
|
||||
ensures that, if a group is specified, the user belongs to it.
|
||||
if public is true the requirement is not needed
|
||||
returns an error message if invalid, or False if all is good.
|
||||
"""
|
||||
if not public and group_id and len(group_id) > 0:
|
||||
# ensure group is valid for user
|
||||
with get_db() as session:
|
||||
db_group = crud.get_group_for_user(session, group_id, author_id)
|
||||
if not db_group:
|
||||
logger.error(em := f"User {author_id} is not part of {group_id}, no permission")
|
||||
return em
|
||||
return False
|
||||
|
||||
|
||||
def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_id: str, author_id: str, task_id:str):
|
||||
logger.info(f"INSERTING {public=} {result} into {task_id}")
|
||||
assert result, "UNABLE TO archive: {url}"
|
||||
with get_db() as session:
|
||||
# create DB URLs
|
||||
db_urls = [models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}")) for i, m in enumerate(result.media) for url in m.urls]
|
||||
# create DB TAGs if needed
|
||||
db_tags = [crud.create_tag(session, tag) for tag in tags]
|
||||
# insert archive
|
||||
db_task = crud.create_task(session, task=schemas.ArchiveCreate(id=task_id, url=result.get_url(), result=json.loads(result.to_json()), public=public, author_id=author_id, group_id=group_id), tags=db_tags, urls=db_urls)
|
||||
logger.debug(f"Added {db_task.id=} to database on {db_task.created_at}")
|
||||
|
||||
|
||||
# INIT
|
||||
ORCHESTRATORS = {}
|
||||
load_orchestrators()
|
||||
load_orchestrators()
|
||||
|
||||
Reference in New Issue
Block a user