Merge pull request #43 from bellingcat/dev

introduces CI tests and does some refactoring on the code and logic
This commit is contained in:
Miguel Sozinho Ramalho
2024-10-23 10:56:37 +01:00
committed by GitHub
49 changed files with 4186 additions and 1486 deletions

View File

@@ -1,3 +1 @@
FLOWER_USERNAME=TODO
FLOWER_PASSWORD=TODO
REDIS_PASSWORD=TODO REDIS_PASSWORD=TODO

45
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,45 @@
# Continuous integration: runs the test-suite with coverage for pushes and
# pull requests targeting the main and dev branches.
name: CI
on:
  push:
    branches:
      - main
      - dev
  pull_request:
    branches:
      - main
      - dev
jobs:
  test:
    runs-on: ubuntu-latest
    services:
      redis:
        image: redis:6-alpine
        ports:
          # quoted so the mapping is always read as a string
          - "6379:6379"
        # hold the job steps until redis answers PING
        options: >-
          --health-cmd "redis-cli ping"
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Install pipenv
        run: pip install pipenv
        working-directory: src
      - name: Install dependencies
        run: pipenv install --dev
        working-directory: src
      - name: Run tests with coverage
        run: PYTHONPATH=. PIPENV_DOTENV_LOCATION=.env.test pipenv run coverage run -m pytest -v --color=yes tests/
        working-directory: src
      - name: Report coverage
        run: pipenv run coverage report
        working-directory: src

15
.gitignore vendored
View File

@@ -7,10 +7,19 @@ secrets
__pycache __pycache
.pytest_cach .pytest_cach
.env .env
.env.dev
.env.prod
*.db *.db
redis/data/* redis/data/*
.ipynb_checkpoints* .ipynb_checkpoints*
#temp
tests
src/user-groups.yaml src/user-groups.yaml
wit* src/user-groups.dev.yaml
wit*
src/crawls
.coverage
.pytest_cache/*
htmlcov
local_archive
local_archive_test
*db-wal
*db-shm

View File

@@ -1,16 +1,18 @@
# Auto Archiver API # Auto Archiver API
An api that uses celery workers to process URL archive requests via [bellingcat/auto-archiver](https://github.com/bellingcat/auto-archiver), it allows authentication via Google OAuth Apps an d enables CORS, everything runs on docker but development can be done without docker (except for redis). An api that uses celery workers to process URL archive requests via [bellingcat/auto-archiver](https://github.com/bellingcat/auto-archiver), it allows authentication via Google OAuth Apps and enables CORS, everything runs on docker but development can be done without docker (except for redis).
## Development ## Development
http://localhost:8004 http://localhost:8004
TODO: update the `.env` file instructions: use `.env.prod` and `.env.dev` for environment-specific settings, and reserve `.env` for values that must always override the dev/prod settings.
requires `src/.env` requires `src/.env`
cd /src cd /src
<!-- * `pipenv install --editable ../../auto-archiver` --> <!-- * `pipenv install --editable ../../auto-archiver` -->
* console 1 - `docker compose up redis` optionally add `dashboard` for flower dashboard and `web` if not running uvicorn locally * console 1 - `docker compose up redis` optionally add `web` if not running uvicorn locally
* console 2 - `pipenv shell` + `celery worker --app=worker.celery --loglevel=info --logfile=logs/celery_dev.log` * console 2 - `pipenv shell` + `celery worker --app=worker.celery --loglevel=info --logfile=logs/celery_dev.log`
* `celery --app=worker.celery worker --loglevel=info --logfile=logs/celery_dev.log` celery 5 * `celery --app=worker.celery worker --loglevel=info --logfile=logs/celery_dev.log` celery 5
* or with watchdog for dev auto-reload `watchmedo auto-restart -d ./ -- celery --app=worker.celery worker --loglevel=info --logfile=logs/celery_dev.log` * or with watchdog for dev auto-reload `watchmedo auto-restart -d ./ -- celery --app=worker.celery worker --loglevel=info --logfile=logs/celery_dev.log`
@@ -37,7 +39,7 @@ Auto-archiver orchestrator files configurations. For each archiving task an orch
orchestrators: orchestrators:
group1: secrets/orchestration-group1.yaml group1: secrets/orchestration-group1.yaml
group2: secrets/orchestration-group2.yaml group2: secrets/orchestration-group2.yaml
default: secrets/orchestration-default:.yaml default: secrets/orchestration-default:orchestration.yaml
``` ```
## Database migrations ## Database migrations
@@ -66,4 +68,21 @@ Run `pipenv update auto-archiver` inside `src` to update the auto-archiver versi
# CALL /sheet POST endpoint # CALL /sheet POST endpoint
curl -XPOST -H "Authorization: Bearer GOOGLE_OAUTH_TOKEN" -H "Content-type: application/json" -d '{"sheet_id": "SHEET_ID", "header": 1}' 'http://localhost:8004/sheet' curl -XPOST -H "Authorization: Bearer GOOGLE_OAUTH_TOKEN" -H "Content-type: application/json" -d '{"sheet_id": "SHEET_ID", "header": 1}' 'http://localhost:8004/sheet'
```
### Testing
```bash
# can be done from top level but let's do it from the src folder for consistency with CI etc
cd src
# run tests and generate coverage
PYTHONPATH=. PIPENV_DOTENV_LOCATION=.env.test pipenv run coverage run -m pytest -vv --disable-warnings --color=yes tests/ && pipenv run coverage html
# get coverage report in command line
pipenv run coverage report
# get coverage HTML
pipenv run coverage html
# > open/run server on htmlcov/index.html to navigate through line coverage
``` ```

View File

@@ -1,17 +1,19 @@
version: '3.8'
services: services:
web: web:
restart: "no" restart: "no"
env_file: src/.env.dev
environment: environment:
- SERVE_LOCAL_ARCHIVE=/app/local_archive # See orchestration.yaml local_storage.save_to - SERVE_LOCAL_ARCHIVE=/app/local_archive # See orchestration.yaml local_storage.save_to
- ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp - ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp
- API_BEARER_TOKEN=dev-api-bearer-token - USER_GROUPS_FILENAME=user-groups.dev.yaml
- DATABASE_PATH=sqlite:////app/auto-archiver.db
worker: worker:
restart: "no" restart: "no"
env_file: src/.env.dev
redis: redis:
restart: "no" restart: "no"
env_file: src/.env.dev
ports: ports:
- 6379:6379 - 6379:6379

View File

@@ -4,13 +4,11 @@ x-broker-url: &broker-url "redis://:${REDIS_PASSWORD}@redis:6379/0"
x-base-setup: &base-setup x-base-setup: &base-setup
build: ./src build: ./src
restart: always restart: always
env_file: src/.env env_file: src/.env.prod
environment: environment:
CELERY_BROKER_URL: *broker-url CELERY_BROKER_URL: *broker-url
CELERY_RESULT_BACKEND: *broker-url CELERY_RESULT_BACKEND: *broker-url
version: '3.8'
volumes: volumes:
crawls: crawls:
@@ -20,20 +18,26 @@ services:
<<: *base-setup <<: *base-setup
ports: ports:
- "127.0.0.1:8004:8000" - "127.0.0.1:8004:8000"
command: uvicorn main:app --host 0.0.0.0 --reload command: uvicorn web:app --host 0.0.0.0 --reload
volumes: volumes:
- ./src:/app - ./src:/app
depends_on: depends_on:
- redis - redis
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
worker: worker:
<<: *base-setup <<: *base-setup
command: celery worker --app=worker.celery --loglevel=info --logfile=logs/celery.log command: celery --app=worker.main.celery worker --loglevel=info --logfile=logs/celery.log
volumes: volumes:
- ./src:/app - ./src:/app
- /var/run/docker.sock:/var/run/docker.sock - /var/run/docker.sock:/var/run/docker.sock
- crawls:/crawls # BROWSERTRIX_HOME_HOST:BROWSERTRIX_HOME_CONTAINER, do not change /crawls - crawls:/crawls # BROWSERTRIX_HOME_HOST:BROWSERTRIX_HOME_CONTAINER, do not change /crawls
environment: environment:
# celery broker-url needs to be duplicated here, do not remove
CELERY_BROKER_URL: *broker-url CELERY_BROKER_URL: *broker-url
CELERY_RESULT_BACKEND: *broker-url CELERY_RESULT_BACKEND: *broker-url
WACZ_ENABLE_DOCKER: 1 # Enable calling docker from this container WACZ_ENABLE_DOCKER: 1 # Enable calling docker from this container
@@ -42,6 +46,11 @@ services:
depends_on: depends_on:
- web - web
- redis - redis
healthcheck:
test: ["CMD", "pipenv", "run", "celery", "-A", "worker.celery", "status"]
interval: 30s
timeout: 10s
retries: 3
redis: redis:
image: redis:6-alpine image: redis:6-alpine
@@ -50,16 +59,8 @@ services:
volumes: volumes:
- "./redis/data:/data" - "./redis/data:/data"
- "./redis/config:/conf" - "./redis/config:/conf"
healthcheck:
# dashboard service will only launch the dashboard if "--profile flower" is passed to docker compose; or if explicitly called "docker compose up dashboard" test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
dashboard: interval: 30s
<<: *base-setup timeout: 10s
profiles: retries: 3
- flower
command: ["flower", "--app=worker.celery", "--port=5555", "--broker", *broker-url, "--basic_auth=${FLOWER_USERNAME}:${FLOWER_PASSWORD}"]
ports:
- 5556:5555
depends_on:
- web
- redis
- worker

9
src/.env.test Normal file
View File

@@ -0,0 +1,9 @@
CHROME_APP_IDS='["test_app_id_1","test_app_id_2"]'
ALLOWED_ORIGINS='["chrome-extension://example1","chrome-extension://example2","http://localhost:8081"]'
BLOCKED_EMAILS='["blocked@example.com"]'
DATABASE_PATH="sqlite:///auto-archiver.test.db"
API_BEARER_TOKEN=this_is_the_test_api_token
USER_GROUPS_FILENAME=tests/user-groups.test.yaml
SHEET_ORCHESTRATION_YAML=tests/orchestration.test.yaml

View File

@@ -7,8 +7,8 @@ WORKDIR /app
RUN curl -fsSL https://get.docker.com -o get-docker.sh && \ RUN curl -fsSL https://get.docker.com -o get-docker.sh && \
sh get-docker.sh sh get-docker.sh
# set environment variables # set environment variables
ENV PYTHONUNBUFFERED 1 ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE 1 ENV PYTHONDONTWRITEBYTECODE=1
# install dependencies # install dependencies
RUN pip install --upgrade pip && \ RUN pip install --upgrade pip && \

View File

@@ -5,9 +5,8 @@ name = "pypi"
[packages] [packages]
aiofiles = "==0.6.0" aiofiles = "==0.6.0"
celery = "==4.4.7" celery = ">=5.0"
fastapi = "*" fastapi = "*"
flower = "==0.9.7"
jinja2 = "*" jinja2 = "*"
redis = "==3.5.3" redis = "==3.5.3"
requests = ">=2.25.1" requests = ">=2.25.1"
@@ -20,10 +19,14 @@ alembic = "*"
fastapi-utils = "*" fastapi-utils = "*"
prometheus-fastapi-instrumentator = "*" prometheus-fastapi-instrumentator = "*"
auto-archiver = "*" auto-archiver = "*"
pydantic-settings = "*"
[dev-packages] [dev-packages]
watchdog = "*" watchdog = "*"
pytest = "==6.2.4" pytest = "*"
httpx = "*"
coverage = "*"
pytest-asyncio = "*"
[requires] [requires]
python_version = "3.10" python_version = "3.10"

3136
src/Pipfile.lock generated

File diff suppressed because it is too large Load Diff

0
src/core/__init__.py Normal file
View File

13
src/core/config.py Normal file
View File

@@ -0,0 +1,13 @@
# Version of the API; reported to clients in the "version" field of the `/` response.
VERSION = "0.7.0"
# Human-readable (markdown) description of the API.
# NOTE(review): presumably passed to the FastAPI app as its description - confirm where it is consumed.
API_DESCRIPTION = """
#### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets.
**Usage notes:**
- The API requires a Bearer token for most operations, which you can obtain by logging in with your Google account.
- You can use this API to archive single URLs or entire Google Sheets.
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
"""
# Compatibility notice reported to clients in the "breakingChanges" field of the `/` response.
# NOTE(review): "minVersion" looks like the oldest client/extension version still supported - confirm.
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
# changing this will corrupt the database logic
# Sentinel "email" value: the crud queries skip ownership/group filtering when the caller's email equals it.
ALLOW_ANY_EMAIL = "*"

45
src/core/events.py Normal file
View File

@@ -0,0 +1,45 @@
import asyncio
import logging
import alembic.config
from fastapi import FastAPI
from contextlib import asynccontextmanager
from fastapi_utils.tasks import repeat_every
from loguru import logger
from db import crud, models
from db.database import get_db, make_engine
from shared.settings import get_settings
from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan handler: everything before `yield` runs at startup,
    everything after it runs at shutdown.
    see https://fastapi.tiangolo.com/advanced/events/#lifespan

    Startup creates the tables, applies the alembic migrations, silences the
    uvicorn access logger and launches the background jobs. Shutdown logs a
    message and cancels whatever background task is still pending.
    """
    # STARTUP
    settings = get_settings()
    engine = make_engine(settings.DATABASE_PATH)
    models.Base.metadata.create_all(bind=engine)
    alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])

    # disabling uvicorn logger since we use loguru in logging_middleware
    logging.getLogger("uvicorn.access").disabled = True

    # The event loop only keeps weak references to tasks, so strong references
    # are held here to make sure none is garbage-collected while still running.
    background_tasks = [
        asyncio.create_task(redis_subscribe_worker_exceptions(settings.REDIS_EXCEPTIONS_CHANNEL, settings.CELERY_BROKER_URL)),
        asyncio.create_task(refresh_user_groups()),
        asyncio.create_task(repeat_measure_regular_metrics()),
    ]

    try:
        yield  # separates startup from shutdown instructions
    finally:
        # SHUTDOWN
        logger.info("shutting down")
        for task in background_tasks:
            # cancelling an already finished task is a no-op
            task.cancel()
# CRON JOBS
@repeat_every(seconds=3600)  # 1 hour
async def refresh_user_groups():
    """Periodic job: re-apply the user-groups configuration to the database."""
    with get_db() as session:
        crud.upsert_user_groups(session)
@repeat_every(seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS)
async def repeat_measure_regular_metrics():
    """Periodic job: collect the regular metrics, using the configured interval as the measuring window."""
    database_path = get_settings().DATABASE_PATH
    interval_seconds = get_settings().REPEAT_COUNT_METRICS_SECONDS
    await measure_regular_metrics(database_path, interval_seconds)

27
src/core/logging.py Normal file
View File

@@ -0,0 +1,27 @@
import traceback
from loguru import logger
from fastapi import Request
# logging configurations
def _is_error_log(record) -> bool:
    """True for records emitted through `error_logger` (they carry `error_log` in their `extra`)."""
    return bool(record["extra"].get("error_log"))


# General log file: everything except the verbose traceback records.
logger.add("logs/api_logs.log", retention="30 days", rotation="3 days", filter=lambda record: not _is_error_log(record))
# Error log file: only the traceback records emitted through `error_logger`.
logger.add("logs/error_logs.log", retention="30 days", filter=_is_error_log)
# `logger.add` returns an integer handler id, not a logger object, so a
# dedicated bound logger is what provides the `.error(...)` method used below.
error_logger = logger.bind(error_log=True)


def log_error(e: Exception, traceback_str: str = None, extra: str = ""):
    """
    Log an exception: a one-line summary through the general logger and the
    same summary followed by the traceback through `error_logger`.

    :param e: the exception to report.
    :param traceback_str: pre-formatted traceback; when empty it is built from
        the traceback attached to `e`.
    :param extra: optional context, written on its own line before the summary.
    """
    # EXCEPTION_COUNTER.labels(type(e).__name__).inc()
    if not traceback_str:
        # formatted from `e` itself so it also works outside an `except` block
        traceback_str = "".join(traceback.format_exception(type(e), e, e.__traceback__))
    if extra:
        extra = f"{extra}\n"
    summary = f"{extra}{e.__class__.__name__}: {e}"
    logger.error(summary)
    error_logger.error(f"{summary}\n{traceback_str}")
async def logging_middleware(request: Request, call_next):
    """
    HTTP middleware that writes one access-log line per handled request and
    counts + logs any exception raised while handling it.

    :param request: the incoming request.
    :param call_next: callable that forwards the request and returns the response.
    :return: the response produced by `call_next`, unchanged.
    :raises Exception: re-raises whatever `call_next` raised, after recording it.
    """
    try:
        response = await call_next(request)
    except Exception as e:
        # imported here rather than at module level, as in the original code
        from utils.metrics import EXCEPTION_COUNTER
        EXCEPTION_COUNTER.labels(type(e).__name__).inc()
        log_error(e)
        raise

    # `request.client` can be None; that must not turn a good response into an error
    client = request.client
    origin = f"{client.host}:{client.port}" if client else "unknown"
    logger.info(f"{origin} {request.method} {request.url} - HTTP {response.status_code}")
    return response

View File

@@ -4,27 +4,28 @@ from sqlalchemy import Column, or_, func
from loguru import logger from loguru import logger
from datetime import datetime, timedelta from datetime import datetime, timedelta
from security import ALLOW_ANY_EMAIL from core.config import ALLOW_ANY_EMAIL
from shared.settings import get_settings
from . import models, schemas from . import models, schemas
import yaml, os import yaml
DOMAIN_GROUPS = {} DOMAIN_GROUPS = {}
DOMAIN_GROUPS_LOADED = False DOMAIN_GROUPS_LOADED = False
MAX_LIMIT = 100 DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
# --------------- TASK = Archive # --------------- TASK = Archive
def get_task(db: Session, task_id: str, email: str): def get_archive(db: Session, id: str, email: str):
email = email.lower() email = email.lower()
query = base_query(db).filter(models.Archive.id == task_id) query = base_query(db).filter(models.Archive.id == id)
if email != ALLOW_ANY_EMAIL: if email != ALLOW_ANY_EMAIL:
groups = get_user_groups(db, email) groups = get_user_groups(db, email)
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups))) query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
return query.first() return query.first()
def search_tasks_by_url(db: Session, url: str, email: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False): def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False):
# searches for partial URLs, if email is * no ownership filtering happens # searches for partial URLs, if email is * no ownership filtering happens
query = base_query(db) query = base_query(db)
if email != ALLOW_ANY_EMAIL: if email != ALLOW_ANY_EMAIL:
@@ -36,15 +37,15 @@ def search_tasks_by_url(db: Session, url: str, email: str, skip: int = 0, limit:
else: else:
query = query.filter(models.Archive.url.like(f'%{url}%')) query = query.filter(models.Archive.url.like(f'%{url}%'))
if archived_after: if archived_after:
query = query.filter(models.Archive.created_at >= archived_after) query = query.filter(models.Archive.created_at > archived_after)
if archived_before: if archived_before:
query = query.filter(models.Archive.created_at <= archived_before) query = query.filter(models.Archive.created_at < archived_before)
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, MAX_LIMIT)).all() return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
def search_tasks_by_email(db: Session, email: str, skip: int = 0, limit: int = 100): def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
email = email.lower() email = email.lower()
return base_query(db).filter(models.Archive.author.has(email=email)).offset(skip).limit(min(limit, MAX_LIMIT)).all() return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]): def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]):
@@ -65,21 +66,28 @@ def soft_delete_task(db: Session, task_id: str, email: str) -> bool:
db.commit() db.commit()
return db_task is not None return db_task is not None
def count_archives(db:Session):
def count_archives(db: Session):
return db.query(func.count(models.Archive.id)).scalar() return db.query(func.count(models.Archive.id)).scalar()
def count_archive_urls(db:Session):
def count_archive_urls(db: Session):
return db.query(func.count(models.ArchiveUrl.url)).scalar() return db.query(func.count(models.ArchiveUrl.url)).scalar()
def count_by_user_since(db:Session, seconds_delta: int = 15): def count_users(db: Session):
return db.query(func.count(models.User.email)).scalar()
def count_by_user_since(db: Session, seconds_delta: int = 15):
time_threshold = datetime.now() - timedelta(seconds=seconds_delta) time_threshold = datetime.now() - timedelta(seconds=seconds_delta)
return db.query(models.Archive.author_id,func.count().label('total'))\ return db.query(models.Archive.author_id, func.count().label('total'))\
.filter(models.Archive.created_at >= time_threshold)\ .filter(models.Archive.created_at >= time_threshold)\
.group_by(models.Archive.author_id)\ .group_by(models.Archive.author_id)\
.order_by(func.count().desc()).limit(5 * MAX_LIMIT).all() .order_by(func.count().desc())\
.limit(500).all()
def base_query(db: Session): def base_query(db: Session):
# allow only some fields to be returned, for example author should remain hidden # TODO: allow only some fields to be returned, for example author should remain hidden
return db.query(models.Archive)\ return db.query(models.Archive)\
.options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result))\ .options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result))\
.filter(models.Archive.deleted == False) .filter(models.Archive.deleted == False)
@@ -97,10 +105,6 @@ def create_tag(db: Session, tag: str):
return db_tag return db_tag
def search_tags(db: Session, tag: str, skip: int = 0, limit: int = 100):
return db.query(models.Tag).filter(models.Tag.url.like(f'%{tag}%')).offset(skip).limit(min(limit, MAX_LIMIT)).all()
def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group: def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
if email == ALLOW_ANY_EMAIL: return True if email == ALLOW_ANY_EMAIL: return True
return len(group_name) and len(email) and group_name in get_user_groups(db, email) return len(group_name) and len(email) and group_name in get_user_groups(db, email)
@@ -108,6 +112,7 @@ def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
def get_user_groups(db: Session, email: str): def get_user_groups(db: Session, email: str):
email = email.lower() email = email.lower()
if "@" not in email: return []
global DOMAIN_GROUPS, DOMAIN_GROUPS_LOADED global DOMAIN_GROUPS, DOMAIN_GROUPS_LOADED
if not DOMAIN_GROUPS_LOADED: upsert_user_groups(db) if not DOMAIN_GROUPS_LOADED: upsert_user_groups(db)
# given an email retrieves the user groups from the DB and then the email-domain groups from a global variable # given an email retrieves the user groups from the DB and then the email-domain groups from a global variable
@@ -121,8 +126,8 @@ def get_user_groups(db: Session, email: str):
# --------------- INIT User-Groups # --------------- INIT User-Groups
def get_user(db: Session, author_id: str): def create_or_get_user(db: Session, author_id: str):
if type(author_id)==str: author_id = author_id.lower() if type(author_id) == str: author_id = author_id.lower()
db_user = db.query(models.User).filter(models.User.email == author_id).first() db_user = db.query(models.User).filter(models.User.email == author_id).first()
if not db_user: if not db_user:
db_user = models.User(email=author_id) db_user = models.User(email=author_id)
@@ -133,11 +138,13 @@ def get_user(db: Session, author_id: str):
@cache @cache
def get_group(db: Session, group_name: str) -> models.Group: def create_or_get_group(db: Session, group_name: str) -> models.Group:
db_group = db.query(models.Group).filter(models.Group.id == group_name).first() db_group = db.query(models.Group).filter(models.Group.id == group_name).first()
if db_group is None: if db_group is None:
db_group = models.Group(id=group_name) db_group = models.Group(id=group_name)
db.add(db_group) db.add(db_group)
db.commit()
db.refresh(db_group)
return db_group return db_group
@@ -148,15 +155,15 @@ def upsert_user_groups(db: Session):
along with new participation of users in groups along with new participation of users in groups
""" """
logger.debug("Updating user-groups configuration.") logger.debug("Updating user-groups configuration.")
filename = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml") filename = get_settings().USER_GROUPS_FILENAME
# read yaml safely # read yaml safely
with open(filename) as inf: try:
try: with open(filename) as inf:
user_groups_yaml = yaml.safe_load(inf) user_groups_yaml = yaml.safe_load(inf)
except yaml.YAMLError as e: except Exception as e:
logger.error(f"could not open user groups filename {filename}: {e}") logger.error(f"could not open user groups filename {filename}: {e}")
raise e raise e
# updating domain->groups access # updating domain->groups access
DOMAIN_GROUPS = user_groups_yaml.get("domains", {}) DOMAIN_GROUPS = user_groups_yaml.get("domains", {})
@@ -175,7 +182,7 @@ def upsert_user_groups(db: Session):
db.add(db_user) db.add(db_user)
if not groups: continue # avoid hanging in for x in None: if not groups: continue # avoid hanging in for x in None:
for group in groups: for group in groups:
db_group = get_group(db, group) db_group = create_or_get_group(db, group)
db_group.users.append(db_user) db_group.users.append(db_user)
db.commit() db.commit()

View File

@@ -1,15 +1,34 @@
from sqlalchemy import create_engine from sqlalchemy import Engine, create_engine, event
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
import os from shared.settings import get_settings
from contextlib import contextmanager
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_PATH")#"sqlite:///./auto-archiver.db"
# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db"
engine = create_engine( def make_engine(database_url: str):
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} engine = create_engine(database_url, connect_args={"check_same_thread": False})
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() @event.listens_for(engine, "connect")
def set_sqlite_pragma(conn, _) -> None:
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.close()
return engine
def make_session_local(engine: Engine):
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return session_local
@contextmanager
def get_db():
session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
try: yield session
finally: session.close()
def get_db_dependency():
# to use with Depends and ensure proper session closing
with get_db() as db:
yield db

View File

@@ -1,8 +1,10 @@
from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey
from sqlalchemy.sql import func from sqlalchemy.sql import func
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship, declarative_base
import uuid import uuid
from .database import Base
Base = declarative_base()
def generate_uuid(): def generate_uuid():
return str(uuid.uuid4()) return str(uuid.uuid4())
@@ -59,7 +61,6 @@ class Tag(Base):
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags) archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
class User(Base): class User(Base):
__tablename__ = "users" __tablename__ = "users"

View File

@@ -2,6 +2,13 @@ from pydantic import BaseModel
from datetime import datetime from datetime import datetime
class Tag(BaseModel):
id: str
created_at: datetime
model_config = { "from_attributes": True }
__hash__ = object.__hash__
class ArchiveCreate(BaseModel): class ArchiveCreate(BaseModel):
id: str | None = None id: str | None = None
url: str url: str
@@ -9,7 +16,7 @@ class ArchiveCreate(BaseModel):
public: bool = True public: bool = True
author_id: str | None = None author_id: str | None = None
group_id: str | None = None group_id: str | None = None
tags: set = set() tags: set[Tag] | None = set()
rearchive: bool = True rearchive: bool = True
# urls: list = [] # urls: list = []
@@ -19,9 +26,7 @@ class Archive(ArchiveCreate):
updated_at: datetime | None updated_at: datetime | None
deleted: bool deleted: bool
class Config: model_config = { "from_attributes": True }
orm_mode = True
class SubmitSheet(BaseModel): class SubmitSheet(BaseModel):
sheet_name: str | None = None sheet_name: str | None = None
@@ -30,7 +35,7 @@ class SubmitSheet(BaseModel):
public: bool = False public: bool = False
author_id: str | None = None author_id: str | None = None
group_id: str | None = None group_id: str | None = None
tags: set | None = set() tags: set[str] | None = set()
columns: dict | None = {} # TODO: implement columns: dict | None = {} # TODO: implement
class SubmitManual(BaseModel): class SubmitManual(BaseModel):
@@ -38,4 +43,14 @@ class SubmitManual(BaseModel):
public: bool = False public: bool = False
author_id: str | None = None author_id: str | None = None
group_id: str | None = None group_id: str | None = None
tags: set | None = set() tags: set[str] | None = set()
class Task(BaseModel):
id: str
class TaskResult(Task):
status: str
result: str
class TaskDelete(Task):
deleted: bool

View File

@@ -0,0 +1,5 @@
from endpoints.default import default_router
from endpoints.url import url_router
from endpoints.task import task_router
from endpoints.interoperability import interoperability_router
from endpoints.sheet import sheet_router

40
src/endpoints/default.py Normal file
View File

@@ -0,0 +1,40 @@
from fastapi import APIRouter, Depends, Request, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from sqlalchemy.orm import Session
from core.config import VERSION, BREAKING_CHANGES
from core.logging import log_error
from db import crud
from db.database import get_db_dependency, get_db
from web.security import get_user_auth, bearer_security
default_router = APIRouter()
@default_router.get("/")
async def home(request: Request):
    """
    API status: always returns the version and breaking-changes notice, plus
    the caller's groups when the request carries valid authentication.
    """
    # TODO: maybe split into 2 routes: one non authenticated and one authenticated for the groups info only
    payload = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
    try:
        credentials = await bearer_security(request)
        email = await get_user_auth(credentials)
        with get_db() as session:
            payload["groups"] = crud.get_user_groups(session, email)
    except HTTPException:
        # not authenticated is fine
        pass
    except Exception as e:
        log_error(e)
    return JSONResponse(payload)
@default_router.get("/health")
async def health():
    """Health probe: answers with a static OK payload."""
    ok_payload = {"status": "ok"}
    return JSONResponse(ok_payload)
@default_router.get("/groups", response_model=list[str])
def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
    """Return the groups of the authenticated user, as resolved by `crud.get_user_groups`."""
    groups = crud.get_user_groups(db, email)
    return groups
@default_router.get('/favicon.ico', include_in_schema=False)
async def favicon():
    """Serve the static favicon; the route is left out of the OpenAPI schema."""
    icon_path = "static/favicon.ico"
    return FileResponse(icon_path)

View File

@@ -0,0 +1,27 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from auto_archiver import Metadata
from loguru import logger
import sqlalchemy
from web.security import token_api_key_auth
from db import models, schemas
from worker.main import insert_result_into_db
from core.logging import log_error
interoperability_router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."])
# ----- endpoint to submit data archived elsewhere
@interoperability_router.post("/submit-archive", status_code=201, summary="Submit a manual archive entry, for data that was archived elsewhere.")
def submit_manual_archive(manual: schemas.SubmitManual, auth=Depends(token_api_key_auth)):
    """
    Store an archive result that was produced outside this API.

    :param manual: the submission; `manual.result` is parsed with `Metadata.from_json`.
    :param auth: API token authentication dependency.
    :return: 201 JSON response with the id of the inserted archive.
    :raises HTTPException: 422 when the database rejects the entry with an integrity error.
    """
    result = Metadata.from_json(manual.result)
    logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
    # `tags` is declared as `set[str] | None`, so an explicit null must not crash the request
    manual.tags = set(manual.tags or ()) | {"manual"}
    try:
        archive_id = insert_result_into_db(result, manual.tags, manual.public, manual.group_id, manual.author_id, models.generate_uuid())
    except sqlalchemy.exc.IntegrityError as e:
        log_error(e)
        raise HTTPException(status_code=422, detail="Cannot insert into DB due to integrity error") from e
    return JSONResponse({"id": archive_id}, status_code=201)

24
src/endpoints/sheet.py Normal file
View File

@@ -0,0 +1,24 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from loguru import logger
from core.config import ALLOW_ANY_EMAIL
from web.security import get_token_or_user_auth
from db import schemas
from worker.main import create_sheet_task
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
@sheet_router.post("/archive", status_code=201, summary="Submit a Google Sheet archive request, starts a sheet archiving task.", response_model=schemas.Task, response_description="task_id for the archiving task.")
def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_token_or_user_auth)):
    """
    Queue a sheet archiving task and return its id.

    The author is always taken from the authenticated identity; only callers
    authenticated as ALLOW_ANY_EMAIL may supply `author_id` themselves
    (falling back to "api-endpoint").
    """
    logger.info(f"SHEET TASK for {sheet=}")

    if email == ALLOW_ANY_EMAIL:
        author = sheet.author_id or "api-endpoint"
    else:
        author = email
    sheet.author_id = author

    has_target = bool(sheet.sheet_name or sheet.sheet_id)
    if not has_target:
        raise HTTPException(status_code=422, detail="sheet name or id is required")

    queued = create_sheet_task.delay(sheet.model_dump_json())
    return JSONResponse({"id": queued.id}, status_code=201)

41
src/endpoints/task.py Normal file
View File

@@ -0,0 +1,41 @@
from celery.result import AsyncResult
from fastapi import APIRouter, Depends
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from loguru import logger
from web.security import get_token_or_user_auth
from db import schemas
from core.logging import log_error
from worker.main import celery
task_router = APIRouter(prefix="/task", tags=["Async task operations"])
@task_router.get("/{task_id}", response_model=schemas.TaskResult, summary="Check the status of an async task by its id, works for URLs and Sheet tasks.")
def get_status(task_id, email=Depends(get_token_or_user_auth)):
    """
    Report the celery status and result of a task.

    A failed task, or any error while reading it, is logged and reported as
    {"id", "status": "FAILURE", "result": {"error": <message>}}.
    """
    logger.info(f"status check for user {email} task {task_id}")
    async_result = AsyncResult(task_id, app=celery)
    try:
        if task.status == "FAILURE":
            # *FAILURE* The task raised an exception, or has exceeded the retry limit.
            # The :attr:`result` attribute then contains the exception raised by the task.
            # https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
            raise task.result

        payload = jsonable_encoder(
            {"id": task_id, "status": task.status, "result": task.result},
            exclude_unset=True,
        )
        return JSONResponse(payload)
    except Exception as e:
        log_error(e)
        failure = {"id": task_id, "status": "FAILURE", "result": {"error": str(e)}}
        return JSONResponse(failure)

59
src/endpoints/url.py Normal file
View File

@@ -0,0 +1,59 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from datetime import datetime
from loguru import logger
from web.security import get_user_auth, get_token_or_user_auth
from sqlalchemy.orm import Session
from db import crud, schemas
from db.database import get_db_dependency
from worker.main import create_archive_task
# Router for single-URL endpoints; every route registered on it is served under the /url prefix.
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
@url_router.post(
    "/archive",
    status_code=201,
    summary="Submit a single URL archive request, starts an archiving task.",
    response_model=schemas.Task,
    response_description="task_id for the archiving task, will match the archive id.",
)
def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)):
    # Queues an archiving task for one URL and answers 201 with the task id.
    # Raises HTTP 422 when the url is not a plain string longer than 5 characters.

    # The authenticated identity overrides whatever author the client sent.
    archive.author_id = email
    url = archive.url
    logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")

    acceptable = type(url) is str and len(url) > 5
    if not acceptable:
        raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")

    logger.info("creating task")
    queued = create_archive_task.delay(archive.model_dump_json())
    return JSONResponse(schemas.Task(id=queued.id).model_dump(), status_code=201)
@url_router.get(
    "/search",
    response_model=list[schemas.Archive],
    summary="Search for archive entries by URL.",
)
def search_by_url(
    url: str,
    skip: int = 0,
    limit: int = 25,
    archived_after: datetime = None,
    archived_before: datetime = None,
    db: Session = Depends(get_db_dependency),
    email=Depends(get_token_or_user_auth),
):
    # Delegates to crud.search_archives_by_url with the whitespace-stripped url,
    # the caller identity, the pagination values and the optional date bounds.
    date_bounds = {"archived_after": archived_after, "archived_before": archived_before}
    return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, **date_bounds)
@url_router.get(
    "/latest",
    response_model=list[schemas.Archive],
    summary="Fetch latest URL archives for the authenticated user.",
)
def latest(
    skip: int = 0,
    limit: int = 25,
    db: Session = Depends(get_db_dependency),
    email=Depends(get_user_auth),
):
    # Returns whatever crud.search_archives_by_email yields for the authenticated
    # email, paginated with skip/limit.
    own_archives = crud.search_archives_by_email(db, email, skip=skip, limit=limit)
    return own_archives
@url_router.get(
    "/{id}",
    response_model=schemas.Archive,
    summary="Fetch a single URL archive by the associated id.",
)
def lookup(id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
    # Returns the archive crud.get_archive finds for this id and caller;
    # answers HTTP 404 when it finds none.
    found = crud.get_archive(db, id, email)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Archive not found")
@url_router.delete(
    "/{id}",
    response_model=schemas.TaskDelete,
    summary="Delete a single URL archive by id.",
)
def delete_task(id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
    # Soft-deletes the archive on behalf of the authenticated user and reports
    # the outcome of crud.soft_delete_task in the "deleted" field.
    logger.info(f"deleting url archive task {id} request by {email}")
    was_deleted = crud.soft_delete_task(db, id, email)
    return JSONResponse({"id": id, "deleted": was_deleted})

View File

@@ -1,257 +0,0 @@
from celery.result import AsyncResult
from fastapi import FastAPI, Depends, Request, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi_utils.tasks import repeat_every
import alembic.config
from dotenv import load_dotenv
import traceback, os, logging
from loguru import logger
from datetime import datetime
import sqlalchemy
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_client import Counter, Gauge
from contextlib import asynccontextmanager
import asyncio, json
import shutil
from worker import REDIS_EXCEPTIONS_CHANNEL, create_archive_task, create_sheet_task, celery, insert_result_into_db, Rdis
from db import crud, models, schemas
from db.database import engine, SessionLocal, SQLALCHEMY_DATABASE_URL
from sqlalchemy.orm import Session
from security import get_user_auth, token_api_key_auth, bearer_security, get_token_or_user_auth
from auto_archiver import Metadata
load_dotenv()
# Configuration
# CORS origins: comma-separated ALLOWED_ORIGINS environment variable, defaulting to two chrome-extension origins.
ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "chrome-extension://ondkcheoicfckabcnkdgbepofpjmjcmb,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp").split(",")
# API version, exposed through the FastAPI metadata and the home endpoint payload.
VERSION = "0.6.3"
# min-version refers to the version of auto-archiver-extension on the webstore
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
@repeat_every(seconds=60 * 60)  # 1 hour
async def refresh_user_groups():
    """Periodically re-apply the user/group configuration via crud.upsert_user_groups."""
    # Keep a handle on the generator: `next(get_db())` would drop it right after
    # handing out the session, so get_db's `finally: session.close()` would run at
    # garbage-collection time instead of after the work below has finished.
    db_gen = get_db()
    try:
        db: Session = next(db_gen)
        crud.upsert_user_groups(db)
    finally:
        db_gen.close()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: set up the database and background jobs, then log on shutdown."""
    # see https://fastapi.tiangolo.com/advanced/events/#lifespan
    # STARTUP
    models.Base.metadata.create_all(bind=engine)
    alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
    # disabling uvicorn logger since we use loguru in logging_middleware
    logging.getLogger("uvicorn.access").disabled = True

    # The event loop only keeps weak references to tasks, so hold strong ones here
    # for as long as the application lives; otherwise a task may be garbage
    # collected before it is done.
    background_tasks = [
        asyncio.create_task(redis_subscribe_worker_exceptions()),
        asyncio.create_task(refresh_user_groups()),
        asyncio.create_task(measure_regular_metrics()),
    ]

    yield  # separates startup from shutdown instructions

    # SHUTDOWN
    logger.info("shutting down")
    # Cancelling an already finished task is a no-op, so this is safe for all of them.
    for background_task in background_tasks:
        background_task.cancel()
# Application object; statement order below is significant (middleware, metrics, mounts).
app = FastAPI(
    title="Auto-Archiver API",
    version=VERSION,
    contact={"name": "Bellingcat", "url": "https://github.com/bellingcat/auto-archiver-api"},
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Incremented by logging_middleware, labelled with the exception class name.
EXCEPTION_COUNTER = Counter(
    "exceptions",
    "Number of times a certain exception has occurred.",
    labelnames=("types",),
)

# prometheus exposed in /metrics with authentication
Instrumentator(
    should_group_status_codes=False,
    excluded_handlers=["/metrics"],
).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])

app.mount("/static", StaticFiles(directory="static"), name="static")

# Optionally serve an archive directory, mounted at a route equal to the directory path itself.
SERVE_LOCAL_ARCHIVE = os.environ.get("SERVE_LOCAL_ARCHIVE", "")
if len(SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(SERVE_LOCAL_ARCHIVE):
    logger.info(f"mounting local archive {SERVE_LOCAL_ARCHIVE}")
    app.mount(
        SERVE_LOCAL_ARCHIVE,
        StaticFiles(directory=SERVE_LOCAL_ARCHIVE),
        name=SERVE_LOCAL_ARCHIVE,
    )
def get_db():
    """Yield a fresh SessionLocal session and close it once the generator is finalised."""
    db_session = SessionLocal()
    try:
        yield db_session
    finally:
        db_session.close()
# logging configurations
# File sink for loguru: rotated every 3 days, files kept for 30 days.
logger.add("logs/api_logs.log", retention="30 days", rotation="3 days")
@app.middleware("http")
async def logging_middleware(request: Request, call_next):
    """Log one access line per request and count exceptions by class name.

    Exceptions raised further down the stack are counted in EXCEPTION_COUNTER
    and re-raised unchanged.
    """
    try:
        response = await call_next(request)
        # request.client is None when the server cannot determine the peer address;
        # without this guard a successfully produced response would be replaced by
        # an AttributeError raised while building the log line.
        client = request.client
        origin = f"{client.host}:{client.port}" if client else "unknown"
        logger.info(f"{origin} {request.method} {request.url._url} - HTTP {response.status_code}")
        return response
    except Exception as e:
        EXCEPTION_COUNTER.labels(type(e).__name__).inc()
        # bare raise keeps the original traceback untouched
        raise
@app.get("/")
async def home(request: Request):
    # Returns the API version and breaking-change notice; when the request carries
    # valid user credentials the user's groups are added under "groups".
    status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
    try:
        # if authenticated will load available groups
        email = await get_user_auth(await bearer_security(request))
        # Hold the generator so the session is closed after the query, not
        # whenever the discarded generator happens to be finalised.
        db_gen = get_db()
        try:
            db: Session = next(db_gen)
            status["groups"] = crud.get_user_groups(db, email)
        finally:
            db_gen.close()
    except HTTPException:
        # unauthenticated callers simply get the payload without groups
        pass
    except Exception as e:
        logger.error(e)
    return JSONResponse(status)
#-----Submit URL and manipulate tasks. Bearer protected below
@app.get("/groups", response_model=list[str])
def get_user_groups(db: Session = Depends(get_db), email=Depends(get_user_auth)):
    # Groups reported by crud.get_user_groups for the authenticated user.
    groups = crud.get_user_groups(db, email)
    return groups
@app.get("/tasks/search-url", response_model=list[schemas.Archive])
def search_by_url(
    url: str,
    skip: int = 0,
    limit: int = 100,
    archived_after: datetime = None,
    archived_before: datetime = None,
    db: Session = Depends(get_db),
    email=Depends(get_token_or_user_auth),
):
    # Delegates to crud.search_tasks_by_url with the whitespace-stripped url,
    # the caller identity, pagination and the optional date bounds.
    date_bounds = {"archived_after": archived_after, "archived_before": archived_before}
    return crud.search_tasks_by_url(db, url.strip(), email, skip=skip, limit=limit, **date_bounds)
@app.get("/tasks/sync", response_model=list[schemas.Archive])
def search(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    email=Depends(get_user_auth),
):
    # Returns whatever crud.search_tasks_by_email yields for the authenticated email.
    own_tasks = crud.search_tasks_by_email(db, email, skip=skip, limit=limit)
    return own_tasks
@app.post("/tasks", status_code=201)
def archive_tasks(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)):
    # Queues an archiving task for one URL and returns its id.
    # Raises HTTP 422 when the url is not a plain string longer than 5 characters.
    archive.author_id = email
    url = archive.url
    logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")

    acceptable = type(url) is str and len(url) > 5
    if not acceptable:
        raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")

    logger.info("creating task")
    queued = create_archive_task.delay(archive.json())
    return JSONResponse({"id": queued.id})
@app.get("/archive/{task_id}")
def lookup(task_id, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)):
    # Returns whatever crud.get_task finds for this id and caller.
    found = crud.get_task(db, task_id, email)
    return found
@app.get("/tasks/{task_id}")
def get_status(task_id, email=Depends(get_token_or_user_auth)):
    # Reports the celery state and result of `task_id` as {"id", "status", "result"}.
    # Any failure (of the task, or while reading/encoding its result) is logged and
    # reported as status "FAILURE" with the error text in result.error.
    logger.info(f"status check for user {email} task {task_id}")
    task = AsyncResult(task_id, app=celery)
    try:
        # Read the state exactly once: the task can change state between two reads,
        # which would let the FAILURE check and the reported status disagree.
        status = task.status
        if status == "FAILURE":
            # *FAILURE* The task raised an exception, or has exceeded the retry limit.
            # The :attr:`result` attribute then contains the exception raised by the task.
            # https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
            raise task.result
        response = {
            "id": task_id,
            "status": status,
            "result": task.result,
        }
        return JSONResponse(jsonable_encoder(response, exclude_unset=True))
    except Exception as e:
        logger.error(e)
        logger.error(traceback.format_exc())
        return JSONResponse({
            "id": task_id,
            "status": "FAILURE",
            "result": {"error": str(e)},
        })
@app.delete("/tasks/{task_id}")
def delete_task(task_id, db: Session = Depends(get_db), email=Depends(get_user_auth)):
    # Soft-deletes the task for the authenticated user and reports the outcome
    # of crud.soft_delete_task in the "deleted" field.
    logger.info(f"deleting task {task_id} request by {email}")
    was_deleted = crud.soft_delete_task(db, task_id, email)
    return JSONResponse({"id": task_id, "deleted": was_deleted})
#----- Google Sheets Logic
@app.post("/sheet", status_code=201)
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth)):
    # Queues a sheet archiving task authored by the authenticated user.
    # Raises HTTP 422 when neither a sheet name nor a sheet id was supplied.
    logger.info(f"SHEET TASK for {sheet=}")
    sheet.author_id = email
    if not (sheet.sheet_name or sheet.sheet_id):
        raise HTTPException(status_code=422, detail="sheet name or id is required")
    queued = create_sheet_task.delay(sheet.json())
    return JSONResponse({"id": queued.id})
@app.post("/sheet_service", status_code=201)
def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth)):
    # Token-authenticated variant of the sheet endpoint: keeps the author given in
    # the payload, falling back to "api-endpoint" when it is empty.
    # Raises HTTP 422 when neither a sheet name nor a sheet id was supplied.
    logger.info(f"SHEET TASK for {sheet=}")
    if not sheet.author_id:
        sheet.author_id = "api-endpoint"
    if not (sheet.sheet_name or sheet.sheet_id):
        raise HTTPException(status_code=422, detail="sheet name or id is required")
    queued = create_sheet_task.delay(sheet.json())
    return JSONResponse({"id": queued.id})
#----- endpoint to submit data archived elsewhere
@app.post("/submit-archive", status_code=201)
def submit_manual_archive(manual: schemas.SubmitManual, auth=Depends(token_api_key_auth)):
    # Stores an archive result produced outside this API, tagging it "manual".
    # Raises HTTP 422 when the database rejects the insert with an IntegrityError.
    result = Metadata.from_json(manual.result)
    logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
    manual.tags.add("manual")
    try:
        archive_id = insert_result_into_db(
            result,
            manual.tags,
            manual.public,
            manual.group_id,
            manual.author_id,
            models.generate_uuid(),
        )
    except sqlalchemy.exc.IntegrityError as integrity_error:
        logger.error(integrity_error)
        raise HTTPException(status_code=422, detail="Cannot insert into DB due to integrity error")
    return JSONResponse({"id": archive_id})
# --------- Prometheus metrics
# Counter of exceptions reported by workers, labelled by exception and task;
# incremented by redis_subscribe_worker_exceptions.
WORKER_EXCEPTION = Counter(
"worker_exceptions_total",
"Number of times a certain exception has occurred on the worker.",
labelnames=("exception", "task",)
)
async def redis_subscribe_worker_exceptions():
    """Forward worker exceptions published on redis into the WORKER_EXCEPTION counter.

    Polls the REDIS_EXCEPTIONS_CHANNEL subscription once per second, forever.
    Each message is expected to be UTF-8 encoded JSON with "exception" and
    "task" keys; messages that do not match are logged and skipped.
    """
    PubSubExceptions = Rdis.pubsub()
    PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL)
    while True:
        message = PubSubExceptions.get_message()
        if message and message["type"] == "message":
            try:
                data = json.loads(message["data"].decode("utf-8"))
                WORKER_EXCEPTION.labels(exception=data["exception"], task=data["task"]).inc()
            except (ValueError, KeyError, TypeError, AttributeError) as e:
                # A single malformed payload must not end this loop: it runs as a
                # background task, so an uncaught error would silently stop all
                # further worker-exception metrics.
                logger.error(f"ignoring malformed worker exception message: {e}")
        await asyncio.sleep(1)
# Gauge for disk usage figures, one series per "type" label (set in measure_regular_metrics).
DISK_UTILIZATION = Gauge(
"disk_utilization",
"Disk utilization in GB",
labelnames=("type",)
)
# Gauge for values obtained from database queries, labelled by query name and user.
DATABASE_METRICS = Gauge(
"database_metrics",
"Useful database metrics from queries",
labelnames=("query", "user")
)
# Interval, in seconds, between two runs of measure_regular_metrics; also passed to crud.count_by_user_since.
REPEAT_COUNT_METRICS_SECONDS = 15
@repeat_every(seconds=REPEAT_COUNT_METRICS_SECONDS)
async def measure_regular_metrics():
    """Refresh the disk and database gauges; scheduled every REPEAT_COUNT_METRICS_SECONDS."""
    gib = 2**30
    _total, used, free = shutil.disk_usage("/")
    DISK_UTILIZATION.labels(type="used").set(used / gib)
    DISK_UTILIZATION.labels(type="free").set(free / gib)
    try:
        # size of the database file, derived from the sqlite URL
        fs = os.stat(SQLALCHEMY_DATABASE_URL.replace("sqlite:///", ""))
        DISK_UTILIZATION.labels(type="database").set(fs.st_size / gib)
    except Exception as e:
        logger.error(e)

    # Hold the generator so the session is closed after the queries, not whenever
    # a discarded `next(get_db())` generator happens to be finalised.
    db_gen = get_db()
    try:
        session: Session = next(db_gen)
        count_archives = crud.count_archives(session)
        count_archive_urls = crud.count_archive_urls(session)
        DATABASE_METRICS.labels(query="count_archives", user="-").set(count_archives)
        DATABASE_METRICS.labels(query="count_archive_urls", user="-").set(count_archive_urls)
        for user in crud.count_by_user_since(session, REPEAT_COUNT_METRICS_SECONDS):
            DATABASE_METRICS.labels(query="count_by_user", user=user.author_id).set(user.total)
    finally:
        db_gen.close()

View File

@@ -4,14 +4,13 @@ from sqlalchemy import engine_from_config
from sqlalchemy import pool from sqlalchemy import pool
from alembic import context from alembic import context
from dotenv import load_dotenv
load_dotenv() from shared.settings import get_settings
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
config = context.config config = context.config
config.set_main_option('sqlalchemy.url', os.environ.get("DATABASE_PATH")) config.set_main_option('sqlalchemy.url', get_settings().DATABASE_PATH)
# Interpret the config file for Python logging. # Interpret the config file for Python logging.
# This line sets up loggers basically. # This line sets up loggers basically.
if config.config_file_name is not None: if config.config_file_name is not None:

37
src/shared/settings.py Normal file
View File

@@ -0,0 +1,37 @@
from functools import lru_cache
from pydantic_settings import BaseSettings
from pydantic import ConfigDict
from typing import Annotated, Set
from annotated_types import Len
class Settings(BaseSettings):
    # Application configuration. Fields declared without a default
    # (DATABASE_PATH, API_BEARER_TOKEN, ALLOWED_ORIGINS, CHROME_APP_IDS) must be provided.
    # Unknown inputs are ignored and surrounding whitespace is stripped from strings.
    model_config = ConfigDict(extra='ignore', str_strip_whitespace=True)

    # general
    SERVE_LOCAL_ARCHIVE: str = ""
    USER_GROUPS_FILENAME: str = "user-groups.yaml"
    SHEET_ORCHESTRATION_YAML: str = "secrets/orchestration-sheet.yaml"

    # database
    DATABASE_PATH: str
    DATABASE_QUERY_LIMIT: int = 100

    # redis
    CELERY_BROKER_URL: str = "redis://localhost:6379"
    CELERY_RESULT_BACKEND: str = "redis://localhost:6379"
    REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"

    # observability
    REPEAT_COUNT_METRICS_SECONDS: int = 30

    # security
    # token must be at least 20 characters long
    API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
    # at least one origin is required
    ALLOWED_ORIGINS: Annotated[set[str], Len(min_length=1)]
    # at least one id is required, each at least 10 characters long
    CHROME_APP_IDS: Annotated[set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
    # optional; empty by default
    BLOCKED_EMAILS: Annotated[set[str], Len(min_length=0)] = set()
@lru_cache
def get_settings():
    """Build the Settings object on first call and return that same cached instance afterwards."""
    settings = Settings()
    return settings

BIN
src/static/favicon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 93 KiB

94
src/tests/conftest.py Normal file
View File

@@ -0,0 +1,94 @@
import os
from fastapi.testclient import TestClient
import pytest
from unittest.mock import patch
from shared.settings import Settings
@pytest.fixture(autouse=True)
def mock_logger_add():
    """Autouse fixture replacing loguru.logger.add with a mock for every test."""
    with patch('loguru.logger.add') as patched_add:
        # hand the mock to tests that want to inspect the calls made to it
        yield patched_add
@pytest.fixture()
def get_settings():
    """Settings instance loaded from the .env.test file."""
    test_settings = Settings(_env_file=".env.test")
    return test_settings
@pytest.fixture(autouse=True)
def mock_settings():
    """Autouse fixture: calling shared.settings.Settings yields the .env.test configuration."""
    test_settings = Settings(_env_file=".env.test")
    with patch('shared.settings.Settings', return_value=test_settings) as patched_settings:
        yield patched_settings
@pytest.fixture()
def test_db(get_settings: Settings):
    """Provide a connection to a freshly created test database and remove it afterwards.

    Setup creates the sqlite file (if missing) and all model tables. Teardown
    closes the connection, drops the tables, disposes the engine and deletes the
    database file together with its -wal/-shm companions.
    """
    from db.database import make_engine
    from db import models

    engine = make_engine(get_settings.DATABASE_PATH)
    fs = get_settings.DATABASE_PATH.replace("sqlite:///", "")
    if not os.path.exists(fs):
        open(fs, 'w').close()
    models.Base.metadata.create_all(engine)
    connection = engine.connect()

    yield connection

    try:
        connection.close()
        models.Base.metadata.drop_all(bind=engine)
    finally:
        # Release pooled connections before deleting the files they point at, and
        # delete the files even if the steps above failed so that a broken
        # teardown cannot leak state into the next test.
        engine.dispose()
        for suffix in ["", "-wal", "-shm"]:
            new_fs = fs + suffix
            if os.path.exists(new_fs):
                os.remove(new_fs)
@pytest.fixture()
def db_session(test_db):
    """Yield a session created by make_session_local on top of the test_db connection."""
    from db.database import make_session_local

    session_factory = make_session_local(test_db)
    with session_factory() as active_session:
        yield active_session
@pytest.fixture()
def app(db_session):
    """Application built by app_factory, after the user groups were upserted into the test database."""
    from web.main import app_factory
    from db import crud

    application = app_factory()
    crud.upsert_user_groups(db_session)
    return application
@pytest.fixture()
def client(app):
    """TestClient for the app without any authentication override."""
    return TestClient(app)
@pytest.fixture()
def app_with_auth(app):
    """App whose three auth dependencies are overridden to return fixed identities."""
    from web.security import get_token_or_user_auth, get_user_auth, token_api_key_auth

    # each auth flavour resolves to a different, recognisable user
    app.dependency_overrides.update({
        get_token_or_user_auth: lambda: "rick@example.com",
        get_user_auth: lambda: "morty@example.com",
        token_api_key_auth: lambda: "jerry@example.com",
    })
    return app
@pytest.fixture()
def client_with_auth(app_with_auth):
    """TestClient for the app with the authentication overrides in place."""
    return TestClient(app_with_auth)
@pytest.fixture()
def test_no_auth():
    """Return a helper asserting that calling an endpoint is rejected as unauthenticated."""
    # reusable code to ensure a method/endpoint combination is unauthorized
    def no_auth(http_method, endpoint):
        reply = http_method(endpoint)
        assert reply.status_code == 403
        assert reply.json() == {"detail": "Not authenticated"}

    return no_auth

389
src/tests/db/test_crud.py Normal file
View File

@@ -0,0 +1,389 @@
from datetime import datetime
from unittest.mock import patch
import pytest
import yaml
from db import models
from shared.settings import Settings
# Emails of the three users seeded by the test_data fixture; archives are assigned to them round-robin.
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
@pytest.fixture()
def test_data(db_session):
    """Seed the database: 3 users, 100 archives (with tags and urls) and the configured groups."""
    # creates 3 users
    db_session.add_all([models.User(email=email) for email in authors])
    db_session.commit()
    assert db_session.query(models.User).count() == 3

    # creates 100 archives for 3 users over 2 months with repeating URLs
    # (divisor, id prefix): an archive gets the tag when its index is a multiple of the divisor
    tag_rules = ((5, "tag"), (10, "tag-second"), (4, "tag-third"))
    for i in range(100):
        author = authors[i % 3]
        in_spaceship = author == "morty@example.com" and i % 2 == 0
        archive = models.Archive(
            id=f"archive-id-456-{i}",
            url=f"https://example-{i%3}.com",
            result={},
            public=author == "jerry@example.com",
            author_id=author,
            group_id="spaceship" if in_spaceship else None,
            created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1),
        )
        for divisor, prefix in tag_rules:
            if i % divisor == 0:
                archive.tags.append(models.Tag(id=f"{prefix}-{i}"))
        for j in range(10):
            archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
        db_session.add(archive)
    db_session.commit()

    archive_urls = db_session.query(models.ArchiveUrl)
    assert db_session.query(models.Archive).count() == 100
    assert db_session.query(models.Tag).count() == 20 + 10 + 25
    assert archive_urls.count() == 1000
    assert archive_urls.filter(models.ArchiveUrl.archive_id == "archive-id-456-0").count() == 10

    # setup groups
    assert db_session.query(models.Group).count() == 0
    from db import crud
    crud.upsert_user_groups(db_session)
    assert db_session.query(models.Group).count() == 3
    assert db_session.query(models.User).count() == 4
def test_get_archive(test_data, db_session):
    """get_archive returns own, public and ALLOW_ANY_EMAIL-visible archives, and None otherwise."""
    from db import crud
    from core.config import ALLOW_ANY_EMAIL

    # each author's archives work
    assert (a0 := crud.get_archive(db_session, "archive-id-456-0", authors[0])) is not None
    assert a0.id == "archive-id-456-0"
    assert a0.url == "https://example-0.com"
    assert a0.author_id == authors[0]
    assert a0.public is False
    assert crud.get_archive(db_session, "archive-id-456-1", authors[1]) is not None
    assert crud.get_archive(db_session, "archive-id-456-2", authors[2]) is not None

    # ALLOW_ANY_EMAIL
    assert crud.get_archive(db_session, "archive-id-456-0", ALLOW_ANY_EMAIL) is not None
    assert crud.get_archive(db_session, "archive-id-456-1", ALLOW_ANY_EMAIL) is not None

    # not found
    assert crud.get_archive(db_session, "archive-missing", authors[0]) is None

    # public
    assert (a_public := crud.get_archive(db_session, "archive-id-456-2", authors[0])) is not None
    assert a_public.public is True

    # not public - rick's
    assert crud.get_archive(db_session, "archive-id-456-0", authors[1]) is None
def test_search_archives_by_url(test_data, db_session):
    """search_archives_by_url honours visibility, fuzzy/absolute matching, date bounds and paging."""
    from db import crud
    from core.config import ALLOW_ANY_EMAIL

    def hits(url, email=ALLOW_ANY_EMAIL, **options):
        # number of archives returned for this url as seen by `email`
        return len(crud.search_archives_by_url(db_session, url, email, **options))

    base = "https://example"

    # rick's archives are private
    assert hits("https://example-0.com", "rick@example.com") == 34
    assert hits("https://example-0.com", ALLOW_ANY_EMAIL) == 34
    assert hits("https://example-0.com", "morty@example.com") == 0

    # morty's archives are private, but the half stored in the spaceship group is returned for rick
    assert hits("https://example-1.com", "rick@example.com") == 16

    # jerry's archives are public
    assert hits("https://example-2.com", "jerry@example.com") == 33
    assert hits("https://example-2.com", "rick@example.com") == 33

    # fuzzy search
    assert hits(base) == 100
    assert hits("https://EXAMPLE") == 100
    assert hits("2.com") == 33

    # absolute search
    assert hits("example-2.com", absolute_search=True) == 0
    assert hits("https://example-2.com", absolute_search=True) == 33

    # archived_after
    assert hits(base, archived_after=datetime(2010, 1, 1)) == 100
    assert hits(base, archived_after=datetime(2021, 1, 15)) == 70
    assert hits(base, archived_after=datetime(2031, 1, 1)) == 0

    # archived before
    assert hits(base, archived_before=datetime(2010, 1, 1)) == 0
    assert hits(base, archived_before=datetime(2021, 1, 15)) == 28
    assert hits(base, archived_before=datetime(2031, 1, 1)) == 100

    # archived before and after
    assert hits(base, archived_after=datetime(2001, 1, 1), archived_before=datetime(2031, 1, 11)) == 100
    assert hits(base, archived_after=datetime(2021, 1, 14), archived_before=datetime(2021, 1, 16)) == 2

    # limit
    assert hits(base, limit=10) == 10

    # skip
    assert hits(base, skip=10) == 90
def test_search_archives_by_email(test_data, db_session):
    """search_archives_by_email ignores email case and returns newest archives first."""
    from core.config import ALLOW_ANY_EMAIL
    from db import crud

    def by_email(email, **paging):
        return crud.search_archives_by_email(db_session, email, **paging)

    # lower/upper case
    for spelling in ("rick@example.com", "RICK@example.com"):
        assert len(by_email(spelling)) == 34

    # ALLOW_ANY_EMAIL is not a user
    assert len(by_email(ALLOW_ANY_EMAIL)) == 0

    # most recent first
    newest = by_email("rick@example.com", limit=1)
    assert len(newest) == 1
    assert newest[0].created_at == datetime(2021, 2, 25)

    # earliest is the last
    oldest = by_email("rick@example.com", skip=33)
    assert len(oldest) == 1
    assert oldest[0].created_at == datetime(2021, 1, 1)
@patch("db.crud.DATABASE_QUERY_LIMIT", new=25)
def test_max_query_limit(test_data, db_session):
    """With DATABASE_QUERY_LIMIT patched to 25, no search returns more than 25 rows."""
    from db import crud
    from core.config import ALLOW_ANY_EMAIL

    # default limit, then an explicit limit far above the cap
    for extra in ({}, {"limit": 1000}):
        assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, **extra)) == 25
    for extra in ({}, {"limit": 1000}):
        assert len(crud.search_archives_by_email(db_session, "rick@example.com", **extra)) == 25
def test_create_task(db_session):
    """create_task stores the archive fields and attaches the given tags and urls."""
    from db import crud
    from db import schemas

    task = schemas.ArchiveCreate(
        id="archive-id-456-101",
        url="https://example-0.com",
        result={},
        public=False,
        author_id="rick@example.com",
        group_id="spaceship",
        tags=[],
        urls=[],
    )

    def assert_stored(created, expected_id):
        # checks shared by both scenarios below
        assert created is not None
        assert created.id == expected_id
        assert created.url == "https://example-0.com"
        assert created.author_id == "rick@example.com"
        assert created.public == False
        assert created.group_id == "spaceship"
        assert created.created_at is not None

    # with tags and urls
    one_tag = [models.Tag(id="tag-101")]
    one_url = [models.ArchiveUrl(url="https://example-0.com/0", key="media_0")]
    nt = crud.create_task(db_session, task, one_tag, one_url)
    assert_stored(nt, "archive-id-456-101")
    assert len(nt.tags) == 1
    assert nt.tags[0].id == "tag-101"
    assert len(nt.urls) == 1
    assert nt.urls[0].url == "https://example-0.com/0"
    assert nt.urls[0].key == "media_0"

    # without tags and urls
    task.id = "archive-id-456-102"
    nt = crud.create_task(db_session, task, [], [])
    assert_stored(nt, "archive-id-456-102")
    assert len(nt.tags) == 0
    assert len(nt.urls) == 0
def test_soft_delete(test_data, db_session):
    """soft_delete_task flags the archive as deleted once and hides it from get_archive."""
    from db import crud

    target, owner = "archive-id-456-0", "rick@example.com"

    def flagged_deleted():
        # `== True` builds the SQL filter expression, it is not a Python comparison
        return db_session.query(models.Archive).filter(models.Archive.deleted == True).count()

    # none deleted yet
    assert crud.get_archive(db_session, target, owner) is not None
    assert flagged_deleted() == 0

    # delete
    assert crud.soft_delete_task(db_session, target, owner) == True

    # ensure soft delete
    assert flagged_deleted() == 1
    assert crud.get_archive(db_session, target, owner) is None

    # already deleted
    assert crud.soft_delete_task(db_session, target, owner) == False
def test_count_archives(test_data, db_session):
    """count_archives reflects the number of archive rows."""
    from db import crud

    assert crud.count_archives(db_session) == 100

    first_archive = db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0")
    first_archive.delete()
    db_session.commit()

    assert crud.count_archives(db_session) == 99
def test_count_archive_urls(test_data, db_session):
    """count_archive_urls reflects url rows; deleting an archive does not remove its urls."""
    from db import crud

    assert crud.count_archive_urls(db_session) == 1000

    one_url = db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.url == "https://example-0.com/0")
    one_url.delete()
    db_session.commit()
    assert crud.count_archive_urls(db_session) == 999

    parent_archive = db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0")
    parent_archive.delete()
    db_session.commit()

    # no Cascade is enabled
    assert crud.count_archives(db_session) == 99
    assert crud.count_archive_urls(db_session) == 999
def test_count_users(test_data, db_session):
    """count_users reflects the number of user rows."""
    from db import crud

    assert crud.count_users(db_session) == 4

    rick = db_session.query(models.User).filter(models.User.email == "rick@example.com")
    rick.delete()
    db_session.commit()

    assert crud.count_users(db_session) == 3
def test_count_by_users_since(test_data, db_session):
    """count_by_user_since returns one row per author with that author's archive total."""
    from db import crud

    # 100y window
    window_seconds = 60 * 60 * 24 * 31 * 12 * 100
    per_user = crud.count_by_user_since(db_session, window_seconds)
    assert len(per_user) == 3
    assert [row.total for row in per_user] == [34, 33, 33]
def test_create_tag(db_session):
    """create_tag inserts a new tag once and returns the existing one for a repeated id."""
    from db import crud

    def tag_count():
        return db_session.query(models.Tag).count()

    assert tag_count() == 0

    # create first
    first = crud.create_tag(db_session, "tag-101")
    assert first is not None
    assert first.id == "tag-101"
    assert tag_count() == 1
    assert db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() == first

    # same id does not add new db entry
    again = crud.create_tag(db_session, "tag-101")
    assert again == first
    assert tag_count() == 1

    # create second
    second = crud.create_tag(db_session, "tag-102")
    assert second is not None
    assert second.id == "tag-102"
    assert tag_count() == 2
def test_is_user_in_group(test_data, db_session):
    """is_user_in_group matches the expected membership for each (email, group) pair."""
    from db import crud
    from core.config import ALLOW_ANY_EMAIL

    # see user-groups.test.yaml
    rick, morty, jerry = "rick@example.com", "morty@example.com", "jerry@example.com"
    expectations = [
        (ALLOW_ANY_EMAIL, "spaceship", True),
        (ALLOW_ANY_EMAIL, "non-existant!@#!%!", True),
        (rick, "spaceship", True),
        (rick, "SPACESHIP", False),
        ("RICK@example.com", "interdimensional", True),
        (rick, "the-jerrys-club", False),
        (morty, "spaceship", True),
        (morty, "interdimensional", False),
        (morty, "the-jerrys-club", False),
        (jerry, "spaceship", False),
        (jerry, "interdimensional", False),
        (jerry, "the-jerrys-club", True),
        (rick, "animated-characters", True),
        (morty, "animated-characters", True),
        (jerry, "animated-characters", True),
        (rick, "", False),
        ("", "spaceship", False),
        ("BADEMAILexample.com", "spaceship", False),
    ]
    for who, group_name, is_member in expectations:
        assert crud.is_user_in_group(db_session, group_name, who) == is_member
def test_create_or_get_user(test_data, db_session):
    """create_or_get_user returns an existing user and creates a missing one."""
    from db import crud

    def user_count():
        return db_session.query(models.User).count()

    assert user_count() == 4

    # rick already exists, beth is new
    for address in ("rick@example.com", "beth@example.com"):
        user = crud.create_or_get_user(db_session, address)
        assert user is not None
        assert user.email == address
        assert user.is_active == True

    assert user_count() == 5
def test_get_group(test_data, db_session):
    """create_or_get_group returns existing groups with their users and creates missing ones."""
    from db import crud

    def group_count():
        return db_session.query(models.Group).count()

    assert group_count() == 3

    spaceship = crud.create_or_get_group(db_session, "spaceship")
    assert spaceship is not None
    assert spaceship.id == "spaceship"
    assert len(spaceship.users) == 2
    assert [member.email for member in spaceship.users] == ["rick@example.com", "morty@example.com"]

    jerrys = crud.create_or_get_group(db_session, "the-jerrys-club")
    assert jerrys is not None
    assert jerrys.id == "the-jerrys-club"
    assert len(jerrys.users) == 1
    assert jerrys.users[0].email == "jerry@example.com"

    brand_new = crud.create_or_get_group(db_session, "this-is-a-new-group")
    assert brand_new is not None
    assert brand_new.id == "this-is-a-new-group"
    assert len(brand_new.users) == 0

    assert group_count() == 4
def test_upsert_user_groups(db_session):
    """upsert_user_groups raises for a missing groups file and for a file with broken YAML."""
    from db import crud

    bad_settings = Settings(_env_file=".env.test")
    failure_cases = [
        ("tests/user-groups.test.missing.yaml", FileNotFoundError),
        ("tests/user-groups.test.broken.yaml", yaml.YAMLError),
    ]
    for filename, expected_error in failure_cases:
        bad_settings.USER_GROUPS_FILENAME = filename
        # make db.crud read the tampered settings for the duration of the call
        with patch('db.crud.get_settings', new=lambda: bad_settings):
            with pytest.raises(expected_error):
                crud.upsert_user_groups(db_session)

View File

@@ -0,0 +1,6 @@
def test_generate_uuid():
    """generate_uuid yields distinct 36-character ids containing four dashes."""
    from db.models import generate_uuid

    first, second = generate_uuid(), generate_uuid()
    assert first != second

    sample = generate_uuid()
    assert len(sample) == 36
    assert sample.count("-") == 4

View File

@@ -0,0 +1,120 @@
from unittest.mock import AsyncMock, patch
from fastapi.testclient import TestClient
import pytest
from core.config import VERSION
def test_endpoint_home(client_with_auth):
    """GET / reports version and breaking changes and carries no groups entry."""
    response = client_with_auth.get("/")
    assert response.status_code == 200

    payload = response.json()
    assert "version" in payload and payload["version"] == VERSION
    assert "breakingChanges" in payload
    assert "groups" not in payload
@patch("endpoints.default.bearer_security", new_callable=AsyncMock)
@patch("endpoints.default.get_user_auth", new_callable=AsyncMock, return_value="test@example.com")
@patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"])
def test_endpoint_home_with_groups(m_groups, m_auth, m_bearer, client_with_auth):
    """GET / includes the groups list when authentication and the group lookup are mocked to succeed."""
    response = client_with_auth.get("/")
    assert response.status_code == 200

    payload = response.json()
    assert "version" in payload and payload["version"] == VERSION
    assert "breakingChanges" in payload
    assert "groups" in payload
    assert payload["groups"] == ["group1", "group2"]
@patch("endpoints.default.bearer_security", new_callable=AsyncMock)
@patch("endpoints.default.get_user_auth", new_callable=AsyncMock, return_value="test@example.com")
@patch("endpoints.default.crud.get_user_groups", side_effect=Exception('mocked error'))
def test_endpoint_home_with_groups_exception(m_groups, m_auth, m_bearer, client_with_auth):
    """GET / still answers 200, without groups, when the group lookup raises internally."""
    response = client_with_auth.get("/")
    assert response.status_code == 200

    payload = response.json()
    assert "version" in payload and payload["version"] == VERSION
    assert "breakingChanges" in payload
    assert "groups" not in payload
def test_endpoint_health(client_with_auth):
    """GET /health answers 200 with a fixed ok status."""
    response = client_with_auth.get("/health")
    assert response.status_code == 200
    assert response.json() == {"status": "ok"}
def test_endpoint_groups_no_auth(client, test_no_auth):
    """GET /groups is checked with the shared no-auth helper."""
    test_no_auth(client.get, "/groups")
def test_endpoint_groups_rick_and_morty(client_with_auth):
    """GET /groups for the authenticated test user lists exactly two groups."""
    response = client_with_auth.get("/groups")
    assert response.status_code == 200

    groups = response.json()
    assert len(groups) == 2
    assert 'animated-characters' in groups
    assert 'spaceship' in groups
@patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"])
def test_endpoint_groups(m_groups, app):
    """GET /groups returns the mocked group list once the auth dependency is overridden."""
    from web.security import get_user_auth

    # replace the real authentication dependency with a stub
    app.dependency_overrides[get_user_auth] = lambda: True

    response = TestClient(app).get("/groups")
    assert response.status_code == 200
    assert response.json() == ["group1", "group2"]
def test_no_serve_local_archive_by_default(client_with_auth):
    """The local archive path is not mounted with the default test settings, so the file is a 404."""
    response = client_with_auth.get("/app/local_archive_test/temp.txt")
    assert response.status_code == 404
def test_favicon(client_with_auth):
    """GET /favicon.ico serves an icon with the Microsoft icon content type."""
    response = client_with_auth.get("/favicon.ico")
    assert response.status_code == 200
    assert response.headers["content-type"] == "image/vnd.microsoft.icon"
from tests.db.test_crud import test_data
@pytest.mark.asyncio
async def test_prometheus_metrics(test_data, client_with_auth, get_settings):
    """/metrics declares the custom metrics up front and shows values only after measure_regular_metrics ran."""
    from utils.metrics import measure_regular_metrics

    # before metrics calculation: families are present, labelled samples are not
    before = client_with_auth.get("/metrics")
    assert before.status_code == 200
    assert before.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
    for family in ("disk_utilization", "database_metrics", "exceptions", "worker_exceptions_total"):
        assert family in before.text
    assert 'disk_utilization{type="used"}' not in before.text

    database_lines = (
        'database_metrics{query="count_archives"} 100.0',
        'database_metrics{query="count_archive_urls"} 1000.0',
        'database_metrics{query="count_users"} 4.0',
        'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0',
        'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0',
        'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0',
    )

    # after metrics calculation, using a very large look-back window (seconds)
    await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100)
    measured = client_with_auth.get("/metrics")
    for disk_type in ("used", "free", "database"):
        assert 'disk_utilization{type="%s"}' % disk_type in measured.text
    for line in database_lines:
        assert line in measured.text

    # 30s window, should not change the gauges nor the total in the counters
    await measure_regular_metrics(get_settings.DATABASE_PATH, 30)
    remeasured = client_with_auth.get("/metrics")
    for line in database_lines:
        assert line in remeasured.text

View File

@@ -0,0 +1,19 @@
import json
def test_submit_manual_archive_unauthenticated(client, test_no_auth):
    """POST /interop/submit-archive is checked with the shared no-auth helper."""
    test_no_auth(client.post, "/interop/submit-archive")
def test_submit_manual_archive(client_with_auth):
    """A manual archive submission is stored (201); a result with a repeated media URL is rejected (422)."""

    def submit(result):
        body = {"result": json.dumps(result), "public": False, "author_id": "jerry@gmail.com", "group_id": None, "tags": ["test"]}
        return client_with_auth.post("/interop/submit-archive", json=body)

    accepted = submit({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": []})
    assert accepted.status_code == 201
    assert "id" in accepted.json()

    # cannot have the same URL twice
    duplicated_media = [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]
    rejected = submit({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": duplicated_media})
    assert rejected.status_code == 422
    assert rejected.json() == {"detail": "Cannot insert into DB due to integrity error"}

View File

@@ -0,0 +1,46 @@
import json
from unittest.mock import patch
from db.schemas import TaskResult
def test_sheet_no_auth(client, test_no_auth):
    """POST /sheet/archive is checked with the shared no-auth helper."""
    test_no_auth(client.post, "/sheet/archive")
@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
def test_sheet_rick(m_delay, client_with_auth):
    """An authenticated sheet request enqueues one task whose payload carries the caller as author."""
    response = client_with_auth.post("/sheet/archive", json={"sheet_id": "123-sheet-id"})
    assert response.status_code == 201
    assert response.json() == {'id': '123-456-789'}

    m_delay.assert_called_once()
    # the task receives the request serialized as a JSON string in its first positional argument
    queued = json.loads(m_delay.call_args.args[0])
    assert queued == {"sheet_id": "123-sheet-id", "sheet_name": None, "public": False, "author_id": "rick@example.com", "group_id": None, "tags": [], "columns": {}, "header": 1}
def test_sheet_missing_sheet_data(client_with_auth):
    """A sheet request with neither name nor id is rejected with 422."""
    response = client_with_auth.post("/sheet/archive", json={})
    assert response.status_code == 422
    assert response.json() == {"detail": "sheet name or id is required"}
@patch("worker.main.create_sheet_task.delay", return_value=TaskResult(id="123-API-789", status="PENDING", result=""))
def test_sheet_api(m_delay, client):
    """With the API token, the author defaults to "api-endpoint" and can be overridden in the request body."""
    api_headers = {"Authorization": "Bearer this_is_the_test_api_token"}

    # default author when none is supplied
    first = client.post("/sheet/archive", json={"sheet_name": "456-sheet_name-id"}, headers=api_headers)
    assert first.status_code == 201
    assert first.json() == {'id': '123-API-789'}
    m_delay.assert_called_once()
    queued = json.loads(m_delay.call_args.args[0])
    assert queued == {"sheet_name": "456-sheet_name-id", "sheet_id": None, "public": False, "author_id": "api-endpoint", "group_id": None, "tags": [], "columns": {}, "header": 1}

    # explicit author is passed through
    second = client.post("/sheet/archive", json={"sheet_id": "456-sheet-id", "author_id": "custom-author"}, headers=api_headers)
    assert second.status_code == 201
    assert second.json() == {'id': '123-API-789'}
    assert m_delay.call_count == 2
    queued = json.loads(m_delay.call_args.args[0])
    assert queued == {"sheet_id": "456-sheet-id", "sheet_name": None, "public": False, "author_id": "custom-author", "group_id": None, "tags": [], "columns": {}, "header": 1}

View File

@@ -0,0 +1,51 @@
from unittest.mock import patch
def test_endpoint_task_status_no_auth(client, test_no_auth):
    """GET /task/{id} is checked with the shared no-auth helper."""
    test_no_auth(client.get, "/task/test-task-id")
@patch("endpoints.task.AsyncResult")
def test_get_status_success(mock_async_result, client_with_auth):
    """A SUCCESS task is reported together with its result payload."""
    task = mock_async_result.return_value
    task.status = "SUCCESS"
    task.result = {"data": "some result"}

    response = client_with_auth.get("/task/test-task-id")

    assert response.status_code == 200
    expected = {"id": "test-task-id", "status": "SUCCESS", "result": {"data": "some result"}}
    assert response.json() == expected
@patch("endpoints.task.AsyncResult")
def test_get_status_failure(mock_async_result, client_with_auth):
    """A FAILURE task whose result is an exception is reported as an error message, still with HTTP 200."""
    task = mock_async_result.return_value
    task.status = "FAILURE"
    task.result = Exception("Some error")

    response = client_with_auth.get("/task/test-task-id")

    assert response.status_code == 200
    expected = {"id": "test-task-id", "status": "FAILURE", "result": {"error": "Some error"}}
    assert response.json() == expected
@patch("endpoints.task.AsyncResult")
def test_get_status_pending(mock_async_result, client_with_auth):
    """A PENDING task is reported with a null result."""
    task = mock_async_result.return_value
    task.status = "PENDING"
    task.result = None

    response = client_with_auth.get("/task/test-task-id")

    assert response.status_code == 200
    expected = {"id": "test-task-id", "status": "PENDING", "result": None}
    assert response.json() == expected

View File

@@ -0,0 +1,145 @@
import json
from unittest.mock import patch
from db.schemas import ArchiveCreate, TaskResult
def test_archive_url_unauthenticated(client, test_no_auth):
    """Both POST and GET on /url/archive are checked with the shared no-auth helper."""
    for method in (client.post, client.get):
        test_no_auth(method, "/url/archive")
@patch("worker.main.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
def test_archive_url(m_delay, client_with_auth):
    """An invalid URL is rejected without queuing anything; a valid one queues exactly one archive task."""
    rejected = client_with_auth.post("/url/archive", json={"url": "bad"})
    assert rejected.status_code == 422
    assert rejected.json() == {'detail': 'Invalid URL received: bad'}
    m_delay.assert_not_called()

    accepted = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
    assert accepted.status_code == 201
    assert accepted.json() == {'id': '123-456-789'}
    m_delay.assert_called_once()

    # the task receives the archive serialized as a JSON string in its first positional argument
    queued = json.loads(m_delay.call_args.args[0])
    assert queued == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True}
def test_search_by_url_unauthenticated(client, test_no_auth):
    """GET /url/search is checked with the shared no-auth helper."""
    test_no_auth(client.get, "/url/search")
def test_search_by_url(client_with_auth, db_session):
    """Exercises /url/search: required url parameter, url matching, skip/limit paging and date filters."""
    from db import crud

    def search(query=""):
        return client_with_auth.get("/url/search" + query)

    # the url query parameter is mandatory
    missing = search()
    assert missing.status_code == 422
    assert missing.json()["detail"][0]["msg"] == "Field required"

    # nothing stored yet
    empty = search("?url=https://example.com")
    assert empty.status_code == 200
    assert empty.json() == []

    # ten archives for the searched URL and one for a different URL
    # NB: this insertion is too fast for the ordering to be correct as they are within the same second
    for i in range(11):
        url = "https://example.com" if i < 10 else "https://something-else.com"
        archive = ArchiveCreate(id=f"url-456-{i}", url=url, result={}, public=True, author_id="rick@example.com", group_id=None)
        crud.create_task(db_session, archive, [], [])

    found = search("?url=https://example.com")
    assert found.status_code == 200
    hits = found.json()
    assert len(hits) == 10
    hit_ids = [hit["id"] for hit in hits]
    assert "url-456-0" in hit_ids
    assert "url-456-9" in hit_ids
    assert "url-456-10" not in hit_ids

    # paging and date filters: (query string, expected number of results)
    expectations = (
        ("?url=https://example.com&limit=5", 5),
        ("?url=https://example.com&skip=5&limit=2", 2),
        ("?url=https://example.com&archived_before=2010-01-01", 0),
        ("?url=https://example.com&archived_after=2010-01-01", 10),
    )
    for query, expected_count in expectations:
        response = search(query)
        assert response.status_code == 200
        assert len(response.json()) == expected_count
def test_latest_unauthenticated(client, test_no_auth):
    """GET /url/latest is checked with the shared no-auth helper."""
    test_no_auth(client.get, "/url/latest")
def test_latest(client_with_auth, db_session):
    """Exercises /url/latest: empty at first, then ten of the eleven stored archives, with skip/limit paging."""
    from db import crud

    initial = client_with_auth.get("/url/latest")
    assert initial.status_code == 200
    assert initial.json() == []

    # ten archives authored by morty and one by rick
    # NB: this insertion is too fast for the ordering to be correct as they are within the same second
    for i in range(11):
        author = "morty@example.com" if i < 10 else "rick@example.com"
        archive = ArchiveCreate(id=f"latest-456-{i}", url="https://example.com", result={}, public=True, author_id=author, group_id=None)
        crud.create_task(db_session, archive, [], [])

    # user must exist for /latest to work
    crud.create_or_get_user(db_session, "morty@example.com")

    listed = client_with_auth.get("/url/latest")
    assert listed.status_code == 200
    entries = listed.json()
    assert len(entries) == 10
    entry_ids = [entry["id"] for entry in entries]
    assert "latest-456-0" in entry_ids
    assert "latest-456-9" in entry_ids
    assert "latest-456-10" not in entry_ids

    for query, expected_count in (("?limit=5", 5), ("?skip=5&limit=2", 2)):
        page = client_with_auth.get("/url/latest" + query)
        assert page.status_code == 200
        assert len(page.json()) == expected_count
def test_lookup_unauthenticated(client, test_no_auth):
    """GET /url/{id} is checked with the shared no-auth helper."""
    test_no_auth(client.get, "/url/123-456-789")
def test_lookup(client_with_auth, db_session):
    """GET /url/{id} is a 404 for an unknown id and returns the stored fields once the archive exists."""
    from db import crud

    archive_id = "lookup-123-456-789"

    unknown = client_with_auth.get("/url/lookup-123-456-789")
    assert unknown.status_code == 404
    assert unknown.json() == {"detail": "Archive not found"}

    stored = ArchiveCreate(id=archive_id, url="https://example.com", result={}, public=True, author_id="rick@example.com", group_id=None)
    crud.create_task(db_session, stored, [], [])

    found = client_with_auth.get("/url/lookup-123-456-789")
    assert found.status_code == 200
    body = found.json()
    expected_fields = [
        ("id", "lookup-123-456-789"),
        ("url", "https://example.com"),
        ("result", {}),
        ("public", True),
        ("author_id", "rick@example.com"),
        ("group_id", None),
        ("tags", []),
        ("updated_at", None),
        ("rearchive", True),
    ]
    for field, value in expected_fields:
        assert body[field] == value
def test_delete_task_unauthenticated(client, test_no_auth):
    """DELETE /url/{id} is checked with the shared no-auth helper."""
    test_no_auth(client.delete, "/url/123-456-789")
def test_delete_task(client_with_auth, db_session):
    """DELETE /url/{id} always answers 200; the deleted flag tells whether a row was actually removed."""
    from db import crud

    def delete():
        return client_with_auth.delete("/url/delete-123-456-789")

    # nothing to delete yet
    before = delete()
    assert before.status_code == 200
    assert before.json() == {"id": "delete-123-456-789", "deleted": False}

    archive = ArchiveCreate(id="delete-123-456-789", url="https://example.com", result={}, public=True, author_id="morty@example.com", group_id=None)
    crud.create_task(db_session, archive, [], [])

    after = delete()
    assert after.status_code == 200
    assert after.json() == {"id": "delete-123-456-789", "deleted": True}

View File

@@ -0,0 +1,24 @@
# Orchestration fixture for the test-suite.
# NOTE(review): presumably this is tests/orchestration.test.yaml, the file the
# user-groups fixture points its orchestrators at - confirm against the repo.

# Pipeline composition: which component is used at each stage.
steps:
  feeder: cli_feeder
  archivers: # order matters
    - youtubedl_archiver
  enrichers:
    - hash_enricher
  formatter: html_formatter # defaults to mute_formatter
  storages:
    - local_storage
  databases:
    - console_db
# Per-component settings, keyed by the component names used under `steps`.
configurations:
  cli_feeder:
    urls:
      - "url1"
  hash_enricher:
    algorithm: "SHA-256"
  local_storage:
    save_to: "./local_archive"
    save_absolute: true
    filename_generator: static
    path_generator: flat

View File

@@ -0,0 +1,6 @@
broken: True
This is just an invalid yaml for testing
still broken: True
- one
- two

View File

@@ -0,0 +1,19 @@
# User/group fixture for the test-suite.
# NOTE: all emails should be lower-cased
# Explicit memberships: email -> list of group ids.
users:
  rick@example.com:
    - spaceship
    - interdimensional
  morty@example.com:
    - spaceship
  jerry@example.com:
    - the-jerrys-club
  # listed without any group (parses as null, not as an empty list)
  birdman@example.com:
# Memberships keyed by email domain instead of by individual address.
# NOTE(review): presumably applied to every user whose email is on the domain - confirm in db.crud.
domains:
  example.com:
    - animated-characters
# Group id -> orchestration config file, plus a `default` entry.
orchestrators:
  spaceship: tests/orchestration.test.yaml
  interdimensional: tests/orchestration.test.yaml
  default: tests/orchestration.test.yaml

View File

@@ -0,0 +1,49 @@
import os
from unittest.mock import patch
from fastapi.testclient import TestClient
import shutil
import pytest
def test_lifespan(app):
    """The app answers /health when the TestClient is used as a context manager."""
    with TestClient(app) as live_client:
        response = live_client.get("/health")
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}
def test_alembic(db_session):
    """The alembic migrations can be upgraded to head and downgraded back to base without errors."""
    import alembic.config

    for command, target in (("upgrade", "head"), ("downgrade", "base")):
        alembic.config.main(argv=['--raiseerr', command, target])
@patch("endpoints.default.crud.get_user_groups", side_effect=Exception('mocked error'))
def test_logging_middleware(m_groups, client_with_auth):
    """An unhandled endpoint exception is re-raised and recorded in EXCEPTION_COUNTER."""
    from utils.metrics import EXCEPTION_COUNTER

    def sample_count():
        return len(EXCEPTION_COUNTER.collect()[0].samples)

    assert sample_count() == 0

    with pytest.raises(Exception, match="mocked error"):
        client_with_auth.get("/groups")

    # creates one empty and one from above
    assert sample_count() == 2
def test_serve_local_archive_logic(get_settings):
    """With SERVE_LOCAL_ARCHIVE set, files in the matching local directory are served by the app."""
    # create a test file first
    os.makedirs("local_archive_test", exist_ok=True)
    with open("local_archive_test/temp.txt", "w") as handle:
        handle.write("test")

    try:
        # modify the settings and build a fresh app from them
        get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test"
        from web.main import app_factory
        served = TestClient(app_factory(get_settings)).get("/app/local_archive_test/temp.txt")

        assert served.status_code == 200
        assert served.text == "test"
    finally:
        # cleanup
        shutil.rmtree("local_archive_test")

View File

@@ -0,0 +1,108 @@
from unittest.mock import patch
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
import pytest
from core.config import ALLOW_ANY_EMAIL
def test_secure_compare():
    """secure_compare is truthy for equal strings and falsy for different ones."""
    from web.security import secure_compare

    same = secure_compare("test", "test")
    different = secure_compare("test", "test2")
    assert same
    assert not different
@pytest.mark.asyncio
async def test_get_token_or_user_auth_with_api():
    """The test API token resolves to the ALLOW_ANY_EMAIL marker."""
    from web.security import get_token_or_user_auth

    api_credentials = HTTPAuthorizationCredentials(scheme="lorem", credentials="this_is_the_test_api_token")
    resolved = await get_token_or_user_auth(api_credentials)
    assert resolved == ALLOW_ANY_EMAIL
@pytest.mark.asyncio
async def test_get_token_or_user_auth_with_user():
    """Credentials that are neither the API token nor a valid user token raise a 401."""
    from web.security import get_token_or_user_auth

    bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="invalid")
    with pytest.raises(HTTPException) as raised:
        await get_token_or_user_auth(bad_user)

    error = raised.value
    assert error.status_code == 401
    assert error.detail == "invalid access_token"
@patch("web.security.authenticate_user", return_value=(True, "summer@example.com"))
@pytest.mark.asyncio
async def test_get_user_auth(m_authenticate):
    """get_user_auth returns the email reported by a successful authenticate_user."""
    from web.security import get_user_auth

    valid_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good")
    email = await get_user_auth(valid_user)
    assert email == "summer@example.com"
@patch("web.security.secure_compare", return_value=False)
@pytest.mark.asyncio
async def test_token_api_key_auth_exception(m_compare):
    """token_api_key_auth raises a 401 when the token comparison fails and auto_error is on."""
    from web.security import token_api_key_auth

    credentials = HTTPAuthorizationCredentials(scheme="ipsum", credentials="does-not-matter")
    with pytest.raises(HTTPException) as raised:
        await token_api_key_auth(credentials, auto_error=True)

    error = raised.value
    assert error.status_code == 401
    assert error.detail == "Wrong auth credentials"
@pytest.mark.asyncio
async def test_authenticate_user():
    """authenticate_user rejects bad tokens and every invalid tokeninfo reply, accepting only a fully valid one."""
    from web.security import authenticate_user

    # these inputs are reported as invalid access tokens
    for malformed in ("test", 123):
        assert authenticate_user(malformed) == (False, "invalid access_token")

    token = "this-will-call-requests"
    wrong_app = "token does not belong to valid APP_ID"

    with patch("web.security.requests.get") as mock_get:
        reply = mock_get.return_value

        # bad response from oauth2
        reply.status_code = 403
        assert authenticate_user(token) == (False, "invalid token")
        assert mock_get.call_count == 1

        # 200 but the JSON body is not a real token description
        reply.status_code = 200
        assert authenticate_user(token) == (False, wrong_app)
        assert mock_get.call_count == 2

        # (tokeninfo payload, expected outcome), checked in this order
        cases = [
            # 200 but invalid azp and aud
            ({"email": "summer@example.com", "azp": "not_an_app"}, (False, wrong_app)),
            ({"email": "summer@example.com", "aud": "not_an_app"}, (False, wrong_app)),
            ({"email": "summer@example.com", "azp": "not_an_app", "aud": "not_an_app"}, (False, wrong_app)),
            # blocked email
            ({"email": "blocked@example.com", "azp": "test_app_id_1", "aud": "not_an_app"}, (False, "email 'blocked@example.com' not allowed")),
            # not verified
            ({"email": "summer@example.com", "azp": "not_an_app", "aud": "test_app_id_1"}, (False, "email 'summer@example.com' not verified")),
            # token expired
            ({"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true"}, (False, "Token expired")),
            # 200 and valid azp and aud and verified
            ({"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true", "expires_in": 100}, (True, "summer@example.com")),
        ]
        for payload, expected in cases:
            reply.json.return_value = payload
            assert authenticate_user(token) == expected

        assert mock_get.call_count == 9
@pytest.mark.asyncio
async def test_authenticate_user_exception():
    """An exception while reading the tokeninfo reply is reported as a failed authentication."""
    from web.security import authenticate_user

    with patch("web.security.requests.get") as mock_get:
        reply = mock_get.return_value
        reply.status_code = 200
        reply.json.side_effect = Exception("mocked error")

        outcome = authenticate_user("this-will-call-requests")
        assert outcome == (False, "exception occurred")

View File

@@ -0,0 +1,196 @@
from unittest import mock
from unittest.mock import MagicMock, patch
import pytest
from db import models, schemas
from auto_archiver import Metadata
from auto_archiver.core import Media
@pytest.fixture()
def worker_init():
    """Run the worker's at_start hook (with no sender) before the test body."""
    from worker import main as worker_main

    worker_main.at_start(None)
class Test_create_archive_task():
    """Tests for worker.main.create_archive_task with the orchestrator and DB insertion mocked out."""

    URL = "https://example-live.com"
    archive = schemas.ArchiveCreate(url=URL, tags=[], public=True, group_id=None, author_id="rick@example.com")

    def mock_orchestrator_choice(self, m_choose):
        """Make the patched choose_orchestrator return an orchestrator whose feed_item succeeds for URL."""
        successful_result = Metadata().set_url(self.URL).success()
        orchestrator = mock.MagicMock()
        orchestrator.configure_mock(feed_item=mock.MagicMock(return_value=successful_result))
        m_choose.return_value = orchestrator
        return orchestrator

    @patch("worker.main.insert_result_into_db")
    @patch("worker.main.is_group_invalid_for_user", return_value=None)
    @patch("worker.main.choose_orchestrator")
    @patch("celery.app.task.Task.request")
    def test_success(self, m_req, m_choose, m_is_group, m_insert, worker_init, db_session):
        """A valid archive request feeds the orchestrator once and returns its successful result."""
        from worker.main import create_archive_task

        m_req.id = "this-just-in"
        orchestrator = self.mock_orchestrator_choice(m_choose)

        outcome = create_archive_task(self.archive.model_dump_json())

        m_choose.assert_called_once()
        orchestrator.feed_item.assert_called_once()
        assert outcome["status"] == "success"
        assert outcome["metadata"]["url"] == self.URL
        assert len(outcome["media"]) == 0

    @patch("worker.main.is_group_invalid_for_user", return_value=True)
    def test_raise_invalid(self, m_is_group, worker_init):
        """The task raises when the group check reports the request as invalid."""
        from worker.main import create_archive_task

        payload = self.archive.model_dump_json()
        with pytest.raises(Exception):
            create_archive_task(payload)

    @patch("worker.main.insert_result_into_db", side_effect=Exception)
    @patch("worker.main.is_group_invalid_for_user", return_value=False)
    @patch("worker.main.choose_orchestrator")
    def test_raise_db_error(self, m_choose, m_is_group, m_insert, worker_init):
        """A failing DB insertion propagates, after the orchestrator has already been fed once."""
        from worker.main import create_archive_task

        orchestrator = self.mock_orchestrator_choice(m_choose)
        payload = self.archive.model_dump_json()
        with pytest.raises(Exception):
            create_archive_task(payload)
        orchestrator.feed_item.assert_called_once()
class Test_create_sheet_task():
    """Tests for worker.main.create_sheet_task with the orchestrator mocked out."""

    URL = "https://example-live.com"
    sheet = schemas.SubmitSheet(sheet_name="Sheet", sheet_id="123", author_id="rick@example.com", group_id=None)

    def _metadata_with_media(self):
        """Build a successful Metadata for URL carrying a single media entry."""
        metadata = Metadata().set_url(self.URL).success()
        metadata.add_media(Media("fn1.txt", urls=["outcome1.com"]))
        return metadata

    def _count_archives(self, db_session):
        """Number of stored archives whose url equals URL."""
        return db_session.query(models.Archive).filter(models.Archive.url == self.URL).count()

    # insert_result_into_db is intentionally NOT patched here: the test reads the inserted row back
    @patch("worker.main.models.generate_uuid", return_value="constant-uuid")
    @patch("worker.main.is_group_invalid_for_user", return_value=False)
    @patch("worker.main.ArchivingOrchestrator")
    def test_success(self, m_orch_generator, m_is_group, m_uuid, worker_init, db_session):
        """Of the three fed items only one ends up archived, stored in the DB and tagged "gsheet"."""
        from worker.main import create_sheet_task

        assert self._count_archives(db_session) == 0

        mock_metadata = self._metadata_with_media()
        m_orch = MagicMock()
        # a falsy item followed by the same metadata twice
        m_orch.feed.return_value = iter([False, mock_metadata, mock_metadata])
        m_orch_generator.return_value = m_orch

        res = create_sheet_task(self.sheet.model_dump_json())

        assert res["archived"] == 1
        assert res["failed"] == 0
        assert len(res["errors"]) == 0
        assert res["sheet"] == "Sheet"
        assert res["sheet_id"] == "123"
        assert res["success"] == True
        assert len(res["time"]) > 0

        # query created archive entry
        inserted = db_session.query(models.Archive).filter(models.Archive.url == self.URL).one()
        assert inserted is not None
        assert inserted.url == self.URL
        assert inserted.tags[0].id == "gsheet"

    # Four patches need four mock parameters. With only three, the last mock was bound to
    # `worker_init` and pytest no longer requested that fixture, so at_start never ran here.
    @patch("worker.main.insert_result_into_db", side_effect=Exception("some-error"))
    @patch("worker.main.models.generate_uuid", return_value="constant-uuid")
    @patch("worker.main.is_group_invalid_for_user", return_value=False)
    @patch("worker.main.ArchivingOrchestrator")
    def test_has_exception(self, m_orch_generator, m_is_group, m_uuid, m_insert, worker_init, db_session):
        """A failing DB insertion is counted as a failed row and reported in errors; nothing is stored."""
        from worker.main import create_sheet_task

        assert self._count_archives(db_session) == 0

        m_orch = MagicMock()
        m_orch.feed.return_value = iter([self._metadata_with_media()])
        m_orch_generator.return_value = m_orch

        res = create_sheet_task(self.sheet.model_dump_json())

        assert res["archived"] == 0
        assert res["failed"] == 1
        assert res["errors"] == ["some-error"]
        assert res["sheet_id"] == "123"
        # the task itself still reports success even though the row failed
        assert res["success"] == True
        assert self._count_archives(db_session) == 0

    @patch("worker.main.is_group_invalid_for_user", return_value="Access denied")
    def test_error_access(self, m_is_group, worker_init, db_session):
        """The message returned by the group check is surfaced under the "error" key."""
        from worker.main import create_sheet_task

        res = create_sheet_task(self.sheet.model_dump_json())
        assert "error" in res
        assert res["error"] == "Access denied"
def test_choose_orchestrator(worker_init):
    """choose_orchestrator returns an ArchivingOrchestrator for a known user without an explicit group."""
    from worker.main import choose_orchestrator

    chosen = choose_orchestrator(None, "rick@example.com")
    assert chosen.__class__.__name__ == "ArchivingOrchestrator"
# The patch injects its mock as the first positional argument. Without a dedicated
# parameter the mock was bound to `worker_init` and pytest skipped that fixture.
@patch("worker.main.get_user_first_group", return_value="does-not-exist")
def test_choose_orchestrator_assertion(m_first_group, worker_init):
    """choose_orchestrator raises when the user's first group has no orchestrator configured."""
    from worker.main import choose_orchestrator

    with pytest.raises(Exception):
        choose_orchestrator(None, "rick@example.com")
@patch("worker.main.read_user_groups")
def test_get_user_first_group(m_read_user_groups, worker_init):
    """get_user_first_group falls back to "default" and otherwise returns the first listed group."""
    from worker.main import get_user_first_group

    # (users mapping returned by read_user_groups, expected first group for "email1")
    scenarios = [
        ({}, "default"),
        ({"email1": []}, "default"),
        ({"email1": ["group1", "group2"]}, "group1"),
    ]
    for users, expected in scenarios:
        m_read_user_groups.return_value = {"users": users}
        assert get_user_first_group("email1") == expected
def test_is_group_invalid_for_user(worker_init, db_session):
    """is_group_invalid_for_user returns False when access is allowed and an explanatory message otherwise."""
    from worker.main import is_group_invalid_for_user
    from db.crud import upsert_user_groups

    # load the user/group fixture into the database first
    upsert_user_groups(db_session)

    # ((first argument, group id, email), expected return value)
    scenarios = [
        ((True, "", ""), False),
        ((False, "", ""), False),
        ((False, "default", ""), "User is not part of default, no permission"),
        ((False, "spaceship", "jerry@example.com"), "User jerry@example.com is not part of spaceship, no permission"),
        ((False, "spaceship", "rick@example.com"), False),
    ]
    for arguments, expected in scenarios:
        assert is_group_invalid_for_user(*arguments) == expected
def test_get_all_urls(worker_init, db_session):
    """get_all_urls collects the URLs of top-level media and of media nested in their properties."""
    from worker.main import get_all_urls
    from auto_archiver import Metadata

    meta = Metadata().set_url("https://example.com")
    first, second, third = (
        meta.add_media(Media(f"fn{n}.txt", urls=[f"outcome{n}.com"])) for n in (1, 2, 3)
    )

    # nested media in three shapes: a Media, a list of Media, and a Media serialized to a dict
    first.set("screenshot", Media("screenshot.png", urls=["screenshot.com"]))
    second.set("thumbnails", [Media("thumb1.png", urls=["thumb1.com"]), Media("thumb2.png", urls=["thumb2.com"])])
    third.set("ssl_data", Media("ssl_data.txt", urls=["ssl_data.com"]).to_dict())
    # a plain dict contributes no URL
    third.set("bad_data", {"bad": "dict is ignored"})

    urls = [entry.url for entry in get_all_urls(meta)]
    assert len(urls) == 7
    expected_urls = (
        "outcome1.com", "outcome2.com", "outcome3.com",
        "screenshot.com", "thumb1.com", "thumb2.com", "ssl_data.com",
    )
    for expected in expected_urls:
        assert expected in urls

0
src/utils/__init__.py Normal file
View File

69
src/utils/metrics.py Normal file
View File

@@ -0,0 +1,69 @@
import asyncio
import json
import os
import shutil
from prometheus_client import Counter, Gauge
import redis
from db import crud
from db.database import get_db
from core.logging import log_error
# Custom metrics
EXCEPTION_COUNTER = Counter(
"exceptions",
"Number of times a certain exception has occurred.",
labelnames=["types"]
)
WORKER_EXCEPTION = Counter(
"worker_exceptions_total",
"Number of times a certain exception has occurred on the worker.",
labelnames=["types", "exception", "task", "traceback"]
)
DISK_UTILIZATION = Gauge(
"disk_utilization",
"Disk utilization in GB",
labelnames=["type"]
)
DATABASE_METRICS = Gauge(
"database_metrics",
"Database metric readings at a certain point in time",
labelnames=["query"]
)
DATABASE_METRICS_COUNTER = Counter(
"database_metrics_counter",
"Database metrics that increase over time",
labelnames=["query", "user"]
)
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL, CELERY_BROKER_URL):
    """Forward worker exceptions published on a Redis channel into the WORKER_EXCEPTION counter.

    Runs forever: subscribes to the channel, polls it, and increments the counter once per
    published message, labelled with the exception, task and traceback found in the payload.

    Args:
        REDIS_EXCEPTIONS_CHANNEL: name of the Redis pub/sub channel to subscribe to.
        CELERY_BROKER_URL: Redis URL used to open the connection.
    """
    pubsub = redis.Redis.from_url(CELERY_BROKER_URL).pubsub()
    pubsub.subscribe(REDIS_EXCEPTIONS_CHANNEL)

    while True:
        message = pubsub.get_message()
        if not message:
            # nothing pending: wait before polling again
            await asyncio.sleep(1)
            continue

        if message["type"] == "message":
            try:
                data = json.loads(message["data"].decode("utf-8"))
                # NOTE(review): data["exception"] comes out of JSON, so type(...) is the JSON value's
                # Python type (e.g. "str"), not the original exception class - confirm this is intended.
                WORKER_EXCEPTION.labels(
                    types=type(data["exception"]).__name__,
                    exception=data["exception"],
                    task=data["task"],
                    traceback=data["traceback"],
                ).inc()
            except Exception as e:
                # a malformed payload must not terminate the subscriber loop
                log_error(e)

        # more messages may be queued: yield to the event loop, then poll again right away
        await asyncio.sleep(0)
async def measure_regular_metrics(sqlite_db_url: str, repeat_in_seconds: int):
    """Refresh the disk and database metrics.

    Args:
        sqlite_db_url: database location; a leading "sqlite:///" is removed to get the file path.
        repeat_in_seconds: look-back window, in seconds, passed to crud.count_by_user_since.
    """
    bytes_per_unit = 2**30

    # disk usage of the root filesystem
    usage = shutil.disk_usage("/")
    DISK_UTILIZATION.labels(type="used").set(usage.used / bytes_per_unit)
    DISK_UTILIZATION.labels(type="free").set(usage.free / bytes_per_unit)

    # size of the database file; a failure is logged and the gauge is left untouched
    try:
        db_file = sqlite_db_url.replace("sqlite:///", "")
        db_size = os.stat(db_file).st_size
        DISK_UTILIZATION.labels(type="database").set(db_size / bytes_per_unit)
    except Exception as e:
        log_error(e)

    with get_db() as db:
        DATABASE_METRICS.labels(query="count_archives").set(crud.count_archives(db))
        DATABASE_METRICS.labels(query="count_archive_urls").set(crud.count_archive_urls(db))
        DATABASE_METRICS.labels(query="count_users").set(crud.count_users(db))

        # per-user totals within the window are added to a monotonically increasing counter
        for row in crud.count_by_user_since(db, repeat_in_seconds):
            per_user = DATABASE_METRICS_COUNTER.labels(query="count_by_user", user=row.author_id)
            per_user.inc(row.total)

4
src/web/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from web.main import app_factory
# `app` is the factory callable itself, not a built FastAPI instance.
# NOTE(review): presumably the ASGI server is started in factory mode so it calls this - confirm in the run command.
app = app_factory

167
src/web/main.py Normal file
View File

@@ -0,0 +1,167 @@
import os
from celery.result import AsyncResult
from fastapi import FastAPI, Depends, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from prometheus_fastapi_instrumentator import Instrumentator
from datetime import datetime
import sqlalchemy
from sqlalchemy.orm import Session
from loguru import logger
from core.logging import logging_middleware, log_error
from worker.main import create_archive_task, create_sheet_task, celery, insert_result_into_db
from db import crud, models, schemas
from web.security import get_user_auth, token_api_key_auth, get_token_or_user_auth
from core.config import VERSION, API_DESCRIPTION
from db.database import get_db_dependency
from core.events import lifespan
from shared.settings import get_settings
from auto_archiver import Metadata
from endpoints import default_router, url_router, sheet_router, task_router, interoperability_router
def app_factory(settings=None):
    """
    Build and return the configured FastAPI application.

    The factory registers the CORS and logging middlewares, includes the
    endpoint routers, exposes the instrumentation metrics behind the API token,
    optionally mounts a local archive directory as static files and finally
    registers the deprecated endpoints kept for backward compatibility.

    :param settings: settings object to configure the app with. When omitted
        (or None) `get_settings()` is called at invocation time.
    :return: the FastAPI application instance.
    """
    # Resolve settings lazily: a `settings=get_settings()` default would be
    # evaluated only once, when this module is imported.
    if settings is None:
        settings = get_settings()

    app = FastAPI(
        title="Auto-Archiver API",
        description=API_DESCRIPTION,
        version=VERSION,
        contact={"name": "GitHub", "url": "https://github.com/bellingcat/auto-archiver-api"},
        lifespan=lifespan
    )

    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.ALLOWED_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.middleware("http")(logging_middleware)

    app.include_router(default_router)
    app.include_router(url_router)
    app.include_router(sheet_router)
    app.include_router(task_router)
    app.include_router(interoperability_router)

    # prometheus exposed in /metrics with authentication
    Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])

    # The configured directory may be written with an "/app" prefix; when that
    # path does not exist, fall back to the same path with "/app" replaced by ".".
    local_dir = settings.SERVE_LOCAL_ARCHIVE
    relative_dir = local_dir.replace("/app", ".")
    if not os.path.isdir(local_dir) and os.path.isdir(relative_dir):
        local_dir = relative_dir

    # the length check skips mounting for an empty or single-character setting
    if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
        logger.warning(f"MOUNTing local archive {settings.SERVE_LOCAL_ARCHIVE}")
        app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE)

    # NOTE: the nested endpoints below are documented with comments rather than
    # docstrings, because FastAPI publishes docstrings as OpenAPI descriptions.

    # -----Submit URL and manipulate tasks. Bearer protected below
    # Search archives by URL, optionally bounded by archive date.
    @app.get("/tasks/search-url", response_model=list[schemas.Archive], deprecated=True)  # DEPRECATED
    def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
        return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)

    # List the archives of the authenticated user.
    @app.get("/tasks/sync", response_model=list[schemas.Archive], deprecated=True)  # DEPRECATED
    def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
        return crud.search_archives_by_email(db, email, skip=skip, limit=limit)

    # Queue a new archive task for a single URL and return the task id.
    @app.post("/tasks", status_code=201, deprecated=True)  # DEPRECATED
    def archive_tasks(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)):
        archive.author_id = email
        url = archive.url
        logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
        if not isinstance(url, str) or len(url) <= 5:
            raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
        logger.info("creating task")
        task = create_archive_task.delay(archive.model_dump_json())
        return JSONResponse({"id": task.id})

    # Fetch a stored archive by its task id.
    @app.get("/archive/{task_id}", deprecated=True)  # DEPRECATED
    def lookup(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
        return crud.get_archive(db, task_id, email)

    # Report the celery status/result of a task; any error is reported as FAILURE.
    @app.get("/tasks/{task_id}", deprecated=True)  # DEPRECATED
    def get_status(task_id, email=Depends(get_token_or_user_auth)):
        logger.info(f"status check for user {email} task {task_id}")
        task = AsyncResult(task_id, app=celery)
        try:
            if task.status == "FAILURE":
                # *FAILURE* The task raised an exception, or has exceeded the retry limit.
                # The :attr:`result` attribute then contains the exception raised by the task.
                # https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
                raise task.result

            response = {
                "id": task_id,
                "status": task.status,
                "result": task.result
            }
            return JSONResponse(jsonable_encoder(response, exclude_unset=True))
        except Exception as e:
            # also reached by the deliberate `raise task.result` above
            log_error(e)
            return JSONResponse({
                "id": task_id,
                "status": "FAILURE",
                "result": {"error": str(e)}
            })

    # Soft-delete a task on behalf of the authenticated user.
    @app.delete("/tasks/{task_id}", deprecated=True)  # DEPRECATED
    def delete_task(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
        logger.info(f"deleting task {task_id} request by {email}")
        return JSONResponse({
            "id": task_id,
            "deleted": crud.soft_delete_task(db, task_id, email)
        })

    # ----- Google Sheets Logic
    # Queue a sheet task for the authenticated user.
    @app.post("/sheet", status_code=201, deprecated=True)  # DEPRECATED
    def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth)):
        logger.info(f"SHEET TASK for {sheet=}")
        sheet.author_id = email
        if not sheet.sheet_name and not sheet.sheet_id:
            raise HTTPException(status_code=422, detail="sheet name or id is required")
        task = create_sheet_task.delay(sheet.model_dump_json())
        return JSONResponse({"id": task.id})

    # Queue a sheet task authenticated with the API token instead of a user.
    @app.post("/sheet_service", status_code=201, deprecated=True)  # DEPRECATED
    def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth)):
        logger.info(f"SHEET TASK for {sheet=}")
        # no user identity is available here, so fall back to a fixed author id
        sheet.author_id = sheet.author_id or "api-endpoint"
        if not sheet.sheet_name and not sheet.sheet_id:
            raise HTTPException(status_code=422, detail="sheet name or id is required")
        task = create_sheet_task.delay(sheet.model_dump_json())
        return JSONResponse({"id": task.id})

    # ----- endpoint to submit data archived elsewhere
    # Store an externally produced archive result, tagging it as "manual".
    @app.post("/submit-archive", status_code=201, deprecated=True)  # DEPRECATED
    def submit_manual_archive(manual: schemas.SubmitManual, auth=Depends(token_api_key_auth)):
        result = Metadata.from_json(manual.result)
        logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
        manual.tags.add("manual")
        try:
            archive_id = insert_result_into_db(result, manual.tags, manual.public, manual.group_id, manual.author_id, models.generate_uuid())
        except sqlalchemy.exc.IntegrityError as e:
            log_error(e)
            # chain the original error so the cause is kept in tracebacks
            raise HTTPException(status_code=422, detail="Cannot insert into DB due to integrity error") from e
        return JSONResponse({"id": archive_id})

    return app

View File

@@ -1,26 +1,18 @@
from loguru import logger from loguru import logger
import requests, os, secrets import requests, secrets
from fastapi import HTTPException, status, Depends from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from core.config import ALLOW_ANY_EMAIL
from shared.settings import get_settings
settings = get_settings()
# Configuration
CHROME_APP_IDS = set([app_id.strip() for app_id in os.environ.get("CHROME_APP_IDS", "").split(",")])
assert len(CHROME_APP_IDS) > 0, "CHROME_APP_IDS env variable not properly set, it's a csv"
for app_id in CHROME_APP_IDS:
assert len(app_id) > 10, f"CHROME_APP_IDS got invalid id: {app_id} env variable not set"
logger.info(f"{CHROME_APP_IDS=}")
BLOCKED_EMAILS = set([e.strip().lower() for e in os.environ.get("BLOCKED_EMAILS", "").split(",")])
logger.info(f"{len(BLOCKED_EMAILS)=}")
bearer_security = HTTPBearer() bearer_security = HTTPBearer()
ALLOW_ANY_EMAIL = "*"
def secure_compare(token, api_key): def secure_compare(token, api_key):
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8")) return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
# Factory method to create an authentication dependency for a specific key # Factory method to create an authentication dependency for a specific key
def api_key_auth(api_key): def api_key_auth(api_key):
@@ -39,20 +31,22 @@ def api_key_auth(api_key):
return auth return auth
# --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus # --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus
API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN)
token_api_key_auth = api_key_auth(API_BEARER_TOKEN)
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# tries to use the static API_KEY and defaults to google JWT auth # tries to use the static API_KEY and defaults to google JWT auth
access_token = credentials.credentials if await token_api_key_auth(credentials, auto_error=False): return ALLOW_ANY_EMAIL
if token_api_key_auth(access_token, auto_error=False): return ALLOW_ANY_EMAIL
return await get_user_auth(credentials) return await get_user_auth(credentials)
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# validates the Bearer token in the case that it requires it # validates the Bearer token in the case that it requires it
valid_user, info = authenticate_user(credentials.credentials) valid_user, info = authenticate_user(credentials.credentials)
if valid_user: return info if valid_user:
return info
logger.debug(f"TOKEN FAILURE: {valid_user=} {info=}") logger.debug(f"TOKEN FAILURE: {valid_user=} {info=}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@@ -60,16 +54,17 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
def authenticate_user(access_token): def authenticate_user(access_token):
# https://cloud.google.com/docs/authentication/token-types#access # https://cloud.google.com/docs/authentication/token-types#access
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token" if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token": access_token}) r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token": access_token})
if r.status_code != 200: return False, "error occurred" if r.status_code != 200: return False, "invalid token"
try: try:
j = r.json() j = r.json()
if j.get("azp") not in CHROME_APP_IDS and j.get("aud") not in CHROME_APP_IDS: if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS:
return False, f"token does not belong to valid APP_ID" return False, f"token does not belong to valid APP_ID"
if j.get("email") in BLOCKED_EMAILS: if j.get("email") in settings.BLOCKED_EMAILS:
return False, f"email '{j.get('email')}' not allowed" return False, f"email '{j.get('email')}' not allowed"
if j.get("email_verified") != "true": if j.get("email_verified") != "true":
return False, f"email '{j.get('email')}' not verified" return False, f"email '{j.get('email')}' not verified"
@@ -77,6 +72,5 @@ def authenticate_user(access_token):
return False, "Token expired" return False, "Token expired"
return True, j.get('email') return True, j.get('email')
except Exception as e: except Exception as e:
logger.warning(f"EXCEPTION occurred: {e}") logger.warning(f"AUTH EXCEPTION occurred: {e}")
return False, f"EXCEPTION occurred" return False, "exception occurred"

0
src/worker/__init__.py Normal file
View File

View File

@@ -1,82 +1,80 @@
import os, traceback, yaml, datetime import traceback, yaml, datetime
from typing import List, Set from typing import List, Set
from celery import Celery from celery import Celery
from celery.signals import task_failure from celery.signals import task_failure, worker_init
from auto_archiver import Config, ArchivingOrchestrator, Metadata from auto_archiver import Config, ArchivingOrchestrator, Metadata
from auto_archiver.core import Media from auto_archiver.core import Media
from loguru import logger from loguru import logger
from db import crud, schemas, models from db import crud, schemas, models
from db.database import SessionLocal from db.database import get_db
from contextlib import contextmanager from shared.settings import get_settings
import json import json
import redis import redis
from sqlalchemy import exc from sqlalchemy import exc
from core.logging import log_error
settings = get_settings()
celery = Celery(__name__) celery = Celery(__name__)
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379") celery.conf.broker_url = settings.CELERY_BROKER_URL
celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379") celery.conf.result_backend = settings.CELERY_RESULT_BACKEND
USER_GROUPS_FILENAME = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml") USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
REDIS_EXCEPTIONS_CHANNEL = "exceptions-channel"
Rdis = redis.Redis.from_url(celery.conf.broker_url)
@contextmanager Rdis = redis.Redis.from_url(celery.conf.broker_url)
def get_db():
session = SessionLocal()
try: yield session
finally: session.close()
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 3}) @celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 3})
def create_archive_task(self, archive_json: str): def create_archive_task(self, archive_json: str):
archive = schemas.ArchiveCreate.parse_raw(archive_json) archive = schemas.ArchiveCreate.model_validate_json(archive_json)
logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}") logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}")
invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id) invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)
if invalid: if invalid:
raise Exception(invalid) # marks task FAILED, saves the Exception as result raise Exception(invalid) # marks task FAILED, saves the Exception as result
url = archive.url url = archive.url
logger.info(f"{url=} {archive=}") logger.info(f"{url=} {archive=}")
# TODO: re-evaluate if this logic is to be used
if not archive.rearchive: if not archive.rearchive:
with get_db() as session: with get_db() as session:
archives = crud.search_tasks_by_url(session, url, archive.author_id, absolute_search=True) archives = crud.search_archives_by_url(session, url, archive.author_id, absolute_search=True)
if len(archives): if len(archives):
logger.info(f"Skipping {url=} as it was already archived") logger.info(f"Skipping {url=} as it was already archived")
return Metadata.choose_most_complete([a.result for a in archives]) return Metadata.choose_most_complete([a.result for a in archives])
orchestrator = choose_orchestrator(archive.group_id, archive.author_id) orchestrator = choose_orchestrator(archive.group_id, archive.author_id)
result = orchestrator.feed_item(Metadata().set_url(url)) result = orchestrator.feed_item(Metadata().set_url(url))
try: try:
insert_result_into_db(result, archive.tags, archive.public, archive.group_id, archive.author_id, self.request.id) insert_result_into_db(result, archive.tags, archive.public, archive.group_id, archive.author_id, self.request.id)
except Exception as e: except Exception as e:
# Log it, then raise again to store the error as the task result # Log it, then raise again to store the error as the task result
logger.error(e) log_error(e)
logger.error(traceback.format_exc()) redis_publish_exception(e, self.name, traceback.format_exc())
redis_publish_exception(e, self.name)
raise e raise e
return result.to_dict() return result.to_dict()
@celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0}) @celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
def create_sheet_task(self, sheet_json: str): def create_sheet_task(self, sheet_json: str):
sheet = schemas.SubmitSheet.parse_raw(sheet_json) sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
sheet.tags.add("gsheet") sheet.tags.add("gsheet")
logger.info(f"SHEET START {sheet=}") logger.info(f"SHEET START {sheet=}")
if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)): return {"error": em} if (em := is_group_invalid_for_user(sheet.public, sheet.group_id, sheet.author_id)):
return {"error": em}
config = Config() config = Config()
# TODO: use choose_orchestrator and overwrite the feeder # TODO: use choose_orchestrator and overwrite the feeder
config.parse(use_cli=False, yaml_config_filename="secrets/orchestration-sheet.yaml", overwrite_configs={"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}}) config.parse(use_cli=False, yaml_config_filename=get_settings().SHEET_ORCHESTRATION_YAML, overwrite_configs={"configurations": {"gsheet_feeder": {"sheet": sheet.sheet_name, "sheet_id": sheet.sheet_id, "header": sheet.header}}})
orchestrator = ArchivingOrchestrator(config) orchestrator = ArchivingOrchestrator(config)
stats = {"archived": 0, "failed": 0, "errors": []} stats = {"archived": 0, "failed": 0, "errors": []}
for result in orchestrator.feed(): for result in orchestrator.feed():
if not result: if not result:
logger.error("Got empty result from feeder, an internal error must have occurred.") logger.error("Got empty result from feeder, an internal error must have occurred.")
continue continue
try: try:
@@ -84,12 +82,9 @@ def create_sheet_task(self, sheet_json: str):
stats["archived"] += 1 stats["archived"] += 1
except exc.IntegrityError as e: except exc.IntegrityError as e:
logger.warning(f"cached result detected: {e}") logger.warning(f"cached result detected: {e}")
stats["archived"] += 1
except Exception as e: except Exception as e:
logger.error(type(e)) log_error(e, extra=f"{self.name}: {sheet_json}")
logger.error(e) redis_publish_exception(e, self.name, traceback.format_exc())
logger.error(traceback.format_exc())
redis_publish_exception(e, self.name)
stats["failed"] += 1 stats["failed"] += 1
stats["errors"].append(str(e)) stats["errors"].append(str(e))
@@ -100,11 +95,11 @@ def create_sheet_task(self, sheet_json: str):
@task_failure.connect(sender=create_sheet_task) @task_failure.connect(sender=create_sheet_task)
@task_failure.connect(sender=create_archive_task) @task_failure.connect(sender=create_archive_task)
def task_failure_notifier(sender, **kwargs): def task_failure_notifier(sender, **kwargs):
logger.warning("😅 From task_failure_notifier ==> Task failed successfully! ") traceback_msg = "\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback'])))
logger.error(kwargs['exception']) logger.warning("😅 From task_failure_notifier ==> Task failed successfully!")
logger.error(kwargs['traceback']) log_error(kwargs['exception'], traceback_msg, f"task_failure: {sender.name}")
logger.error("\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback'])))) redis_publish_exception(kwargs['exception'], sender.name, traceback_msg)
redis_publish_exception(kwargs['exception'], sender.name)
def choose_orchestrator(group, email): def choose_orchestrator(group, email):
global ORCHESTRATORS global ORCHESTRATORS
@@ -127,7 +122,8 @@ def read_user_groups():
def get_user_first_group(email): def get_user_first_group(email):
user_groups_yaml = read_user_groups() user_groups_yaml = read_user_groups()
groups = user_groups_yaml.get("users", {}).get(email, []) groups = user_groups_yaml.get("users", {}).get(email, [])
if groups != None and len(groups): return groups[0] if groups != None and len(groups):
return groups[0]
return "default" return "default"
@@ -157,12 +153,14 @@ def is_group_invalid_for_user(public: bool, group_id: str, author_id: str):
if public is true the requirement is not needed if public is true the requirement is not needed
returns an error message if invalid, or False if all is good. returns an error message if invalid, or False if all is good.
""" """
if not public and group_id and len(group_id) > 0: if public: return False
# ensure group is valid for user if not group_id or len(group_id) == 0: return False
with get_db() as session:
if not crud.is_user_in_group(session, group_id, author_id): # otherwise group must match
logger.error(em := f"User {author_id} is not part of {group_id}, no permission") with get_db() as session:
return em if not crud.is_user_in_group(session, group_id, author_id):
logger.error(em := f"User {author_id} is not part of {group_id}, no permission")
return em
return False return False
@@ -172,7 +170,7 @@ def insert_result_into_db(result: Metadata, tags: Set[str], public: bool, group_
with get_db() as session: with get_db() as session:
# urls are created by get_all_urls # urls are created by get_all_urls
# create author_id if needed # create author_id if needed
crud.get_user(session, author_id) crud.create_or_get_user(session, author_id)
# create DB TAGs if needed # create DB TAGs if needed
db_tags = [crud.create_tag(session, tag) for tag in tags] db_tags = [crud.create_tag(session, tag) for tag in tags]
# insert archive # insert archive
@@ -191,10 +189,11 @@ def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
if isinstance(prop, list): if isinstance(prop, list):
for i, prop_media in enumerate(prop): for i, prop_media in enumerate(prop):
if prop_media := convert_if_media(prop_media): if prop_media := convert_if_media(prop_media):
for j, url in enumerate(prop_media.urls): for j, url in enumerate(prop_media.urls):
db_urls.append(models.ArchiveUrl(url=url, key=prop_media.get("id", f"{k}{prop_media.key}_{i}.{j}"))) db_urls.append(models.ArchiveUrl(url=url, key=prop_media.get("id", f"{k}{prop_media.key}_{i}.{j}")))
return db_urls return db_urls
def convert_if_media(media): def convert_if_media(media):
if isinstance(media, Media): return media if isinstance(media, Media): return media
elif isinstance(media, dict): elif isinstance(media, dict):
@@ -203,15 +202,18 @@ def convert_if_media(media):
logger.debug(f"error parsing {media} : {e}") logger.debug(f"error parsing {media} : {e}")
return False return False
def redis_publish_exception(exception, task_name):
def redis_publish_exception(exception, task_name, traceback: str = ""):
REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL
try: try:
Rdis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps({"exception": exception, "task": task_name}, default=str)) Rdis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps({"exception": exception, "task": task_name, "traceback": traceback}, default=str))
except Exception as e: except Exception as e:
logger.error(e) log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}")
logger.error(traceback.format_exc())
logger.error(f"Could not publish to {REDIS_EXCEPTIONS_CHANNEL}")
# INIT @worker_init.connect
ORCHESTRATORS = {} def at_start(sender, **kwargs):
load_orchestrators() global ORCHESTRATORS
ORCHESTRATORS = {}
load_orchestrators()
logger.info("Orchestrators loaded successfully.")