mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-08 03:28:35 +03:00
Merge pull request #43 from bellingcat/dev
introduces CI tests and does some refactoring on the code and logic
This commit is contained in:
@@ -1,3 +1 @@
|
||||
FLOWER_USERNAME=TODO
|
||||
FLOWER_PASSWORD=TODO
|
||||
REDIS_PASSWORD=TODO
|
||||
45
.github/workflows/ci.yml
vendored
Normal file
45
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
name: CI
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: redis:6-alpine
|
||||
ports:
|
||||
- 6379:6379
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install pipenv
|
||||
run: pip install pipenv
|
||||
working-directory: src
|
||||
|
||||
- name: Install dependencies
|
||||
run: pipenv install --dev
|
||||
working-directory: src
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: PYTHONPATH=. PIPENV_DOTENV_LOCATION=.env.test pipenv run coverage run -m pytest -v --color=yes tests/
|
||||
working-directory: src
|
||||
|
||||
- name: Report coverage
|
||||
run: pipenv run coverage report
|
||||
working-directory: src
|
||||
15
.gitignore
vendored
15
.gitignore
vendored
@@ -7,10 +7,19 @@ secrets
|
||||
__pycache
|
||||
.pytest_cach
|
||||
.env
|
||||
.env.dev
|
||||
.env.prod
|
||||
*.db
|
||||
redis/data/*
|
||||
.ipynb_checkpoints*
|
||||
#temp
|
||||
tests
|
||||
src/user-groups.yaml
|
||||
wit*
|
||||
src/user-groups.dev.yaml
|
||||
wit*
|
||||
src/crawls
|
||||
.coverage
|
||||
.pytest_cache/*
|
||||
htmlcov
|
||||
local_archive
|
||||
local_archive_test
|
||||
*db-wal
|
||||
*db-shm
|
||||
25
README.md
25
README.md
@@ -1,16 +1,18 @@
|
||||
# Auto Archiver API
|
||||
|
||||
An api that uses celery workers to process URL archive requests via [bellingcat/auto-archiver](https://github.com/bellingcat/auto-archiver), it allows authentication via Google OAuth Apps an d enables CORS, everything runs on docker but development can be done without docker (except for redis).
|
||||
An api that uses celery workers to process URL archive requests via [bellingcat/auto-archiver](https://github.com/bellingcat/auto-archiver), it allows authentication via Google OAuth Apps and enables CORS, everything runs on docker but development can be done without docker (except for redis).
|
||||
|
||||
|
||||
## Development
|
||||
http://localhost:8004
|
||||
|
||||
TODO: update .env file instructions, should use .env.prod and .env.dev and only use .env for always overwriting dev/prod settings.
|
||||
|
||||
requires `src/.env`
|
||||
|
||||
cd /src
|
||||
<!-- * `pipenv install --editable ../../auto-archiver` -->
|
||||
* console 1 - `docker compose up redis` optionally add `dashboard` for flower dashboard and `web` if not running uvicorn locally
|
||||
* console 1 - `docker compose up redis` optionally add `web` if not running uvicorn locally
|
||||
* console 2 - `pipenv shell` + `celery worker --app=worker.celery --loglevel=info --logfile=logs/celery_dev.log`
|
||||
* `celery --app=worker.celery worker --loglevel=info --logfile=logs/celery_dev.log` celery 5
|
||||
* or with watchdog for dev auto-reload `watchmedo auto-restart -d ./ -- celery --app=worker.celery worker --loglevel=info --logfile=logs/celery_dev.log`
|
||||
@@ -37,7 +39,7 @@ Auto-archiver orchestrator files configurations. For each archiving task an orch
|
||||
orchestrators:
|
||||
group1: secrets/orchestration-group1.yaml
|
||||
group2: secrets/orchestration-group2.yaml
|
||||
default: secrets/orchestration-default:.yaml
|
||||
default: secrets/orchestration-default:orchestration.yaml
|
||||
```
|
||||
|
||||
## Database migrations
|
||||
@@ -66,4 +68,21 @@ Run `pipenv update auto-archiver` inside `src` to update the auto-archiver versi
|
||||
# CALL /sheet POST endpoint
|
||||
curl -XPOST -H "Authorization: Bearer GOOGLE_OAUTH_TOKEN" -H "Content-type: application/json" -d '{"sheet_id": "SHEET_ID", "header": 1}' 'http://localhost:8004/sheet'
|
||||
|
||||
```
|
||||
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
# can be done from top level but let's do it from the src folder for consistency with CI etc
|
||||
cd src
|
||||
# run tests and generate coverage
|
||||
PYTHONPATH=. PIPENV_DOTENV_LOCATION=.env.test pipenv run coverage run -m pytest -vv --disable-warnings --color=yes tests/ && pipenv run coverage html
|
||||
|
||||
# get coverage report in command line
|
||||
pipenv run coverage report
|
||||
|
||||
# get coverage HTML
|
||||
pipenv run coverage html
|
||||
|
||||
# > open/run server on htmlcov/index.html to navigate through line coverage
|
||||
```
|
||||
@@ -1,17 +1,19 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
web:
|
||||
restart: "no"
|
||||
env_file: src/.env.dev
|
||||
environment:
|
||||
- SERVE_LOCAL_ARCHIVE=/app/local_archive # See orchestration.yaml local_storage.save_to
|
||||
- ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp
|
||||
- API_BEARER_TOKEN=dev-api-bearer-token
|
||||
- USER_GROUPS_FILENAME=user-groups.dev.yaml
|
||||
- DATABASE_PATH=sqlite:////app/auto-archiver.db
|
||||
|
||||
worker:
|
||||
restart: "no"
|
||||
env_file: src/.env.dev
|
||||
|
||||
redis:
|
||||
restart: "no"
|
||||
env_file: src/.env.dev
|
||||
ports:
|
||||
- 6379:6379
|
||||
|
||||
@@ -4,13 +4,11 @@ x-broker-url: &broker-url "redis://:${REDIS_PASSWORD}@redis:6379/0"
|
||||
x-base-setup: &base-setup
|
||||
build: ./src
|
||||
restart: always
|
||||
env_file: src/.env
|
||||
env_file: src/.env.prod
|
||||
environment:
|
||||
CELERY_BROKER_URL: *broker-url
|
||||
CELERY_RESULT_BACKEND: *broker-url
|
||||
|
||||
version: '3.8'
|
||||
|
||||
volumes:
|
||||
crawls:
|
||||
|
||||
@@ -20,20 +18,26 @@ services:
|
||||
<<: *base-setup
|
||||
ports:
|
||||
- "127.0.0.1:8004:8000"
|
||||
command: uvicorn main:app --host 0.0.0.0 --reload
|
||||
command: uvicorn web:app --host 0.0.0.0 --reload
|
||||
volumes:
|
||||
- ./src:/app
|
||||
depends_on:
|
||||
- redis
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
worker:
|
||||
<<: *base-setup
|
||||
command: celery worker --app=worker.celery --loglevel=info --logfile=logs/celery.log
|
||||
command: celery --app=worker.main.celery worker --loglevel=info --logfile=logs/celery.log
|
||||
volumes:
|
||||
- ./src:/app
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- crawls:/crawls # BROWSERTRIX_HOME_HOST:BROWSERTRIX_HOME_CONTAINER, do not change /crawls
|
||||
environment:
|
||||
# celery broker-url needs to be duplicated here, do not remove
|
||||
CELERY_BROKER_URL: *broker-url
|
||||
CELERY_RESULT_BACKEND: *broker-url
|
||||
WACZ_ENABLE_DOCKER: 1 # Enable calling docker from this container
|
||||
@@ -42,6 +46,11 @@ services:
|
||||
depends_on:
|
||||
- web
|
||||
- redis
|
||||
healthcheck:
|
||||
test: ["CMD", "pipenv", "run", "celery", "-A", "worker.celery", "status"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
redis:
|
||||
image: redis:6-alpine
|
||||
@@ -50,16 +59,8 @@ services:
|
||||
volumes:
|
||||
- "./redis/data:/data"
|
||||
- "./redis/config:/conf"
|
||||
|
||||
# dashboard service will only launch the dashboard if "--profile flower" is passed to docker compose; or if explicitly called "docker compose up dashboard"
|
||||
dashboard:
|
||||
<<: *base-setup
|
||||
profiles:
|
||||
- flower
|
||||
command: ["flower", "--app=worker.celery", "--port=5555", "--broker", *broker-url, "--basic_auth=${FLOWER_USERNAME}:${FLOWER_PASSWORD}"]
|
||||
ports:
|
||||
- 5556:5555
|
||||
depends_on:
|
||||
- web
|
||||
- redis
|
||||
- worker
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
9
src/.env.test
Normal file
9
src/.env.test
Normal file
@@ -0,0 +1,9 @@
|
||||
CHROME_APP_IDS='["test_app_id_1","test_app_id_2"]'
|
||||
ALLOWED_ORIGINS='["chrome-extension://example1","chrome-extension://example2","http://localhost:8081"]'
|
||||
BLOCKED_EMAILS='["blocked@example.com"]'
|
||||
|
||||
|
||||
DATABASE_PATH="sqlite:///auto-archiver.test.db"
|
||||
API_BEARER_TOKEN=this_is_the_test_api_token
|
||||
USER_GROUPS_FILENAME=tests/user-groups.test.yaml
|
||||
SHEET_ORCHESTRATION_YAML=tests/orchestration.test.yaml
|
||||
@@ -7,8 +7,8 @@ WORKDIR /app
|
||||
RUN curl -fsSL https://get.docker.com -o get-docker.sh && \
|
||||
sh get-docker.sh
|
||||
# set environment variables
|
||||
ENV PYTHONUNBUFFERED 1
|
||||
ENV PYTHONDONTWRITEBYTECODE 1
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
|
||||
# install dependencies
|
||||
RUN pip install --upgrade pip && \
|
||||
|
||||
@@ -5,9 +5,8 @@ name = "pypi"
|
||||
|
||||
[packages]
|
||||
aiofiles = "==0.6.0"
|
||||
celery = "==4.4.7"
|
||||
celery = ">=5.0"
|
||||
fastapi = "*"
|
||||
flower = "==0.9.7"
|
||||
jinja2 = "*"
|
||||
redis = "==3.5.3"
|
||||
requests = ">=2.25.1"
|
||||
@@ -20,10 +19,14 @@ alembic = "*"
|
||||
fastapi-utils = "*"
|
||||
prometheus-fastapi-instrumentator = "*"
|
||||
auto-archiver = "*"
|
||||
pydantic-settings = "*"
|
||||
|
||||
[dev-packages]
|
||||
watchdog = "*"
|
||||
pytest = "==6.2.4"
|
||||
pytest = "*"
|
||||
httpx = "*"
|
||||
coverage = "*"
|
||||
pytest-asyncio = "*"
|
||||
|
||||
[requires]
|
||||
python_version = "3.10"
|
||||
|
||||
3136
src/Pipfile.lock
generated
3136
src/Pipfile.lock
generated
File diff suppressed because it is too large
Load Diff
0
src/core/__init__.py
Normal file
0
src/core/__init__.py
Normal file
13
src/core/config.py
Normal file
13
src/core/config.py
Normal file
@@ -0,0 +1,13 @@
|
||||
VERSION = "0.7.0"
|
||||
API_DESCRIPTION = """
|
||||
#### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets.
|
||||
|
||||
**Usage notes:**
|
||||
- The API requires a Bearer token for most operations, which you can obtain by logging in with your Google account.
|
||||
- You can use this API to archive single URLs or entire Google Sheets.
|
||||
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
|
||||
"""
|
||||
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
|
||||
|
||||
# changing this will corrupt the database logic
|
||||
ALLOW_ANY_EMAIL = "*"
|
||||
45
src/core/events.py
Normal file
45
src/core/events.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import alembic.config
|
||||
from fastapi import FastAPI
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
from loguru import logger
|
||||
|
||||
from db import crud, models
|
||||
from db.database import get_db, make_engine
|
||||
from shared.settings import get_settings
|
||||
from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# see https://fastapi.tiangolo.com/advanced/events/#lifespan
|
||||
|
||||
# STARTUP
|
||||
engine = make_engine(get_settings().DATABASE_PATH)
|
||||
models.Base.metadata.create_all(bind=engine)
|
||||
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
|
||||
# disabling uvicorn logger since we use loguru in logging_middleware
|
||||
logging.getLogger("uvicorn.access").disabled = True
|
||||
asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL, get_settings().CELERY_BROKER_URL))
|
||||
asyncio.create_task(refresh_user_groups())
|
||||
asyncio.create_task(repeat_measure_regular_metrics())
|
||||
|
||||
yield # separates startup from shutdown instructions
|
||||
|
||||
# SHUTDOWN
|
||||
logger.info("shutting down")
|
||||
|
||||
|
||||
# CRON JOBS
|
||||
|
||||
@repeat_every(seconds=60 * 60) # 1 hour
|
||||
async def refresh_user_groups():
|
||||
with get_db() as db:
|
||||
crud.upsert_user_groups(db)
|
||||
|
||||
|
||||
@repeat_every(seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS)
|
||||
async def repeat_measure_regular_metrics():
|
||||
await measure_regular_metrics(get_settings().DATABASE_PATH, get_settings().REPEAT_COUNT_METRICS_SECONDS)
|
||||
27
src/core/logging.py
Normal file
27
src/core/logging.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import traceback
|
||||
from loguru import logger
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
# logging configurations
|
||||
logger.add("logs/api_logs.log", retention="30 days", rotation="3 days")
|
||||
error_logger = logger.add("logs/error_logs.log", retention="30 days")
|
||||
|
||||
|
||||
def log_error(e: Exception, traceback_str: str = None, extra:str = ""):
|
||||
# EXCEPTION_COUNTER.labels(type(e).__name__).inc()
|
||||
if not traceback_str: traceback_str = traceback.format_exc()
|
||||
if extra: extra = f"{extra}\n"
|
||||
logger.error(f"{extra}{e.__class__.__name__}: {e}")
|
||||
error_logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")
|
||||
|
||||
async def logging_middleware(request: Request, call_next):
|
||||
try:
|
||||
response = await call_next(request)
|
||||
logger.info(f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}")
|
||||
return response
|
||||
except Exception as e:
|
||||
from utils.metrics import EXCEPTION_COUNTER
|
||||
EXCEPTION_COUNTER.labels(type(e).__name__).inc()
|
||||
log_error(e)
|
||||
raise e
|
||||
@@ -4,27 +4,28 @@ from sqlalchemy import Column, or_, func
|
||||
from loguru import logger
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from security import ALLOW_ANY_EMAIL
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
from shared.settings import get_settings
|
||||
from . import models, schemas
|
||||
import yaml, os
|
||||
import yaml
|
||||
|
||||
DOMAIN_GROUPS = {}
|
||||
DOMAIN_GROUPS_LOADED = False
|
||||
MAX_LIMIT = 100
|
||||
DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
|
||||
|
||||
# --------------- TASK = Archive
|
||||
|
||||
|
||||
def get_task(db: Session, task_id: str, email: str):
|
||||
def get_archive(db: Session, id: str, email: str):
|
||||
email = email.lower()
|
||||
query = base_query(db).filter(models.Archive.id == task_id)
|
||||
query = base_query(db).filter(models.Archive.id == id)
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
groups = get_user_groups(db, email)
|
||||
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
||||
return query.first()
|
||||
|
||||
|
||||
def search_tasks_by_url(db: Session, url: str, email: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False):
|
||||
def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False):
|
||||
# searches for partial URLs, if email is * no ownership filtering happens
|
||||
query = base_query(db)
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
@@ -36,15 +37,15 @@ def search_tasks_by_url(db: Session, url: str, email: str, skip: int = 0, limit:
|
||||
else:
|
||||
query = query.filter(models.Archive.url.like(f'%{url}%'))
|
||||
if archived_after:
|
||||
query = query.filter(models.Archive.created_at >= archived_after)
|
||||
query = query.filter(models.Archive.created_at > archived_after)
|
||||
if archived_before:
|
||||
query = query.filter(models.Archive.created_at <= archived_before)
|
||||
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, MAX_LIMIT)).all()
|
||||
query = query.filter(models.Archive.created_at < archived_before)
|
||||
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
|
||||
|
||||
|
||||
def search_tasks_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
|
||||
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
|
||||
email = email.lower()
|
||||
return base_query(db).filter(models.Archive.author.has(email=email)).offset(skip).limit(min(limit, MAX_LIMIT)).all()
|
||||
return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
|
||||
|
||||
|
||||
def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]):
|
||||
@@ -65,21 +66,28 @@ def soft_delete_task(db: Session, task_id: str, email: str) -> bool:
|
||||
db.commit()
|
||||
return db_task is not None
|
||||
|
||||
def count_archives(db:Session):
|
||||
|
||||
def count_archives(db: Session):
|
||||
return db.query(func.count(models.Archive.id)).scalar()
|
||||
|
||||
def count_archive_urls(db:Session):
|
||||
|
||||
def count_archive_urls(db: Session):
|
||||
return db.query(func.count(models.ArchiveUrl.url)).scalar()
|
||||
|
||||
def count_by_user_since(db:Session, seconds_delta: int = 15):
|
||||
def count_users(db: Session):
|
||||
return db.query(func.count(models.User.email)).scalar()
|
||||
|
||||
def count_by_user_since(db: Session, seconds_delta: int = 15):
|
||||
time_threshold = datetime.now() - timedelta(seconds=seconds_delta)
|
||||
return db.query(models.Archive.author_id,func.count().label('total'))\
|
||||
return db.query(models.Archive.author_id, func.count().label('total'))\
|
||||
.filter(models.Archive.created_at >= time_threshold)\
|
||||
.group_by(models.Archive.author_id)\
|
||||
.order_by(func.count().desc()).limit(5 * MAX_LIMIT).all()
|
||||
.order_by(func.count().desc())\
|
||||
.limit(500).all()
|
||||
|
||||
|
||||
def base_query(db: Session):
|
||||
# allow only some fields to be returned, for example author should remain hidden
|
||||
# TODO: allow only some fields to be returned, for example author should remain hidden
|
||||
return db.query(models.Archive)\
|
||||
.options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result))\
|
||||
.filter(models.Archive.deleted == False)
|
||||
@@ -97,10 +105,6 @@ def create_tag(db: Session, tag: str):
|
||||
return db_tag
|
||||
|
||||
|
||||
def search_tags(db: Session, tag: str, skip: int = 0, limit: int = 100):
|
||||
return db.query(models.Tag).filter(models.Tag.url.like(f'%{tag}%')).offset(skip).limit(min(limit, MAX_LIMIT)).all()
|
||||
|
||||
|
||||
def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
|
||||
if email == ALLOW_ANY_EMAIL: return True
|
||||
return len(group_name) and len(email) and group_name in get_user_groups(db, email)
|
||||
@@ -108,6 +112,7 @@ def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
|
||||
|
||||
def get_user_groups(db: Session, email: str):
|
||||
email = email.lower()
|
||||
if "@" not in email: return []
|
||||
global DOMAIN_GROUPS, DOMAIN_GROUPS_LOADED
|
||||
if not DOMAIN_GROUPS_LOADED: upsert_user_groups(db)
|
||||
# given an email retrieves the user groups from the DB and then the email-domain groups from a global variable
|
||||
@@ -121,8 +126,8 @@ def get_user_groups(db: Session, email: str):
|
||||
# --------------- INIT User-Groups
|
||||
|
||||
|
||||
def get_user(db: Session, author_id: str):
|
||||
if type(author_id)==str: author_id = author_id.lower()
|
||||
def create_or_get_user(db: Session, author_id: str):
|
||||
if type(author_id) == str: author_id = author_id.lower()
|
||||
db_user = db.query(models.User).filter(models.User.email == author_id).first()
|
||||
if not db_user:
|
||||
db_user = models.User(email=author_id)
|
||||
@@ -133,11 +138,13 @@ def get_user(db: Session, author_id: str):
|
||||
|
||||
|
||||
@cache
|
||||
def get_group(db: Session, group_name: str) -> models.Group:
|
||||
def create_or_get_group(db: Session, group_name: str) -> models.Group:
|
||||
db_group = db.query(models.Group).filter(models.Group.id == group_name).first()
|
||||
if db_group is None:
|
||||
db_group = models.Group(id=group_name)
|
||||
db.add(db_group)
|
||||
db.commit()
|
||||
db.refresh(db_group)
|
||||
return db_group
|
||||
|
||||
|
||||
@@ -148,15 +155,15 @@ def upsert_user_groups(db: Session):
|
||||
along with new participation of users in groups
|
||||
"""
|
||||
logger.debug("Updating user-groups configuration.")
|
||||
filename = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml")
|
||||
filename = get_settings().USER_GROUPS_FILENAME
|
||||
|
||||
# read yaml safely
|
||||
with open(filename) as inf:
|
||||
try:
|
||||
try:
|
||||
with open(filename) as inf:
|
||||
user_groups_yaml = yaml.safe_load(inf)
|
||||
except yaml.YAMLError as e:
|
||||
logger.error(f"could not open user groups filename {filename}: {e}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"could not open user groups filename {filename}: {e}")
|
||||
raise e
|
||||
# updating domain->groups access
|
||||
DOMAIN_GROUPS = user_groups_yaml.get("domains", {})
|
||||
|
||||
@@ -175,7 +182,7 @@ def upsert_user_groups(db: Session):
|
||||
db.add(db_user)
|
||||
if not groups: continue # avoid hanging in for x in None:
|
||||
for group in groups:
|
||||
db_group = get_group(db, group)
|
||||
db_group = create_or_get_group(db, group)
|
||||
db_group.users.append(db_user)
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -1,15 +1,34 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy import Engine, create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
import os
|
||||
from shared.settings import get_settings
|
||||
from contextlib import contextmanager
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_PATH")#"sqlite:///./auto-archiver.db"
|
||||
# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db"
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
def make_engine(database_url: str):
|
||||
engine = create_engine(database_url, connect_args={"check_same_thread": False})
|
||||
|
||||
Base = declarative_base()
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(conn, _) -> None:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def make_session_local(engine: Engine):
|
||||
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
return session_local
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db():
|
||||
session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
|
||||
try: yield session
|
||||
finally: session.close()
|
||||
|
||||
|
||||
def get_db_dependency():
|
||||
# to use with Depends and ensure proper session closing
|
||||
with get_db() as db:
|
||||
yield db
|
||||
@@ -1,8 +1,10 @@
|
||||
from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
import uuid
|
||||
from .database import Base
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
def generate_uuid():
|
||||
return str(uuid.uuid4())
|
||||
@@ -59,7 +61,6 @@ class Tag(Base):
|
||||
|
||||
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
|
||||
@@ -2,6 +2,13 @@ from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class Tag(BaseModel):
|
||||
id: str
|
||||
created_at: datetime
|
||||
|
||||
model_config = { "from_attributes": True }
|
||||
__hash__ = object.__hash__
|
||||
|
||||
class ArchiveCreate(BaseModel):
|
||||
id: str | None = None
|
||||
url: str
|
||||
@@ -9,7 +16,7 @@ class ArchiveCreate(BaseModel):
|
||||
public: bool = True
|
||||
author_id: str | None = None
|
||||
group_id: str | None = None
|
||||
tags: set = set()
|
||||
tags: set[Tag] | None = set()
|
||||
rearchive: bool = True
|
||||
# urls: list = []
|
||||
|
||||
@@ -19,9 +26,7 @@ class Archive(ArchiveCreate):
|
||||
updated_at: datetime | None
|
||||
deleted: bool
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
model_config = { "from_attributes": True }
|
||||
|
||||
class SubmitSheet(BaseModel):
|
||||
sheet_name: str | None = None
|
||||
@@ -30,7 +35,7 @@ class SubmitSheet(BaseModel):
|
||||
public: bool = False
|
||||
author_id: str | None = None
|
||||
group_id: str | None = None
|
||||
tags: set | None = set()
|
||||
tags: set[str] | None = set()
|
||||
columns: dict | None = {} # TODO: implement
|
||||
|
||||
class SubmitManual(BaseModel):
|
||||
@@ -38,4 +43,14 @@ class SubmitManual(BaseModel):
|
||||
public: bool = False
|
||||
author_id: str | None = None
|
||||
group_id: str | None = None
|
||||
tags: set | None = set()
|
||||
tags: set[str] | None = set()
|
||||
|
||||
class Task(BaseModel):
|
||||
id: str
|
||||
|
||||
class TaskResult(Task):
|
||||
status: str
|
||||
result: str
|
||||
|
||||
class TaskDelete(Task):
|
||||
deleted: bool
|
||||
5
src/endpoints/__init__.py
Normal file
5
src/endpoints/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from endpoints.default import default_router
|
||||
from endpoints.url import url_router
|
||||
from endpoints.task import task_router
|
||||
from endpoints.interoperability import interoperability_router
|
||||
from endpoints.sheet import sheet_router
|
||||
40
src/endpoints/default.py
Normal file
40
src/endpoints/default.py
Normal file
@@ -0,0 +1,40 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, HTTPException
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.config import VERSION, BREAKING_CHANGES
|
||||
from core.logging import log_error
|
||||
from db import crud
|
||||
from db.database import get_db_dependency, get_db
|
||||
from web.security import get_user_auth, bearer_security
|
||||
|
||||
default_router = APIRouter()
|
||||
|
||||
|
||||
@default_router.get("/")
|
||||
async def home(request: Request):
|
||||
# TODO: maybe split into 2 routes: one non authenticated and one authenticated for the groups info only
|
||||
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
|
||||
try:
|
||||
email = await get_user_auth(await bearer_security(request))
|
||||
with get_db() as db:
|
||||
status["groups"] = crud.get_user_groups(db, email)
|
||||
except HTTPException: pass # not authenticated is fine
|
||||
except Exception as e: log_error(e)
|
||||
return JSONResponse(status)
|
||||
|
||||
|
||||
@default_router.get("/health")
|
||||
async def health():
|
||||
return JSONResponse({"status": "ok"})
|
||||
|
||||
|
||||
@default_router.get("/groups", response_model=list[str])
|
||||
def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
return crud.get_user_groups(db, email)
|
||||
|
||||
|
||||
@default_router.get('/favicon.ico', include_in_schema=False)
|
||||
async def favicon():
|
||||
return FileResponse("static/favicon.ico")
|
||||
27
src/endpoints/interoperability.py
Normal file
27
src/endpoints/interoperability.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from auto_archiver import Metadata
|
||||
from loguru import logger
|
||||
import sqlalchemy
|
||||
|
||||
from web.security import token_api_key_auth
|
||||
from db import models, schemas
|
||||
from worker.main import insert_result_into_db
|
||||
from core.logging import log_error
|
||||
|
||||
|
||||
interoperability_router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."])
|
||||
|
||||
|
||||
# ----- endpoint to submit data archived elsewhere
|
||||
@interoperability_router.post("/submit-archive", status_code=201, summary="Submit a manual archive entry, for data that was archived elsewhere.")
|
||||
def submit_manual_archive(manual: schemas.SubmitManual, auth=Depends(token_api_key_auth)):
|
||||
result = Metadata.from_json(manual.result)
|
||||
logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
|
||||
manual.tags.add("manual")
|
||||
try:
|
||||
archive_id = insert_result_into_db(result, manual.tags, manual.public, manual.group_id, manual.author_id, models.generate_uuid())
|
||||
except sqlalchemy.exc.IntegrityError as e:
|
||||
log_error(e)
|
||||
raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error")
|
||||
return JSONResponse({"id": archive_id}, status_code=201)
|
||||
24
src/endpoints/sheet.py
Normal file
24
src/endpoints/sheet.py
Normal file
@@ -0,0 +1,24 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
from web.security import get_token_or_user_auth
|
||||
from db import schemas
|
||||
from worker.main import create_sheet_task
|
||||
|
||||
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
|
||||
|
||||
|
||||
@sheet_router.post("/archive", status_code=201, summary="Submit a Google Sheet archive request, starts a sheet archiving task.", response_model=schemas.Task, response_description="task_id for the archiving task.")
|
||||
def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_token_or_user_auth)):
|
||||
logger.info(f"SHEET TASK for {sheet=}")
|
||||
if email == ALLOW_ANY_EMAIL:
|
||||
email = sheet.author_id or "api-endpoint"
|
||||
sheet.author_id = email
|
||||
if not sheet.sheet_name and not sheet.sheet_id:
|
||||
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
||||
task = create_sheet_task.delay(sheet.model_dump_json())
|
||||
return JSONResponse({"id": task.id}, status_code=201)
|
||||
41
src/endpoints/task.py
Normal file
41
src/endpoints/task.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from celery.result import AsyncResult
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from loguru import logger
|
||||
from web.security import get_token_or_user_auth
|
||||
|
||||
from db import schemas
|
||||
from core.logging import log_error
|
||||
from worker.main import celery
|
||||
|
||||
|
||||
task_router = APIRouter(prefix="/task", tags=["Async task operations"])
|
||||
|
||||
|
||||
@task_router.get("/{task_id}", response_model=schemas.TaskResult, summary="Check the status of an async task by its id, works for URLs and Sheet tasks.")
|
||||
def get_status(task_id, email=Depends(get_token_or_user_auth)):
|
||||
logger.info(f"status check for user {email} task {task_id}")
|
||||
task = AsyncResult(task_id, app=celery)
|
||||
try:
|
||||
if task.status == "FAILURE":
|
||||
# *FAILURE* The task raised an exception, or has exceeded the retry limit.
|
||||
# The :attr:`result` attribute then contains the exception raised by the task.
|
||||
# https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
|
||||
raise task.result
|
||||
|
||||
response = {
|
||||
"id": task_id,
|
||||
"status": task.status,
|
||||
"result": task.result
|
||||
}
|
||||
return JSONResponse(jsonable_encoder(response, exclude_unset=True))
|
||||
|
||||
except Exception as e:
|
||||
log_error(e)
|
||||
return JSONResponse({
|
||||
"id": task_id,
|
||||
"status": "FAILURE",
|
||||
"result": {"error": str(e)}
|
||||
})
|
||||
59
src/endpoints/url.py
Normal file
59
src/endpoints/url.py
Normal file
@@ -0,0 +1,59 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from datetime import datetime
|
||||
|
||||
from loguru import logger
|
||||
from web.security import get_user_auth, get_token_or_user_auth
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from db import crud, schemas
|
||||
from db.database import get_db_dependency
|
||||
|
||||
from worker.main import create_archive_task
|
||||
|
||||
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
|
||||
|
||||
|
||||
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_model=schemas.Task, response_description="task_id for the archiving task, will match the archive id.")
|
||||
def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)):
|
||||
archive.author_id = email
|
||||
url = archive.url
|
||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
|
||||
if type(url) != str or len(url) <= 5:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
|
||||
logger.info("creating task")
|
||||
task = create_archive_task.delay(archive.model_dump_json())
|
||||
task_response = schemas.Task(id=task.id)
|
||||
return JSONResponse(task_response.model_dump(), status_code=201)
|
||||
|
||||
|
||||
@url_router.get("/search", response_model=list[schemas.Archive], summary="Search for archive entries by URL.")
|
||||
def search_by_url(
|
||||
url: str, skip: int = 0, limit: int = 25,
|
||||
archived_after: datetime = None, archived_before: datetime = None,
|
||||
db: Session = Depends(get_db_dependency),
|
||||
email=Depends(get_token_or_user_auth)):
|
||||
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
||||
|
||||
|
||||
@url_router.get("/latest", response_model=list[schemas.Archive], summary="Fetch latest URL archives for the authenticated user.")
|
||||
def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
return crud.search_archives_by_email(db, email, skip=skip, limit=limit)
|
||||
|
||||
|
||||
@url_router.get("/{id}", response_model=schemas.Archive, summary="Fetch a single URL archive by the associated id.")
|
||||
def lookup(id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
||||
archive = crud.get_archive(db, id, email)
|
||||
if archive is None:
|
||||
raise HTTPException(status_code=404, detail="Archive not found")
|
||||
return archive
|
||||
|
||||
|
||||
@url_router.delete("/{id}", response_model=schemas.TaskDelete, summary="Delete a single URL archive by id.")
|
||||
def delete_task(id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
logger.info(f"deleting url archive task {id} request by {email}")
|
||||
return JSONResponse({
|
||||
"id": id,
|
||||
"deleted": crud.soft_delete_task(db, id, email)
|
||||
})
|
||||
257
src/main.py
257
src/main.py
@@ -1,257 +0,0 @@
|
||||
from celery.result import AsyncResult
|
||||
from fastapi import FastAPI, Depends, Request, HTTPException
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
import alembic.config
|
||||
from dotenv import load_dotenv
|
||||
import traceback, os, logging
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
import sqlalchemy
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from prometheus_client import Counter, Gauge
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio, json
|
||||
import shutil
|
||||
|
||||
from worker import REDIS_EXCEPTIONS_CHANNEL, create_archive_task, create_sheet_task, celery, insert_result_into_db, Rdis
|
||||
|
||||
from db import crud, models, schemas
|
||||
from db.database import engine, SessionLocal, SQLALCHEMY_DATABASE_URL
|
||||
from sqlalchemy.orm import Session
|
||||
from security import get_user_auth, token_api_key_auth, bearer_security, get_token_or_user_auth
|
||||
from auto_archiver import Metadata
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Configuration
|
||||
ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "chrome-extension://ondkcheoicfckabcnkdgbepofpjmjcmb,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp").split(",")
|
||||
VERSION = "0.6.3"
|
||||
# min-version refers to the version of auto-archiver-extension on the webstore
|
||||
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
|
||||
|
||||
@repeat_every(seconds=60 * 60) # 1 hour
|
||||
async def refresh_user_groups():
|
||||
db: Session = next(get_db())
|
||||
crud.upsert_user_groups(db)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# see https://fastapi.tiangolo.com/advanced/events/#lifespan
|
||||
# STARTUP
|
||||
|
||||
models.Base.metadata.create_all(bind=engine)
|
||||
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
|
||||
# disabling uvicorn logger since we use loguru in logging_middleware
|
||||
logging.getLogger("uvicorn.access").disabled = True
|
||||
asyncio.create_task(redis_subscribe_worker_exceptions())
|
||||
asyncio.create_task(refresh_user_groups())
|
||||
asyncio.create_task(measure_regular_metrics())
|
||||
|
||||
yield # separates startup from shutdown instructions
|
||||
|
||||
# SHUTDOWN
|
||||
logger.info("shutting down")
|
||||
|
||||
|
||||
app = FastAPI(title="Auto-Archiver API", version=VERSION, contact={"name":"Bellingcat", "url":"https://github.com/bellingcat/auto-archiver-api"}, lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
EXCEPTION_COUNTER = Counter(
|
||||
"exceptions",
|
||||
"Number of times a certain exception has occurred.",
|
||||
labelnames=("types",)
|
||||
)
|
||||
# prometheus exposed in /metrics with authentication
|
||||
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
|
||||
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
SERVE_LOCAL_ARCHIVE = os.environ.get("SERVE_LOCAL_ARCHIVE", "")
|
||||
if len(SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(SERVE_LOCAL_ARCHIVE):
|
||||
logger.info(f"mounting local archive {SERVE_LOCAL_ARCHIVE}")
|
||||
app.mount(SERVE_LOCAL_ARCHIVE, StaticFiles(directory=SERVE_LOCAL_ARCHIVE), name=SERVE_LOCAL_ARCHIVE)
|
||||
|
||||
def get_db():
|
||||
session = SessionLocal()
|
||||
try: yield session
|
||||
finally: session.close()
|
||||
|
||||
|
||||
# logging configurations
|
||||
logger.add("logs/api_logs.log", retention="30 days", rotation="3 days")
|
||||
@app.middleware("http")
|
||||
async def logging_middleware(request: Request, call_next):
|
||||
try:
|
||||
response = await call_next(request)
|
||||
logger.info(f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}")
|
||||
return response
|
||||
except Exception as e:
|
||||
EXCEPTION_COUNTER.labels(type(e).__name__).inc()
|
||||
raise e
|
||||
|
||||
@app.get("/")
|
||||
async def home(request: Request):
|
||||
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
|
||||
try:
|
||||
# if authenticated will load available groups
|
||||
email = await get_user_auth(await bearer_security(request))
|
||||
db: Session = next(get_db())
|
||||
status["groups"] = crud.get_user_groups(db, email)
|
||||
except HTTPException: pass
|
||||
except Exception as e: logger.error(e)
|
||||
return JSONResponse(status)
|
||||
|
||||
|
||||
#-----Submit URL and manipulate tasks. Bearer protected below
|
||||
|
||||
@app.get("/groups", response_model=list[str])
|
||||
def get_user_groups(db: Session = Depends(get_db), email = Depends(get_user_auth)):
|
||||
return crud.get_user_groups(db, email)
|
||||
|
||||
@app.get("/tasks/search-url", response_model=list[schemas.Archive])
|
||||
def search_by_url(url:str, skip: int = 0, limit: int = 100, archived_after:datetime=None, archived_before:datetime=None, db: Session = Depends(get_db), email = Depends(get_token_or_user_auth)):
|
||||
return crud.search_tasks_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
||||
|
||||
@app.get("/tasks/sync", response_model=list[schemas.Archive])
|
||||
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_user_auth)):
|
||||
return crud.search_tasks_by_email(db, email, skip=skip, limit=limit)
|
||||
|
||||
@app.post("/tasks", status_code=201)
|
||||
def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_token_or_user_auth)):
|
||||
archive.author_id = email
|
||||
url = archive.url
|
||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
|
||||
if type(url)!=str or len(url)<=5:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
|
||||
logger.info("creating task")
|
||||
task = create_archive_task.delay(archive.json())
|
||||
return JSONResponse({"id": task.id})
|
||||
|
||||
@app.get("/archive/{task_id}")
|
||||
def lookup(task_id, db: Session = Depends(get_db), email = Depends(get_token_or_user_auth)):
|
||||
return crud.get_task(db, task_id, email)
|
||||
|
||||
@app.get("/tasks/{task_id}")
|
||||
def get_status(task_id, email = Depends(get_token_or_user_auth)):
|
||||
logger.info(f"status check for user {email} task {task_id}")
|
||||
task = AsyncResult(task_id, app=celery)
|
||||
try:
|
||||
if task.status == "FAILURE":
|
||||
# *FAILURE* The task raised an exception, or has exceeded the retry limit.
|
||||
# The :attr:`result` attribute then contains the exception raised by the task.
|
||||
# https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
|
||||
raise task.result
|
||||
|
||||
response = {
|
||||
"id": task_id,
|
||||
"status": task.status,
|
||||
"result": task.result
|
||||
}
|
||||
return JSONResponse(jsonable_encoder(response, exclude_unset=True))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(traceback.format_exc())
|
||||
return JSONResponse({
|
||||
"id": task_id,
|
||||
"status": "FAILURE",
|
||||
"result": {"error": str(e)}
|
||||
})
|
||||
|
||||
@app.delete("/tasks/{task_id}")
|
||||
def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_user_auth)):
|
||||
logger.info(f"deleting task {task_id} request by {email}")
|
||||
return JSONResponse({
|
||||
"id": task_id,
|
||||
"deleted": crud.soft_delete_task(db, task_id, email)
|
||||
})
|
||||
|
||||
#----- Google Sheets Logic
|
||||
@app.post("/sheet", status_code=201)
|
||||
def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_user_auth)):
|
||||
logger.info(f"SHEET TASK for {sheet=}")
|
||||
sheet.author_id = email
|
||||
if not sheet.sheet_name and not sheet.sheet_id:
|
||||
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
||||
task = create_sheet_task.delay(sheet.json())
|
||||
return JSONResponse({"id": task.id})
|
||||
|
||||
@app.post("/sheet_service", status_code=201)
|
||||
def archive_sheet_service(sheet:schemas.SubmitSheet, auth = Depends(token_api_key_auth)):
|
||||
logger.info(f"SHEET TASK for {sheet=}")
|
||||
sheet.author_id = sheet.author_id or "api-endpoint"
|
||||
if not sheet.sheet_name and not sheet.sheet_id:
|
||||
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
||||
task = create_sheet_task.delay(sheet.json())
|
||||
return JSONResponse({"id": task.id})
|
||||
|
||||
#----- endpoint to submit data archived elsewhere
|
||||
@app.post("/submit-archive", status_code=201)
|
||||
def submit_manual_archive(manual:schemas.SubmitManual, auth = Depends(token_api_key_auth)):
|
||||
result = Metadata.from_json(manual.result)
|
||||
logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
|
||||
manual.tags.add("manual")
|
||||
try:
|
||||
archive_id = insert_result_into_db(result, manual.tags, manual.public, manual.group_id, manual.author_id, models.generate_uuid())
|
||||
except sqlalchemy.exc.IntegrityError as e:
|
||||
logger.error(e)
|
||||
raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error")
|
||||
return JSONResponse({"id": archive_id})
|
||||
|
||||
# --------- Prometheus metrics
|
||||
|
||||
WORKER_EXCEPTION = Counter(
|
||||
"worker_exceptions_total",
|
||||
"Number of times a certain exception has occurred on the worker.",
|
||||
labelnames=("exception", "task",)
|
||||
)
|
||||
async def redis_subscribe_worker_exceptions():
|
||||
PubSubExceptions = Rdis.pubsub()
|
||||
PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL)
|
||||
while True:
|
||||
message = PubSubExceptions.get_message()
|
||||
if message and message["type"] == "message":
|
||||
data = json.loads(message["data"].decode("utf-8"))
|
||||
WORKER_EXCEPTION.labels(exception=data["exception"], task=data["task"]).inc()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
DISK_UTILIZATION = Gauge(
|
||||
"disk_utilization",
|
||||
"Disk utilization in GB",
|
||||
labelnames=("type",)
|
||||
)
|
||||
DATABASE_METRICS = Gauge(
|
||||
"database_metrics",
|
||||
"Useful database metrics from queries",
|
||||
labelnames=("query", "user")
|
||||
)
|
||||
|
||||
REPEAT_COUNT_METRICS_SECONDS = 15
|
||||
@repeat_every(seconds=REPEAT_COUNT_METRICS_SECONDS)
|
||||
async def measure_regular_metrics():
|
||||
_total, used, free = shutil.disk_usage("/")
|
||||
DISK_UTILIZATION.labels(type="used").set(used / (2**30))
|
||||
DISK_UTILIZATION.labels(type="free").set(free / (2**30))
|
||||
try:
|
||||
fs = os.stat(SQLALCHEMY_DATABASE_URL.replace("sqlite:///", ""))
|
||||
DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30))
|
||||
except Exception as e: logger.error(e)
|
||||
|
||||
session: Session = next(get_db())
|
||||
count_archives = crud.count_archives(session)
|
||||
count_archive_urls = crud.count_archive_urls(session)
|
||||
DATABASE_METRICS.labels(query="count_archives", user="-").set(count_archives)
|
||||
DATABASE_METRICS.labels(query="count_archive_urls", user="-").set(count_archive_urls)
|
||||
|
||||
for user in crud.count_by_user_since(session, REPEAT_COUNT_METRICS_SECONDS):
|
||||
DATABASE_METRICS.labels(query="count_by_user", user=user.author_id).set(user.total)
|
||||
@@ -4,14 +4,13 @@ from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
from shared.settings import get_settings
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
config.set_main_option('sqlalchemy.url', os.environ.get("DATABASE_PATH"))
|
||||
config.set_main_option('sqlalchemy.url', get_settings().DATABASE_PATH)
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
|
||||
37
src/shared/settings.py
Normal file
37
src/shared/settings.py
Normal file
@@ -0,0 +1,37 @@
|
||||
|
||||
from functools import lru_cache
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import ConfigDict
|
||||
from typing import Annotated, Set
|
||||
from annotated_types import Len
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = ConfigDict(extra='ignore', str_strip_whitespace=True)
|
||||
|
||||
# general
|
||||
SERVE_LOCAL_ARCHIVE: str = ""
|
||||
USER_GROUPS_FILENAME: str = "user-groups.yaml"
|
||||
SHEET_ORCHESTRATION_YAML : str = "secrets/orchestration-sheet.yaml"
|
||||
|
||||
# database
|
||||
DATABASE_PATH: str
|
||||
DATABASE_QUERY_LIMIT: int = 100
|
||||
|
||||
# redis
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379"
|
||||
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
|
||||
|
||||
# observability
|
||||
REPEAT_COUNT_METRICS_SECONDS: int = 30
|
||||
|
||||
# security
|
||||
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
|
||||
ALLOWED_ORIGINS: Annotated[set[str], Len(min_length=1)]
|
||||
CHROME_APP_IDS: Annotated[set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
|
||||
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
|
||||
|
||||
@lru_cache
|
||||
def get_settings():
|
||||
return Settings()
|
||||
BIN
src/static/favicon.ico
Normal file
BIN
src/static/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 93 KiB |
94
src/tests/conftest.py
Normal file
94
src/tests/conftest.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import os
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from shared.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_logger_add():
|
||||
"""Fixture to mock loguru.logger.add for all tests."""
|
||||
with patch('loguru.logger.add') as mock_add:
|
||||
yield mock_add # This makes the mock available to tests
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def get_settings():
|
||||
return Settings(_env_file=".env.test")
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_settings():
|
||||
with patch('shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings:
|
||||
yield mock_settings
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_db(get_settings: Settings):
|
||||
from db.database import make_engine
|
||||
from db import models
|
||||
|
||||
engine = make_engine(get_settings.DATABASE_PATH)
|
||||
|
||||
fs = get_settings.DATABASE_PATH.replace("sqlite:///", "")
|
||||
if not os.path.exists(fs):
|
||||
open(fs, 'w').close()
|
||||
|
||||
models.Base.metadata.create_all(engine)
|
||||
|
||||
connection = engine.connect()
|
||||
yield connection
|
||||
connection.close()
|
||||
|
||||
models.Base.metadata.drop_all(bind=engine)
|
||||
for suffix in ["", "-wal", "-shm"]:
|
||||
new_fs = fs + suffix
|
||||
if os.path.exists(new_fs):
|
||||
os.remove(new_fs)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session(test_db):
|
||||
from db.database import make_session_local
|
||||
session_local = make_session_local(test_db)
|
||||
with session_local() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app(db_session):
|
||||
from web.main import app_factory
|
||||
from db import crud
|
||||
app = app_factory()
|
||||
crud.upsert_user_groups(db_session)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(app):
|
||||
client = TestClient(app)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app_with_auth(app):
|
||||
from web.security import get_token_or_user_auth, get_user_auth, token_api_key_auth
|
||||
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
|
||||
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
|
||||
app.dependency_overrides[token_api_key_auth] = lambda: "jerry@example.com"
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client_with_auth(app_with_auth):
|
||||
client = TestClient(app_with_auth)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_no_auth():
|
||||
# reusable code to ensure a method/endpoint combination is unauthorized
|
||||
def no_auth(http_method, endpoint):
|
||||
response = http_method(endpoint)
|
||||
assert response.status_code == 403
|
||||
assert response.json() == {"detail": "Not authenticated"}
|
||||
return no_auth
|
||||
389
src/tests/db/test_crud.py
Normal file
389
src/tests/db/test_crud.py
Normal file
@@ -0,0 +1,389 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from db import models
|
||||
from shared.settings import Settings
|
||||
|
||||
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_data(db_session):
|
||||
|
||||
# creates 3 users
|
||||
for email in authors:
|
||||
db_session.add(models.User(email=email))
|
||||
db_session.commit()
|
||||
assert db_session.query(models.User).count() == 3
|
||||
|
||||
# creates 100 archives for 3 users over 2 months with repeating URLs
|
||||
for i in range(100):
|
||||
author = authors[i % 3]
|
||||
archive = models.Archive(
|
||||
id=f"archive-id-456-{i}",
|
||||
url=f"https://example-{i%3}.com",
|
||||
result={},
|
||||
public=author == "jerry@example.com",
|
||||
author_id=author,
|
||||
group_id="spaceship" if author == "morty@example.com" and i % 2 == 0 else None,
|
||||
created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1)
|
||||
)
|
||||
if i % 5 == 0:
|
||||
archive.tags.append(models.Tag(id=f"tag-{i}"))
|
||||
if i % 10 == 0:
|
||||
archive.tags.append(models.Tag(id=f"tag-second-{i}"))
|
||||
if i % 4 == 0:
|
||||
archive.tags.append(models.Tag(id=f"tag-third-{i}"))
|
||||
for j in range(10):
|
||||
archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
|
||||
db_session.add(archive)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
assert db_session.query(models.Archive).count() == 100
|
||||
assert db_session.query(models.Tag).count() == 20 + 10 + 25
|
||||
assert db_session.query(models.ArchiveUrl).count() == 1000
|
||||
assert db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.archive_id == "archive-id-456-0").count() == 10
|
||||
|
||||
# setup groups
|
||||
assert db_session.query(models.Group).count() == 0
|
||||
from db import crud
|
||||
crud.upsert_user_groups(db_session)
|
||||
assert db_session.query(models.Group).count() == 3
|
||||
assert db_session.query(models.User).count() == 4
|
||||
|
||||
|
||||
def test_get_archive(test_data, db_session):
|
||||
from db import crud
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
|
||||
print(db_session.query(models.Group).all())
|
||||
|
||||
# each author's archives work
|
||||
assert (a0 := crud.get_archive(db_session, "archive-id-456-0", authors[0])) is not None
|
||||
assert a0.id == "archive-id-456-0"
|
||||
assert a0.url == "https://example-0.com"
|
||||
assert a0.author_id == authors[0]
|
||||
assert a0.public == False
|
||||
|
||||
assert crud.get_archive(db_session, "archive-id-456-1", authors[1]) is not None
|
||||
assert crud.get_archive(db_session, "archive-id-456-2", authors[2]) is not None
|
||||
|
||||
# ALLOW_ANY_EMAIL
|
||||
assert crud.get_archive(db_session, "archive-id-456-0", ALLOW_ANY_EMAIL) is not None
|
||||
assert crud.get_archive(db_session, "archive-id-456-1", ALLOW_ANY_EMAIL) is not None
|
||||
|
||||
# not found
|
||||
assert crud.get_archive(db_session, "archive-missing", authors[0]) is None
|
||||
|
||||
# public
|
||||
assert (a_public := crud.get_archive(db_session, "archive-id-456-2", authors[0])) is not None
|
||||
assert a_public.public == True
|
||||
|
||||
# not public - rick's
|
||||
assert crud.get_archive(db_session, "archive-id-456-0", authors[1]) is None
|
||||
|
||||
|
||||
def test_search_archives_by_url(test_data, db_session):
|
||||
from db import crud
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
|
||||
# rick's archives are private
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com")) == 34
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL)) == 34
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com")) == 0
|
||||
|
||||
# morty's archives are public but half are in spaceship group
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com")) == 16
|
||||
|
||||
# jerry's archives are public
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "jerry@example.com")) == 33
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "rick@example.com")) == 33
|
||||
|
||||
# fuzzy search
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL)) == 100
|
||||
assert len(crud.search_archives_by_url(db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL)) == 100
|
||||
assert len(crud.search_archives_by_url(db_session, "2.com", ALLOW_ANY_EMAIL)) == 33
|
||||
|
||||
# absolute search
|
||||
assert len(crud.search_archives_by_url(db_session, "example-2.com", ALLOW_ANY_EMAIL, absolute_search=True)) == 0
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", ALLOW_ANY_EMAIL, absolute_search=True)) == 33
|
||||
|
||||
# archived_after
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_after=datetime(2010, 1, 1))) == 100
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_after=datetime(2021, 1, 15))) == 70
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_after=datetime(2031, 1, 1))) == 0
|
||||
|
||||
# archived before
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_before=datetime(2010, 1, 1))) == 0
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_before=datetime(2021, 1, 15))) == 28
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_before=datetime(2031, 1, 1))) == 100
|
||||
|
||||
# archived before and after
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_after=datetime(2001, 1, 1), archived_before=datetime(2031, 1, 11))) == 100
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_after=datetime(2021, 1, 14), archived_before=datetime(2021, 1, 16))) == 2
|
||||
|
||||
# limit
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, limit=10)) == 10
|
||||
|
||||
# skip
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, skip=10)) == 90
|
||||
|
||||
|
||||
def test_search_archives_by_email(test_data, db_session):
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
from db import crud
|
||||
|
||||
# lower/upper case
|
||||
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34
|
||||
assert len(crud.search_archives_by_email(db_session, "RICK@example.com")) == 34
|
||||
|
||||
# ALLOW_ANY_EMAIL is not a user
|
||||
assert len(crud.search_archives_by_email(db_session, ALLOW_ANY_EMAIL)) == 0
|
||||
|
||||
# most recent first
|
||||
a1 = crud.search_archives_by_email(db_session, "rick@example.com", limit=1)
|
||||
assert len(a1) == 1
|
||||
assert a1[0].created_at == datetime(2021, 2, 25)
|
||||
|
||||
# earliest is the last
|
||||
a2 = crud.search_archives_by_email(db_session, "rick@example.com", skip=33)
|
||||
assert len(a2) == 1
|
||||
assert a2[0].created_at == datetime(2021, 1, 1)
|
||||
|
||||
|
||||
@patch("db.crud.DATABASE_QUERY_LIMIT", new=25)
|
||||
def test_max_query_limit(test_data, db_session):
|
||||
from db import crud
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL)) == 25
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, limit=1000)) == 25
|
||||
|
||||
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25
|
||||
assert len(crud.search_archives_by_email(db_session, "rick@example.com", limit=1000)) == 25
|
||||
|
||||
|
||||
def test_create_task(db_session):
|
||||
from db import crud
|
||||
from db import schemas
|
||||
|
||||
task = schemas.ArchiveCreate(
|
||||
id="archive-id-456-101",
|
||||
url="https://example-0.com",
|
||||
result={},
|
||||
public=False,
|
||||
author_id="rick@example.com",
|
||||
group_id="spaceship",
|
||||
tags=[],
|
||||
urls=[]
|
||||
)
|
||||
|
||||
# with tags and urls
|
||||
nt = crud.create_task(db_session, task, [models.Tag(id="tag-101")], [models.ArchiveUrl(url="https://example-0.com/0", key="media_0")])
|
||||
|
||||
assert nt is not None
|
||||
assert nt.id == "archive-id-456-101"
|
||||
assert nt.url == "https://example-0.com"
|
||||
assert nt.author_id == "rick@example.com"
|
||||
assert nt.public == False
|
||||
assert nt.group_id == "spaceship"
|
||||
assert len(nt.tags) == 1
|
||||
assert nt.tags[0].id == "tag-101"
|
||||
assert len(nt.urls) == 1
|
||||
assert nt.urls[0].url == "https://example-0.com/0"
|
||||
assert nt.urls[0].key == "media_0"
|
||||
assert nt.created_at is not None
|
||||
|
||||
# without tags and urls
|
||||
task.id = "archive-id-456-102"
|
||||
nt = crud.create_task(db_session, task, [], [])
|
||||
assert nt is not None
|
||||
assert nt.id == "archive-id-456-102"
|
||||
assert nt.url == "https://example-0.com"
|
||||
assert nt.author_id == "rick@example.com"
|
||||
assert nt.public == False
|
||||
assert nt.group_id == "spaceship"
|
||||
assert len(nt.tags) == 0
|
||||
assert len(nt.urls) == 0
|
||||
assert nt.created_at is not None
|
||||
|
||||
|
||||
def test_soft_delete(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
# none deleted yet
|
||||
assert crud.get_archive(db_session, "archive-id-456-0", "rick@example.com") is not None
|
||||
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 0
|
||||
|
||||
# delete
|
||||
assert crud.soft_delete_task(db_session, "archive-id-456-0", "rick@example.com") == True
|
||||
|
||||
# ensure soft delete
|
||||
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 1
|
||||
assert crud.get_archive(db_session, "archive-id-456-0", "rick@example.com") is None
|
||||
|
||||
# already deleted
|
||||
assert crud.soft_delete_task(db_session, "archive-id-456-0", "rick@example.com") == False
|
||||
|
||||
|
||||
def test_count_archives(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert crud.count_archives(db_session) == 100
|
||||
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
|
||||
db_session.commit()
|
||||
assert crud.count_archives(db_session) == 99
|
||||
|
||||
|
||||
def test_count_archive_urls(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert crud.count_archive_urls(db_session) == 1000
|
||||
db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.url == "https://example-0.com/0").delete()
|
||||
db_session.commit()
|
||||
assert crud.count_archive_urls(db_session) == 999
|
||||
|
||||
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
|
||||
db_session.commit()
|
||||
# no Cascade is enabled
|
||||
assert crud.count_archives(db_session) == 99
|
||||
assert crud.count_archive_urls(db_session) == 999
|
||||
|
||||
def test_count_users(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert crud.count_users(db_session) == 4
|
||||
db_session.query(models.User).filter(models.User.email == "rick@example.com").delete()
|
||||
db_session.commit()
|
||||
assert crud.count_users(db_session) == 3
|
||||
|
||||
def test_count_by_users_since(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
# 100y window
|
||||
assert len(cu := crud.count_by_user_since(db_session, 60 * 60 * 24 * 31 * 12 * 100)) == 3
|
||||
assert cu[0].total == 34
|
||||
assert cu[1].total == 33
|
||||
assert cu[2].total == 33
|
||||
|
||||
|
||||
def test_create_tag(db_session):
|
||||
from db import crud
|
||||
|
||||
assert db_session.query(models.Tag).count() == 0
|
||||
|
||||
# create first
|
||||
create_tag = crud.create_tag(db_session, "tag-101")
|
||||
assert create_tag is not None
|
||||
assert create_tag.id == "tag-101"
|
||||
assert db_session.query(models.Tag).count() == 1
|
||||
assert db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() == create_tag
|
||||
|
||||
# same id does not add new db entry
|
||||
existing_tag = crud.create_tag(db_session, "tag-101")
|
||||
assert existing_tag == create_tag
|
||||
assert db_session.query(models.Tag).count() == 1
|
||||
|
||||
# create second
|
||||
second_tag = crud.create_tag(db_session, "tag-102")
|
||||
assert second_tag is not None
|
||||
assert second_tag.id == "tag-102"
|
||||
assert db_session.query(models.Tag).count() == 2
|
||||
|
||||
|
||||
def test_is_user_in_group(test_data, db_session):
|
||||
from db import crud
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
|
||||
# see user-groups.test.yaml
|
||||
test_pairs = [
|
||||
(ALLOW_ANY_EMAIL, "spaceship", True),
|
||||
(ALLOW_ANY_EMAIL, "non-existant!@#!%!", True),
|
||||
|
||||
("rick@example.com", "spaceship", True),
|
||||
("rick@example.com", "SPACESHIP", False),
|
||||
("RICK@example.com", "interdimensional", True),
|
||||
("rick@example.com", "the-jerrys-club", False),
|
||||
|
||||
("morty@example.com", "spaceship", True),
|
||||
("morty@example.com", "interdimensional", False),
|
||||
("morty@example.com", "the-jerrys-club", False),
|
||||
|
||||
("jerry@example.com", "spaceship", False),
|
||||
("jerry@example.com", "interdimensional", False),
|
||||
("jerry@example.com", "the-jerrys-club", True),
|
||||
|
||||
("rick@example.com", "animated-characters", True),
|
||||
("morty@example.com", "animated-characters", True),
|
||||
("jerry@example.com", "animated-characters", True),
|
||||
|
||||
("rick@example.com", "", False),
|
||||
("", "spaceship", False),
|
||||
("BADEMAILexample.com", "spaceship", False),
|
||||
]
|
||||
for email, group, expected in test_pairs:
|
||||
assert crud.is_user_in_group(db_session, group, email) == expected
|
||||
|
||||
|
||||
def test_create_or_get_user(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert db_session.query(models.User).count() == 4
|
||||
|
||||
assert (u1 := crud.create_or_get_user(db_session, "rick@example.com")) is not None
|
||||
assert u1.email == "rick@example.com"
|
||||
assert u1.is_active == True
|
||||
|
||||
assert (u2 := crud.create_or_get_user(db_session, "beth@example.com")) is not None
|
||||
assert u2.email == "beth@example.com"
|
||||
assert u2.is_active == True
|
||||
|
||||
assert db_session.query(models.User).count() == 5
|
||||
|
||||
|
||||
def test_get_group(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert db_session.query(models.Group).count() == 3
|
||||
|
||||
assert (g1 := crud.create_or_get_group(db_session, "spaceship")) is not None
|
||||
assert g1.id == "spaceship"
|
||||
assert len(g1.users) == 2
|
||||
assert [u.email for u in g1.users] == ["rick@example.com", "morty@example.com"]
|
||||
|
||||
assert (g2 := crud.create_or_get_group(db_session, "the-jerrys-club")) is not None
|
||||
assert g2.id == "the-jerrys-club"
|
||||
assert len(g2.users) == 1
|
||||
assert g2.users[0].email == "jerry@example.com"
|
||||
|
||||
assert (g3 := crud.create_or_get_group(db_session, "this-is-a-new-group")) is not None
|
||||
assert g3.id == "this-is-a-new-group"
|
||||
assert len(g3.users) == 0
|
||||
|
||||
assert db_session.query(models.Group).count() == 4
|
||||
|
||||
|
||||
def test_upsert_user_groups(db_session):
|
||||
from db import crud
|
||||
|
||||
@patch('db.crud.get_settings', new = lambda: bad_setings)
|
||||
def test_missing_yaml(db_session):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
crud.upsert_user_groups(db_session)
|
||||
|
||||
|
||||
@patch('db.crud.get_settings', new = lambda: bad_setings)
|
||||
def test_broken_yaml(db_session):
|
||||
with pytest.raises(yaml.YAMLError):
|
||||
crud.upsert_user_groups(db_session)
|
||||
|
||||
bad_setings = Settings(_env_file=".env.test")
|
||||
|
||||
bad_setings.USER_GROUPS_FILENAME = "tests/user-groups.test.missing.yaml"
|
||||
test_missing_yaml(db_session)
|
||||
|
||||
bad_setings.USER_GROUPS_FILENAME = "tests/user-groups.test.broken.yaml"
|
||||
test_broken_yaml(db_session)
|
||||
6
src/tests/db/test_models.py
Normal file
6
src/tests/db/test_models.py
Normal file
@@ -0,0 +1,6 @@
|
||||
def test_generate_uuid():
|
||||
from db.models import generate_uuid
|
||||
|
||||
assert generate_uuid() != generate_uuid()
|
||||
assert len(generate_uuid()) == 36
|
||||
assert generate_uuid().count("-") == 4
|
||||
120
src/tests/endpoints/test_default.py
Normal file
120
src/tests/endpoints/test_default.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from core.config import VERSION
|
||||
|
||||
|
||||
def test_endpoint_home(client_with_auth):
|
||||
r = client_with_auth.get("/")
|
||||
assert r.status_code == 200
|
||||
j = r.json()
|
||||
assert "version" in j and j["version"] == VERSION
|
||||
assert "breakingChanges" in j
|
||||
assert "groups" not in j
|
||||
|
||||
|
||||
@patch("endpoints.default.bearer_security", new_callable=AsyncMock)
|
||||
@patch("endpoints.default.get_user_auth", new_callable=AsyncMock, return_value="test@example.com")
|
||||
@patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"])
|
||||
def test_endpoint_home_with_groups(m1, m2, m3, client_with_auth):
|
||||
r = client_with_auth.get("/")
|
||||
assert r.status_code == 200
|
||||
j = r.json()
|
||||
assert "version" in j and j["version"] == VERSION
|
||||
assert "breakingChanges" in j
|
||||
assert "groups" in j
|
||||
assert j["groups"] == ["group1", "group2"]
|
||||
|
||||
|
||||
@patch("endpoints.default.bearer_security", new_callable=AsyncMock)
|
||||
@patch("endpoints.default.get_user_auth", new_callable=AsyncMock, return_value="test@example.com")
|
||||
@patch("endpoints.default.crud.get_user_groups", side_effect=Exception('mocked error'))
|
||||
def test_endpoint_home_with_groups_exception(m1, m2, m3, client_with_auth): # mocks call that triggers an internal error
|
||||
r = client_with_auth.get("/")
|
||||
assert r.status_code == 200
|
||||
j = r.json()
|
||||
assert "version" in j and j["version"] == VERSION
|
||||
assert "breakingChanges" in j
|
||||
assert "groups" not in j
|
||||
|
||||
|
||||
def test_endpoint_health(client_with_auth):
|
||||
r = client_with_auth.get("/health")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"status": "ok"}
|
||||
|
||||
|
||||
def test_endpoint_groups_no_auth(client, test_no_auth):
|
||||
test_no_auth(client.get, "/groups")
|
||||
|
||||
|
||||
def test_endpoint_groups_rick_and_morty(client_with_auth):
|
||||
r = client_with_auth.get("/groups")
|
||||
assert r.status_code == 200
|
||||
assert len(j := r.json()) == 2
|
||||
assert 'animated-characters' in j
|
||||
assert 'spaceship' in j
|
||||
|
||||
|
||||
@patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"])
|
||||
def test_endpoint_groups(m1, app):
|
||||
from web.security import get_user_auth
|
||||
app.dependency_overrides[get_user_auth] = lambda: True
|
||||
client = TestClient(app)
|
||||
|
||||
r = client.get("/groups")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert r.json() == ["group1", "group2"]
|
||||
|
||||
|
||||
def test_no_serve_local_archive_by_default(client_with_auth):
|
||||
r = client_with_auth.get("/app/local_archive_test/temp.txt")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_favicon(client_with_auth):
|
||||
r = client_with_auth.get("/favicon.ico")
|
||||
assert r.status_code == 200
|
||||
assert r.headers["content-type"] == "image/vnd.microsoft.icon"
|
||||
|
||||
|
||||
from tests.db.test_crud import test_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prometheus_metrics(test_data, client_with_auth, get_settings):
|
||||
# before metrics calculation
|
||||
r = client_with_auth.get("/metrics")
|
||||
assert r.status_code == 200
|
||||
assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
|
||||
assert "disk_utilization" in r.text
|
||||
assert "database_metrics" in r.text
|
||||
assert "exceptions" in r.text
|
||||
assert "worker_exceptions_total" in r.text
|
||||
assert 'disk_utilization{type="used"}' not in r.text
|
||||
|
||||
# after metrics calculation
|
||||
from utils.metrics import measure_regular_metrics
|
||||
await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100)
|
||||
r2 = client_with_auth.get("/metrics")
|
||||
assert 'disk_utilization{type="used"}' in r2.text
|
||||
assert 'disk_utilization{type="free"}' in r2.text
|
||||
assert 'disk_utilization{type="database"}' in r2.text
|
||||
assert 'database_metrics{query="count_archives"} 100.0' in r2.text
|
||||
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r2.text
|
||||
assert 'database_metrics{query="count_users"} 4.0' in r2.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r2.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r2.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r2.text
|
||||
|
||||
# 30s window, should not change the gauges nor the total in the counters
|
||||
from utils.metrics import measure_regular_metrics
|
||||
await measure_regular_metrics(get_settings.DATABASE_PATH, 30)
|
||||
r3 = client_with_auth.get("/metrics")
|
||||
assert 'database_metrics{query="count_archives"} 100.0' in r3.text
|
||||
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text
|
||||
assert 'database_metrics{query="count_users"} 4.0' in r3.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r3.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r3.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r3.text
|
||||
19
src/tests/endpoints/test_interopreability.py
Normal file
19
src/tests/endpoints/test_interopreability.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import json
|
||||
|
||||
|
||||
def test_submit_manual_archive_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.post, "/interop/submit-archive")
|
||||
|
||||
|
||||
def test_submit_manual_archive(client_with_auth):
|
||||
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": []})
|
||||
|
||||
r = client_with_auth.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]})
|
||||
assert r.status_code == 201
|
||||
assert "id" in r.json()
|
||||
|
||||
# cannot have the same URL twice
|
||||
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]})
|
||||
r = client_with_auth.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]})
|
||||
assert r.status_code == 422
|
||||
assert r.json() == {"detail": "Cannot insert into DB due to integrity error"}
|
||||
46
src/tests/endpoints/test_sheet.py
Normal file
46
src/tests/endpoints/test_sheet.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from db.schemas import TaskResult
|
||||
|
||||
|
||||
def test_sheet_no_auth(client, test_no_auth):
|
||||
test_no_auth(client.post, "/sheet/archive")
|
||||
|
||||
|
||||
@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
|
||||
def test_sheet_rick(m1, client_with_auth):
|
||||
|
||||
response = client_with_auth.post("/sheet/archive", json={"sheet_id": "123-sheet-id"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-456-789'}
|
||||
|
||||
m1.assert_called_once()
|
||||
called_val = m1.call_args.args[0]
|
||||
assert json.loads(called_val) == {"sheet_id": "123-sheet-id", "sheet_name": None, "public": False, "author_id": "rick@example.com", "group_id": None, "tags": [], "columns": {}, "header": 1}
|
||||
|
||||
|
||||
def test_sheet_missing_sheet_data(client_with_auth):
|
||||
r = client_with_auth.post("/sheet/archive", json={})
|
||||
assert r.status_code == 422
|
||||
assert r.json() == {"detail": "sheet name or id is required"}
|
||||
|
||||
|
||||
@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-API-789", status="PENDING", result=""))
|
||||
def test_sheet_api(m1, client):
|
||||
|
||||
response = client.post("/sheet/archive", json={"sheet_name": "456-sheet_name-id"}, headers={"Authorization": "Bearer this_is_the_test_api_token"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-API-789'}
|
||||
|
||||
m1.assert_called_once()
|
||||
called_val = m1.call_args.args[0]
|
||||
assert json.loads(called_val) == {"sheet_name": "456-sheet_name-id", "sheet_id": None, "public": False, "author_id": "api-endpoint", "group_id": None, "tags": [], "columns": {}, "header": 1}
|
||||
|
||||
response = client.post("/sheet/archive", json={"sheet_id": "456-sheet-id", "author_id": "custom-author"}, headers={"Authorization": "Bearer this_is_the_test_api_token"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-API-789'}
|
||||
|
||||
assert m1.call_count == 2
|
||||
called_val = m1.call_args.args[0]
|
||||
assert json.loads(called_val) == {"sheet_id": "456-sheet-id", "sheet_name": None, "public": False, "author_id": "custom-author", "group_id": None, "tags": [], "columns": {}, "header": 1}
|
||||
51
src/tests/endpoints/test_task.py
Normal file
51
src/tests/endpoints/test_task.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def test_endpoint_task_status_no_auth(client, test_no_auth):
|
||||
test_no_auth(client.get, "/task/test-task-id")
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
def test_get_status_success(mock_async_result, client_with_auth):
|
||||
mock_async_result.return_value.status = "SUCCESS"
|
||||
mock_async_result.return_value.result = {"data": "some result"}
|
||||
|
||||
response = client_with_auth.get("/task/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"id": "test-task-id",
|
||||
"status": "SUCCESS",
|
||||
"result": {"data": "some result"}
|
||||
}
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
def test_get_status_failure(mock_async_result, client_with_auth):
|
||||
|
||||
mock_async_result.return_value.status = "FAILURE"
|
||||
mock_async_result.return_value.result = Exception("Some error")
|
||||
|
||||
response = client_with_auth.get("/task/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"id": "test-task-id",
|
||||
"status": "FAILURE",
|
||||
"result": {"error": "Some error"}
|
||||
}
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
def test_get_status_pending(mock_async_result, client_with_auth):
|
||||
mock_async_result.return_value.status = "PENDING"
|
||||
mock_async_result.return_value.result = None
|
||||
|
||||
response = client_with_auth.get("/task/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"id": "test-task-id",
|
||||
"status": "PENDING",
|
||||
"result": None
|
||||
}
|
||||
145
src/tests/endpoints/test_url.py
Normal file
145
src/tests/endpoints/test_url.py
Normal file
@@ -0,0 +1,145 @@
|
||||
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from db.schemas import ArchiveCreate, TaskResult
|
||||
|
||||
def test_archive_url_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.post, "/url/archive")
|
||||
test_no_auth(client.get, "/url/archive")
|
||||
|
||||
|
||||
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
|
||||
def test_archive_url(m1, client_with_auth):
|
||||
response = client_with_auth.post("/url/archive", json={"url": "bad"})
|
||||
assert response.status_code == 422
|
||||
assert response.json() == {'detail': 'Invalid URL received: bad'}
|
||||
m1.assert_not_called()
|
||||
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-456-789'}
|
||||
|
||||
m1.assert_called_once()
|
||||
called_val = m1.call_args.args[0]
|
||||
assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True}
|
||||
|
||||
|
||||
def test_search_by_url_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.get, "/url/search")
|
||||
|
||||
|
||||
def test_search_by_url(client_with_auth, db_session):
|
||||
# tests the search endpoint, including through some db data for the endpoint params
|
||||
response = client_with_auth.get("/url/search")
|
||||
assert response.status_code == 422
|
||||
assert response.json()["detail"][0]["msg"] == "Field required"
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
from db import crud
|
||||
for i in range(11):
|
||||
crud.create_task(db_session, ArchiveCreate(id=f"url-456-{i}", url="https://example.com" if i < 10 else "https://something-else.com", result={}, public=True, author_id="rick@example.com", group_id=None), [], [])
|
||||
#NB: this insertion is too fast for the ordering to be correct as they are within the same second
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com")
|
||||
assert response.status_code == 200
|
||||
assert len(j := response.json()) == 10
|
||||
assert "url-456-0" in [i["id"] for i in j]
|
||||
assert "url-456-9" in [i["id"] for i in j]
|
||||
assert "url-456-10" not in [i["id"] for i in j]
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com&limit=5")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 5
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com&skip=5&limit=2")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com&archived_before=2010-01-01")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 0
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com&archived_after=2010-01-01")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 10
|
||||
|
||||
|
||||
def test_latest_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.get, "/url/latest")
|
||||
|
||||
|
||||
def test_latest(client_with_auth, db_session):
|
||||
response = client_with_auth.get("/url/latest")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
from db import crud
|
||||
for i in range(11):
|
||||
crud.create_task(db_session, ArchiveCreate(id=f"latest-456-{i}", url="https://example.com", result={}, public=True, author_id="morty@example.com" if i < 10 else "rick@example.com", group_id=None), [], [])
|
||||
#NB: this insertion is too fast for the ordering to be correct as they are within the same second
|
||||
|
||||
# user must exist for /latest to work
|
||||
crud.create_or_get_user(db_session, "morty@example.com")
|
||||
|
||||
response = client_with_auth.get("/url/latest")
|
||||
assert response.status_code == 200
|
||||
assert len(j := response.json()) == 10
|
||||
assert "latest-456-0" in [i["id"] for i in j]
|
||||
assert "latest-456-9" in [i["id"] for i in j]
|
||||
assert "latest-456-10" not in [i["id"] for i in j]
|
||||
|
||||
response = client_with_auth.get("/url/latest?limit=5")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 5
|
||||
|
||||
response = client_with_auth.get("/url/latest?skip=5&limit=2")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
|
||||
|
||||
def test_lookup_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.get, "/url/123-456-789")
|
||||
|
||||
|
||||
def test_lookup(client_with_auth, db_session):
|
||||
response = client_with_auth.get("/url/lookup-123-456-789")
|
||||
assert response.status_code == 404
|
||||
assert response.json() == {"detail": "Archive not found"}
|
||||
|
||||
from db import crud
|
||||
crud.create_task(db_session, ArchiveCreate(id="lookup-123-456-789", url="https://example.com", result={}, public=True, author_id="rick@example.com", group_id=None), [], [])
|
||||
|
||||
response = client_with_auth.get("/url/lookup-123-456-789")
|
||||
assert response.status_code == 200
|
||||
j = response.json()
|
||||
assert j["id"] == "lookup-123-456-789"
|
||||
assert j["url"] == "https://example.com"
|
||||
assert j["result"] == {}
|
||||
assert j["public"] == True
|
||||
assert j["author_id"] == "rick@example.com"
|
||||
assert j["group_id"] == None
|
||||
assert j["tags"] == []
|
||||
assert j["updated_at"] == None
|
||||
assert j["rearchive"] == True
|
||||
|
||||
|
||||
def test_delete_task_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.delete, "/url/123-456-789")
|
||||
|
||||
|
||||
def test_delete_task(client_with_auth, db_session):
|
||||
response = client_with_auth.delete("/url/delete-123-456-789")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": "delete-123-456-789", "deleted": False}
|
||||
|
||||
from db import crud
|
||||
crud.create_task(db_session, ArchiveCreate(id="delete-123-456-789", url="https://example.com", result={}, public=True, author_id="morty@example.com", group_id=None), [], [])
|
||||
|
||||
response = client_with_auth.delete("/url/delete-123-456-789")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": "delete-123-456-789", "deleted": True}
|
||||
24
src/tests/orchestration.test.yaml
Normal file
24
src/tests/orchestration.test.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
steps:
|
||||
feeder: cli_feeder
|
||||
archivers: # order matters
|
||||
- youtubedl_archiver
|
||||
enrichers:
|
||||
- hash_enricher
|
||||
|
||||
formatter: html_formatter # defaults to mute_formatter
|
||||
storages:
|
||||
- local_storage
|
||||
databases:
|
||||
- console_db
|
||||
|
||||
configurations:
|
||||
cli_feeder:
|
||||
urls:
|
||||
- "url1"
|
||||
hash_enricher:
|
||||
algorithm: "SHA-256"
|
||||
local_storage:
|
||||
save_to: "./local_archive"
|
||||
save_absolute: true
|
||||
filename_generator: static
|
||||
path_generator: flat
|
||||
6
src/tests/user-groups.test.broken.yaml
Normal file
6
src/tests/user-groups.test.broken.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
broken: True
|
||||
This is just an invalid yaml for testing
|
||||
|
||||
still broken: True
|
||||
- one
|
||||
- two
|
||||
19
src/tests/user-groups.test.yaml
Normal file
19
src/tests/user-groups.test.yaml
Normal file
@@ -0,0 +1,19 @@
|
||||
# NOTE: all emails should be lower-cased
|
||||
users:
|
||||
rick@example.com:
|
||||
- spaceship
|
||||
- interdimensional
|
||||
morty@example.com:
|
||||
- spaceship
|
||||
jerry@example.com:
|
||||
- the-jerrys-club
|
||||
birdman@example.com:
|
||||
|
||||
domains:
|
||||
example.com:
|
||||
- animated-characters
|
||||
|
||||
orchestrators:
|
||||
spaceship: tests/orchestration.test.yaml
|
||||
interdimensional: tests/orchestration.test.yaml
|
||||
default: tests/orchestration.test.yaml
|
||||
49
src/tests/web/test_main.py
Normal file
49
src/tests/web/test_main.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
def test_lifespan(app):
|
||||
with TestClient(app) as client:
|
||||
r = client.get("/health")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"status": "ok"}
|
||||
|
||||
def test_alembic(db_session):
|
||||
import alembic.config
|
||||
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
|
||||
alembic.config.main(argv=['--raiseerr', 'downgrade', 'base'])
|
||||
|
||||
@patch("endpoints.default.crud.get_user_groups", side_effect=Exception('mocked error'))
|
||||
def test_logging_middleware(m1, client_with_auth):
|
||||
from utils.metrics import EXCEPTION_COUNTER
|
||||
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 0
|
||||
with pytest.raises(Exception, match="mocked error"):
|
||||
client_with_auth.get("/groups")
|
||||
# creates one empty and one from above
|
||||
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 2
|
||||
|
||||
|
||||
def test_serve_local_archive_logic(get_settings):
|
||||
# create a test file first
|
||||
os.makedirs("local_archive_test", exist_ok=True)
|
||||
with open("local_archive_test/temp.txt", "w") as f:
|
||||
f.write("test")
|
||||
|
||||
try:
|
||||
# modify the settings
|
||||
get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test"
|
||||
from web.main import app_factory
|
||||
app = app_factory(get_settings)
|
||||
|
||||
# test
|
||||
client = TestClient(app)
|
||||
r = client.get("/app/local_archive_test/temp.txt")
|
||||
assert r.status_code == 200
|
||||
assert r.text == "test"
|
||||
finally:
|
||||
# cleanup
|
||||
shutil.rmtree("local_archive_test")
|
||||
108
src/tests/web/test_security.py
Normal file
108
src/tests/web/test_security.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
import pytest
|
||||
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
|
||||
|
||||
def test_secure_compare():
|
||||
from web.security import secure_compare
|
||||
|
||||
assert secure_compare("test", "test")
|
||||
assert not secure_compare("test", "test2")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_token_or_user_auth_with_api():
|
||||
from web.security import get_token_or_user_auth
|
||||
mock_api = HTTPAuthorizationCredentials(scheme="lorem", credentials="this_is_the_test_api_token")
|
||||
assert await get_token_or_user_auth(mock_api) == ALLOW_ANY_EMAIL
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_token_or_user_auth_with_user():
|
||||
from web.security import get_token_or_user_auth
|
||||
bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="invalid")
|
||||
e: pytest.ExceptionInfo = None
|
||||
with pytest.raises(HTTPException) as e:
|
||||
await get_token_or_user_auth(bad_user)
|
||||
assert e.value.status_code == 401
|
||||
assert e.value.detail == "invalid access_token"
|
||||
|
||||
|
||||
@patch("web.security.authenticate_user", return_value=(True, "summer@example.com"))
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_auth(m1):
|
||||
from web.security import get_user_auth
|
||||
bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good")
|
||||
assert await get_user_auth(bad_user) == "summer@example.com"
|
||||
|
||||
|
||||
@patch("web.security.secure_compare", return_value=False)
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_api_key_auth_exception(m1):
|
||||
from web.security import token_api_key_auth
|
||||
|
||||
e: pytest.ExceptionInfo = None
|
||||
with pytest.raises(HTTPException) as e:
|
||||
await token_api_key_auth(HTTPAuthorizationCredentials(scheme="ipsum", credentials="does-not-matter"), auto_error=True)
|
||||
assert e.value.status_code == 401
|
||||
assert e.value.detail == "Wrong auth credentials"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user():
|
||||
from web.security import authenticate_user
|
||||
|
||||
assert authenticate_user("test") == (False, "invalid access_token")
|
||||
assert authenticate_user(123) == (False, "invalid access_token")
|
||||
|
||||
with patch("web.security.requests.get") as mock_get:
|
||||
# bad response from oauth2
|
||||
mock_get.return_value.status_code = 403
|
||||
assert authenticate_user("this-will-call-requests") == (False, "invalid token")
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# 200 but invalid json
|
||||
mock_get.return_value.status_code = 200
|
||||
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
# 200 but invalid azp and aud
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
|
||||
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "aud": "not_an_app"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
|
||||
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "not_an_app"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
|
||||
|
||||
# blocked email
|
||||
mock_get.return_value.json.return_value = {"email": "blocked@example.com", "azp": "test_app_id_1", "aud": "not_an_app"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "email 'blocked@example.com' not allowed")
|
||||
|
||||
# not verified
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "test_app_id_1"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "email 'summer@example.com' not verified")
|
||||
|
||||
# token expired
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "Token expired")
|
||||
|
||||
# 200 and valid azp and aup and verified
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true", "expires_in": 100}
|
||||
assert authenticate_user("this-will-call-requests") == (True, "summer@example.com")
|
||||
assert mock_get.call_count == 9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_exception():
|
||||
from web.security import authenticate_user
|
||||
|
||||
with patch("web.security.requests.get") as mock_get:
|
||||
mock_get.return_value.status_code = 200
|
||||
mock_get.return_value.json.side_effect = Exception("mocked error")
|
||||
assert authenticate_user("this-will-call-requests") == (False, "exception occurred")
|
||||
196
src/tests/worker/test_worker_main.py
Normal file
196
src/tests/worker/test_worker_main.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from unittest import mock
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from db import models, schemas
|
||||
from auto_archiver import Metadata
|
||||
from auto_archiver.core import Media
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def worker_init():
|
||||
from worker.main import at_start
|
||||
at_start(None)
|
||||
|
||||
|
||||
class Test_create_archive_task():
|
||||
URL = "https://example-live.com"
|
||||
archive = schemas.ArchiveCreate(url=URL, tags=[], public=True, group_id=None, author_id="rick@example.com")
|
||||
|
||||
@patch("worker.main.insert_result_into_db")
|
||||
@patch("worker.main.is_group_invalid_for_user", return_value=None)
|
||||
@patch("worker.main.choose_orchestrator")
|
||||
@patch("celery.app.task.Task.request")
|
||||
def test_success(self, m_req, m_choose, m_is_group, m_insert, worker_init, db_session):
|
||||
from worker.main import create_archive_task
|
||||
|
||||
m_req.id = "this-just-in"
|
||||
mock_orchestrator = self.mock_orchestrator_choice(m_choose)
|
||||
|
||||
task = create_archive_task(self.archive.model_dump_json())
|
||||
|
||||
m_choose.assert_called_once()
|
||||
mock_orchestrator.feed_item.assert_called_once()
|
||||
|
||||
assert task["status"] == "success"
|
||||
assert task["metadata"]["url"] == self.URL
|
||||
assert len(task["media"]) == 0
|
||||
|
||||
@patch("worker.main.is_group_invalid_for_user", return_value=True)
|
||||
def test_raise_invalid(self, m_is_group, worker_init):
|
||||
from worker.main import create_archive_task
|
||||
with pytest.raises(Exception):
|
||||
create_archive_task(self.archive.model_dump_json())
|
||||
|
||||
@patch("worker.main.insert_result_into_db", side_effect=Exception)
|
||||
@patch("worker.main.is_group_invalid_for_user", return_value=False)
|
||||
@patch("worker.main.choose_orchestrator")
|
||||
def test_raise_db_error(self, m_choose, m_is_group, m_insert, worker_init):
|
||||
from worker.main import create_archive_task
|
||||
mock_orchestrator = self.mock_orchestrator_choice(m_choose)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
create_archive_task(self.archive.model_dump_json())
|
||||
mock_orchestrator.feed_item.assert_called_once()
|
||||
|
||||
def mock_orchestrator_choice(self, m_choose):
|
||||
mock_orchestrator = mock.MagicMock()
|
||||
mock_orchestrator.configure_mock(feed_item=mock.MagicMock(return_value=Metadata().set_url(self.URL).success()))
|
||||
m_choose.return_value = mock_orchestrator
|
||||
return mock_orchestrator
|
||||
|
||||
|
||||
class Test_create_sheet_task():
|
||||
URL = "https://example-live.com"
|
||||
sheet = schemas.SubmitSheet(sheet_name="Sheet", sheet_id="123", author_id="rick@example.com", group_id=None)
|
||||
|
||||
# @patch("worker.main.insert_result_into_db")
|
||||
@patch("worker.main.models.generate_uuid", return_value="constant-uuid")
|
||||
@patch("worker.main.is_group_invalid_for_user", return_value=False)
|
||||
@patch("worker.main.ArchivingOrchestrator")
|
||||
def test_success(self, m_orch_generator, m_is_group, m_uuid, worker_init, db_session):
|
||||
from worker.main import create_sheet_task
|
||||
|
||||
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
|
||||
|
||||
mock_metadata = Metadata().set_url(self.URL).success()
|
||||
mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"]))
|
||||
m_orch = MagicMock()
|
||||
m_orch.feed.return_value = iter([False, mock_metadata, mock_metadata])
|
||||
m_orch_generator.return_value = m_orch
|
||||
|
||||
res = create_sheet_task(self.sheet.model_dump_json())
|
||||
print(res)
|
||||
assert res["archived"] == 1
|
||||
assert res["failed"] == 0
|
||||
assert len(res["errors"]) == 0
|
||||
assert res["sheet"] == "Sheet"
|
||||
assert res["sheet_id"] == "123"
|
||||
assert res["success"] == True
|
||||
assert len(res["time"]) > 0
|
||||
|
||||
# query created archive entry
|
||||
inserted = db_session.query(models.Archive).filter(models.Archive.url == self.URL).one()
|
||||
assert inserted is not None
|
||||
assert inserted.url == self.URL
|
||||
assert inserted.tags[0].id == "gsheet"
|
||||
|
||||
@patch("worker.main.insert_result_into_db", side_effect=Exception("some-error"))
|
||||
@patch("worker.main.models.generate_uuid", return_value="constant-uuid")
|
||||
@patch("worker.main.is_group_invalid_for_user", return_value=False)
|
||||
@patch("worker.main.ArchivingOrchestrator")
|
||||
def test_has_exception(self, m_orch_generator, m_is_group, m_uuid, worker_init, db_session):
|
||||
from worker.main import create_sheet_task
|
||||
|
||||
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
|
||||
|
||||
mock_metadata = Metadata().set_url(self.URL).success()
|
||||
mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"]))
|
||||
m_orch = MagicMock()
|
||||
m_orch.feed.return_value = iter([mock_metadata])
|
||||
m_orch_generator.return_value = m_orch
|
||||
|
||||
res = create_sheet_task(self.sheet.model_dump_json())
|
||||
print(res)
|
||||
assert res["archived"] == 0
|
||||
assert res["failed"] == 1
|
||||
assert res["errors"] == ["some-error"]
|
||||
assert res["sheet_id"] == "123"
|
||||
assert res["success"] == True
|
||||
|
||||
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
|
||||
|
||||
@patch("worker.main.is_group_invalid_for_user", return_value="Access denied")
|
||||
def test_error_access(self, m_insert, worker_init, db_session):
|
||||
from worker.main import create_sheet_task
|
||||
|
||||
res = create_sheet_task(self.sheet.model_dump_json())
|
||||
assert "error" in res
|
||||
assert res["error"] == "Access denied"
|
||||
|
||||
|
||||
def test_choose_orchestrator(worker_init):
|
||||
from worker.main import choose_orchestrator
|
||||
|
||||
assert choose_orchestrator(None, "rick@example.com").__class__.__name__ == "ArchivingOrchestrator"
|
||||
|
||||
|
||||
@patch("worker.main.get_user_first_group", return_value="does-not-exist")
|
||||
def test_choose_orchestrator_assertion(worker_init):
|
||||
from worker.main import choose_orchestrator
|
||||
|
||||
with pytest.raises(Exception):
|
||||
choose_orchestrator(None, "rick@example.com")
|
||||
|
||||
|
||||
@patch("worker.main.read_user_groups")
|
||||
def test_get_user_first_group(m_read_user_groups, worker_init):
|
||||
from worker.main import get_user_first_group
|
||||
|
||||
m_read_user_groups.return_value = {"users": {}}
|
||||
assert get_user_first_group("email1") == "default"
|
||||
m_read_user_groups.return_value = {"users": {"email1": []}}
|
||||
assert get_user_first_group("email1") == "default"
|
||||
m_read_user_groups.return_value = {"users": {"email1": ["group1", "group2"]}}
|
||||
assert get_user_first_group("email1") == "group1"
|
||||
|
||||
|
||||
def test_is_group_invalid_for_user(worker_init, db_session):
|
||||
from worker.main import is_group_invalid_for_user
|
||||
from db.crud import upsert_user_groups
|
||||
|
||||
upsert_user_groups(db_session)
|
||||
|
||||
assert is_group_invalid_for_user(True, "", "") == False
|
||||
assert is_group_invalid_for_user(False, "", "") == False
|
||||
|
||||
assert is_group_invalid_for_user(False, "default", "") == "User is not part of default, no permission"
|
||||
assert is_group_invalid_for_user(False, "spaceship", "jerry@example.com") == "User jerry@example.com is not part of spaceship, no permission"
|
||||
|
||||
assert is_group_invalid_for_user(False, "spaceship", "rick@example.com") == False
|
||||
|
||||
|
||||
def test_get_all_urls(worker_init, db_session):
|
||||
from worker.main import get_all_urls
|
||||
from auto_archiver import Metadata
|
||||
|
||||
meta = Metadata().set_url("https://example.com")
|
||||
m1 = meta.add_media(Media("fn1.txt", urls=["outcome1.com"]))
|
||||
m2 = meta.add_media(Media("fn2.txt", urls=["outcome2.com"]))
|
||||
m3 = meta.add_media(Media("fn3.txt", urls=["outcome3.com"]))
|
||||
m1.set("screenshot", Media("screenshot.png", urls=["screenshot.com"]))
|
||||
m2.set("thumbnails", [Media("thumb1.png", urls=["thumb1.com"]), Media("thumb2.png", urls=["thumb2.com"])])
|
||||
m3.set("ssl_data", Media("ssl_data.txt", urls=["ssl_data.com"]).to_dict())
|
||||
m3.set("bad_data", {"bad": "dict is ignored"})
|
||||
|
||||
urls = [u.url for u in get_all_urls(meta)]
|
||||
assert len(urls) == 7
|
||||
assert "outcome1.com" in urls
|
||||
assert "outcome2.com" in urls
|
||||
assert "outcome3.com" in urls
|
||||
assert "screenshot.com" in urls
|
||||
assert "thumb1.com" in urls
|
||||
assert "thumb2.com" in urls
|
||||
assert "ssl_data.com" in urls
|
||||
0
src/utils/__init__.py
Normal file
0
src/utils/__init__.py
Normal file
69
src/utils/metrics.py
Normal file
69
src/utils/metrics.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from prometheus_client import Counter, Gauge
|
||||
import redis
|
||||
|
||||
from db import crud
|
||||
from db.database import get_db
|
||||
from core.logging import log_error
|
||||
|
||||
|
||||
# Custom metrics
|
||||
EXCEPTION_COUNTER = Counter(
|
||||
"exceptions",
|
||||
"Number of times a certain exception has occurred.",
|
||||
labelnames=["types"]
|
||||
)
|
||||
WORKER_EXCEPTION = Counter(
|
||||
"worker_exceptions_total",
|
||||
"Number of times a certain exception has occurred on the worker.",
|
||||
labelnames=["types", "exception", "task", "traceback"]
|
||||
)
|
||||
DISK_UTILIZATION = Gauge(
|
||||
"disk_utilization",
|
||||
"Disk utilization in GB",
|
||||
labelnames=["type"]
|
||||
)
|
||||
DATABASE_METRICS = Gauge(
|
||||
"database_metrics",
|
||||
"Database metric readings at a certain point in time",
|
||||
labelnames=["query"]
|
||||
)
|
||||
DATABASE_METRICS_COUNTER = Counter(
|
||||
"database_metrics_counter",
|
||||
"Database metrics that increase over time",
|
||||
labelnames=["query", "user"]
|
||||
)
|
||||
|
||||
|
||||
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL, CELERY_BROKER_URL):
|
||||
# Subscribe to Redis channel and increment the counter for each exception with info on the exception and task
|
||||
Rdis = redis.Redis.from_url(CELERY_BROKER_URL)
|
||||
PubSubExceptions = Rdis.pubsub()
|
||||
PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL)
|
||||
while True:
|
||||
message = PubSubExceptions.get_message()
|
||||
if message and message["type"] == "message":
|
||||
data = json.loads(message["data"].decode("utf-8"))
|
||||
WORKER_EXCEPTION.labels(types=type(data["exception"]).__name__, exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
async def measure_regular_metrics(sqlite_db_url: str, repeat_in_seconds: int):
|
||||
_total, used, free = shutil.disk_usage("/")
|
||||
DISK_UTILIZATION.labels(type="used").set(used / (2**30))
|
||||
DISK_UTILIZATION.labels(type="free").set(free / (2**30))
|
||||
try:
|
||||
fs = os.stat(sqlite_db_url.replace("sqlite:///", ""))
|
||||
DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30))
|
||||
except Exception as e: log_error(e)
|
||||
|
||||
with get_db() as db:
|
||||
DATABASE_METRICS.labels(query="count_archives").set(crud.count_archives(db))
|
||||
DATABASE_METRICS.labels(query="count_archive_urls").set(crud.count_archive_urls(db))
|
||||
DATABASE_METRICS.labels(query="count_users").set(crud.count_users(db))
|
||||
|
||||
for user in crud.count_by_user_since(db, repeat_in_seconds):
|
||||
DATABASE_METRICS_COUNTER.labels(query="count_by_user", user=user.author_id).inc(user.total)
|
||||
4
src/web/__init__.py
Normal file
4
src/web/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from web.main import app_factory
|
||||
|
||||
|
||||
app = app_factory
|
||||
167
src/web/main.py
Normal file
167
src/web/main.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import os
|
||||
from celery.result import AsyncResult
|
||||
from fastapi import FastAPI, Depends, HTTPException
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from datetime import datetime
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
from core.logging import logging_middleware, log_error
|
||||
from worker.main import create_archive_task, create_sheet_task, celery, insert_result_into_db
|
||||
|
||||
from db import crud, models, schemas
|
||||
from web.security import get_user_auth, token_api_key_auth, get_token_or_user_auth
|
||||
from core.config import VERSION, API_DESCRIPTION
|
||||
from db.database import get_db_dependency
|
||||
from core.events import lifespan
|
||||
from shared.settings import get_settings
|
||||
|
||||
from auto_archiver import Metadata
|
||||
|
||||
from endpoints import default_router, url_router, sheet_router, task_router, interoperability_router
|
||||
|
||||
|
||||
def app_factory(settings = get_settings()):
|
||||
app = FastAPI(
|
||||
title="Auto-Archiver API",
|
||||
description=API_DESCRIPTION,
|
||||
version=VERSION,
|
||||
contact={"name": "GitHub", "url": "https://github.com/bellingcat/auto-archiver-api"},
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.middleware("http")(logging_middleware)
|
||||
|
||||
app.include_router(default_router)
|
||||
app.include_router(url_router)
|
||||
app.include_router(sheet_router)
|
||||
app.include_router(task_router)
|
||||
app.include_router(interoperability_router)
|
||||
|
||||
# prometheus exposed in /metrics with authentication
|
||||
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
|
||||
|
||||
local_dir = settings.SERVE_LOCAL_ARCHIVE
|
||||
if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")):
|
||||
local_dir = local_dir.replace("/app", ".")
|
||||
if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
|
||||
logger.warning(f"MOUNTing local archive {settings.SERVE_LOCAL_ARCHIVE}")
|
||||
app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE)
|
||||
|
||||
|
||||
|
||||
# -----Submit URL and manipulate tasks. Bearer protected below
|
||||
|
||||
|
||||
@app.get("/tasks/search-url", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
|
||||
def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
||||
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
||||
|
||||
|
||||
@app.get("/tasks/sync", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
|
||||
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
return crud.search_archives_by_email(db, email, skip=skip, limit=limit)
|
||||
|
||||
|
||||
@app.post("/tasks", status_code=201, deprecated=True) # DEPRECATED
|
||||
def archive_tasks(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)):
|
||||
archive.author_id = email
|
||||
url = archive.url
|
||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
|
||||
if type(url) != str or len(url) <= 5:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
|
||||
logger.info("creating task")
|
||||
task = create_archive_task.delay(archive.model_dump_json())
|
||||
return JSONResponse({"id": task.id})
|
||||
|
||||
|
||||
@app.get("/archive/{task_id}", deprecated=True) # DEPRECATED
|
||||
def lookup(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
||||
return crud.get_archive(db, task_id, email)
|
||||
|
||||
|
||||
@app.get("/tasks/{task_id}", deprecated=True) # DEPRECATED
|
||||
def get_status(task_id, email=Depends(get_token_or_user_auth)):
|
||||
logger.info(f"status check for user {email} task {task_id}")
|
||||
task = AsyncResult(task_id, app=celery)
|
||||
try:
|
||||
if task.status == "FAILURE":
|
||||
# *FAILURE* The task raised an exception, or has exceeded the retry limit.
|
||||
# The :attr:`result` attribute then contains the exception raised by the task.
|
||||
# https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
|
||||
raise task.result
|
||||
|
||||
response = {
|
||||
"id": task_id,
|
||||
"status": task.status,
|
||||
"result": task.result
|
||||
}
|
||||
return JSONResponse(jsonable_encoder(response, exclude_unset=True))
|
||||
|
||||
except Exception as e:
|
||||
log_error(e)
|
||||
return JSONResponse({
|
||||
"id": task_id,
|
||||
"status": "FAILURE",
|
||||
"result": {"error": str(e)}
|
||||
})
|
||||
|
||||
|
||||
@app.delete("/tasks/{task_id}", deprecated=True) # DEPRECATED
|
||||
def delete_task(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
logger.info(f"deleting task {task_id} request by {email}")
|
||||
return JSONResponse({
|
||||
"id": task_id,
|
||||
"deleted": crud.soft_delete_task(db, task_id, email)
|
||||
})
|
||||
|
||||
# ----- Google Sheets Logic
|
||||
|
||||
|
||||
@app.post("/sheet", status_code=201, deprecated=True) # DEPRECATED
|
||||
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth)):
|
||||
logger.info(f"SHEET TASK for {sheet=}")
|
||||
sheet.author_id = email
|
||||
if not sheet.sheet_name and not sheet.sheet_id:
|
||||
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
||||
task = create_sheet_task.delay(sheet.model_dump_json())
|
||||
return JSONResponse({"id": task.id})
|
||||
|
||||
|
||||
@app.post("/sheet_service", status_code=201, deprecated=True) # DEPRECATED
|
||||
def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth)):
|
||||
logger.info(f"SHEET TASK for {sheet=}")
|
||||
sheet.author_id = sheet.author_id or "api-endpoint"
|
||||
if not sheet.sheet_name and not sheet.sheet_id:
|
||||
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
||||
task = create_sheet_task.delay(sheet.model_dump_json())
|
||||
return JSONResponse({"id": task.id})
|
||||
|
||||
# ----- endpoint to submit data archived elsewhere
|
||||
|
||||
|
||||
@app.post("/submit-archive", status_code=201, deprecated=True) # DEPRECATED
|
||||
def submit_manual_archive(manual: schemas.SubmitManual, auth=Depends(token_api_key_auth)):
|
||||
result = Metadata.from_json(manual.result)
|
||||
logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
|
||||
manual.tags.add("manual")
|
||||
try:
|
||||
archive_id = insert_result_into_db(result, manual.tags, manual.public, manual.group_id, manual.author_id, models.generate_uuid())
|
||||
except sqlalchemy.exc.IntegrityError as e:
|
||||
log_error(e)
|
||||
raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error")
|
||||
return JSONResponse({"id": archive_id})
|
||||
|
||||
return app
|
||||
@@ -1,26 +1,18 @@
|
||||
from loguru import logger
|
||||
import requests, os, secrets
|
||||
import requests, secrets
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
from shared.settings import get_settings
|
||||
|
||||
|
||||
# Configuration
|
||||
CHROME_APP_IDS = set([app_id.strip() for app_id in os.environ.get("CHROME_APP_IDS", "").split(",")])
|
||||
assert len(CHROME_APP_IDS) > 0, "CHROME_APP_IDS env variable not properly set, it's a csv"
|
||||
for app_id in CHROME_APP_IDS:
|
||||
assert len(app_id) > 10, f"CHROME_APP_IDS got invalid id: {app_id} env variable not set"
|
||||
logger.info(f"{CHROME_APP_IDS=}")
|
||||
|
||||
BLOCKED_EMAILS = set([e.strip().lower() for e in os.environ.get("BLOCKED_EMAILS", "").split(",")])
|
||||
logger.info(f"{len(BLOCKED_EMAILS)=}")
|
||||
|
||||
settings = get_settings()
|
||||
bearer_security = HTTPBearer()
|
||||
|
||||
ALLOW_ANY_EMAIL = "*"
|
||||
|
||||
def secure_compare(token, api_key):
|
||||
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
|
||||
|
||||
|
||||
# Factory method to create an authentication dependency for a specific key
|
||||
def api_key_auth(api_key):
|
||||
|
||||
@@ -39,20 +31,22 @@ def api_key_auth(api_key):
|
||||
|
||||
return auth
|
||||
|
||||
|
||||
# --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus
|
||||
API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars
|
||||
token_api_key_auth = api_key_auth(API_BEARER_TOKEN)
|
||||
token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN)
|
||||
|
||||
|
||||
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||
# tries to use the static API_KEY and defaults to google JWT auth
|
||||
access_token = credentials.credentials
|
||||
if token_api_key_auth(access_token, auto_error=False): return ALLOW_ANY_EMAIL
|
||||
if await token_api_key_auth(credentials, auto_error=False): return ALLOW_ANY_EMAIL
|
||||
return await get_user_auth(credentials)
|
||||
|
||||
|
||||
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||
# validates the Bearer token in the case that it requires it
|
||||
valid_user, info = authenticate_user(credentials.credentials)
|
||||
if valid_user: return info
|
||||
if valid_user:
|
||||
return info
|
||||
logger.debug(f"TOKEN FAILURE: {valid_user=} {info=}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -60,16 +54,17 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def authenticate_user(access_token):
|
||||
# https://cloud.google.com/docs/authentication/token-types#access
|
||||
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
|
||||
r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token": access_token})
|
||||
if r.status_code != 200: return False, "error occurred"
|
||||
if r.status_code != 200: return False, "invalid token"
|
||||
try:
|
||||
j = r.json()
|
||||
if j.get("azp") not in CHROME_APP_IDS and j.get("aud") not in CHROME_APP_IDS:
|
||||
if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS:
|
||||
return False, f"token does not belong to valid APP_ID"
|
||||
if j.get("email") in BLOCKED_EMAILS:
|
||||
if j.get("email") in settings.BLOCKED_EMAILS:
|
||||
return False, f"email '{j.get('email')}' not allowed"
|
||||
if j.get("email_verified") != "true":
|
||||
return False, f"email '{j.get('email')}' not verified"
|
||||
@@ -77,6 +72,5 @@ def authenticate_user(access_token):
|
||||
return False, "Token expired"
|
||||
return True, j.get('email')
|
||||
except Exception as e:
|
||||
logger.warning(f"EXCEPTION occurred: {e}")
|
||||
return False, f"EXCEPTION occurred"
|
||||
|
||||
logger.warning(f"AUTH EXCEPTION occurred: {e}")
|
||||
return False, "exception occurred"
|
||||
0
src/worker/__init__.py
Normal file
0
src/worker/__init__.py
Normal file
@@ -1,82 +1,80 @@
|
||||
|
||||
import os, traceback, yaml, datetime
|
||||
import traceback, yaml, datetime
|
||||
from typing import List, Set
|
||||
|
||||
from celery import Celery
|
||||
from celery.signals import task_failure
|
||||
from celery.signals import task_failure, worker_init
|
||||
from auto_archiver import Config, ArchivingOrchestrator, Metadata
|
||||
from auto_archiver.core import Media
|
||||
from loguru import logger
|
||||
|
||||
from db import crud, schemas, models
|
||||
from db.database import SessionLocal
|
||||
from contextlib import contextmanager
|
||||
from db.database import get_db
|
||||
from shared.settings import get_settings
|
||||
import json
|
||||
import redis
|
||||
from sqlalchemy import exc
|
||||
from core.logging import log_error
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
celery = Celery(__name__)
|
||||
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379")
|
||||
celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379")
|
||||
USER_GROUPS_FILENAME = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml")
|
||||
REDIS_EXCEPTIONS_CHANNEL = "exceptions-channel"
|
||||
Rdis = redis.Redis.from_url(celery.conf.broker_url)
|
||||
celery.conf.broker_url = settings.CELERY_BROKER_URL
|
||||
celery.conf.result_backend = settings.CELERY_RESULT_BACKEND
|
||||
USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
|
||||
|
||||
@contextmanager
|
||||
def get_db():
|
||||
session = SessionLocal()
|
||||
try: yield session
|
||||
finally: session.close()
|
||||
Rdis = redis.Redis.from_url(celery.conf.broker_url)
|
||||
|
||||
|
||||
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 3})
|
||||
def create_archive_task(self, archive_json: str):
|
||||
archive = schemas.ArchiveCreate.parse_raw(archive_json)
|
||||
archive = schemas.ArchiveCreate.model_validate_json(archive_json)
|
||||
logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}")
|
||||
invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)
|
||||
if invalid:
|
||||
raise Exception(invalid) # marks task FAILED, saves the Exception as result
|
||||
raise Exception(invalid) # marks task FAILED, saves the Exception as result
|
||||
|
||||
url = archive.url
|
||||
logger.info(f"{url=} {archive=}")
|
||||
|
||||
# TODO: re-evaluate if this logic is to be used
|
||||
if not archive.rearchive:
|
||||
with get_db() as session:
|
||||
archives = crud.search_tasks_by_url(session, url, archive.author_id, absolute_search=True)
|
||||
archives = crud.search_archives_by_url(session, url, archive.author_id, absolute_search=True)
|
||||
if len(archives):
|
||||
logger.info(f"Skipping {url=} as it was already archived")
|
||||
return Metadata.choose_most_complete([a.result for a in archives])
|
||||
|
||||
orchestrator = choose_orchestrator(archive.group_id, archive.author_id)
|
||||
result = orchestrator.feed_item(Metadata().set_url(url))
|
||||
|
||||
|
||||
try:
|
||||
insert_result_into_db(result, archive.tags, archive.public, archive.group_id, archive.author_id, self.request.id)
|
||||
except Exception as e:
|
||||
# Log it, then raise again to store the error as the task result
|
||||
logger.error(e)
|
||||
logger.error(traceback.format_exc())
|
||||
redis_publish_exception(e, self.name)
|
||||
log_error(e)
|
||||
redis_publish_exception(e, self.name, traceback.format_exc())
|
||||
raise e
|
||||
return result.to_dict()
|
||||
|
||||
|
||||
@celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
|
||||
def create_sheet_task(self, sheet_json: str):
|
||||
sheet = schemas.SubmitSheet.parse_raw(sheet_json)
|
||||
sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
|
||||
sheet.tags.add("gsheet")
|
||||
logger.info(f"SHEET START {sheet=}")
|
||||
|
||||
if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)): return {"error": em}
|
||||
if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)):
|
||||
return {"error": em}
|
||||
|
||||
config = Config()
|
||||
# TODO: use choose_orchestrator and overwrite the feeder
|
||||
config.parse(use_cli=False, yaml_config_filename="secrets/orchestration-sheet.yaml", overwrite_configs={"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}})
|
||||
config.parse(use_cli=False, yaml_config_filename=get_settings().SHEET_ORCHESTRATION_YAML, overwrite_configs={"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}})
|
||||
orchestrator = ArchivingOrchestrator(config)
|
||||
|
||||
stats = {"archived": 0, "failed": 0, "errors": []}
|
||||
for result in orchestrator.feed():
|
||||
if not result:
|
||||
if not result:
|
||||
logger.error("Got empty result from feeder, an internal error must have occurred.")
|
||||
continue
|
||||
try:
|
||||
@@ -84,12 +82,9 @@ def create_sheet_task(self, sheet_json: str):
|
||||
stats["archived"] += 1
|
||||
except exc.IntegrityError as e:
|
||||
logger.warning(f"cached result detected: {e}")
|
||||
stats["archived"] += 1
|
||||
except Exception as e:
|
||||
logger.error(type(e))
|
||||
logger.error(e)
|
||||
logger.error(traceback.format_exc())
|
||||
redis_publish_exception(e, self.name)
|
||||
log_error(e, extra=f"{self.name}: {sheet_json}")
|
||||
redis_publish_exception(e, self.name, traceback.format_exc())
|
||||
stats["failed"] += 1
|
||||
stats["errors"].append(str(e))
|
||||
|
||||
@@ -100,11 +95,11 @@ def create_sheet_task(self, sheet_json: str):
|
||||
@task_failure.connect(sender=create_sheet_task)
|
||||
@task_failure.connect(sender=create_archive_task)
|
||||
def task_failure_notifier(sender, **kwargs):
|
||||
logger.warning("😅 From task_failure_notifier ==> Task failed successfully! ")
|
||||
logger.error(kwargs['exception'])
|
||||
logger.error(kwargs['traceback'])
|
||||
logger.error("\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback']))))
|
||||
redis_publish_exception(kwargs['exception'], sender.name)
|
||||
traceback_msg = "\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback'])))
|
||||
logger.warning("😅 From task_failure_notifier ==> Task failed successfully!")
|
||||
log_error(kwargs['exception'], traceback_msg, f"task_failure: {sender.name}")
|
||||
redis_publish_exception(kwargs['exception'], sender.name, traceback_msg)
|
||||
|
||||
|
||||
def choose_orchestrator(group, email):
|
||||
global ORCHESTRATORS
|
||||
@@ -127,7 +122,8 @@ def read_user_groups():
|
||||
def get_user_first_group(email):
|
||||
user_groups_yaml = read_user_groups()
|
||||
groups = user_groups_yaml.get("users", {}).get(email, [])
|
||||
if groups != None and len(groups): return groups[0]
|
||||
if groups != None and len(groups):
|
||||
return groups[0]
|
||||
return "default"
|
||||
|
||||
|
||||
@@ -157,12 +153,14 @@ def is_group_invalid_for_user(public: bool, group_id: str, author_id: str):
|
||||
if public is true the requirement is not needed
|
||||
returns an error message if invalid, or False if all is good.
|
||||
"""
|
||||
if not public and group_id and len(group_id) > 0:
|
||||
# ensure group is valid for user
|
||||
with get_db() as session:
|
||||
if not crud.is_user_in_group(session, group_id, author_id):
|
||||
logger.error(em := f"User {author_id} is not part of {group_id}, no permission")
|
||||
return em
|
||||
if public: return False
|
||||
if not group_id or len(group_id) == 0: return False
|
||||
|
||||
# otherwise group must match
|
||||
with get_db() as session:
|
||||
if not crud.is_user_in_group(session, group_id, author_id):
|
||||
logger.error(em := f"User {author_id} is not part of {group_id}, no permission")
|
||||
return em
|
||||
return False
|
||||
|
||||
|
||||
@@ -172,7 +170,7 @@ def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_
|
||||
with get_db() as session:
|
||||
# urls are created by get_all_urls
|
||||
# create author_id if needed
|
||||
crud.get_user(session, author_id)
|
||||
crud.create_or_get_user(session, author_id)
|
||||
# create DB TAGs if needed
|
||||
db_tags = [crud.create_tag(session, tag) for tag in tags]
|
||||
# insert archive
|
||||
@@ -191,10 +189,11 @@ def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
|
||||
if isinstance(prop, list):
|
||||
for i, prop_media in enumerate(prop):
|
||||
if prop_media := convert_if_media(prop_media):
|
||||
for j, url in enumerate(prop_media.urls):
|
||||
for j, url in enumerate(prop_media.urls):
|
||||
db_urls.append(models.ArchiveUrl(url=url, key=prop_media.get("id", f"{k}{prop_media.key}_{i}.{j}")))
|
||||
return db_urls
|
||||
|
||||
|
||||
def convert_if_media(media):
|
||||
if isinstance(media, Media): return media
|
||||
elif isinstance(media, dict):
|
||||
@@ -203,15 +202,18 @@ def convert_if_media(media):
|
||||
logger.debug(f"error parsing {media} : {e}")
|
||||
return False
|
||||
|
||||
def redis_publish_exception(exception, task_name):
|
||||
|
||||
def redis_publish_exception(exception, task_name, traceback: str = ""):
|
||||
REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL
|
||||
try:
|
||||
Rdis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps({"exception": exception, "task": task_name}, default=str))
|
||||
Rdis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps({"exception": exception, "task": task_name, "traceback": traceback}, default=str))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"Could not publish to {REDIS_EXCEPTIONS_CHANNEL}")
|
||||
log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}")
|
||||
|
||||
|
||||
# INIT
|
||||
ORCHESTRATORS = {}
|
||||
load_orchestrators()
|
||||
@worker_init.connect
|
||||
def at_start(sender, **kwargs):
|
||||
global ORCHESTRATORS
|
||||
ORCHESTRATORS = {}
|
||||
load_orchestrators()
|
||||
logger.info("Orchestrators loaded successfully.")
|
||||
Reference in New Issue
Block a user