Merge pull request #52 from bellingcat/dev

major refactor completed.
This commit is contained in:
Miguel Sozinho Ramalho
2025-02-19 18:38:12 +00:00
committed by GitHub
116 changed files with 8672 additions and 5919 deletions

3
.coveragerc Normal file
View File

@@ -0,0 +1,3 @@
[run]
omit =
app/migrations/*

View File

@@ -1,5 +1,5 @@
CHROME_APP_IDS='["1234567890"]'
ALLOWED_ORIGINS='["allowed"]'
BLOCKED_EMAILS='[]'
DATABASE_PATH="sqlite:///./auto-archiver.db"
DATABASE_PATH="sqlite:///./database/auto-archiver.db"
API_BEARER_TOKEN=THIS_API_TOKEN_SHOULD_NEVER_BE_USED

38
.env.example Normal file
View File

@@ -0,0 +1,38 @@
# main settings
USER_GROUPS_FILENAME=app/user-groups.yaml
# database
DATABASE_PATH="sqlite:///./database/auto-archiver.db"
DATABASE_QUERY_LIMIT=100
# security settings
API_BEARER_TOKEN=TODO-MODIFY-THIS-API-TOKEN
ALLOWED_ORIGINS='["http://localhost:8000","http://localhost:8004","http://localhost:8081","https://auto-archiver.bellingcat.com"]'
CHROME_APP_IDS='[PROJECT_ID.apps.googleusercontent.com"]'
BLOCKED_EMAILS='[]'
# redis configuration
REDIS_PASSWORD=TODO-MODIFY-THIS-REDIS-PASSWORD
REDIS_HOSTNAME="localhost"
# cronjobs management, enable as needed
CRON_ARCHIVE_SHEETS=true
CRON_DELETE_STALE_SHEETS=true
DELETE_STALE_SHEETS_DAYS=7
CRON_DELETE_SCHEDULED_ARCHIVES=false
DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS=14
# observability for prometheus
REPEAT_COUNT_METRICS_SECONDS=30
# mail service settings, if you want to email users
MAIL_FROM="noreply@auto-archiver.com"
MAIL_FROM_NAME="My Auto Archiver deployment"
MAIL_USERNAME="USERNAME"
MAIL_PASSWORD="PASSWORD"
MAIL_SERVER="mail.server.com"
MAIL_PORT=587
MAIL_STARTTLS=False
MAIL_SSL_TLS=True
# celery workers config
CONCURRENCY=2

View File

@@ -5,5 +5,4 @@ BLOCKED_EMAILS='["blocked@example.com"]'
DATABASE_PATH="sqlite:///auto-archiver.test.db"
API_BEARER_TOKEN=this_is_the_test_api_token
USER_GROUPS_FILENAME=tests/user-groups.test.yaml
SHEET_ORCHESTRATION_YAML=tests/orchestration.test.yaml
USER_GROUPS_FILENAME=app/tests/user-groups.test.yaml

View File

@@ -1 +0,0 @@
REDIS_PASSWORD=TODO

View File

@@ -21,25 +21,24 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install pipenv
run: pip install pipenv
working-directory: src
- name: Install Poetry
run: pipx install poetry
- name: Install dependencies
run: pipenv install --dev
working-directory: src
run: poetry install --no-interaction --with dev
- name: Set dev environment variable
run: echo "ENVIRONMENT_FILE=.env.test" >> $GITHUB_ENV
- name: Run tests with coverage
run: PYTHONPATH=. PIPENV_DOTENV_LOCATION=.env.test pipenv run coverage run -m pytest -v --color=yes tests/
working-directory: src
run: poetry run coverage run -m pytest -v -ra --color=yes app/tests/
- name: Report coverage
run: pipenv run coverage report
working-directory: src
run: poetry run coverage report

22
.gitignore vendored
View File

@@ -1,26 +1,32 @@
user-groups.dev.yaml
user-groups.yaml
orchestration.yaml
my-archives
*.pyc
.DS_Store
secrets
secrets/*
*.log
__pycache
.pytest_cach
__pycache__
.pytest_cache
.env
.env.dev
.env.prod
*.db
redis/data/*
.ipynb_checkpoints*
src/user-groups.yaml
src/user-groups.dev.yaml
app/user-groups.yaml
app/user-groups.dev.yaml
wit*
src/crawls
app/crawls
.coverage
.pytest_cache/*
.pytest_cache/
htmlcov
local_archive
local_archive_test
*db-wal
*db-shm
copy-files.sh
copy-files.sh
temp/
.python-version
orchestration2.yaml
database

View File

@@ -1,6 +1,6 @@
MIT License
Copyright (c) 2023 Stichting Bellingcat
Copyright (c) 2025 Stichting Bellingcat
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@@ -3,15 +3,20 @@ clean-dev:
docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes --remove-orphans
dev:
docker compose -f docker-compose.yml -f docker-compose.dev.yml build
docker compose -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans
dev-redis-only:
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build redis
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans redis
stop-dev:
docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes
prod:
docker compose build
docker compose up -d --remove-orphans
docker compose --env-file .env.prod build
docker compose --env-file .env.prod up -d --remove-orphans
docker buildx prune --keep-storage 20gb -f
docker image prune -f
docker system df

163
README.md
View File

@@ -1,89 +1,120 @@
# Auto Archiver API
An api that uses celery workers to process URL archive requests via [bellingcat/auto-archiver](https://github.com/bellingcat/auto-archiver), it allows authentication via Google OAuth Apps and enables CORS, everything runs on docker but development can be done without docker (except for redis).
[![CI](https://github.com/bellingcat/auto-archiver-api/workflows/CI/badge.svg)](https://github.com/bellingcat/auto-archiver-api/actions/workflows/ci.yaml)
A web API that uses celery workers to process URL archive requests via [bellingcat/auto-archiver](https://github.com/bellingcat/auto-archiver), it allows authentication via Google OAuth Apps and enables CORS, everything runs on docker but development can be done without docker (except for redis).
![image](https://github.com/user-attachments/assets/905d697d-b83e-437b-87d1-cc86d3c8d8bf)
## setup
To properly set up the API you need to install `docker` and to edit 3 files:
1. a `.env.prod` and `.env.dev` to configure the API, stays at the root level
2. a `user-groups.yaml` to manage user permissions
1. note that all local files referenced in `user-groups.yaml` and any orchestration.yaml files should be relative to the home directory so if your service account is in `secrets/orchestration.yaml` use that path and not just `orchestration.yaml`.
Do not commit those files, they are .gitignored by default.
We also advise you to keep any sensitive files in the `secrets/` folder which is pinned and gitignored.
## Development
http://localhost:8004
We have examples for both of those files (`.env.example` and `user-groups.example.yaml`), and here's how to set them up whether you're in development or production:
TODO: update .env file instructions, should use .env.prod and .env.dev and only use .env for always overwriting dev/prod settings.
### setup for DEVELOPMENT
```bash
# copy and modify the .env.dev file according to your needs
cp .env.example .env.dev
# copy the user-groups.example.yaml and modify it accordingly
cp user-groups.example.yaml user-groups.dev.yaml
# run the APP, make sure VPNs are off
make dev
# check it's running by calling the health endpoint
curl 'http://localhost:8004/health'
# > {"status":"ok"}
```
now go to http://localhost:8004/docs#/ and you should see the API documentation
requires `src/.env`
### setup for PRODUCTION
```bash
# copy and modify the .env.prod file according to your needs
cp .env.example .env.prod
# copy the user-groups.example.yaml and modify it accordingly
cp user-groups.example.yaml user-groups.yaml
# deploy the app
make prod
# check it's running by calling the health endpoint
curl 'http://localhost:8004/health'
# > {"status":"ok"}
```
now go to http://localhost:8004/docs#/ and you should see the API documentation
## User, Domains, Groups, and permissions management
there are 2 ways to access the API
1. via an API token which has full control/privileges to archive/search
2. via a Google Auth token which goes through the user access model
#### User access model
The permissions are defined solely via the `user-groups.yaml` file
- users belong to groups which determine their access level/quotas/orchestration setup
- users are assigned to groups explicitly (via email)
- users are assigned to groups implicitly (via email domains) as domains can be associated to groups
- users that are not explicitly or implicitly in the system belong to the `default` group, restrict their permissions if you do not wish them to be able to search/archive
- if a user is assigned to one group which is not explicitly defined, a warning will be thrown, it may be necessary to do that if you discontinue a given group but the database still has entries for it and so
- groups determine
- which orchestrator to use for single URL archives and for spreadsheet archives see [GroupPermissions](app/shared/user_groups.py)
- a set of permissions
- `read` can be [`all`], [] or a comma separated list of group names, meaning people in this group can access either all, none, or those belonging to explicitly listed groups.
- the group itself must be included in the list, otherwise the user cannot search archives of that group
- `read_public` a boolean that enables the user to search public archives
- `archive_url` a boolean that enables the user to archive links in this group
- `archive_sheet` a boolean that enables the user to archive spreadsheets
- `manually_trigger_sheet` a boolean that enables the user to manually trigger a sheet archive for sheets in this group
- `sheet_frequency` a list of options for the sheet archiving frequency, currently max permissions is `["hourly", "daily"]`
- `max_sheets` defines the maximum amount of spreadsheets someone can have in total (`-1` means no limit)
- `max_archive_lifespan_months` defines the lifespan of an archive before being deleted from S3, users will be notified 1 month in advance with instructions to download TODO
- `max_monthly_urls` how many total URLs someone can archive per month (`-1` means no limit)
- `max_monthly_mbs` how many MBs of data someone can archive per month (`-1` means no limit)
- `priority` one of `high` or `low`, this will be used to give archiving priority
- group names are all lower-case
## development of web/worker without docker
cd /src
<!-- * `pipenv install --editable ../../auto-archiver` -->
* console 1 - `docker compose up redis` optionally add `web` if not running uvicorn locally
* console 2 - `pipenv shell` + `celery worker --app=worker.celery --loglevel=info --logfile=logs/celery_dev.log`
* `celery --app=worker.celery worker --loglevel=info --logfile=logs/celery_dev.log` celery 5
* or with watchdog for dev auto-reload `watchmedo auto-restart -d ./ -- celery --app=worker.celery worker --loglevel=info --logfile=logs/celery_dev.log`
* console 3 - `pipenv shell` + `uvicorn main:app --host 0.0.0.0 --reload`
orchestration must be from the console(?)
* turn off VPNs if connection to docker is not working
We advise you to use `make prod` but you can also spin up redis and run the API (uvicorn) and worker (celery) individually like so:
* console 1 - `make dev-redis-only` to spin up redis, turn off any VPNs
* console 2 - `export ENVIRONMENT_FILE=.env.dev` then `poetry run celery --app=app.worker.main.celery worker --loglevel=debug --logfile=/aa-api/logs/celery.log -Q high_priority,low_priority --concurrency=1`
* or with watchdog for dev auto-reload `watchmedo auto-restart --patterns="*.py" --recursive --ignore-directories -- celery -- --app=app.worker.main.celery worker --loglevel=debug --logfile=/aa-api/logs/celery.log -Q high_priority,low_priority --concurrency=1`
* console 3 - `export ENVIRONMENT_FILE=.env.dev` then `poetry run uvicorn main:app --host 0.0.0.0 --reload`
## User management
Copy [example.user-groups.yaml](src/example.user-groups.yaml) into a new file and set the environment variable `USER_GROUPS_FILENAME` to that filename (defaults to `user-groups.yaml`).
This file contains 2 parts user-groups specifications. Each user can archive URLs publicly, privately, or privately for a group so long as they are declared as part of that group. In the example bellow `email1` has 2 groups while `email3` has none.
```yaml
users:
email1@example.com:
- group1
- group2
email2@example.com:
- group2
email3@example-no-group.com:
```
Auto-archiver orchestrator files configurations. For each archiving task an orchestrator is chosen, either from a specified group (if group-level visibility) or the first group the user is assigned to in the above file or the `default` orchestration file which is a required config.
```yaml
orchestrators:
group1: secrets/orchestration-group1.yaml
group2: secrets/orchestration-group2.yaml
default: secrets/orchestration-default:orchestration.yaml
```
## Database migrations
check https://alembic.sqlalchemy.org/en/latest/tutorial.html#the-migration-environment
* create migrations with `alembic revision -m "create account table"`
* if running in the normal pipenv environment use `PIPENV_DOTENV_LOCATION=.env.alembic pipenv run` followed by:
* migrate to most recent with `alembic upgrade head`
* downgrade with `alembic downgrade -1`
```bash
# set the env variables
export ENVIRONMENT_FILE=.env.alembic
# create a new migration with description in app/migrations
poetry run alembic revision -m "create account table"
# perform all migrations
poetry run alembic upgrade head
# downgrade by one migration
poetry run alembic downgrade -1
```
## Release
Update `main.py:VERSION`.
Update the version in [config.py](app/web/config.py)
Copy `.env` and `src/.env` to deployment, along with the contents of `secrets/` including `secrets/orchestration.yaml`.
Make sure environment and user-groups files are up to date.
Then `make prod`.
#### updating packages/app/access
If pipenv packages are updated: `make prod` to build images with new packages.
New users should be added to the `src/.env` file `ALLOWED_EMAILS` prop.
Run `pipenv update auto-archiver` inside `src` to update the auto-archiver version being used, then test with `make dev`.
```bash
# CALL /sheet POST endpoint
curl -XPOST -H "Authorization: Bearer GOOGLE_OAUTH_TOKEN" -H "Content-type: application/json" -d '{"sheet_id": "SHEET_ID", "header": 1}' 'http://localhost:8004/sheet'
```
### Testing
```bash
# can be done from top level but let's do it from the src folder for consistency with CI etc
cd src
# set the testing environment variables
export ENVIRONMENT_FILE=.env.test
# run tests and generate coverage
PYTHONPATH=. PIPENV_DOTENV_LOCATION=.env.test pipenv run coverage run -m pytest -vv --disable-warnings --color=yes tests/ && pipenv run coverage html
poetry run coverage run -m pytest -vv --disable-warnings --color=yes app/tests/
# get coverage report in command line
pipenv run coverage report
# get coverage HTML
pipenv run coverage html
# > open/run server on htmlcov/index.html to navigate through line coverage
```
poetry run coverage report
# get coverage report in HTML format
poetry run coverage html
```

View File

@@ -2,7 +2,7 @@
[alembic]
# path to migration scripts
script_location = migrations
script_location = app/migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time

View File

@@ -1,11 +1,10 @@
from logging.config import fileConfig
import os
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from shared.settings import get_settings
from app.shared.settings import get_settings
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.

View File

@@ -0,0 +1,34 @@
"""create archives.store_until column
Revision ID: 02b2f6d17ed0
Revises: 1636724ec4b1
Create Date: 2025-02-08 15:22:20.392522
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '02b2f6d17ed0'
down_revision = '1636724ec4b1'
branch_labels = None
depends_on = None
STORE_UNTIL_COL = "store_until"
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('archives')]
if STORE_UNTIL_COL not in columns:
op.add_column('archives', sa.Column(STORE_UNTIL_COL, sa.DateTime(), nullable=True, default=None))
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('archives')]
if STORE_UNTIL_COL in columns:
op.drop_column('archives', STORE_UNTIL_COL)

View File

@@ -0,0 +1,32 @@
"""rename sheets last_archived col
Revision ID: 1636724ec4b1
Revises: a23aaf3ae930
Create Date: 2025-02-05 19:19:01.984396
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '1636724ec4b1'
down_revision = 'a23aaf3ae930'
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('sheets')]
if 'last_archived_at' in columns:
op.alter_column('sheets', 'last_archived_at', new_column_name='last_url_archived_at')
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('sheets')]
if 'last_url_archived_at' in columns:
op.alter_column('sheets', 'last_url_archived_at', new_column_name='last_archived_at')

View File

@@ -0,0 +1,36 @@
"""add new service_account_email column to groups table
Revision ID: 63ac79df4ad0
Revises: 02b2f6d17ed0
Create Date: 2025-02-11 21:53:23.293274
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '63ac79df4ad0'
down_revision = '02b2f6d17ed0'
branch_labels = None
depends_on = None
NEW_COL = "service_account_email"
TABLE = "groups"
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns(TABLE)]
if NEW_COL not in columns:
op.add_column(TABLE, sa.Column(NEW_COL, sa.String, nullable=True, default=None))
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns(TABLE)]
if NEW_COL in columns:
op.drop_column(TABLE, NEW_COL)

View File

@@ -0,0 +1,42 @@
"""add sheet_id to archive table
Revision ID: 89121d2c96d8
Revises: fa012ec405b8
Create Date: 2024-11-04 11:12:30.237299
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector
# revision identifiers, used by Alembic.
revision = '89121d2c96d8'
down_revision = 'fa012ec405b8'
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('archives')]
if 'sheet_id' not in columns:
with op.batch_alter_table('archives') as batch_op:
batch_op.add_column(sa.Column('sheet_id', sa.String(), nullable=True, default=None))
batch_op.create_foreign_key('fk_sheet_id', 'sheets', ['sheet_id'], ['id'])
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
foreign_keys = [fk['name'] for fk in inspector.get_foreign_keys('archives')]
columns = [col['name'] for col in inspector.get_columns('archives')]
with op.batch_alter_table('archives') as batch_op:
if 'fk_sheet_id' in foreign_keys:
batch_op.drop_constraint('fk_sheet_id', type_='foreignkey')
if 'sheet_id' in columns:
batch_op.drop_column('sheet_id')

View File

@@ -0,0 +1,34 @@
"""drop active column
Revision ID: a23aaf3ae930
Revises: 89121d2c96d8
Create Date: 2025-02-04 12:19:20.753570
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a23aaf3ae930'
down_revision = '89121d2c96d8'
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('users')]
if 'is_active' in columns:
op.drop_column('users', 'is_active')
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('users')]
if 'is_active' not in columns:
op.add_column('users', sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.false()))

View File

@@ -19,7 +19,7 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('groups')]
if 'description' not in columns:
@@ -35,8 +35,11 @@ def upgrade() -> None:
def downgrade() -> None:
op.drop_column('groups', 'description')
op.drop_column('groups', 'orchestrator')
op.drop_column('groups', 'orchestrator_sheet')
op.drop_column('groups', 'permissions')
op.drop_column('groups', 'domains')
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('groups')]
column_names = ['description', 'orchestrator', 'orchestrator_sheet', 'permissions', 'domains']
for column_name in column_names:
if column_name in columns:
op.drop_column('groups', column_name)

32
app/shared/aa_utils.py Normal file
View File

@@ -0,0 +1,32 @@
# TODO: code in this file should eventually be moved to the auto-archiver code base
from typing import List
from loguru import logger
from auto_archiver.core import Media, Metadata
from app.shared.db import models
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
db_urls = []
for m in result.media:
for i, url in enumerate(m.urls): db_urls.append(models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}")))
for k, prop in m.properties.items():
if prop_converted := convert_if_media(prop):
for i, url in enumerate(prop_converted.urls): db_urls.append(models.ArchiveUrl(url=url, key=prop_converted.get("id", f"{k}_{i}")))
if isinstance(prop, list):
for i, prop_media in enumerate(prop):
if prop_media := convert_if_media(prop_media):
for j, url in enumerate(prop_media.urls):
db_urls.append(models.ArchiveUrl(url=url, key=prop_media.get("id", f"{k}{prop_media.key}_{i}.{j}")))
return db_urls
def convert_if_media(media):
if isinstance(media, Media): return media
elif isinstance(media, dict):
try: return Media.from_dict(media)
except Exception as e:
logger.debug(f"error parsing {media} : {e}")
return False

View File

@@ -0,0 +1,16 @@
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do decide
import datetime
from sqlalchemy.orm import Session
from app.shared.db import worker_crud
def get_store_archive_until(db: Session, group_id: str) -> datetime.datetime:
group = worker_crud.get_group(db, group_id)
assert group, f"Group {group_id} not found."
max_lifespan = group.permissions.get("max_archive_lifespan_months", -1)
if max_lifespan == -1: return None
return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan)

74
app/shared/db/database.py Normal file
View File

@@ -0,0 +1,74 @@
from functools import lru_cache
from sqlalchemy import Engine, create_engine, event, text
from sqlalchemy.orm import sessionmaker
from contextlib import asynccontextmanager, contextmanager
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine, async_sessionmaker
from app.shared.settings import get_settings
@lru_cache
def make_engine(database_url: str):
engine = create_engine(
database_url,
connect_args={"check_same_thread": False},
pool_size=15, # Increase pool size
max_overflow=20, # Allow more temporary connections
pool_recycle=1800 # Recycle connections every 30 minutes
)
@event.listens_for(engine, "connect")
def set_sqlite_pragma(conn, _) -> None:
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.close()
return engine
def make_session_local(engine: Engine):
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return session_local
@contextmanager
def get_db():
session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
try: yield session
finally: session.close()
def get_db_dependency():
# to use with Depends and ensure proper session closing
with get_db() as db:
yield db
def wal_checkpoint():
# WAL checkpointing, make sure the .sqlite file receives the latest changes
# to be called at startup as it halts writes
with get_db() as db:
db.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
# ASYNC connections
async def make_async_engine(database_url: str) -> AsyncEngine:
engine = create_async_engine(database_url, connect_args={"check_same_thread": False})
async with engine.begin() as conn:
await conn.run_sync(lambda sync_conn: sync_conn.execute(text("PRAGMA journal_mode=WAL;")))
return engine
async def make_async_session_local(engine: AsyncEngine) -> AsyncSession:
return async_sessionmaker(engine, expire_on_commit=False, autoflush=False, autocommit=False)
@asynccontextmanager
async def get_db_async():
engine = await make_async_engine(get_settings().ASYNC_DATABASE_PATH)
async_session = await make_async_session_local(engine)
async with async_session() as session:
try: yield session
finally: await engine.dispose()

View File

@@ -6,9 +6,11 @@ import uuid
Base = declarative_base()
def generate_uuid():
return str(uuid.uuid4())
# many to many association tables
association_table_archive_tags = Table(
"mtm_archives_tags",
@@ -23,6 +25,7 @@ association_table_user_groups = Table(
Column("group_id", ForeignKey("groups.id")),
)
# data model tables
class Archive(Base):
__tablename__ = "archives"
@@ -30,18 +33,22 @@ class Archive(Base):
id = Column(String, primary_key=True, index=True)
url = Column(String, index=True)
result = Column(JSON, default=None)
public = Column(Boolean, default=True) # if public=false, access to group and author
public = Column(Boolean, default=True) # if public=false, access by group and author
deleted = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
store_until = Column(DateTime(timezone=True), default=None)
group_id = Column(String, ForeignKey("groups.id"), default=None)
author_id = Column(String, ForeignKey("users.email"))
sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags)
group = relationship("Group", back_populates="archives")
author = relationship("User", back_populates="archives")
urls = relationship("ArchiveUrl", back_populates="archive")
sheet = relationship("Sheet", back_populates="archives")
class ArchiveUrl(Base):
__tablename__ = "archive_urls"
@@ -61,15 +68,17 @@ class Tag(Base):
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
class User(Base):
__tablename__ = "users"
email = Column(String, primary_key=True, index=True)
is_active = Column(Boolean, default=False)
archives = relationship("Archive", back_populates="author")
sheets = relationship("Sheet", back_populates="author")
groups = relationship("Group", back_populates="users", secondary=association_table_user_groups)
class Group(Base):
__tablename__ = "groups"
@@ -77,8 +86,29 @@ class Group(Base):
description = Column(String, default=None)
orchestrator = Column(String, default=None)
orchestrator_sheet = Column(String, default=None)
permissions = Column(JSON, default=None)
permissions = Column(JSON, default={})
service_account_email = Column(String, default=None)
domains = Column(JSON, default=[])
archives = relationship("Archive", back_populates="group")
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
sheets = relationship("Sheet", back_populates="group")
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
class Sheet(Base):
__tablename__ = "sheets"
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
name = Column(String, default=None)
author_id = Column(String, ForeignKey("users.email"))
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
# TODO: stats is not being used, consider removing
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...")
last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
group = relationship("Group", back_populates="sheets")
author = relationship("User", back_populates="sheets")
archives = relationship("Archive", back_populates="sheet")

View File

@@ -0,0 +1,60 @@
from sqlalchemy.orm import Session
from datetime import datetime
from app.shared.db import models
from app.shared import schemas
# TODO: isolate database operations away from worker and into WEB
# ONLY WORKER
def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
if db_sheet:
db_sheet.last_url_archived_at = datetime.now()
db.commit()
return True
return False
# ONLY WORKER and INTEROP
def get_group(db: Session, group_name: str) -> models.Group:
return db.query(models.Group).filter(models.Group.id == group_name).first()
def create_or_get_user(db: Session, author_id: str) -> models.User:
if type(author_id) == str: author_id = author_id.lower()
db_user = db.query(models.User).filter(models.User.email == author_id).first()
if not db_user:
db_user = models.User(email=author_id)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
def create_tag(db: Session, tag: str) -> models.Tag:
db_tag = db.query(models.Tag).filter(models.Tag.id == tag).first()
if not db_tag:
db_tag = models.Tag(id=tag)
db.add(db_tag)
db.commit()
db.refresh(db_tag)
return db_tag
def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]) -> models.Archive:
db_archive = models.Archive(id=archive.id, url=archive.url, result=archive.result, public=archive.public, author_id=archive.author_id, group_id=archive.group_id, sheet_id=archive.sheet_id, store_until=archive.store_until)
db_archive.tags = tags
db_archive.urls = urls
db.add(db_archive)
db.commit()
db.refresh(db_archive)
return db_archive
def store_archived_url(db: Session, archive: schemas.ArchiveCreate) -> models.Archive:
# create and load user, tags, if needed
create_or_get_user(db, archive.author_id)
db_tags = [create_tag(db, tag) for tag in (archive.tags or [])]
# insert everything
db_archive = create_archive(db, archive=archive, tags=db_tags, urls=archive.urls)
return db_archive

13
app/shared/log.py Normal file
View File

@@ -0,0 +1,13 @@
import traceback
from loguru import logger
# logging configurations
logger.add("logs/api_logs.log", retention="30 days")
logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
def log_error(e: Exception, traceback_str: str = None, extra:str = ""):
if not traceback_str: traceback_str = traceback.format_exc()
if extra: extra = f"{extra}\n"
logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")

100
app/shared/schemas.py Normal file
View File

@@ -0,0 +1,100 @@
from typing import Annotated
from annotated_types import Len
from pydantic import BaseModel
from datetime import datetime
class SubmitSheet(BaseModel):
sheet_id: str | None
author_id: str | None = None
group_id: str = "default"
tags: set[str] | None = set()
class ArchiveUrl(BaseModel):
url: str
public: bool = False
author_id: str | None
group_id: str | None
tags: set[str] | None = set()
class ArchiveResult(BaseModel):
id: str
url: str
result: dict
created_at: datetime
store_until: datetime | None
class Task(BaseModel):
id: str
class TaskResult(Task):
status: str
result: str
class DeleteResponse(Task):
deleted: bool
class ActiveUser(BaseModel):
active: bool
class SheetAdd(BaseModel):
id: str
name: str
group_id: str
frequency: str
class SheetResponse(SheetAdd):
author_id: str
created_at: datetime
last_url_archived_at: datetime | None
class ArchiveTrigger(BaseModel):
author_id: str | None = None
url: Annotated[str, Len(min_length=5)]
public: bool = False
group_id: Annotated[str, Len(min_length=1)] = "default"
tags: set[str] | None = None
class ArchiveCreate(ArchiveTrigger):
id: str | None = None
result: dict | None = None
sheet_id: str | None = None
urls: list | None = None
store_until: datetime | None = None
class Archive(ArchiveCreate):
created_at: datetime
updated_at: datetime | None
deleted: bool
model_config = {"from_attributes": True}
class Usage(BaseModel):
monthly_urls: int = 0
monthly_mbs: int = 0
total_sheets: int = 0
class UsageResponse(Usage):
groups: dict[str, Usage]
class CelerySheetTask(BaseModel):
success: bool
sheet_id: str
time: datetime
stats: dict
class SubmitManualArchive(ArchiveTrigger):
result: str # should be a Metadata.to_json()

76
app/shared/settings.py Normal file
View File

@@ -0,0 +1,76 @@
from functools import lru_cache
import os
from fastapi_mail import ConnectionConfig
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Annotated, Set
from annotated_types import Len
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=os.environ.get("ENVIRONMENT_FILE") , env_file_encoding='utf-8', extra='ignore', str_strip_whitespace=True)
# general
SERVE_LOCAL_ARCHIVE: str | None = None
USER_GROUPS_FILENAME: str = "app/user-groups.yaml"
# database
DATABASE_PATH: str
DATABASE_QUERY_LIMIT: int = 100
@property
def ASYNC_DATABASE_PATH(self) -> str:
return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://")
# security
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)]
CHROME_APP_IDS: Annotated[Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
# redis
REDIS_PASSWORD: str = ""
REDIS_HOSTNAME: str = "localhost"
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
@property
def CELERY_BROKER_URL(self)-> str:
if self.REDIS_PASSWORD:
return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379"
return f"redis://{self.REDIS_HOSTNAME}:6379"
# cronjobs
CRON_ARCHIVE_SHEETS: bool = False
CRON_DELETE_STALE_SHEETS: bool = False
DELETE_STALE_SHEETS_DAYS: int = 14
CRON_DELETE_SCHEDULED_ARCHIVES: bool = False
DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7
# observability
REPEAT_COUNT_METRICS_SECONDS: int = 30
# email configuration, if needed
MAIL_FROM: str = "noreply@bellingcat.com"
MAIL_FROM_NAME: str = "Bellingcat's Auto Archiver"
MAIL_USERNAME: str = ""
MAIL_PASSWORD: str = ""
MAIL_SERVER: str = ""
MAIL_PORT: int = 587
MAIL_STARTTLS: bool = False
MAIL_SSL_TLS: bool = True
@property
def MAIL_CONFIG(self) -> str:
return ConnectionConfig(
MAIL_FROM=self.MAIL_FROM,
MAIL_FROM_NAME=self.MAIL_FROM_NAME,
MAIL_USERNAME=self.MAIL_USERNAME,
MAIL_PASSWORD=self.MAIL_PASSWORD,
MAIL_SERVER=self.MAIL_SERVER,
MAIL_PORT=self.MAIL_PORT,
MAIL_STARTTLS=self.MAIL_STARTTLS,
MAIL_SSL_TLS=self.MAIL_SSL_TLS,
)
@lru_cache
def get_settings():
return Settings()

View File

@@ -0,0 +1,23 @@
from functools import lru_cache
from celery import Celery
import redis
from app.shared.settings import get_settings
@lru_cache
def get_celery(name: str = "") -> Celery:
return Celery(
name,
broker_url=get_settings().CELERY_BROKER_URL,
result_backend=get_settings().CELERY_BROKER_URL,
broker_connection_retry_on_startup=False,
broker_transport_options={
'queue_order_strategy': 'priority',
}
)
def get_redis() -> redis.Redis:
return redis.Redis.from_url(get_settings().CELERY_BROKER_URL)

169
app/shared/user_groups.py Normal file
View File

@@ -0,0 +1,169 @@
import json
import os
import yaml
from loguru import logger
from pydantic import BaseModel, computed_field, field_validator, Field, model_validator
from typing import Dict, List, Set
from typing_extensions import Self
class UserGroups:
def __init__(self, filename):
user_groups = self.read_yaml(filename)
self.validate_and_load(user_groups)
def read_yaml(self, user_groups_filename):
# read yaml safely
with open(user_groups_filename) as inf:
try:
return yaml.safe_load(inf)
except yaml.YAMLError as e:
logger.error(f"could not open user groups filename {user_groups_filename}: {e}")
raise e
def validate_and_load(self, user_groups):
try:
configs = UserGroupModel(**user_groups)
self.users = configs.users
self.domains = configs.domains
self.groups = configs.groups
except Exception as e:
logger.error(f"Validation error: {e}")
raise e
class GroupPermissions(BaseModel):
read: Set[str] | bool = Field(default_factory=list)
read_public: bool = False
archive_url: bool = False
archive_sheet: bool = False
manually_trigger_sheet: bool = False
sheet_frequency: Set[str] = Field(default_factory=list)
max_sheets: int = 0
max_archive_lifespan_months: int = 12
max_monthly_urls: int = 0
max_monthly_mbs: int = 0
priority: str = "low"
@field_validator('max_sheets', 'max_archive_lifespan_months', 'max_monthly_urls', 'max_monthly_mbs', mode='before')
def validate_max_values(cls, v):
if v < -1:
raise ValueError("max_* values should be positive integers or -1 (for no limit).")
return v
@field_validator('sheet_frequency', mode='before')
def validate_sheet_frequency(cls, v):
if not v: return []
allowed = ["daily", "hourly"]
for k in v:
if k not in allowed:
raise ValueError(f"Invalid sheet frequency: '{k}', expected one of {allowed}")
return v
@field_validator('priority', mode='before')
def validate_priority(cls, v):
v = v.lower()
if v not in ["low", "high"]:
raise ValueError("priority must be either 'low' or 'high'.")
return v
class GroupModel(BaseModel):
description: str
orchestrator: str
orchestrator_sheet: str
permissions: GroupPermissions
@field_validator('orchestrator', 'orchestrator_sheet', mode='before')
def validate_orchestrator(cls, v):
if not os.path.exists(v):
raise ValueError(f"Orchestrator file not found with this path: {v}")
return v
@computed_field
@property
def service_account_email(self) -> str:
if hasattr(self, "_service_account_email"):
return self._service_account_email
orch = yaml.safe_load(open(self.orchestrator_sheet))
def find_service_account_email(d):
for k, v in d.items():
if k == "service_account":
return v
if isinstance(v, dict):
if result := find_service_account_email(v):
return result
return False
service_account_json = find_service_account_email(orch)
if not service_account_json:
raise ValueError(f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}.")
with open(service_account_json) as f:
self._service_account_email = json.load(f).get("client_email")
if not self._service_account_email:
raise ValueError(f"Service account email not found in {service_account_json}.")
return self._service_account_email
class UserGroupModel(BaseModel):
users: Dict[str, List[str]] = Field(default_factory=dict)
domains: Dict[str, List[str]] = Field(default_factory=dict)
groups: Dict[str, GroupModel] = Field(default_factory=dict)
@field_validator('users', mode='before')
@classmethod
def validate_emails(cls, v):
for email in v.keys():
if '@' not in email:
raise ValueError(f"Invalid user, it should be an address: {email}")
if not v[email]:
raise ValueError(f"User {email} has no explicitly listed groups, only include them here if they should be in a group.")
# all users belong to the default group
return {k.lower().strip(): list(set(["default"] + [g.lower().strip() for g in v])) for k, v in v.items()}
@field_validator('domains', mode='before')
@classmethod
def validate_domains(cls, v):
for domain, members in v.items():
if '.' not in domain:
raise ValueError(f"Invalid domain, it should contain a dot: {domain}")
if not members:
raise ValueError(f"Domain {domain} should have at least one member.")
return {k.lower().strip(): list(set([g.lower().strip() for g in v])) for k, v in v.items()}
@field_validator('groups', mode='before')
@classmethod
def validate_groups(cls, v):
if "default" not in v.keys():
raise ValueError("Please include a 'default' group.")
if "all" in v.keys():
raise ValueError("'all' is a reserved group name.")
for group in v.keys():
if not group == group.lower():
raise ValueError(f"Group names should be lowercase: {group}")
return v
@model_validator(mode='after')
def check_groups_consistency(self) -> Self:
groups_in_domains = set([g for domain in self.domains for g in self.domains[domain]])
groups_in_users = set([g for user in self.users for g in self.users[user]])
configured_groups = set(self.groups.keys())
# groups mentioned in domains and users should be defined, but this is not a ValueError since historical DB data may require it
if groups_in_domains - configured_groups:
logger.warning(f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}")
if groups_in_users - configured_groups:
logger.warning(f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}")
return self
# for the API return values
class GroupInfo(GroupPermissions):
description: str = ""
service_account_email: str = ""

10
app/shared/utils/misc.py Normal file
View File

@@ -0,0 +1,10 @@
def fnv1a_hash_mod(s: str, modulo:int) -> int:
# receives a string and returns a number in [0:modulo-1], ensures an even distribution over the modulo range
hash = 0x811c9dc5 # FNV offset basis
fnv_prime = 0x01000193 # FNV prime
for char in s:
hash ^= ord(char)
hash *= fnv_prime
hash &= 0xFFFFFFFF # Keep it 32-bit
return (hash if hash < 0x80000000 else hash - 0x100000000) % modulo

333
app/test.ipynb Normal file

File diff suppressed because one or more lines are too long

160
app/tests/conftest.py Normal file
View File

@@ -0,0 +1,160 @@
import os
from typing import AsyncGenerator
from fastapi.testclient import TestClient
import pytest
from unittest.mock import patch
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
from app.web.config import ALLOW_ANY_EMAIL
from app.shared.settings import Settings
from app.web.db.user_state import UserState
@pytest.fixture(autouse=True)
def mock_logger_add():
"""Fixture to mock loguru.logger.add for all tests."""
with patch('loguru.logger.add') as mock_add:
yield mock_add # This makes the mock available to tests
@pytest.fixture()
def get_settings():
return Settings(_env_file=".env.test")
@pytest.fixture(autouse=True)
def mock_settings():
with patch('app.shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings:
yield mock_settings
@pytest.fixture()
def test_db(get_settings: Settings):
from app.shared.db import models
from app.shared.db.database import make_engine
from app.web.db.crud import get_user_group_names
get_user_group_names.cache_clear()
make_engine.cache_clear()
engine = make_engine(get_settings.DATABASE_PATH)
fs = get_settings.DATABASE_PATH.replace("sqlite:///", "")
if not os.path.exists(fs):
open(fs, 'w').close()
models.Base.metadata.create_all(engine)
connection = engine.connect()
yield connection
connection.close()
models.Base.metadata.drop_all(bind=engine)
for suffix in ["", "-wal", "-shm"]:
new_fs = fs + suffix
if os.path.exists(new_fs):
os.remove(new_fs)
@pytest.fixture()
def db_session(test_db):
from app.shared.db.database import make_session_local
session_local = make_session_local(test_db)
with session_local() as session:
yield session
@pytest_asyncio.fixture()
async def async_test_db(get_settings: Settings):
from app.shared.db import models
from app.shared.db.database import make_async_engine
from app.web.db.crud import get_user_group_names
import asyncio
get_user_group_names.cache_clear()
engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH)
fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "")
if not os.path.exists(fs):
open(fs, 'w').close()
async def create_all():
async with engine.begin() as conn:
await conn.run_sync(models.Base.metadata.create_all)
await create_all()
yield engine
async def drop_all():
async with engine.begin() as conn:
await conn.run_sync(models.Base.metadata.drop_all)
await drop_all()
engine.dispose()
for suffix in ["", "-wal", "-shm"]:
new_fs = fs + suffix
if os.path.exists(new_fs):
os.remove(new_fs)
@pytest_asyncio.fixture()
async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
from app.shared.db.database import make_async_session_local
session_local = await make_async_session_local(async_test_db)
async with session_local() as session:
yield session
@pytest.fixture()
def app(db_session):
from app.web.main import app_factory
from app.web.db import crud
app = app_factory()
crud.upsert_user_groups(db_session)
return app
@pytest.fixture()
def client(app):
client = TestClient(app)
return client
@pytest.fixture()
def app_with_auth(app, db_session):
from app.web.security import get_token_or_user_auth, get_user_auth, get_user_state
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
app.dependency_overrides[get_user_state] = lambda: UserState(db_session, "MORTY@example.com")
return app
@pytest.fixture()
def client_with_auth(app_with_auth):
client = TestClient(app_with_auth)
return client
@pytest.fixture()
def app_with_token(app):
from app.web.security import token_api_key_auth, get_token_or_user_auth
app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL
app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL
return app
@pytest.fixture()
def client_with_token(app_with_token):
client = TestClient(app_with_token)
return client
@pytest.fixture()
def test_no_auth():
# reusable code to ensure a method/endpoint combination is unauthorized
def no_auth(http_method, endpoint):
response = http_method(endpoint)
assert response.status_code == 403
assert response.json() == {"detail": "Not authenticated"}
return no_auth

View File

@@ -0,0 +1,3 @@
{
"client_email": "fake_service_account@fake_service_account.iam.gserviceaccount.com"
}

View File

@@ -12,6 +12,8 @@ steps:
- console_db
configurations:
gsheet_feeder:
service_account: "app/tests/fake_service_account.json"
cli_feeder:
urls:
- "url1"

View File

@@ -1,5 +1,5 @@
def test_generate_uuid():
from db.models import generate_uuid
from app.shared.db.models import generate_uuid
assert generate_uuid() != generate_uuid()
assert len(generate_uuid()) == 36

View File

@@ -0,0 +1,117 @@
from app.shared.db import models
from app.shared.db import worker_crud, models
from datetime import datetime
from app.tests.web.db.test_crud import test_data
def test_update_sheet_last_url_archived_at(db_session):
# Create test sheet
test_sheet = models.Sheet(id="sheet-123")
db_session.add(test_sheet)
db_session.commit()
# Test updating existing sheet
assert isinstance(test_sheet.last_url_archived_at, datetime)
before = test_sheet.last_url_archived_at
assert worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123") is True
db_session.refresh(test_sheet)
assert isinstance(test_sheet.last_url_archived_at, datetime)
assert test_sheet.last_url_archived_at > before
# Test non-existent sheet
assert worker_crud.update_sheet_last_url_archived_at(db_session, "non-existent-sheet") is False
def test_get_group(test_data, db_session):
from app.shared.db import worker_crud
assert worker_crud.get_group(db_session, "spaceship") is not None
assert worker_crud.get_group(db_session, "interdimensional") is not None
assert worker_crud.get_group(db_session, "animated-characters") is not None
assert worker_crud.get_group(db_session, "non-existent!@#!%!") is None
def test_create_or_get_user(test_data, db_session):
from app.shared.db import worker_crud
assert db_session.query(models.User).count() == 3
# already exists
assert (u1 := worker_crud.create_or_get_user(db_session, "rick@example.com")) is not None
assert u1.email == "rick@example.com"
# new user
assert (u2 := worker_crud.create_or_get_user(db_session, "beth@example.com")) is not None
assert u2.email == "beth@example.com"
assert db_session.query(models.User).count() == 4
def test_create_tag(db_session):
from app.shared.db import worker_crud
assert db_session.query(models.Tag).count() == 0
# create first
create_tag = worker_crud.create_tag(db_session, "tag-101")
assert create_tag is not None
assert create_tag.id == "tag-101"
assert db_session.query(models.Tag).count() == 1
assert db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() == create_tag
# same id does not add new db entry
existing_tag = worker_crud.create_tag(db_session, "tag-101")
assert existing_tag == create_tag
assert db_session.query(models.Tag).count() == 1
# create second
second_tag = worker_crud.create_tag(db_session, "tag-102")
assert second_tag is not None
assert second_tag.id == "tag-102"
assert db_session.query(models.Tag).count() == 2
def test_create_task(db_session):
from app.shared.db import worker_crud
from app.shared import schemas
task = schemas.ArchiveCreate(
id="archive-id-456-101",
url="https://example-0.com",
result={},
public=False,
author_id="rick@example.com",
group_id="spaceship",
tags=[],
urls=[]
)
# with tags and urls
nt = worker_crud.create_archive(db_session, task, [models.Tag(id="tag-101")], [models.ArchiveUrl(url="https://example-0.com/0", key="media_0")])
assert nt is not None
assert nt.id == "archive-id-456-101"
assert nt.url == "https://example-0.com"
assert nt.author_id == "rick@example.com"
assert nt.public == False
assert nt.group_id == "spaceship"
assert len(nt.tags) == 1
assert nt.tags[0].id == "tag-101"
assert len(nt.urls) == 1
assert nt.urls[0].url == "https://example-0.com/0"
assert nt.urls[0].key == "media_0"
assert nt.created_at is not None
# without tags and urls
task.id = "archive-id-456-102"
nt = worker_crud.create_archive(db_session, task, [], [])
assert nt is not None
assert nt.id == "archive-id-456-102"
assert nt.url == "https://example-0.com"
assert nt.author_id == "rick@example.com"
assert nt.public == False
assert nt.group_id == "spaceship"
assert len(nt.tags) == 0
assert len(nt.urls) == 0
assert nt.created_at is not None

View File

@@ -0,0 +1,36 @@
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
from app.shared.business_logic import get_store_archive_until
class Test_get_store_archive_until:
GROUP_ID = "test-group"
def test_group_not_found(self, db_session):
with pytest.raises(AssertionError) as exc:
get_store_archive_until(db_session, self.GROUP_ID)
assert str(exc.value) == f"Group {self.GROUP_ID} not found."
@patch("app.shared.db.worker_crud.get_group")
def test_no_max_lifespan(self, mock_get_group, db_session):
group = MagicMock()
group.permissions = {"max_archive_lifespan_months": -1}
mock_get_group.return_value = group
result = get_store_archive_until(db_session, self.GROUP_ID)
assert result is None
mock_get_group.assert_called_once_with(db_session, self.GROUP_ID)
@patch("app.shared.db.worker_crud.get_group")
def test_with_max_lifespan(self, mock_get_group, db_session):
group = MagicMock()
group.permissions = {"max_archive_lifespan_months": 6}
mock_get_group.return_value = group
result = get_store_archive_until(db_session, self.GROUP_ID)
expected = datetime.now() + timedelta(days=180) # 6 months
assert isinstance(result, datetime)
# Allow 1 second difference due to execution time
assert abs(result - expected) < timedelta(seconds=1)
mock_get_group.assert_called_once_with(db_session, self.GROUP_ID)

View File

@@ -0,0 +1,31 @@
from app.shared.utils.misc import fnv1a_hash_mod
def test_fnv1a_hash_mod():
# Test basic string hashing
assert fnv1a_hash_mod("test", 10) == fnv1a_hash_mod("test", 10)
assert 0 <= fnv1a_hash_mod("test", 10) < 10
# Test different strings give different hashes
assert fnv1a_hash_mod("test1", 100) != fnv1a_hash_mod("test2", 100)
# Test different modulos
hash1 = fnv1a_hash_mod("test", 5)
hash2 = fnv1a_hash_mod("test", 10)
assert 0 <= hash1 < 5
assert 0 <= hash2 < 10
# Test empty string
assert isinstance(fnv1a_hash_mod("", 10), int)
assert 0 <= fnv1a_hash_mod("", 10) < 10
# Test long string
long_str = "a" * 1000
assert 0 <= fnv1a_hash_mod(long_str, 20) < 20
# Test unicode string
assert isinstance(fnv1a_hash_mod("测试", 10), int)
assert 0 <= fnv1a_hash_mod("测试", 10) < 10
# Test modulo = 1 edge case
assert fnv1a_hash_mod("test", 1) == 0

View File

@@ -0,0 +1,87 @@
# NOTE: all emails should be lower-cased
users:
rick@example.com:
- spaceship
- interdimensional
morty@example.com:
- spaceship
jerry@example.com:
- the-jerrys-club
# summer@herself.com:
# badyemail.com:
domains:
example.com:
- animated-characters
birdy.com:
- animated-characters
- this-does-not-exist
orchestrators:
spaceship: app/tests/orchestration.test.yaml
interdimensional: app/tests/orchestration.test.yaml
default: app/tests/orchestration.test.yaml
default_orchestrator: app/tests/orchestration.test.yaml
groups:
spaceship:
description: "The spaceship crew"
orchestrator: app/tests/orchestration.test.yaml
orchestrator_sheet: app/tests/orchestration.test.yaml
permissions:
read: ["all"]
archive_url: true
archive_sheet: true
manually_trigger_sheet: true
sheet_frequency: ["hourly", "daily"]
max_sheets: -1
max_archive_lifespan_months: -1
max_monthly_urls: -1
max_monthly_mbs: -1
priority: "high"
interdimensional:
description: "Interdimensional travelers"
orchestrator: app/tests/orchestration.test.yaml
orchestrator_sheet: app/tests/orchestration.test.yaml
permissions:
read: ["interdimensional", "animated-characters"]
archive_url: true
archive_sheet: true
manually_trigger_sheet: true
sheet_frequency: ["hourly", "daily"]
max_sheets: 5
max_archive_lifespan_months: 12
max_monthly_urls: 1000
max_monthly_mbs: 1000
priority: "high"
animated-characters:
description: "Animated characters"
orchestrator: app/tests/orchestration.test.yaml
orchestrator_sheet: app/tests/orchestration.test.yaml
permissions:
read: ["animated-characters"]
archive_url: true
archive_sheet: true
sheet_frequency: ["daily"]
max_sheets: 1
max_archive_lifespan_months: 12
max_monthly_urls: 2
max_monthly_mbs: 10
priority: "low"
default:
description: "Public access"
orchestrator: app/tests/orchestration.test.yaml
orchestrator_sheet: app/tests/orchestration.test.yaml
permissions:
# read: []
archive_url: true
# manually_trigger_sheet: false
# archive_sheet: false
# sheet_frequency: []
# max_sheets: 0
# max_archive_lifespan_months: 12
max_monthly_urls: 1
# max_monthly_mbs: 50
priority: "low"

View File

@@ -0,0 +1,438 @@
from datetime import datetime, timedelta
from unittest.mock import patch
import pytest
import yaml
from app.shared.db import models
from app.shared.settings import Settings
from app.web.db import crud
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
@pytest.fixture()
def test_data(db_session):
# creates 3 users
for email in authors:
db_session.add(models.User(email=email))
db_session.commit()
assert db_session.query(models.User).count() == 3
# creates 100 archives for 3 users over 2 months with repeating URLs
for i in range(100):
author = authors[i % 3]
archive = models.Archive(
id=f"archive-id-456-{i}",
url=f"https://example-{i%3}.com",
result={},
public=author == "jerry@example.com",
author_id=author,
group_id="spaceship" if author == "morty@example.com" and i % 2 == 0 else None,
created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1)
)
if i % 5 == 0:
archive.tags.append(models.Tag(id=f"tag-{i}"))
if i % 10 == 0:
archive.tags.append(models.Tag(id=f"tag-second-{i}"))
if i % 4 == 0:
archive.tags.append(models.Tag(id=f"tag-third-{i}"))
for j in range(10):
archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
db_session.add(archive)
# creates a sheet for each user
for i, email in enumerate(authors):
db_session.add(models.Sheet(id=f"sheet-{i}", name=f"sheet-{i}", author_id=email, group_id=None, frequency="daily"))
if email == "rick@example.com":
db_session.add(models.Sheet(id=f"sheet-{i}-2", name=f"sheet-{i}-2", author_id=email, group_id="spaceship", frequency="hourly"))
db_session.commit()
assert db_session.query(models.Archive).count() == 100
assert db_session.query(models.Tag).count() == 20 + 10 + 25
assert db_session.query(models.ArchiveUrl).count() == 1000
assert db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.archive_id == "archive-id-456-0").count() == 10
# setup groups
assert db_session.query(models.Group).count() == 0
from app.web.db import crud
crud.upsert_user_groups(db_session)
assert db_session.query(models.Group).count() == 4
assert db_session.query(models.User).count() == 3
def test_search_archives_by_url(test_data, db_session):
from app.web.config import ALLOW_ANY_EMAIL
# rick's archives are private
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", True, False)) == 34
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", [], False)) == 34
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", [], True)) == 34
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL, [], False)) == 34
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL, True, False)) == 34
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com", [], False)) == 0
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com", [], True)) == 0
# morty's archives are public but half are in spaceship group
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com", ["spaceship"], False)) == 16
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com", True, False)) == 16
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "jerry@example.com", True, True)) == 16
# jerry's archives are public
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "jerry@example.com", [], True)) == 33
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "rick@example.com", [], True)) == 33
# fuzzy search
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False)) == 100
assert len(crud.search_archives_by_url(db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL, False, False)) == 100
assert len(crud.search_archives_by_url(db_session, "2.com", ALLOW_ANY_EMAIL, False, False)) == 33
# absolute search
assert len(crud.search_archives_by_url(db_session, "example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True)) == 0
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True)) == 33
# archived_after
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, archived_after=datetime(2010, 1, 1))) == 100
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 15))) == 70
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2031, 1, 1))) == 0
# archived before
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2010, 1, 1))) == 0
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2021, 1, 15))) == 28
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2031, 1, 1))) == 100
# archived before and after
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2001, 1, 1), archived_before=datetime(2031, 1, 11))) == 100
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 14), archived_before=datetime(2021, 1, 16))) == 2
# limit
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, limit=10)) == 10
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, limit=-1)) == 1
# skip
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, skip=10)) == 90
def test_search_archives_by_email(test_data, db_session):
from app.web.config import ALLOW_ANY_EMAIL
# lower/upper case
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34
# ALLOW_ANY_EMAIL is not a user
assert len(crud.search_archives_by_email(db_session, ALLOW_ANY_EMAIL)) == 0
# most recent first
a1 = crud.search_archives_by_email(db_session, "rick@example.com", limit=1)
assert len(a1) == 1
assert a1[0].created_at == datetime(2021, 2, 25)
# earliest is the last
a2 = crud.search_archives_by_email(db_session, "rick@example.com", skip=33)
assert len(a2) == 1
assert a2[0].created_at == datetime(2021, 1, 1)
@patch("app.web.db.crud.DATABASE_QUERY_LIMIT", new=25)
def test_max_query_limit(test_data, db_session):
from app.web.config import ALLOW_ANY_EMAIL
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, [], False)) == 25
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, limit=1000)) == 25
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25
assert len(crud.search_archives_by_email(db_session, "rick@example.com", limit=1000)) == 25
def test_soft_delete(test_data, db_session):
# none deleted yet
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").first() is not None
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 0
# delete
assert crud.soft_delete_archive(db_session, "archive-id-456-0", "rick@example.com") == True
# ensure soft delete
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 1
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").first() is None
# already deleted
assert crud.soft_delete_archive(db_session, "archive-id-456-0", "rick@example.com") == False
def test_count_archives(test_data, db_session):
assert crud.count_archives(db_session) == 100
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
db_session.commit()
assert crud.count_archives(db_session) == 99
def test_count_archive_urls(test_data, db_session):
assert crud.count_archive_urls(db_session) == 1000
db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.url == "https://example-0.com/0").delete()
db_session.commit()
assert crud.count_archive_urls(db_session) == 999
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
db_session.commit()
# no Cascade is enabled
assert crud.count_archives(db_session) == 99
assert crud.count_archive_urls(db_session) == 999
def test_count_users(test_data, db_session):
assert crud.count_users(db_session) == 3
db_session.query(models.User).filter(models.User.email == "rick@example.com").delete()
db_session.commit()
assert crud.count_users(db_session) == 2
def test_count_by_users_since(test_data, db_session):
from app.web.db import crud
# 100y window
assert len(cu := crud.count_by_user_since(db_session, 60 * 60 * 24 * 31 * 12 * 100)) == 3
assert cu[0].total == 34
assert cu[1].total == 33
assert cu[2].total == 33
def test_upsert_group(test_data, db_session):
assert db_session.query(models.Group).count() == 4
repeatable_params = ["desc 1", "orch.yaml", "sheet.yaml", "service_account_email@example.com", {"read": ["all"]}, ["example.com"]]
assert (g1 := crud.upsert_group(db_session, "spaceship", *repeatable_params)) is not None
assert g1.id == "spaceship"
assert g1.description == "desc 1"
assert g1.orchestrator == "orch.yaml"
assert g1.orchestrator_sheet == "sheet.yaml"
assert g1.service_account_email == "service_account_email@example.com"
assert g1.permissions == {"read": ["all"]}
assert g1.domains == ["example.com"]
assert len(g1.users) == 2
assert [u.email for u in g1.users] == ["rick@example.com", "morty@example.com"]
assert (g2 := crud.upsert_group(db_session, "interdimensional", *repeatable_params)) is not None
assert g2.id == "interdimensional"
assert len(g2.users) == 1
assert [u.email for u in g2.users] == ["rick@example.com"]
assert (g3 := crud.upsert_group(db_session, "this-is-a-new-group", *repeatable_params)) is not None
assert g3.id == "this-is-a-new-group"
assert len(g3.users) == 0
assert db_session.query(models.Group).count() == 5
def test_upsert_user_groups(db_session):
@patch('app.web.db.crud.get_settings', new=lambda: bad_setings)
def test_missing_yaml(db_session):
with pytest.raises(FileNotFoundError):
crud.upsert_user_groups(db_session)
@patch('app.web.db.crud.get_settings', new=lambda: bad_setings)
def test_broken_yaml(db_session):
with pytest.raises(yaml.YAMLError):
crud.upsert_user_groups(db_session)
bad_setings = Settings(_env_file=".env.test")
bad_setings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.missing.yaml"
test_missing_yaml(db_session)
bad_setings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.broken.yaml"
test_broken_yaml(db_session)
def test_create_sheet(db_session):
assert db_session.query(models.Sheet).count() == 0
s = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly")
assert s is not None
assert s.id == "sheet-id-123"
assert s.name == "sheet name"
assert s.author_id == "email@example.com"
assert s.group_id == "group-id"
assert s.frequency == "hourly"
assert db_session.query(models.Sheet).count() == 1
# duplicate id
import sqlalchemy
with pytest.raises(sqlalchemy.exc.IntegrityError):
crud.create_sheet(db_session, "sheet-id-123", "I thought this was another sheet", "email", "group-id", "hourly")
def test_get_user_sheet(test_data, db_session):
assert crud.get_user_sheet(db_session, "", "sheet-0") is None
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0") is not None
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2") is not None
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-1") is not None
def test_get_user_sheets(test_data, db_session):
assert len(crud.get_user_sheets(db_session, "")) == 0
rick_sheets = crud.get_user_sheets(db_session, "rick@example.com")
assert len(rick_sheets) == 2
assert [s.id for s in rick_sheets] == ["sheet-0", "sheet-0-2"]
assert len(crud.get_user_sheets(db_session, "morty@example.com")) == 1
def test_delete_sheet(test_data, db_session):
assert crud.delete_sheet(db_session, "sheet-0", "") == False
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == True
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == False
@pytest.mark.asyncio
async def test_find_by_store_until(async_db_session):
# Add archives with different store_until dates
now = datetime.now()
archive1 = models.Archive(
id="archive-expired-1",
url="https://example-expired-1.com",
result={},
author_id="rick@example.com",
store_until=now - timedelta(days=1)
)
archive2 = models.Archive(
id="archive-expired-2",
url="https://example-expired-2.com",
result={},
author_id="rick@example.com",
store_until=now - timedelta(hours=1)
)
archive3 = models.Archive(
id="archive-active",
url="https://example-active.com",
result={},
author_id="rick@example.com",
store_until=now + timedelta(days=1)
)
async_db_session.add_all([archive1, archive2, archive3])
await async_db_session.commit()
# Should find 2 expired archives
expired = await crud.find_by_store_until(async_db_session, now)
assert len(list(expired)) == 2
# Should find 1 archive expired before 2 hours ago
expired = await crud.find_by_store_until(async_db_session, now - timedelta(hours=2))
assert len(list(expired)) == 1
# Should find no archives expired before 2 days ago
expired = await crud.find_by_store_until(async_db_session, now - timedelta(days=2))
assert len(list(expired)) == 0
# Should not find deleted archives
archive1.deleted = True
await async_db_session.commit()
expired = await crud.find_by_store_until(async_db_session, now)
assert len(list(expired)) == 1
@pytest.mark.asyncio
async def test_get_sheets_by_id_hash(async_db_session):
# Add test data
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
sheets = [
models.Sheet(id="sheet-0", name="sheet-0", author_id=authors[0], group_id=None, frequency="daily"),
models.Sheet(id="sheet-0-2", name="sheet-0-2", author_id=authors[0], group_id="spaceship", frequency="hourly"),
models.Sheet(id="sheet-1", name="sheet-1", author_id=authors[1], group_id=None, frequency="daily"),
models.Sheet(id="sheet-2", name="sheet-2", author_id=authors[2], group_id=None, frequency="daily")
]
async_db_session.add_all(sheets)
await async_db_session.commit()
with patch("app.web.db.crud.fnv1a_hash_mod", return_value=1):
# Test retrieving hourly sheets
hourly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "hourly", 4, 1)
assert len(hourly_sheets) == 1
assert hourly_sheets[0].id == "sheet-0-2"
assert hourly_sheets[0].frequency == "hourly"
# Test retrieving daily sheets
daily_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 1)
assert len(daily_sheets) == 3
assert all(sheet.frequency == "daily" for sheet in daily_sheets)
assert {sheet.id for sheet in daily_sheets} == {"sheet-0", "sheet-1", "sheet-2"}
# Test with non-matching hash
no_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 3)
assert len(no_sheets) == 0
# Test with non-existent frequency
weekly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "weekly", 4, 1)
assert len(weekly_sheets) == 0
@pytest.mark.asyncio
async def test_delete_stale_sheets(async_db_session):
from datetime import datetime, timedelta
from sqlalchemy.sql import select
now = datetime.now()
active_date = now - timedelta(days=5)
stale_date = now - timedelta(days=15)
# Create test sheets with different last_url_archived_at dates
sheets = [
models.Sheet(
id="sheet-active-1",
name="Active Sheet 1",
author_id="rick@example.com",
frequency="daily",
last_url_archived_at=active_date
),
models.Sheet(
id="sheet-active-2",
name="Active Sheet 2",
author_id="morty@example.com",
frequency="hourly",
last_url_archived_at=active_date
),
models.Sheet(
id="sheet-stale-1",
name="Stale Sheet 1",
author_id="rick@example.com",
frequency="daily",
last_url_archived_at=stale_date
),
models.Sheet(
id="sheet-stale-2",
name="Stale Sheet 2",
author_id="morty@example.com",
frequency="daily",
last_url_archived_at=stale_date
)
]
async_db_session.add_all(sheets)
await async_db_session.commit()
# Should not delete sheets with 20 days inactivity threshold
deleted = await crud.delete_stale_sheets(async_db_session, 20)
assert len(deleted) == 0 # No sheets should be deleted
result = await async_db_session.execute(select(models.Sheet))
assert len(list(result.scalars())) == 4 # All sheets should remain
# Should delete sheets with 7 days inactivity threshold
deleted = await crud.delete_stale_sheets(async_db_session, 7)
assert len(deleted) == 2 # Two authors affected
assert len(deleted["rick@example.com"]) == 1 # One sheet deleted for Rick
assert len(deleted["morty@example.com"]) == 1 # One sheet deleted for Morty
assert deleted["rick@example.com"][0].id == "sheet-stale-1"
assert deleted["morty@example.com"][0].id == "sheet-stale-2"
# Verify only active sheets remain
result = await async_db_session.execute(select(models.Sheet))
remaining = list(result.scalars())
assert len(remaining) == 2
assert {s.id for s in remaining} == {"sheet-active-1", "sheet-active-2"}
# Running again should not delete anything
deleted = await crud.delete_stale_sheets(async_db_session, 7)
assert len(deleted) == 0

View File

@@ -0,0 +1,433 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from app.shared.db import models
from app.shared.user_groups import GroupInfo, GroupPermissions
from app.web.db.user_state import UserState
def fresh_user_state():
return UserState(None, email="test@example.com")
@pytest.fixture
def user_state():
return fresh_user_state()
@pytest.fixture
def user_state_with_groups(user_state):
user_groups = [
models.Group(id="no-permissions", permissions={}),
models.Group(id="group1", description="this is g1", service_account_email="sa1@example.com", permissions={"read": ["group1", "no-permissions"], "read_public": True, "archive_url": True, "archive_sheet": True, "max_archive_lifespan_months": 24, "max_monthly_urls": 100, "max_monthly_mbs": 1000, "priority": "high"}),
models.Group(id="group2", description="this is g2", service_account_email="sa2@example.com", permissions={"read": ["all"], "read_public": True, "archive_url": False, "archive_sheet": False, "max_archive_lifespan_months": -1, "max_monthly_urls": -1, "max_monthly_mbs": -1, "priority": "low", "sheet_frequency": {"daily"}}),
]
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=user_groups):
yield user_state
def test_permissions(user_state_with_groups):
permissions = user_state_with_groups.permissions
assert permissions["all"].read == True
assert permissions["all"].read_public == True
assert permissions["all"].archive_url == True
assert permissions["all"].archive_sheet == True
assert permissions["all"].max_archive_lifespan_months == -1
assert permissions["all"].max_monthly_urls == -1
assert permissions["all"].max_monthly_mbs == -1
assert permissions["all"].priority == "high"
assert permissions["group1"].read == set(["group1", "no-permissions"])
assert permissions["group1"].read_public == True
assert permissions["group1"].archive_url == True
assert permissions["group1"].archive_sheet == True
assert permissions["group1"].max_archive_lifespan_months == 24
assert permissions["group1"].max_monthly_urls == 100
assert permissions["group1"].max_monthly_mbs == 1000
assert permissions["group1"].priority == "high"
assert permissions["group2"].read == set(["all"])
assert permissions["group2"].read_public == True
assert permissions["group2"].archive_url == False
assert permissions["group2"].archive_sheet == False
assert permissions["group2"].max_archive_lifespan_months == -1
assert permissions["group2"].max_monthly_urls == -1
assert permissions["group2"].max_monthly_mbs == -1
assert permissions["group2"].priority == "low"
assert len(permissions) == 3
def test_user_groups_names(user_state):
with patch('app.web.db.crud.get_user_group_names', return_value=["group1", "group2"]) as mock:
assert user_state.user_groups_names == ["group1", "group2", "default"]
mock.assert_called_once_with(None, "test@example.com")
def test_user_groups(user_state):
with patch('app.web.db.crud.get_user_groups_by_name', return_value=[MagicMock(), MagicMock()]) as mock:
user_state._user_groups_names = ["group1", "group2"]
assert len(user_state.user_groups) == 2
mock.assert_called_once_with(None, ["group1", "group2"])
def test_read():
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
assert not hasattr(us, "_read")
assert us.read == set()
assert us._read == set()
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read": ["group1", "no-permissions"]})]):
assert us.read == set(["group1", "no-permissions"])
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read": ["all"]})]):
assert us.read == True
def test_read_public():
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
assert not hasattr(us, "_read_public")
assert us.read_public == False
assert us._read_public == False
mock.assert_called_once()
# no new calls
assert us.read_public == False
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read_public": True})]):
assert us.read_public == True
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read_public": False})]):
assert us.read_public == False
def test_archive_url():
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
assert not hasattr(us, "_archive_url")
assert us.archive_url == False
assert us._archive_url == False
mock.assert_called_once()
# no new calls
assert us.archive_url == False
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_url": False})]):
assert us.archive_url == False
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_url": True})]):
assert us.archive_url == True
def test_archive_sheet():
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
assert not hasattr(us, "_archive_sheet")
assert us.archive_sheet == False
assert us._archive_sheet == False
mock.assert_called_once()
# no new calls
assert us.archive_sheet == False
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_sheet": False})]):
assert us.archive_sheet == False
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_sheet": True})]):
assert us.archive_sheet == True
def test_sheet_frequency():
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
assert not hasattr(us, "_sheet_frequency")
assert us.sheet_frequency == set()
assert us._sheet_frequency == set()
mock.assert_called_once()
# no new calls
assert us.sheet_frequency == set()
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"sheet_frequency": ["daily", "hourly"]})]):
assert us.sheet_frequency == {"daily", "hourly"}
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"sheet_frequency": []})]):
assert us.sheet_frequency == set()
def test_max_archive_lifespan_months():
us = fresh_user_state()
default = GroupPermissions.model_fields["max_archive_lifespan_months"].default
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
assert not hasattr(us, "_max_archive_lifespan_months")
assert us.max_archive_lifespan_months == default
assert us._max_archive_lifespan_months == default
mock.assert_called_once()
# no new calls
assert us.max_archive_lifespan_months == default
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_archive_lifespan_months": 24})]):
assert us.max_archive_lifespan_months == 24
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_archive_lifespan_months": 150}), models.Group(id="group2", permissions={"max_archive_lifespan_months": -1})]):
assert us.max_archive_lifespan_months == -1
def test_max_monthly_urls():
us = fresh_user_state()
default = GroupPermissions.model_fields["max_monthly_urls"].default
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
assert not hasattr(us, "_max_monthly_urls")
assert us.max_monthly_urls == default
assert us._max_monthly_urls == default
mock.assert_called_once()
# no new calls
assert us.max_monthly_urls == default
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_urls": 100})]):
assert us.max_monthly_urls == 100
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_urls": 150}), models.Group(id="group2", permissions={"max_monthly_urls": -1})]):
assert us.max_monthly_urls == -1
def test_max_monthly_mbs():
us = fresh_user_state()
default = GroupPermissions.model_fields["max_monthly_mbs"].default
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
assert not hasattr(us, "_max_monthly_mbs")
assert us.max_monthly_mbs == default
assert us._max_monthly_mbs == default
mock.assert_called_once()
# no new calls
assert us.max_monthly_mbs == default
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_mbs": 1000})]):
assert us.max_monthly_mbs == 1000
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_mbs": 1500}), models.Group(id="group2", permissions={"max_monthly_mbs": -1})]):
assert us.max_monthly_mbs == -1
def test_priority(user_state):
default = GroupPermissions.model_fields["priority"].default
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
assert not hasattr(user_state, "_priority")
assert user_state.priority == default
assert user_state._priority == default
mock.assert_called_once()
# no new calls
assert user_state.priority == default
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"priority": "high"})]):
assert us.priority == "high"
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"priority": "low"}), models.Group(id="group2", permissions={"priority": "medium"})]):
assert us.priority == "low"
def test_active():
for read, read_public, archive_url, archive_sheet, is_active in [
(False, False, False, False, False),
(True, False, False, False, True),
(False, True, False, False, True),
(False, False, True, False, True),
(False, False, False, True, True)
]:
us = fresh_user_state()
with patch.object(UserState, 'read', new_callable=PropertyMock, return_value=read), \
patch.object(UserState, 'read_public', new_callable=PropertyMock, return_value=read_public), \
patch.object(UserState, 'archive_url', new_callable=PropertyMock, return_value=archive_url), \
patch.object(UserState, 'archive_sheet', new_callable=PropertyMock, return_value=archive_sheet):
assert us.active == is_active
def test_in_group(user_state):
with patch.object(UserState, 'user_groups_names', new_callable=PropertyMock, return_value=["group1", "group2"]):
assert user_state.in_group("group1") == True
assert user_state.in_group("group2") == True
assert user_state.in_group("group3") == False
def test_usage(db_session):
user_state = UserState(db_session, email="test@example.com")
user_sheets = [
MagicMock(group_id="group1", sheet_count=5),
MagicMock(group_id="group2", sheet_count=10),
MagicMock(group_id="group3", sheet_count=100),
]
bytes = [1000000, 2000000, 3000000]
urls_by_group = [
MagicMock(group_id="group1", url_count=50, total_bytes=bytes[0]),
MagicMock(group_id="group2", url_count=100, total_bytes=bytes[1]),
MagicMock(group_id="group4", url_count=5, total_bytes=bytes[2]),
]
megabytes = int(sum(bytes) / 1024 / 1024)
with patch.object(db_session, 'query', side_effect=[
MagicMock(filter=MagicMock(return_value=MagicMock(group_by=MagicMock(return_value=MagicMock(all=MagicMock(return_value=user_sheets)))))),
MagicMock(filter=MagicMock(return_value=MagicMock(group_by=MagicMock(return_value=MagicMock(all=MagicMock(return_value=urls_by_group))))))
]):
usage_response = user_state.usage()
assert usage_response.monthly_urls == 155
assert usage_response.monthly_mbs == megabytes
assert usage_response.total_sheets == 115
assert usage_response.groups["group1"].monthly_urls == 50
assert usage_response.groups["group1"].monthly_mbs == int(bytes[0] / 1024 / 1024)
assert usage_response.groups["group1"].total_sheets == 5
assert usage_response.groups["group2"].monthly_urls == 100
assert usage_response.groups["group2"].monthly_mbs == int(bytes[1] / 1024 / 1024)
assert usage_response.groups["group2"].total_sheets == 10
assert usage_response.groups["group3"].monthly_urls == 0
assert usage_response.groups["group3"].monthly_mbs == 0
assert usage_response.groups["group3"].total_sheets == 100
assert usage_response.groups["group4"].monthly_urls == 5
assert usage_response.groups["group4"].monthly_mbs == int(bytes[2] / 1024 / 1024)
assert usage_response.groups["group4"].total_sheets == 0
def test_has_quota_monthly_sheets(db_session):
us = UserState(db_session, email="test@example.com")
test_cases = [
({"unkonwn": GroupInfo(max_sheets=5)}, 1, False),
({"group1": GroupInfo(max_sheets=-1)}, 1000, True),
({"group1": GroupInfo(max_sheets=5)}, 3, True),
({"group1": GroupInfo(max_sheets=5)}, 5, False),
({"group1": GroupInfo(max_sheets=5)}, 6, False),
]
for permissions, count, expected in test_cases:
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))):
assert us.has_quota_monthly_sheets("group1") == expected
def test_has_quota_max_monthly_urls(db_session):
us = UserState(db_session, email="test@example.com")
test_cases = [
({"group1": GroupInfo(max_monthly_urls=-1)}, 1000, True),
({"group1": GroupInfo(max_monthly_urls=100)}, 50, True),
({"group1": GroupInfo(max_monthly_urls=100)}, 100, False),
({"group1": GroupInfo(max_monthly_urls=100)}, 150, False),
]
for permissions, count, expected in test_cases:
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))):
assert us.has_quota_max_monthly_urls("group1") == expected
test_cases = [
(-1, 1000, True),
(100, 50, True),
(100, 100, False),
(100, 150, False),
]
for max_urls, count, expected in test_cases:
with patch.object(UserState, 'max_monthly_urls', new_callable=PropertyMock, return_value=max_urls):
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))):
assert us.has_quota_max_monthly_urls("") == expected
def test_has_quota_max_monthly_mbs(db_session):
us = UserState(db_session, email="test@example.com")
test_cases = [
({"group1": GroupInfo(max_monthly_mbs=-1)}, 1000, True),
({"group1": GroupInfo(max_monthly_mbs=100)}, 50, True),
({"group1": GroupInfo(max_monthly_mbs=100)}, 100, False),
({"group1": GroupInfo(max_monthly_mbs=100)}, 150, False),
]
for permissions, mbs, expected in test_cases:
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(with_entities=MagicMock(return_value=MagicMock(scalar=MagicMock(return_value=mbs * 1024 * 1024))))))):
assert us.has_quota_max_monthly_mbs("group1") == expected
test_cases = [
(-1, 1000, True),
(100, 50, True),
(100, 100, False),
(100, 150, False),
]
for max_mbs, mbs, expected in test_cases:
with patch.object(UserState, 'max_monthly_mbs', new_callable=PropertyMock, return_value=max_mbs):
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(with_entities=MagicMock(return_value=MagicMock(scalar=MagicMock(return_value=mbs * 1024 * 1024))))))):
assert us.has_quota_max_monthly_mbs("") == expected
def test_can_manually_trigger(user_state):
permissions = {
"group1": GroupInfo(manually_trigger_sheet=True),
"group2": GroupInfo(manually_trigger_sheet=False),
}
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
assert user_state.can_manually_trigger("group1") == True
assert user_state.can_manually_trigger("group2") == False
assert user_state.can_manually_trigger("group3") == False
def test_is_sheet_frequency_allowed(user_state):
permissions = {
"group1": GroupInfo(sheet_frequency={"daily", "hourly"}),
"group2": GroupInfo(sheet_frequency={"daily"}),
}
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
assert user_state.is_sheet_frequency_allowed("group1", "daily") == True
assert user_state.is_sheet_frequency_allowed("group1", "hourly") == True
assert user_state.is_sheet_frequency_allowed("group1", "weekly") == False
assert user_state.is_sheet_frequency_allowed("group2", "hourly") == False
assert user_state.is_sheet_frequency_allowed("group2", "daily") == True
assert user_state.is_sheet_frequency_allowed("group3", "daily") == False
def test_priority_group(user_state):
from app.web.utils.misc import convert_priority_to_queue_dict
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[
models.Group(id="group1", permissions={"priority": "high"}),
models.Group(id="group2", permissions={"priority": "medium"}),
models.Group(id="group3", permissions={"priority": "low"}),
]):
assert user_state.priority_group("group1") == convert_priority_to_queue_dict("high")
assert user_state.priority_group("group2") == convert_priority_to_queue_dict("medium")
assert user_state.priority_group("group3") == convert_priority_to_queue_dict("low")
assert user_state.priority_group("group4") == convert_priority_to_queue_dict("low")

View File

@@ -0,0 +1,175 @@
from unittest.mock import MagicMock
from fastapi.testclient import TestClient
import pytest
from app.shared.schemas import Usage, UsageResponse
from app.shared.user_groups import GroupInfo
from app.web.config import VERSION
from app.tests.web.db.test_crud import test_data
def test_endpoint_home(client_with_auth):
r = client_with_auth.get("/")
assert r.status_code == 200
j = r.json()
assert "version" in j and j["version"] == VERSION
assert "breakingChanges" in j
assert "groups" not in j
def test_endpoint_health(client_with_auth):
r = client_with_auth.get("/health")
assert r.status_code == 200
assert r.json() == {"status": "ok"}
def test_endpoint_active_no_auth(client, test_no_auth):
test_no_auth(client.get, "/user/active")
def test_endpoint_active(app):
m_user_state = MagicMock()
from app.web.security import get_user_state
app.dependency_overrides[get_user_state] = lambda: m_user_state
# inactive user
m_user_state.active = False
client = TestClient(app)
r = client.get("/user/active")
assert r.status_code == 200
assert r.json() == {"active": False}
# active user
m_user_state.active = True
client = TestClient(app)
r = client.get("/user/active")
assert r.status_code == 200
assert r.json() == {"active": True}
def test_no_serve_local_archive_by_default(client_with_auth):
r = client_with_auth.get("/app/local_archive_test/temp.txt")
assert r.status_code == 404
def test_favicon(client_with_auth):
r = client_with_auth.get("/favicon.ico")
assert r.status_code == 200
assert r.headers["content-type"] == "image/vnd.microsoft.icon"
def test_endpoint_test_prometheus_no_auth(client, test_no_auth):
test_no_auth(client.get, "/metrics")
def test_endpoint_test_prometheus_no_user_auth(client_with_auth, test_no_auth):
test_no_auth(client_with_auth.get, "/metrics")
@pytest.mark.asyncio
async def test_prometheus_metrics(test_data, client_with_token, get_settings):
# before metrics calculation
r = client_with_token.get("/metrics")
assert r.status_code == 200
assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
assert "disk_utilization" in r.text
assert "database_metrics" in r.text
assert "exceptions" in r.text
assert "worker_exceptions_total" in r.text
assert 'disk_utilization{type="used"}' not in r.text
# after metrics calculation
from app.web.utils.metrics import measure_regular_metrics
await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100)
r2 = client_with_token.get("/metrics")
assert 'disk_utilization{type="used"}' in r2.text
assert 'disk_utilization{type="free"}' in r2.text
assert 'disk_utilization{type="database"}' in r2.text
assert 'database_metrics{query="count_archives"} 100.0' in r2.text
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r2.text
assert 'database_metrics{query="count_users"} 3.0' in r2.text
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r2.text
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r2.text
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r2.text
# 30s window, should not change the gauges nor the total in the counters
from app.web.utils.metrics import measure_regular_metrics
await measure_regular_metrics(get_settings.DATABASE_PATH, 30)
r3 = client_with_token.get("/metrics")
assert 'database_metrics{query="count_archives"} 100.0' in r3.text
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text
assert 'database_metrics{query="count_users"} 3.0' in r3.text
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r3.text
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r3.text
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r3.text
def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth):
test_no_auth(client.get, "/user/permissions")
def test_endpoint_get_user_permissions(app):
from app.web.security import get_user_state
m_user_state = MagicMock()
rv = {
"all": GroupInfo(read=True),
"group1": GroupInfo(archive_url=True),
}
from loguru import logger
logger.info(rv)
m_user_state.permissions = rv
app.dependency_overrides[get_user_state] = lambda: m_user_state
client = TestClient(app)
r = client.get("/user/permissions")
assert r.status_code == 200
response = r.json()
assert response.keys() == {"all", "group1"}
assert response["all"]["read"]
assert response["group1"]["read"] == []
assert response["group1"]["archive_url"]
assert response["all"]["archive_url"] == False
def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth):
test_no_auth(client.get, "/user/usage")
def test_endpoint_get_user_usage_inactive(app):
from app.web.security import get_user_state
m_user_state = MagicMock()
m_user_state.active = False
app.dependency_overrides[get_user_state] = lambda: m_user_state
client = TestClient(app)
r = client.get("/user/usage")
assert r.status_code == 403
assert r.json() == {"detail": "User is not active."}
def test_endpoint_get_user_usage_active(app):
from app.web.security import get_user_state
m_user_state = MagicMock()
m_user_state.active = True
mock_usage = UsageResponse(
monthly_urls=1,
monthly_mbs=2,
total_sheets=3,
groups={
"group1": Usage(monthly_urls=4, monthly_mbs=5, total_sheets=6),
"group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9)
}
)
m_user_state.usage.return_value = mock_usage
app.dependency_overrides[get_user_state] = lambda: m_user_state
client = TestClient(app)
r = client.get("/user/usage")
assert r.status_code == 200
assert UsageResponse(**r.json()) == mock_usage

View File

@@ -0,0 +1,56 @@
from datetime import datetime
import json
from unittest.mock import MagicMock, patch
from app.shared.db import models
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud
def test_submit_manual_archive_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/interop/submit-archive")
def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
test_no_auth(client_with_auth.post, "/interop/submit-archive")
@patch("app.web.endpoints.interoperability.business_logic", return_value=MagicMock(get_store_archive_until=MagicMock(return_value=datetime)))
def test_submit_manual_archive(m1, client_with_token, db_session):
# normal workflow
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]})
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"})
assert r.status_code == 201
assert "id" in r.json()
inserted = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"]).first()
assert inserted.url == "http://example.com"
assert inserted.group_id == "spaceship"
assert inserted.author_id == "jerry@gmail.com"
assert sorted([t.id for t in inserted.tags]) == sorted(["test", "manual"])
assert inserted.public
assert type(inserted.result) == dict
assert [u.url for u in inserted.urls] == ["http://example.s3.com"]
assert type(inserted.store_until) == datetime
# cannot have the same URL twice
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]})
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "tags": ["test"], "url": "http://example.com"})
assert r.status_code == 422
assert r.json() == {"detail": "Cannot insert into DB due to integrity error, likely duplicate urls."}
# test with invalid JSON
def test_submit_manual_archive_invalid_json(client_with_token):
r = client_with_token.post("/interop/submit-archive", json={"result": "invalid json", "public": False, "author_id": "jer", "tags": ["test"], "url": "http://example.com"})
assert r.status_code == 422
assert r.json() == {"detail": "Invalid JSON in result field."}
@patch("app.web.endpoints.interoperability.business_logic")
def test_submit_manual_archive_no_store_until(m_b, client_with_token, db_session):
m_b.get_store_archive_until.side_effect = AssertionError("AssertionError")
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]})
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"})
assert r.status_code == 422
assert r.json() == {"detail": "AssertionError"}

View File

@@ -0,0 +1,193 @@
from datetime import datetime
import json
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
from app.shared.schemas import TaskResult
def test_endpoints_no_auth(client, test_no_auth):
test_no_auth(client.post, "/sheet/create")
test_no_auth(client.get, "/sheet/mine")
test_no_auth(client.delete, "/sheet/123-sheet-id")
test_no_auth(client.post, "/sheet/123-sheet-id/archive")
def test_create_sheet_endpoint(app_with_auth, db_session):
client_with_auth = TestClient(app_with_auth)
good_data = {
"id": "123-sheet-id",
"name": "Test Sheet",
"group_id": "spaceship",
"frequency": "daily"
}
# with good data
response = client_with_auth.post("/sheet/create", json=good_data)
assert response.status_code == 201
j = response.json()
assert datetime.fromisoformat(j.pop("created_at"))
assert datetime.fromisoformat(j.pop("last_url_archived_at"))
assert j.pop("author_id") == 'morty@example.com'
assert j == good_data
# already exists
response = client_with_auth.post("/sheet/create", json=good_data)
assert response.status_code == 400
assert response.json() == {"detail": "Sheet with this ID is already being archived."}
# bad group
bad_data = good_data.copy()
bad_data["group_id"] = "not a group"
response = client_with_auth.post("/sheet/create", json=bad_data)
assert response.status_code == 403
assert response.json() == {"detail": "User does not have access to this group."}
# switch to jerry who's got less quota/permissions
from app.web.security import get_user_state
from app.web.db.user_state import UserState
app_with_auth.dependency_overrides[get_user_state] = lambda: UserState(db_session, "jerry@example.com")
client_jerry = TestClient(app_with_auth)
# frequency not allowed
jerry_data = good_data.copy()
jerry_data["group_id"] = "animated-characters"
jerry_data["frequency"] = "hourly"
jerry_data["id"] = "jerry-sheet-id"
response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == 422
assert response.json() == {"detail": "Invalid frequency selected for this group."}
jerry_data["frequency"] = "daily"
# success for the first sheet, bad quota on second
response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == 201
response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == 429
assert response.json() == {"detail": "User has reached their sheet quota for this group."}
def test_get_user_sheets_endpoint(client_with_auth, db_session):
# no data
response = client_with_auth.get("/sheet/mine")
assert response.status_code == 200
assert response.json() == []
# with data
from app.shared.db import models
db_session.add(
models.Sheet(id="123", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")
)
db_session.commit()
db_session.add_all([
models.Sheet(id="456", name="Test Sheet 2", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
models.Sheet(id="789", name="Test Sheet 3", author_id="rick@example.com", group_id="interdimensional", frequency="hourly"),
])
db_session.commit()
response = client_with_auth.get("/sheet/mine")
assert response.status_code == 200
r = response.json()
assert isinstance(r, list)
assert len(r) == 2
assert datetime.fromisoformat(r[0].pop("created_at"))
assert datetime.fromisoformat(r[0].pop("last_url_archived_at"))
assert datetime.fromisoformat(r[1].pop("created_at"))
assert datetime.fromisoformat(r[1].pop("last_url_archived_at"))
assert r[0] == {
'id': '123',
'author_id': 'morty@example.com',
'frequency': 'hourly',
'group_id': 'spaceship',
'name': 'Test Sheet 1',
}
assert r[1] == {
'id': '456',
'author_id': 'morty@example.com',
'frequency': 'daily',
'group_id': 'interdimensional',
'name': 'Test Sheet 2',
}
def test_delete_sheet_endpoint(client_with_auth, db_session):
# missing sheet
response = client_with_auth.delete("/sheet/123-sheet-id")
assert response.status_code == 200
assert response.json() == {
"id": "123-sheet-id",
"deleted": False
}
# add sheets for deletion
from app.shared.db import models
db_session.add_all([
models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
models.Sheet(id="456-sheet-id", name="Test Sheet 2", author_id="rick@example.com", group_id="spaceship", frequency="hourly"),
])
db_session.commit()
# morty can delete his
response = client_with_auth.delete("/sheet/123-sheet-id")
assert response.status_code == 200
assert response.json() == {"id": "123-sheet-id", "deleted": True}
# but only once
response = client_with_auth.delete("/sheet/123-sheet-id")
assert response.status_code == 200
assert response.json() == {"id": "123-sheet-id", "deleted": False}
# and not rick's
response = client_with_auth.delete("/sheet/456-sheet-id")
assert response.status_code == 200
assert response.json() == {"id": "456-sheet-id", "deleted": False}
class TestArchiveUserSheetEndpoint:
@patch("app.web.endpoints.sheet.celery", return_value=MagicMock())
def test_normal_flow(self, m_celery, client_with_auth, db_session):
from app.shared.db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly"))
db_session.commit()
m_signature = MagicMock()
m_signature.apply_async.return_value = TaskResult(id="123-taskid", status="PENDING", result="")
m_celery.signature.return_value = m_signature
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 201
assert r.json() == {"id": "123-taskid"}
m_celery.signature.assert_called_once()
m_signature.apply_async.assert_called_once()
def test_token_auth(self, client_with_token, test_no_auth):
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
def test_missing_data(self, client_with_auth):
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "No access to this sheet."}
def test_no_access(self, client_with_auth, db_session):
from app.shared.db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="rick@example.com", group_id="spaceship", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "No access to this sheet."}
def test_user_not_in_group(self, client_with_auth, db_session):
from app.shared.db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "User does not have access to this group."}
def test_user_cannot_manually_trigger(self, client_with_auth, db_session):
from app.shared.db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="default", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 429
assert r.json() == {"detail": "User cannot manually trigger sheet archiving in this group."}

View File

@@ -5,7 +5,7 @@ def test_endpoint_task_status_no_auth(client, test_no_auth):
test_no_auth(client.get, "/task/test-task-id")
@patch("endpoints.task.AsyncResult")
@patch("app.web.endpoints.task.AsyncResult")
def test_get_status_success(mock_async_result, client_with_auth):
mock_async_result.return_value.status = "SUCCESS"
mock_async_result.return_value.result = {"data": "some result"}
@@ -20,7 +20,7 @@ def test_get_status_success(mock_async_result, client_with_auth):
}
@patch("endpoints.task.AsyncResult")
@patch("app.web.endpoints.task.AsyncResult")
def test_get_status_failure(mock_async_result, client_with_auth):
mock_async_result.return_value.status = "FAILURE"
@@ -36,7 +36,7 @@ def test_get_status_failure(mock_async_result, client_with_auth):
}
@patch("endpoints.task.AsyncResult")
@patch("app.web.endpoints.task.AsyncResult")
def test_get_status_pending(mock_async_result, client_with_auth):
mock_async_result.return_value.status = "PENDING"
mock_async_result.return_value.result = None

View File

@@ -0,0 +1,193 @@
import json
from unittest.mock import MagicMock, patch
from app.shared.schemas import ArchiveCreate, TaskResult
def test_archive_url_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/url/archive")
@patch("app.web.endpoints.url.UserState")
@patch("app.web.endpoints.url.celery", return_value=MagicMock())
def test_archive_url(m_celery, m2, client_with_auth):
m_signature = MagicMock()
m_signature.apply_async.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
m_celery.signature.return_value = m_signature
m_user_state = MagicMock()
m2.return_value = m_user_state
# url is too short
response = client_with_auth.post("/url/archive", json={"url": "bad"})
assert response.status_code == 422
assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
m_celery.signature.assert_not_called()
# url is invalid
response = client_with_auth.post("/url/archive", json={"url": "example.com"})
assert response.status_code == 400
assert response.json()["detail"] == "Invalid URL received."
# valid request
m_user_state.has_quota_max_monthly_urls.return_value = True
m_user_state.has_quota_max_monthly_mbs.return_value = True
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
m_celery.signature.assert_called_once()
m_signature.apply_async.assert_called_once()
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_archive_task"
assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": "rick@example.com", "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None}
m_user_state.has_quota_max_monthly_urls.assert_called_once()
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
m_user_state.in_group.assert_called_once_with("default")
# user is not in group
m_user_state.in_group.return_value = False
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "new-group"})
assert response.status_code == 403
assert response.json()["detail"] == "User does not have access to this group."
m_user_state.in_group.assert_called_with("new-group")
# user is in group
m_user_state.in_group.return_value = True
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
assert m_celery.signature.call_count == 2
assert m_signature.apply_async.call_count == 2
called_val = m_celery.signature.call_args
assert json.loads(called_val[1]['args'][0])["group_id"] == "spaceship"
m_user_state.in_group.assert_called_with("spaceship")
# user is over monthly URL quota
m_user_state.has_quota_max_monthly_urls.return_value = False
m_user_state.has_quota_max_monthly_mbs.return_value = True
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
assert response.status_code == 429
assert response.json()["detail"] == "User has reached their monthly URL quota."
m_user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")
# user is over monthly MB quota
m_user_state.has_quota_max_monthly_urls.return_value = True
m_user_state.has_quota_max_monthly_mbs.return_value = False
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spacesuit"})
assert response.status_code == 429
assert response.json()["detail"] == "User has reached their monthly MB quota."
m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
assert m_celery.signature.call_count == 2
assert m_signature.apply_async.call_count == 2
@patch("app.web.endpoints.url.UserState")
def test_archive_url_quotas(m1, client_with_auth):
m_user_state = MagicMock()
m1.return_value = m_user_state
# misses on monthly URLs quota
m_user_state.has_quota_max_monthly_urls.return_value = False
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 429
assert response.json()["detail"] == "User has reached their monthly URL quota."
m_user_state.has_quota_max_monthly_urls.assert_called_once()
# misses on monthly MBs quota
m_user_state.has_quota_max_monthly_urls.return_value = True
m_user_state.has_quota_max_monthly_mbs.return_value = False
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 429
assert response.json()["detail"] == "User has reached their monthly MB quota."
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
@patch("app.web.endpoints.url.celery", return_value=MagicMock())
def test_archive_url_with_api_token(m_celery, client_with_token):
m_signature = MagicMock()
m_signature.apply_async.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
m_celery.signature.return_value = m_signature
response = client_with_token.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
m_celery.signature.assert_called_once()
m_signature.apply_async.assert_called_once()
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_archive_task"
def test_search_by_url_unauthenticated(client, test_no_auth):
test_no_auth(client.get, "/url/search")
def test_search_by_url(client_with_auth, client_with_token, db_session):
# tests the search endpoint, including through some db data for the endpoint params
response = client_with_auth.get("/url/search")
assert response.status_code == 422
assert response.json()["detail"][0]["msg"] == "Field required"
response = client_with_auth.get("/url/search?url=https://example.com")
assert response.status_code == 200
assert response.json() == []
from app.shared import schemas
from app.shared.db import worker_crud
for i in range(11):
worker_crud.create_archive(db_session, ArchiveCreate(id=f"url-456-{i}", url="https://example.com" if i < 10 else "https://something-else.com", result={}, public=True, author_id="rick@example.com"), [], [])
# NB: this insertion is too fast for the ordering to be correct as they are within the same second
response = client_with_auth.get("/url/search?url=https://example.com")
assert response.status_code == 200
assert len(j := response.json()) == 10
assert "url-456-0" in [i["id"] for i in j]
assert "url-456-9" in [i["id"] for i in j]
assert "url-456-10" not in [i["id"] for i in j]
assert j[0].keys() == schemas.ArchiveResult.model_fields.keys()
response = client_with_auth.get("/url/search?url=https://example.com&limit=5")
assert response.status_code == 200
assert len(response.json()) == 5
response = client_with_auth.get("/url/search?url=https://example.com&skip=5&limit=2")
assert response.status_code == 200
assert len(response.json()) == 2
response = client_with_auth.get("/url/search?url=https://example.com&archived_before=2010-01-01")
assert response.status_code == 200
assert len(response.json()) == 0
response = client_with_auth.get("/url/search?url=https://example.com&archived_after=2010-01-01")
assert response.status_code == 200
assert len(response.json()) == 10
# API token will also work
response = client_with_token.get("/url/search?url=https://example.com&archived_after=2010-01-01")
assert response.status_code == 200
assert len(response.json()) == 10
@patch("app.web.endpoints.url.UserState")
def test_search_no_read_access(mock_user_state, client_with_auth):
mock_user_state.return_value.read = False
mock_user_state.return_value.read_public = False
response = client_with_auth.get("/url/search?url=https://example.com")
assert response.status_code == 403
assert response.json() == {"detail": "User does not have read access."}
def test_delete_task_unauthenticated(client, test_no_auth):
test_no_auth(client.delete, "/url/123-456-789")
def test_delete_task(client_with_auth, db_session):
response = client_with_auth.delete("/url/delete-123-456-789")
assert response.status_code == 200
assert response.json() == {"id": "delete-123-456-789", "deleted": False}
from app.shared.db import worker_crud
worker_crud.create_archive(db_session, ArchiveCreate(id="delete-123-456-789", url="https://example.com", result={}, public=True, author_id="morty@example.com"), [], [])
response = client_with_auth.delete("/url/delete-123-456-789")
assert response.status_code == 200
assert response.json() == {"id": "delete-123-456-789", "deleted": True}

View File

@@ -17,12 +17,12 @@ def test_alembic(db_session):
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
alembic.config.main(argv=['--raiseerr', 'downgrade', 'base'])
@patch("endpoints.default.crud.get_user_groups", side_effect=Exception('mocked error'))
@patch("app.web.endpoints.url.crud.soft_delete_archive", side_effect=Exception('mocked error'))
def test_logging_middleware(m1, client_with_auth):
from utils.metrics import EXCEPTION_COUNTER
from app.web.utils.metrics import EXCEPTION_COUNTER
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 0
with pytest.raises(Exception, match="mocked error"):
client_with_auth.get("/groups")
client_with_auth.delete("/url/123")
# creates one empty and one from above
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 2
@@ -36,7 +36,7 @@ def test_serve_local_archive_logic(get_settings):
try:
# modify the settings
get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test"
from web.main import app_factory
from app.web.main import app_factory
app = app_factory(get_settings)
# test

View File

@@ -1,14 +1,14 @@
from unittest.mock import patch
from unittest.mock import Mock, patch
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
import pytest
from core.config import ALLOW_ANY_EMAIL
from app.web.config import ALLOW_ANY_EMAIL
def test_secure_compare():
from web.security import secure_compare
from app.web.security import secure_compare
assert secure_compare("test", "test")
assert not secure_compare("test", "test2")
@@ -16,14 +16,14 @@ def test_secure_compare():
@pytest.mark.asyncio
async def test_get_token_or_user_auth_with_api():
from web.security import get_token_or_user_auth
from app.web.security import get_token_or_user_auth
mock_api = HTTPAuthorizationCredentials(scheme="lorem", credentials="this_is_the_test_api_token")
assert await get_token_or_user_auth(mock_api) == ALLOW_ANY_EMAIL
@pytest.mark.asyncio
async def test_get_token_or_user_auth_with_user():
from web.security import get_token_or_user_auth
from app.web.security import get_token_or_user_auth
bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="invalid")
e: pytest.ExceptionInfo = None
with pytest.raises(HTTPException) as e:
@@ -32,18 +32,18 @@ async def test_get_token_or_user_auth_with_user():
assert e.value.detail == "invalid access_token"
@patch("web.security.authenticate_user", return_value=(True, "summer@example.com"))
@patch("app.web.security.authenticate_user", return_value=(True, "summer@example.com"))
@pytest.mark.asyncio
async def test_get_user_auth(m1):
from web.security import get_user_auth
bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good")
assert await get_user_auth(bad_user) == "summer@example.com"
from app.web.security import get_user_auth
good_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good")
assert await get_user_auth(good_user) == "summer@example.com"
@patch("web.security.secure_compare", return_value=False)
@patch("app.web.security.secure_compare", return_value=False)
@pytest.mark.asyncio
async def test_token_api_key_auth_exception(m1):
from web.security import token_api_key_auth
from app.web.security import token_api_key_auth
e: pytest.ExceptionInfo = None
with pytest.raises(HTTPException) as e:
@@ -54,12 +54,12 @@ async def test_token_api_key_auth_exception(m1):
@pytest.mark.asyncio
async def test_authenticate_user():
from web.security import authenticate_user
from app.web.security import authenticate_user
assert authenticate_user("test") == (False, "invalid access_token")
assert authenticate_user(123) == (False, "invalid access_token")
with patch("web.security.requests.get") as mock_get:
with patch("app.web.security.requests.get") as mock_get:
# bad response from oauth2
mock_get.return_value.status_code = 403
assert authenticate_user("this-will-call-requests") == (False, "invalid token")
@@ -100,9 +100,22 @@ async def test_authenticate_user():
@pytest.mark.asyncio
async def test_authenticate_user_exception():
from web.security import authenticate_user
with patch("web.security.requests.get") as mock_get:
from app.web.security import authenticate_user
with patch("app.web.security.requests.get") as mock_get:
mock_get.return_value.status_code = 200
mock_get.return_value.json.side_effect = Exception("mocked error")
assert authenticate_user("this-will-call-requests") == (False, "exception occurred")
def test_get_user_state():
from app.web.security import get_user_state
from app.web.db.user_state import UserState
mock_session = Mock()
test_email = "test@example.com"
state = get_user_state(test_email, mock_session)
assert isinstance(state, UserState)
assert state.email == test_email
assert state.db == mock_session

View File

@@ -0,0 +1,137 @@
from datetime import datetime
from unittest.mock import patch
import pytest
from app.shared.db import models
from app.shared import schemas
from auto_archiver.core import Media, Metadata
class Test_create_archive_task():
URL = "https://example-live.com"
archive = schemas.ArchiveCreate(url=URL, tags=["tag-celery"], public=True, author_id="rick@example.com", group_id="interstellar")
@patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.get_all_urls", return_value=[])
@patch("app.worker.main.insert_result_into_db")
@patch("app.worker.main.get_store_until", return_value=datetime.now())
@patch("app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"])
@patch("celery.app.task.Task.request")
def test_success(self, m_req, m_args, m_store, m_insert, m_urls, m_orchestrator, db_session):
from app.worker.main import create_archive_task
m_req.id = "this-just-in"
m_orchestrator.return_value.feed.return_value = iter([Metadata().set_url(self.URL).success()])
task = create_archive_task(self.archive.model_dump_json())
m_args.assert_called_once()
m_store.assert_called_once_with("interstellar")
m_insert.assert_called_once()
m_urls.assert_called_once()
m_orchestrator.return_value.feed.assert_called_once()
m_orchestrator.return_value.setup.assert_called_once()
assert task["status"] == "success"
assert task["metadata"]["url"] == self.URL
assert len(task["media"]) == 0
def test_raise_invalid(self):
from app.worker.main import create_archive_task
with pytest.raises(Exception):
create_archive_task(self.archive.model_dump_json())
@patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.get_orchestrator_args")
def test_raise_db_error(self, m_args, m_orchestrator):
from app.worker.main import create_archive_task
m_orchestrator.return_value.feed.side_effect = Exception("Orchestrator failed")
with pytest.raises(Exception) as e:
create_archive_task(self.archive.model_dump_json())
assert str(e.value) == "Orchestrator failed"
m_args.assert_called_once()
m_orchestrator.return_value.feed.assert_called_once()
@patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.insert_result_into_db", return_value=None)
@patch("app.worker.main.get_orchestrator_args")
def test_raise_empty_result(self, m_args, m_insert, m_orchestrator):
from app.worker.main import create_archive_task
m_orchestrator.return_value.feed.return_value = iter([None])
with pytest.raises(Exception) as e:
create_archive_task(self.archive.model_dump_json())
assert str(e.value) == "UNABLE TO archive: https://example-live.com"
m_orchestrator.return_value.feed.assert_called_once()
class Test_create_sheet_task():
URL = "https://example-live.com"
sheet = schemas.SubmitSheet(sheet_id="123", author_id="rick@example.com", group_id="interstellar", tags=["spaceship"])
@patch("app.worker.main.get_all_urls", return_value=[])
@patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.models.generate_uuid", return_value="constant-uuid")
@patch("app.worker.main.get_store_until", return_value=datetime.now())
@patch("app.worker.main.get_orchestrator_args")
def test_success(self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session):
from app.worker.main import create_sheet_task
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
mock_metadata = Metadata().set_url(self.URL).success()
mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"]))
m_orchestrator.return_value.feed.return_value = iter([False, mock_metadata, mock_metadata])
res = create_sheet_task(self.sheet.model_dump_json())
m_args.assert_called_once_with("interstellar", True, ["--gsheet_feeder.sheet_id", "123"])
m_orchestrator.return_value.setup.assert_called_once()
m_orchestrator.return_value.feed.assert_called_once()
m_store.assert_called_with("interstellar")
m_store.call_count == 2
m_uuid.call_count == 2
assert type(res) == dict
assert res["stats"]["archived"] == 1
assert res["stats"]["failed"] == 1
assert len(res["stats"]["errors"]) == 1
assert res["sheet_id"] == "123"
assert res["success"]
assert type(res["time"]) == datetime
# query created archive entry
inserted = db_session.query(models.Archive).filter(models.Archive.url == self.URL).one()
assert inserted is not None
assert inserted.url == self.URL
assert len(inserted.tags) == 1
assert inserted.tags[0].id == "spaceship"
assert inserted.group_id == "interstellar"
assert inserted.author_id == "rick@example.com"
assert inserted.public == False
def test_get_all_urls(db_session):
from app.worker.main import get_all_urls
meta = Metadata().set_url("https://example.com")
m1 = meta.add_media(Media("fn1.txt", urls=["outcome1.com"]))
m2 = meta.add_media(Media("fn2.txt", urls=["outcome2.com"]))
m3 = meta.add_media(Media("fn3.txt", urls=["outcome3.com"]))
m1.set("screenshot", Media("screenshot.png", urls=["screenshot.com"]))
m2.set("thumbnails", [Media("thumb1.png", urls=["thumb1.com"]), Media("thumb2.png", urls=["thumb2.com"])])
m3.set("ssl_data", Media("ssl_data.txt", urls=["ssl_data.com"]).to_dict())
m3.set("bad_data", {"bad": "dict is ignored"})
urls = [u.url for u in get_all_urls(meta)]
assert len(urls) == 7
assert "outcome1.com" in urls
assert "outcome2.com" in urls
assert "outcome3.com" in urls
assert "screenshot.com" in urls
assert "thumb1.com" in urls
assert "thumb2.com" in urls
assert "ssl_data.com" in urls

3
app/web/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from app.web.main import app_factory
app = app_factory

View File

@@ -1,4 +1,5 @@
VERSION = "0.8.0"
VERSION = "0.9.0"
API_DESCRIPTION = """
#### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets.
@@ -7,7 +8,7 @@ API_DESCRIPTION = """
- You can use this API to archive single URLs or entire Google Sheets.
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
"""
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
BREAKING_CHANGES = {"minVersion": "0.4.0", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
# changing this will corrupt the database logic
ALLOW_ANY_EMAIL = "*"

273
app/web/db/crud.py Normal file
View File

@@ -0,0 +1,273 @@
from collections import defaultdict
from functools import lru_cache
from sqlalchemy.orm import Session, load_only
from sqlalchemy import Column, or_, func, select
from loguru import logger
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from cachetools import LRUCache, cached
from cachetools.keys import hashkey
from app.web.config import ALLOW_ANY_EMAIL
from app.shared.db import models
from app.shared.settings import get_settings
from app.shared.user_groups import UserGroups
from app.shared.utils.misc import fnv1a_hash_mod
from app.web.utils.misc import convert_priority_to_queue_dict
DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
def get_limit(user_limit: int):
return max(1, min(user_limit, DATABASE_QUERY_LIMIT))
# --------------- TASK = Archive
def base_query(db: Session):
# NOTE: load_only is for optimization and not obfuscation, use .with_entities() if needed
return db.query(models.Archive)\
.filter(models.Archive.deleted == False)\
.options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result, models.Archive.store_until))
def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool | set[str], read_public: bool, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False) -> list[models.Archive]:
# searches for partial URLs, if email is * no ownership (or read/read_public) filtering happens
query = base_query(db)
if email != ALLOW_ANY_EMAIL:
or_filters = [models.Archive.author_id == email]
if read_public:
or_filters.append(models.Archive.public == True)
if read_groups == True:
or_filters.append(models.Archive.group_id.isnot(None))
else:
or_filters.append(models.Archive.group_id.in_(read_groups))
query = query.filter(or_(*or_filters))
if absolute_search:
query = query.filter(models.Archive.url == url)
else:
query = query.filter(models.Archive.url.like(f'%{url}%'))
if archived_after:
query = query.filter(models.Archive.created_at > archived_after)
if archived_before:
query = query.filter(models.Archive.created_at < archived_before)
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
def soft_delete_archive(db: Session, id: str, email: str) -> bool:
# TODO: implement hard-delete with cronjob that deletes from S3
db_archive = db.query(models.Archive).filter(models.Archive.id == id, models.Archive.author_id == email, models.Archive.deleted == False).first()
if db_archive:
db_archive.deleted = True
db.commit()
return db_archive is not None
def count_archives(db: Session):
return db.query(func.count(models.Archive.id)).scalar()
def count_archive_urls(db: Session):
return db.query(func.count(models.ArchiveUrl.url)).scalar()
def count_users(db: Session):
return db.query(func.count(models.User.email)).scalar()
def count_by_user_since(db: Session, seconds_delta: int = 15):
time_threshold = datetime.now() - timedelta(seconds=seconds_delta)
return db.query(models.Archive.author_id, func.count().label('total'))\
.filter(models.Archive.created_at >= time_threshold)\
.group_by(models.Archive.author_id)\
.order_by(func.count().desc())\
.limit(500).all()
async def find_by_store_until(db: AsyncSession, store_until_is_before: datetime) -> list[models.Archive]:
res = await db.execute(
select(models.Archive)
.filter(models.Archive.deleted == False, models.Archive.store_until < store_until_is_before)
)
return res.scalars()
async def soft_delete_expired_archives(db: AsyncSession) -> dict:
to_delete = await find_by_store_until(db, datetime.now())
counter = 0
for archive in to_delete:
archive.deleted = True
counter += 1
await db.commit()
return counter
# --------------- TAG
async def get_group_priority_async(db: AsyncSession, group_id: str) -> dict:
db_group = await db.get(models.Group, group_id)
priority = db_group.permissions.get("priority", "low") if db_group else "low"
return convert_priority_to_queue_dict(priority)
@cached(cache=LRUCache(maxsize=128), key=lambda db, email: hashkey(email))
def get_user_group_names(db: Session, email: str) -> list[str]:
"""
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user.
"""
# TODO: the read: [group1, group2] permissions don't currently work
if not email or not len(email) or "@" not in email: return []
# get user groups
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
user_level_groups_names = [g[0] for g in user_groups]
# get domain groups
domain = email.split('@')[1]
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
domain_level_groups_names = [g[0] for g in domain_level_groups]
return list(set(user_level_groups_names + domain_level_groups_names))
def get_user_groups_by_name(db: Session, groups: list[str]) -> list[models.Group]:
return db.query(models.Group).filter(
models.Group.id.in_(groups)
).all()
# --------------- INIT User-Groups
def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, service_account_email: str, permissions: dict, domains: list) -> models.Group:
db_group = db.query(models.Group).filter(models.Group.id == group_name).first()
if db_group is None:
db_group = models.Group(id=group_name, description=description, orchestrator=orchestrator, orchestrator_sheet=orchestrator_sheet, service_account_email=service_account_email, permissions=permissions, domains=domains)
db.add(db_group)
else:
db_group.description = description
db_group.orchestrator = orchestrator
db_group.orchestrator_sheet = orchestrator_sheet
db_group.service_account_email = service_account_email
db_group.permissions = permissions
db_group.domains = domains
db.commit()
db.refresh(db_group)
return db_group
def upsert_user(db: Session, email: str):
email = email.lower()
db_user = db.query(models.User).filter(models.User.email == email).first()
if db_user is None:
db_user = models.User(email=email)
db.add(db_user)
db.commit()
return db_user
def upsert_user_groups(db: Session):
def display_email_pii(email: str):
return f"'{email[0:3]}...@{email.split('@')[1]}'"
"""
reads the user_groups yaml file and inserts any new users, groups,
along with new participation of users in groups
"""
filename = get_settings().USER_GROUPS_FILENAME
logger.debug(f"Updating user-groups configuration with file {filename}.")
ug = UserGroups(filename)
# delete all user-groups relationships
db.query(models.association_table_user_groups).delete()
# create a map of group_id -> domains and another of domain -> groups
group_domains = defaultdict(set)
domain_groups = defaultdict(list)
for domain, explicit_groups in ug.domains.items():
domain_groups[domain] = list(set(explicit_groups))
for group in explicit_groups:
group_domains[group].add(domain)
import json
# upsert groups and save a map of groupid -> dbobject
for group_id, g in ug.groups.items():
upsert_group(db, group_id, g.description, g.orchestrator, g.orchestrator_sheet, g.service_account_email, json.loads(g.permissions.model_dump_json()), list(group_domains.get(group_id, [])))
db_groups: dict[str, models.Group] = {g.id: g for g in db.query(models.Group).all()}
# integrity checks
for group_in_domains in group_domains:
if group_in_domains not in db_groups:
logger.warning(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.")
# reinsert users in their EXPLICITLY DEFINED groups
# domain groups are check live, as there may be new users that are not explicitly registered but belong to a domain
for email, explicit_groups in ug.users.items():
explicit_groups = explicit_groups or []
logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}")
db_user = upsert_user(db, email)
# connect users to groups
for group_id in explicit_groups:
if group_id not in db_groups:
logger.warning(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.")
continue
db_groups[group_id].users.append(db_user)
db.commit()
count_user_groups = db.query(models.association_table_user_groups).count()
count_groups = db.query(func.count(models.Group.id)).scalar()
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
# --------------- SHEET
def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: str, frequency: str):
db_sheet = models.Sheet(id=sheet_id, name=name, author_id=email, group_id=group_id, frequency=frequency)
db.add(db_sheet)
db.commit()
db.refresh(db_sheet)
return db_sheet
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_url_archived_at.desc()).all()
async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: int) -> list[models.Sheet]:
result = await db.execute(
select(models.Sheet).filter(models.Sheet.frequency == frequency)
)
filtered = []
for sheet in result.scalars():
if fnv1a_hash_mod(sheet.id, modulo) == id_hash:
filtered.append(sheet)
return filtered
async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
time_threshold = datetime.now() - timedelta(days=inactivity_days)
result = await db.execute(
select(models.Sheet).filter(models.Sheet.last_url_archived_at < time_threshold)
)
deleted = defaultdict(list)
for sheet in result.scalars():
await db.delete(sheet)
deleted[sheet.author_id].append(sheet)
await db.commit()
return dict(deleted)
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
if db_sheet:
db.delete(db_sheet)
db.commit()
return db_sheet is not None

341
app/web/db/user_state.py Normal file
View File

@@ -0,0 +1,341 @@
from typing import Dict, Set
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy import func
from datetime import datetime
from app.shared.db import models
from app.shared.user_groups import GroupInfo, GroupPermissions
from app.shared.schemas import Usage, UsageResponse
from app.web.db import crud
from app.web.utils.misc import convert_priority_to_queue_dict
class UserState:
"""
Manage a user's state and permissions
"""
def __init__(self, db: Session, email: str):
self.db = db
self.email = email.lower()
@property
def permissions(self) -> Dict[str, GroupInfo]:
"""
Returns a dict of all group permissions and a special {"all": read/archive_url/archive_sheet} key
"""
if not hasattr(self, '_permissions'):
self._permissions = {}
self._permissions["all"] = GroupInfo(
read=self.read,
read_public=self.read_public,
archive_url=self.archive_url,
archive_sheet=self.archive_sheet,
# below are relevant only for /url endpoints
max_archive_lifespan_months=self.max_archive_lifespan_months,
max_monthly_urls=self.max_monthly_urls,
max_monthly_mbs=self.max_monthly_mbs,
priority=self.priority
)
for group in self.user_groups:
if not group.permissions: continue
self._permissions[group.id] = GroupInfo(**group.permissions, description=group.description, service_account_email=group.service_account_email)
return self._permissions
@property
def user_groups_names(self):
if not hasattr(self, '_user_groups_names'):
self._user_groups_names = crud.get_user_group_names(self.db, self.email) + ["default"]
return self._user_groups_names
@property
def user_groups(self):
if not hasattr(self, '_user_groups'):
self._user_groups = crud.get_user_groups_by_name(self.db, self.user_groups_names)
return self._user_groups
@property
def read(self) -> Set[str] | bool:
"""
Read can be a list of group names or True, if all can be read.
"""
if not hasattr(self, '_read'):
self._read = set()
for group in self.user_groups:
if not group.permissions: continue
group_read_permissions = group.permissions.get("read", [])
if "all" in group_read_permissions:
self._read = True
return self._read
else:
self._read.update(group.permissions.get("read", []))
return self._read
@property
def read_public(self) -> bool:
"""
Read public permission
"""
if not hasattr(self, '_read_public'):
self._read_public = False
for group in self.user_groups:
if not group.permissions: continue
if group.permissions.get("read_public", False):
self._read_public = True
return self._read_public
return self._read_public
@property
def archive_url(self) -> bool:
"""
Archive URL permission
"""
if not hasattr(self, '_archive_url'):
self._archive_url = False
for group in self.user_groups:
if not group.permissions: continue
if group.permissions.get("archive_url", False):
self._archive_url = True
return self._archive_url
return self._archive_url
@property
def archive_sheet(self) -> bool:
"""
Archive sheet permission
"""
if not hasattr(self, '_archive_sheet'):
self._archive_sheet = False
for group in self.user_groups:
if not group.permissions: continue
if group.permissions.get("archive_sheet", False):
self._archive_sheet = True
return self._archive_sheet
return self._archive_sheet
@property
def sheet_frequency(self):
if not hasattr(self, '_sheet_frequency'):
self._sheet_frequency = set()
for group in self.user_groups:
if not group.permissions: continue
self._sheet_frequency.update(group.permissions.get("sheet_frequency", None))
return self._sheet_frequency
@property
def max_archive_lifespan_months(self) -> int:
if not hasattr(self, '_max_archive_lifespan_months'):
self._max_archive_lifespan_months = self._helper_for_grouping_max_numerical_permissions("max_archive_lifespan_months")
return self._max_archive_lifespan_months
@property
def max_monthly_urls(self) -> int:
if not hasattr(self, '_max_monthly_urls'):
self._max_monthly_urls = self._helper_for_grouping_max_numerical_permissions("max_monthly_urls")
return self._max_monthly_urls
@property
def max_monthly_mbs(self) -> int:
if not hasattr(self, '_max_monthly_mbs'):
self._max_monthly_mbs = self._helper_for_grouping_max_numerical_permissions("max_monthly_mbs")
return self._max_monthly_mbs
@property
def priority(self) -> str:
if not hasattr(self, '_priority'):
self._priority = "low"
for group in self.user_groups:
if not group.permissions: continue
if group.permissions.get("priority", self._priority) == "high":
self._priority = "high"
break
return self._priority
@property
def active(self) -> bool:
"""
A user is active if they can read/archive anything
"""
if not hasattr(self, '_active'):
self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet)
return self._active
def _helper_for_grouping_max_numerical_permissions(self, permission_name: str) -> int:
"""
Iterates one of the numerical permissions where -1 means no restrictions and returns either -1 or the maximum value, defaults according to GroupPermissions
"""
default = GroupPermissions.model_fields[permission_name].default
max_value = default
for group in self.user_groups:
if not group.permissions: continue
group_value = group.permissions.get(permission_name, default)
if group_value == -1:
max_value = -1
return max_value
max_value = max(max_value, group_value)
return max_value
def in_group(self, group_id: str) -> bool:
return group_id in self.user_groups_names
def usage(self) -> Dict:
"""
returns the monthly quotas for the URLs/MBs and the totals for Sheets
"""
current_month = datetime.now().month
current_year = datetime.now().year
# find and sum all user sheets over this month
user_sheets = self.db.query(
models.Sheet.group_id,
func.count(models.Sheet.id).label('sheet_count')
).filter(models.Sheet.author_id == self.email).group_by(models.Sheet.group_id).all()
sheets_by_group = {sheet.group_id: sheet.sheet_count for sheet in user_sheets}
# find and sum all user urls over this month
urls_by_group = self.db.query(
models.Archive.group_id,
func.count(models.Archive.id).label('url_count'),
func.coalesce(func.sum(
func.coalesce(
func.cast(
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
sqlalchemy.Integer
), 0
)
), 0).label('total_bytes')
).filter(
models.Archive.author_id == self.email,
func.extract('month', models.Archive.created_at) == current_month,
func.extract('year', models.Archive.created_at) == current_year
).group_by(models.Archive.group_id).all()
# merge the two queries
usage_by_group: Dict[str, Usage] = {
(url.group_id or ""):
Usage(monthly_urls=url.url_count, monthly_mbs=int(url.total_bytes / 1024 / 1024))
for url in urls_by_group
}
for group_id, sheet_count in sheets_by_group.items():
group_id = group_id or ""
if group_id in usage_by_group:
usage_by_group[group_id].total_sheets = sheet_count
else:
usage_by_group[group_id] = Usage(total_sheets=sheet_count)
# calculate totals
total_sheets = sum([sheet.sheet_count for sheet in user_sheets])
total_bytes = sum([url.total_bytes for url in urls_by_group])
total_urls = sum([url.url_count for url in urls_by_group])
return UsageResponse(
monthly_urls=total_urls,
monthly_mbs=int(total_bytes / 1024 / 1024),
total_sheets=total_sheets,
groups=usage_by_group
)
def has_quota_monthly_sheets(self, group_id: str) -> bool:
"""
checks if a user has reached their sheet quota for a given group
"""
if group_id not in self.permissions:
return False
user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email, models.Sheet.group_id == group_id).count()
sheet_quota = self.permissions[group_id].max_sheets
if sheet_quota == -1:
return True
return user_sheets < sheet_quota
def has_quota_max_monthly_urls(self, group_id: str) -> bool:
"""
checks if a user has reached their monthly url quota for a group, if global then group should be empty string
"""
quota = 0
if not group_id:
quota = self.max_monthly_urls
else:
if group_id not in self.permissions: return False
quota = self.permissions[group_id].max_monthly_urls
if quota == -1:
return True
current_month = datetime.now().month
current_year = datetime.now().year
user_urls = self.db.query(models.Archive).filter(
models.Archive.author_id == self.email,
models.Archive.group_id == group_id,
func.extract('month', models.Archive.created_at) == current_month,
func.extract('year', models.Archive.created_at) == current_year
).count()
return user_urls < quota
def has_quota_max_monthly_mbs(self, group_id: str) -> bool:
"""
checks if a user has reached their monthly MBs quota for a group, if global then group should be empty string
"""
quota = 0
if not group_id:
quota = self.max_monthly_mbs
else:
if group_id not in self.permissions: return False
quota = self.permissions[group_id].max_monthly_mbs
if quota == -1:
return True
current_month = datetime.now().month
current_year = datetime.now().year
# find and sum all user bytes over this month
user_bytes = self.db.query(models.Archive).filter(
models.Archive.author_id == self.email,
models.Archive.group_id == group_id,
func.extract('month', models.Archive.created_at) == current_month,
func.extract('year', models.Archive.created_at) == current_year
).with_entities(func.coalesce(func.sum(
func.coalesce(
func.cast(
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
sqlalchemy.Integer
), 0
)
), 0).label('total')).scalar()
# convert bytes to mb
user_mbs = int(user_bytes / 1024 / 1024)
return user_mbs < quota
def can_manually_trigger(self, group_id: str) -> bool:
"""
checks if a user is allowed to manually trigger a sheet
"""
if group_id not in self.permissions:
return False
return self.permissions[group_id].manually_trigger_sheet
def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool:
"""
checks if a user is allowed to create a sheet with this frequency for this group
"""
if group_id not in self.permissions:
return False
return frequency in self.permissions[group_id].sheet_frequency
def priority_group(self, group_id: str) -> str:
priority = "low"
for group in self.user_groups:
if group.id != group_id: continue
if not group.permissions: continue
priority = group.permissions.get("priority", priority)
break
return convert_priority_to_queue_dict(priority)

View File

@@ -0,0 +1,50 @@
from typing import Dict
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from app.web.config import VERSION, BREAKING_CHANGES
from app.shared.schemas import ActiveUser, UsageResponse
from app.web.db.user_state import UserState
from app.web.security import get_user_state
from app.shared.user_groups import GroupInfo
default_router = APIRouter()
@default_router.get("/")
async def home():
return JSONResponse({"version": VERSION, "breakingChanges": BREAKING_CHANGES})
@default_router.get("/health")
async def health():
return JSONResponse({"status": "ok"})
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
async def active(
user: UserState = Depends(get_user_state),
) -> ActiveUser:
return {"active": user.active}
@default_router.get("/user/permissions", summary="Get the user's global 'all' permissions and the permissions for each group they belong to.")
def get_user_permissions(
user: UserState = Depends(get_user_state),
) -> Dict[str, GroupInfo]:
return user.permissions
@default_router.get("/user/usage", summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.")
def get_user_usage(
user: UserState = Depends(get_user_state),
) -> UsageResponse:
if not user.active:
raise HTTPException(status_code=403, detail="User is not active.")
return user.usage()
@default_router.get('/favicon.ico', include_in_schema=False)
async def favicon() -> FileResponse:
return FileResponse("app/web/static/favicon.ico")

View File

@@ -0,0 +1,61 @@
import json
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from loguru import logger
import sqlalchemy
from auto_archiver.core import Metadata
from sqlalchemy.orm import Session
from app.shared.aa_utils import get_all_urls
from app.web.config import ALLOW_ANY_EMAIL
from app.shared import business_logic, schemas
from app.shared.db import worker_crud
from app.shared.db.database import get_db_dependency
from app.web.security import token_api_key_auth
from app.shared.db import models
from app.shared.log import log_error
interoperability_router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."])
# ----- endpoint to submit data archived elsewhere
@interoperability_router.post("/submit-archive", status_code=201, summary="Submit a manual archive entry, for data that was archived elsewhere.")
def submit_manual_archive(
manual: schemas.SubmitManualArchive,
auth=Depends(token_api_key_auth),
db: Session = Depends(get_db_dependency)
):
try:
result: Metadata = Metadata.from_json(manual.result)
except json.JSONDecodeError as e:
log_error(e)
raise HTTPException(status_code=422, detail="Invalid JSON in result field.")
manual.author_id = manual.author_id or ALLOW_ANY_EMAIL
manual.tags.add("manual")
try:
store_until=business_logic.get_store_archive_until(db, manual.group_id)
except AssertionError as e:
log_error(e)
raise HTTPException(status_code=422, detail=str(e))
try:
archive = schemas.ArchiveCreate(
author_id=manual.author_id,
url=result.get_url(),
public=manual.public,
group_id=manual.group_id,
tags=manual.tags,
id=models.generate_uuid(),
result=json.loads(result.to_json()),
urls=get_all_urls(result),
store_until=store_until,
)
db_archive = worker_crud.store_archived_url(db, archive)
logger.debug(f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}")
return JSONResponse({"id": db_archive.id}, status_code=201)
except sqlalchemy.exc.IntegrityError as e:
log_error(e)
raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error, likely duplicate urls.")

View File

@@ -0,0 +1,81 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from sqlalchemy import exc
from sqlalchemy.orm import Session
from app.web.db.user_state import UserState
from app.shared import schemas
from app.shared.task_messaging import get_celery
from app.web.security import get_user_state
from app.web.db import crud
from app.shared.db.database import get_db_dependency
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
celery = get_celery()
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
def create_sheet(
sheet: schemas.SheetAdd,
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency),
) -> schemas.SheetResponse:
if not user.in_group(sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
if not user.has_quota_monthly_sheets(sheet.group_id):
raise HTTPException(status_code=429, detail="User has reached their sheet quota for this group.")
if not user.is_sheet_frequency_allowed(sheet.group_id, sheet.frequency):
raise HTTPException(status_code=422, detail="Invalid frequency selected for this group.")
try:
return crud.create_sheet(db, sheet.id, sheet.name, user.email, sheet.group_id, sheet.frequency)
except exc.IntegrityError as e:
raise HTTPException(status_code=400, detail="Sheet with this ID is already being archived.") from e
@sheet_router.get("/mine", status_code=200, summary="Get the authenticated user's Google Sheets.")
def get_user_sheets(
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency)
) -> list[schemas.SheetResponse]:
return crud.get_user_sheets(db, user.email)
@sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.")
def delete_sheet(
id: str,
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency),
) -> schemas.DeleteResponse:
return JSONResponse({
"id": id,
"deleted": crud.delete_sheet(db, id, user.email)
})
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
def archive_user_sheet(
id: str,
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency),
) -> schemas.Task:
sheet = crud.get_user_sheet(db, user.email, sheet_id=id)
if not sheet:
raise HTTPException(status_code=403, detail="No access to this sheet.")
if not user.in_group(sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
if not user.can_manually_trigger(sheet.group_id):
raise HTTPException(status_code=429, detail="User cannot manually trigger sheet archiving in this group.")
group_queue = user.priority_group(sheet.group_id)
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=id, author_id=user.email, group_id=sheet.group_id).model_dump_json()]).apply_async(**group_queue)
return JSONResponse({"id": task.id}, status_code=201)

View File

@@ -3,21 +3,19 @@ from fastapi import APIRouter, Depends
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from loguru import logger
from web.security import get_token_or_user_auth
from db import schemas
from core.logging import log_error
from worker.main import celery
from utils.mics import custom_jsonable_encoder
from app.shared.task_messaging import get_celery
from app.web.security import get_token_or_user_auth
from app.shared import schemas
from app.shared.log import log_error
from app.web.utils.misc import custom_jsonable_encoder
task_router = APIRouter(prefix="/task", tags=["Async task operations"])
celery = get_celery()
@task_router.get("/{task_id}", summary="Check the status of an async task by its id, works for URLs and Sheet tasks.")
def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult:
logger.info(f"status check for user {email} task {task_id}")
task = AsyncResult(task_id, app=celery)
try:
if task.status == "FAILURE":

84
app/web/endpoints/url.py Normal file
View File

@@ -0,0 +1,84 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from datetime import datetime
from loguru import logger
from sqlalchemy.orm import Session
from app.web.config import ALLOW_ANY_EMAIL
from app.shared import schemas
from app.shared.task_messaging import get_celery
from app.web.security import get_token_or_user_auth, get_user_state
from app.web.db import crud
from app.web.db.user_state import UserState
from app.shared.db.database import get_db_dependency
from urllib.parse import urlparse
from app.web.utils.misc import convert_priority_to_queue_dict
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
celery = get_celery()
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
def archive_url(
archive: schemas.ArchiveTrigger,
email=Depends(get_token_or_user_auth),
db: Session = Depends(get_db_dependency)
) -> schemas.Task:
archive.author_id = email
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
parsed_url = urlparse(archive.url)
if not all([parsed_url.scheme, parsed_url.netloc]):
raise HTTPException(status_code=400, detail="Invalid URL received.")
archive_create = schemas.ArchiveCreate(**archive.model_dump())
if email != ALLOW_ANY_EMAIL:
user = UserState(db, email)
if archive.group_id and not user.in_group(archive.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
if not user.has_quota_max_monthly_urls(archive.group_id):
raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.")
if not user.has_quota_max_monthly_mbs(archive.group_id):
raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.")
group_queue = user.priority_group(archive_create.group_id)
else:
group_queue = convert_priority_to_queue_dict("high")
task = celery.signature("create_archive_task", args=[archive_create.model_dump_json()]).apply_async(**group_queue)
task_response = schemas.Task(id=task.id)
return JSONResponse(task_response.model_dump(), status_code=201)
@url_router.get("/search", summary="Search for archive entries by URL.")
def search_by_url(
url: str, skip: int = 0, limit: int = 25,
archived_after: datetime = None, archived_before: datetime = None,
db: Session = Depends(get_db_dependency),
email: str = Depends(get_token_or_user_auth)
) -> list[schemas.ArchiveResult]:
read_groups, read_public = False, False
if email != ALLOW_ANY_EMAIL:
user = UserState(db, email)
if not user.read and not user.read_public:
raise HTTPException(status_code=403, detail="User does not have read access.")
read_groups = user.read
read_public = user.read_public
return crud.search_archives_by_url(db, url.strip(), email, read_groups, read_public, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
@url_router.delete("/{id}", summary="Delete a single URL archive by id.")
def delete_archive(
id:str,
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency)
) -> schemas.DeleteResponse:
logger.info(f"deleting url archive task {id} request by {user.email}")
return JSONResponse({
"id": id,
"deleted": crud.soft_delete_archive(db, id, user.email)
})

186
app/web/events.py Normal file
View File

@@ -0,0 +1,186 @@
import asyncio
from collections import defaultdict
import datetime
import logging
import alembic.config
from fastapi import FastAPI
from contextlib import asynccontextmanager
from fastapi_utils.tasks import repeat_every
from loguru import logger
from fastapi_mail import FastMail, MessageSchema, MessageType
from app.shared.db import models
from app.shared.db.database import get_db, get_db_async, make_engine, wal_checkpoint
from app.shared import schemas
from app.shared.settings import get_settings
from app.shared.task_messaging import get_celery
from app.web.db import crud
from app.web.middleware import increase_exceptions_counter
from app.web.utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
celery = get_celery()
@asynccontextmanager
async def lifespan(app: FastAPI):
# see https://fastapi.tiangolo.com/advanced/events/#lifespan
# STARTUP
engine = make_engine(get_settings().DATABASE_PATH)
models.Base.metadata.create_all(bind=engine)
alembic.config.main(prog="alembic", argv=['--raiseerr', 'upgrade', 'head'])
logging.getLogger("uvicorn.access").disabled = True # loguru
asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL))
asyncio.create_task(repeat_measure_regular_metrics())
with get_db() as db:
crud.upsert_user_groups(db)
# setup archive cronjobs
if get_settings().CRON_ARCHIVE_SHEETS:
asyncio.create_task(archive_hourly_sheets_cronjob())
asyncio.create_task(archive_daily_sheets_cronjob())
else:
logger.warning("[CRON] Sheet archive cronjobs are disabled.")
if get_settings().CRON_DELETE_STALE_SHEETS:
asyncio.create_task(delete_stale_sheets())
else:
logger.warning("[CRON] Delete stale sheets cronjob is disabled.")
if get_settings().CRON_DELETE_SCHEDULED_ARCHIVES:
asyncio.create_task(notify_about_expired_archives())
else:
logger.warning("[CRON] Delete scheduled archives cronjob is disabled.")
wal_checkpoint()
yield # separates startup from shutdown instructions
# SHUTDOWN
logger.info("shutting down")
# CRON JOBS
@repeat_every(seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS, on_exception=increase_exceptions_counter)
async def repeat_measure_regular_metrics():
await measure_regular_metrics(get_settings().DATABASE_PATH, get_settings().REPEAT_COUNT_METRICS_SECONDS)
@repeat_every(seconds=60, wait_first=120, on_exception=increase_exceptions_counter)
async def archive_hourly_sheets_cronjob():
await archive_sheets_cronjob("hourly", 60, datetime.datetime.now().minute)
@repeat_every(seconds=3600, wait_first=120, on_exception=increase_exceptions_counter)
async def archive_daily_sheets_cronjob():
await archive_sheets_cronjob("daily", 24, datetime.datetime.now().hour)
async def archive_sheets_cronjob(frequency: str, interval: int, current_time_unit: int):
triggered_jobs = []
async with get_db_async() as db:
sheets = await crud.get_sheets_by_id_hash(db, frequency, interval, current_time_unit)
for s in sheets:
group_queue = await crud.get_group_priority_async(db, s.group_id)
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group_id=s.group_id).model_dump_json()]).apply_async(**group_queue)
triggered_jobs.append({"sheet_id": s.id, "task_id": task.id})
logger.debug(f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}")
# TODO: on exception should logerror but also prometheus counter
DELETE_WINDOW = get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS * 24 * 60 * 60
@repeat_every(seconds=DELETE_WINDOW, wait_first=180, on_exception=increase_exceptions_counter)
async def notify_about_expired_archives():
notify_from = datetime.datetime.now() + datetime.timedelta(days=get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS)
async with get_db_async() as db:
scheduled_deletions = await crud.find_by_store_until(db, notify_from)
user_archives = defaultdict(list)
for archive in scheduled_deletions:
user_archives[archive.author_id].append(archive)
if user_archives:
fastmail = FastMail(get_settings().MAIL_CONFIG)
# notify users
for email in user_archives:
list_of_archives = "\n".join([f'{a.url}, {a.id}, {a.store_until.isoformat()}<br/>' for a in user_archives[email]])
# TODO: how can users download them in bulk?
message = MessageSchema(
subject="Auto Archiver: Archives Scheduled for Deletion",
recipients=[email],
body=f"""
<html>
<body>
<p>Hi {email},</p>
<p>Some of your archives will be deleted in the next {get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS} days, as they are reaching their expiration date according to our retention policy for their groups.</p>
<p>If you want to preserve any, make sure to download them now.</p>
<p>Here is a CSV list of URLs:</p>
<code>
url,archive_id,time_of_deletion<br/>
{list_of_archives}
</code>
<p>Best,<br>The Auto Archiver team</p>
</body>
</html>
""",
subtype=MessageType.html
)
await fastmail.send_message(message)
logger.debug(f"[CRON] Email sent to {email} about {len(user_archives[email])} scheduled archives deletion.")
# now schedule the deletion event
asyncio.create_task(delete_expired_archives())
@repeat_every(max_repetitions=1, wait_first=10, seconds=0, on_exception=increase_exceptions_counter)
async def delete_expired_archives():
async with get_db_async() as db:
count_deleted = await crud.soft_delete_expired_archives(db)
if count_deleted:
logger.debug(f"[CRON] Deleted {count_deleted} archives.")
@repeat_every(seconds=86400, wait_first=150, on_exception=increase_exceptions_counter)
async def delete_stale_sheets():
STALE_DAYS = get_settings().DELETE_STALE_SHEETS_DAYS
logger.debug(f"[CRON] Deleting stale sheets older than {STALE_DAYS} days.")
async with get_db_async() as db:
user_sheets = await crud.delete_stale_sheets(db, STALE_DAYS)
if not user_sheets: return
fastmail = FastMail(get_settings().MAIL_CONFIG)
# notify users
for email in user_sheets:
list_of_sheets = "\n".join([f'<li><a href="https://docs.google.com/spreadsheets/d/{s.id}">{s.name}</a></li>' for s in user_sheets[email]])
message = MessageSchema(
subject="Auto Archiver: Stale Sheets Removed",
recipients=[email],
body=f"""
<html>
<body>
<p>Hi {email},</p>
<p>Your stale sheets have been removed from our system as no new URL was archived in the past {STALE_DAYS} days:</p>
<ul>
{list_of_sheets}
</ul>
<p>You can always re-add them at https://auto-archiver.bellingcat.com/.</p>
<p>Best,<br>The Auto Archiver team</p>
</body>
</html>
""",
subtype=MessageType.html
)
await fastmail.send_message(message)
logger.debug(f"[CRON] Email sent to {email} about stale sheets deletion.")
# @repeat_at
async def generate_users_export_csv():
#TODO: implement a cronjob that regularly requested user data to a CSV file
# see https://colab.research.google.com/drive/1QDbo3QXHPBdiTuANlA1AWVvN-rqxuCPa?authuser=0#scrollTo=4nPXeSdK8RBT
pass

60
app/web/main.py Normal file
View File

@@ -0,0 +1,60 @@
import os
from fastapi import FastAPI, Depends
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from prometheus_fastapi_instrumentator import Instrumentator
from loguru import logger
from app.web.middleware import logging_middleware
from app.shared.task_messaging import get_celery
from app.web.security import token_api_key_auth
from app.web.config import VERSION, API_DESCRIPTION
from app.web.events import lifespan
from app.shared.settings import get_settings
from app.web.endpoints.default import default_router
from app.web.endpoints.url import url_router
from app.web.endpoints.sheet import sheet_router
from app.web.endpoints.task import task_router
from app.web.endpoints.interoperability import interoperability_router
celery = get_celery()
def app_factory(settings = get_settings()):
app = FastAPI(
title="Auto-Archiver API",
description=API_DESCRIPTION,
version=VERSION,
contact={"name": "GitHub", "url": "https://github.com/bellingcat/auto-archiver-api"},
lifespan=lifespan
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.middleware("http")(logging_middleware)
app.include_router(default_router)
app.include_router(url_router)
app.include_router(sheet_router)
app.include_router(task_router)
app.include_router(interoperability_router)
# prometheus exposed in /metrics with authentication
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health", "/openapi.json", "/favicon.ico"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
if settings.SERVE_LOCAL_ARCHIVE:
local_dir = settings.SERVE_LOCAL_ARCHIVE
if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")):
local_dir = local_dir.replace("/app", ".")
if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
logger.warning(f"MOUNTing local archive, use this in development only {settings.SERVE_LOCAL_ARCHIVE}")
app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE)
return app

31
app/web/middleware.py Normal file
View File

@@ -0,0 +1,31 @@
import traceback
from loguru import logger
from fastapi import Request
from app.shared.log import log_error
from app.web.utils.metrics import EXCEPTION_COUNTER
async def logging_middleware(request: Request, call_next):
try:
response = await call_next(request)
#TODO: use Origin to have summary prometheus metrics on where requests come from
# origin = request.headers.get("origin")
logger.info(f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}")
return response
except Exception as e:
location = f"{request.method} {request.url._url}"
await increase_exceptions_counter(e, location)
logger.info(f"{request.client.host}:{request.client.port} {location} - {e.__class__.__name__} {e}")
raise e
async def increase_exceptions_counter(e: Exception, location:str="cronjob"):
if location == "cronjob":
try:
last_trace = traceback.extract_tb(e.__traceback__)[-1]
_file, _line, func_name, _text = last_trace
location = func_name
except Exception as e:
logger.error(f"Unable to get function name from cronjob exception traceback: {e}")
EXCEPTION_COUNTER.labels(type=e.__class__.__name__, location=location).inc()
log_error(e)

View File

@@ -2,8 +2,12 @@ from loguru import logger
import requests, secrets
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from core.config import ALLOW_ANY_EMAIL
from shared.settings import get_settings
from sqlalchemy.orm import Session
from app.web.config import ALLOW_ANY_EMAIL
from app.shared.settings import get_settings
from app.shared.db.database import get_db_dependency
from app.web.db.user_state import UserState
settings = get_settings()
bearer_security = HTTPBearer()
@@ -45,7 +49,7 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
# validates the Bearer token in the case that it requires it
valid_user, info = authenticate_user(credentials.credentials)
if valid_user:
return info
return info.lower()
logger.debug(f"TOKEN FAILURE: {valid_user=} {info=}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -69,7 +73,11 @@ def authenticate_user(access_token):
return False, f"email '{j.get('email')}' not verified"
if int(j.get("expires_in", -1)) <= 0:
return False, "Token expired"
return True, j.get('email')
return True, j.get('email').lower()
except Exception as e:
logger.warning(f"AUTH EXCEPTION occurred: {e}")
return False, "exception occurred"
def get_user_state(email:str=Depends(get_user_auth), db:Session=Depends(get_db_dependency)):
return UserState(db, email)

View File

Before

Width:  |  Height:  |  Size: 93 KiB

After

Width:  |  Height:  |  Size: 93 KiB

View File

@@ -3,23 +3,23 @@ import json
import os
import shutil
from prometheus_client import Counter, Gauge
import redis
from db import crud
from db.database import get_db
from core.logging import log_error
from app.web.db import crud
from app.shared.db.database import get_db
from app.shared.log import log_error
from app.shared.task_messaging import get_redis
# Custom metrics
EXCEPTION_COUNTER = Counter(
"exceptions",
"Number of times a certain exception has occurred.",
labelnames=["types"]
labelnames=["type", "location"]
)
WORKER_EXCEPTION = Counter(
"worker_exceptions_total",
"Number of times a certain exception has occurred on the worker.",
labelnames=["types", "exception", "task", "traceback"]
labelnames=["type", "exception", "task", "traceback"]
)
DISK_UTILIZATION = Gauge(
"disk_utilization",
@@ -38,16 +38,16 @@ DATABASE_METRICS_COUNTER = Counter(
)
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL, CELERY_BROKER_URL):
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL: str):
# Subscribe to Redis channel and increment the counter for each exception with info on the exception and task
Rdis = redis.Redis.from_url(CELERY_BROKER_URL)
PubSubExceptions = Rdis.pubsub()
Redis = get_redis()
PubSubExceptions = Redis.pubsub()
PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL)
while True:
message = PubSubExceptions.get_message()
if message and message["type"] == "message":
data = json.loads(message["data"].decode("utf-8"))
WORKER_EXCEPTION.labels(types=type(data["exception"]).__name__, exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc()
WORKER_EXCEPTION.labels(type=data["type"], exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc()
await asyncio.sleep(1)

15
app/web/utils/misc.py Normal file
View File

@@ -0,0 +1,15 @@
import base64
from fastapi.encoders import jsonable_encoder
def custom_jsonable_encoder(obj):
if isinstance(obj, bytes):
return base64.b64encode(obj).decode('utf-8')
return jsonable_encoder(obj)
def convert_priority_to_queue_dict(priority: str) -> dict:
return {
"priority": 0 if priority == "high" else 10,
"queue": f"{priority}_priority"
}

0
app/worker/__init__.py Normal file
View File

147
app/worker/main.py Normal file
View File

@@ -0,0 +1,147 @@
import json
import traceback, datetime
from celery.signals import task_failure
from loguru import logger
from sqlalchemy import exc
from auto_archiver.core.orchestrator import ArchivingOrchestrator
from app.shared.db import models
from app.shared.db.database import get_db
from app.shared import business_logic, schemas
from app.shared.task_messaging import get_celery, get_redis
from app.shared.settings import get_settings
from app.shared.log import log_error
from app.shared.aa_utils import get_all_urls
from app.shared.db import worker_crud
settings = get_settings()
celery = get_celery("worker")
Redis = get_redis()
USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
# TODO: these are temporary PATCHES for new aa's functionality
# logger.add("app/worker/worker_log.log", level="DEBUG")
logger.remove = lambda x: print(f"logger.remove({x})")
# TODO: after release, as it requires updating past entries with sheet_id where tag is used, drop tags
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 1})
def create_archive_task(self, archive_json: str):
archive = schemas.ArchiveCreate.model_validate_json(archive_json)
# call auto-archiver
args = get_orchestrator_args(archive.group_id, False, [archive.url])
try:
orchestrator = ArchivingOrchestrator()
orchestrator.setup(args)
result = next(orchestrator.feed())
except SystemExit as e:
log_error(e, f"create_archive_task: SystemExit from AA")
except Exception as e:
log_error(e, f"create_archive_task")
raise e
assert result, f"UNABLE TO archive: {archive.url}"
# prepare and insert in DB
archive.store_until = get_store_until(archive.group_id)
archive.id = self.request.id
archive.urls = get_all_urls(result)
archive.result = json.loads(result.to_json())
insert_result_into_db(archive)
return archive.result
@celery.task(name="create_sheet_task", bind=True)
def create_sheet_task(self, sheet_json: str):
sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
queue_name = (create_sheet_task.request.delivery_info or {}).get('routing_key', 'unknown')
logger.info(f"[queue={queue_name}] SHEET START {sheet=}")
args = get_orchestrator_args(sheet.group_id, True, ["--gsheet_feeder.sheet_id", sheet.sheet_id])
orchestrator = ArchivingOrchestrator()
orchestrator.setup(args)
stats = {"archived": 0, "failed": 0, "errors": []}
try:
for result in orchestrator.feed():
try:
assert result, f"ERROR archiving URL for sheet {sheet.sheet_id}"
archive = schemas.ArchiveCreate(
author_id=sheet.author_id,
url=result.get_url(),
group_id=sheet.group_id,
tags=sheet.tags,
id=models.generate_uuid(),
result=json.loads(result.to_json()),
sheet_id=sheet.sheet_id,
urls=get_all_urls(result),
store_until=get_store_until(sheet.group_id)
)
insert_result_into_db(archive)
stats["archived"] += 1
except exc.IntegrityError as e:
logger.warning(f"cached result detected: {e}")
except Exception as e:
log_error(e, extra=f"{self.name}: {sheet_json}")
redis_publish_exception(e, self.name, traceback.format_exc())
stats["failed"] += 1
stats["errors"].append(str(e))
except SystemExit as e:
log_error(e, f"create_sheet_task: SystemExit from AA")
if stats["archived"] > 0:
with get_db() as session:
worker_crud.update_sheet_last_url_archived_at(session, sheet.sheet_id)
logger.info(f"SHEET DONE {sheet=}")
# TODO: is this used anywhere? maybe drop it
return schemas.CelerySheetTask(success=True, sheet_id=sheet.sheet_id, time=datetime.datetime.now().isoformat(), stats=stats).model_dump()
def get_orchestrator_args(group_id: str, orchestrator_for_sheet: bool, cli_args: list = []) -> list:
aa_configs = []
with get_db() as session:
group = worker_crud.get_group(session, group_id)
if orchestrator_for_sheet:
orchestrator_fn = group.orchestrator_sheet
else:
orchestrator_fn = worker_crud.get_group(session, group_id).orchestrator
assert orchestrator_fn, f"no orchestrator found for {group_id}"
aa_configs.extend(["--config", orchestrator_fn])
aa_configs.extend(cli_args)
return aa_configs
def insert_result_into_db(archive: schemas.ArchiveCreate) -> str:
with get_db() as session:
db_archive = worker_crud.store_archived_url(session, archive)
logger.debug(f"[ARCHIVE STORED] {db_archive.author_id} {db_archive.url}")
return db_archive.id
def get_store_until(group_id: str) -> datetime.datetime:
with get_db() as session:
return business_logic.get_store_archive_until(session, group_id)
def redis_publish_exception(exception, task_name, traceback: str = ""):
REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL
try:
exception_data = {"task": task_name, "type": exception.__class__.__name__, "exception": exception, "traceback": traceback}
Redis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps(exception_data, default=str))
except Exception as e:
log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}")
@task_failure.connect(sender=create_sheet_task)
@task_failure.connect(sender=create_archive_task)
def task_failure_notifier(sender, **kwargs):
# automatically capture exceptions in the worker tasks
logger.warning(f"⚠️ worker task failed: {sender.name}")
traceback_msg = "\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback'])))
log_error(kwargs['exception'], traceback_msg, f"task_failure: {sender.name}")
redis_publish_exception(kwargs['exception'], sender.name, traceback_msg)

0
database/.gitkeep Normal file
View File

View File

@@ -1,19 +1,30 @@
services:
web:
command: uvicorn app.web:app --factory --host 0.0.0.0 --reload
restart: "no"
env_file: src/.env.dev
env_file: .env.dev
volumes:
- ./app/web:/aa-api/app/web # for --reload to work
- ./app/shared:/aa-api/app/shared # for --reload to work
environment:
- SERVE_LOCAL_ARCHIVE=/app/local_archive # See orchestration.yaml local_storage.save_to
- ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp
- USER_GROUPS_FILENAME=user-groups.dev.yaml
- DATABASE_PATH=sqlite:////app/auto-archiver.db
- ENVIRONMENT_FILE=.env.dev
- SERVE_LOCAL_ARCHIVE=/aa-api/app/local_archive # See orchestration.yaml local_storage.save_to
- ALLOWED_ORIGINS=["http://localhost:8000","http://localhost:8004","http://localhost:8081","chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp"]
- USER_GROUPS_FILENAME=/aa-api/app/user-groups.dev.yaml
- DATABASE_PATH=sqlite:////aa-api/database/auto-archiver.db
worker:
# command: watchmedo auto-restart --patterns="*.py" --recursive --ignore-directories -- celery -- --app=app.worker.main.celery worker --loglevel=debug --logfile=/aa-api/logs/celery.log -Q high_priority,low_priority --concurrency=${CONCURRENCY}
command: celery --app=app.worker.main.celery worker --loglevel=debug --logfile=/aa-api/logs/celery.log -Q high_priority,low_priority --concurrency=${CONCURRENCY}
restart: "no"
env_file: src/.env.dev
env_file: .env.dev
volumes:
- ./app/worker:/aa-api/app/worker # for watchmedo to work
- ./app/shared:/aa-api/app/shared # for watchmedo to work
redis:
restart: "no"
env_file: src/.env.dev
env_file: .env.dev
ports:
- 6379:6379

View File

@@ -1,13 +1,3 @@
# reusable YAML variables
x-broker-url: &broker-url "redis://:${REDIS_PASSWORD}@redis:6379/0"
x-base-setup: &base-setup
build: ./src
restart: always
env_file: src/.env.prod
environment:
CELERY_BROKER_URL: *broker-url
CELERY_RESULT_BACKEND: *broker-url
volumes:
crawls:
@@ -15,31 +5,45 @@ volumes:
name: "auto-archiver-api"
services:
web:
<<: *base-setup
build:
context: .
dockerfile: web.Dockerfile
restart: always
env_file: .env.prod
environment:
ENVIRONMENT_FILE: .env.prod
REDIS_HOSTNAME: redis
ports:
- "127.0.0.1:8004:8000"
command: uvicorn web:app --factory --host 0.0.0.0 --reload
command: uvicorn app.web:app --factory --host 0.0.0.0
volumes:
- ./src:/app
- ./logs:/aa-api/logs
- ./database:/aa-api/database
- ./secrets:/aa-api/secrets
depends_on:
- redis
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
test: ["CMD", "python3", "-c", 'import sys, urllib.request; sys.exit(urllib.request.urlopen("http://localhost:8000/health").getcode() != 200)']
interval: 30s
timeout: 10s
retries: 3
worker:
<<: *base-setup
command: celery --app=worker.main.celery worker --loglevel=info --logfile=logs/celery.log
build:
context: .
dockerfile: worker.Dockerfile
restart: always
env_file: .env.prod
command: celery --app=app.worker.main.celery worker --loglevel=warning --logfile=/aa-api/logs/celery.log -Q high_priority,low_priority --concurrency=${CONCURRENCY}
volumes:
- ./src:/app
- ./logs:/aa-api/logs
- ./database:/aa-api/database
- ./secrets:/aa-api/secrets
- /var/run/docker.sock:/var/run/docker.sock
- crawls:/crawls # BROWSERTRIX_HOME_HOST:BROWSERTRIX_HOME_CONTAINER, do not change /crawls
environment:
# celery broker-url needs to be duplicated here, do not remove
CELERY_BROKER_URL: *broker-url
CELERY_RESULT_BACKEND: *broker-url
REDIS_HOSTNAME: redis
ENVIRONMENT_FILE: .env.prod
WACZ_ENABLE_DOCKER: 1 # Enable calling docker from this container
BROWSERTRIX_HOME_HOST: auto-archiver-api_crawls
BROWSERTRIX_HOME_CONTAINER: /crawls
@@ -47,7 +51,7 @@ services:
- web
- redis
healthcheck:
test: ["CMD", "pipenv", "run", "celery", "-A", "worker.main.celery", "status"]
test: ["CMD-SHELL", "./poetry-venv/bin/poetry run celery -A app.worker.main.celery inspect ping || exit 1"]
interval: 30s
timeout: 10s
retries: 3
@@ -55,10 +59,11 @@ services:
redis:
image: redis:6-alpine
restart: always
env_file: .env.prod
command: redis-server /conf/redis.conf --requirepass ${REDIS_PASSWORD}
volumes:
- "./redis/data:/data"
- "./redis/config:/conf"
- ./redis/data:/data
- ./redis/config:/conf
healthcheck:
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
interval: 30s

0
logs/.gitkeep Normal file
View File

3690
poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

55
pyproject.toml Normal file
View File

@@ -0,0 +1,55 @@
[tool.poetry]
package-mode = false
[project]
name = "auto-archiver-api"
description = "API wrapper for Bellingcat's Auto Archiver, supports users, groups, sheet and url archives."
authors = [
{ name = "Bellingcat", email = "contact-tech@bellingcat.com" },
]
license = {text = "MIT"}
readme = "README.md"
keywords = ["archive", "oosi", "osint", "scraping"]
classifiers = [
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3"
]
requires-python = ">=3.10,<3.13"
dependencies = [
"auto-archiver (>=0.13.1)",
"oscrypto @ git+https://github.com/wbond/oscrypto.git@d5f3437ed24257895ae1edd9e503cfb352e635a8",
"celery (>=5.0)",
"redis (==3.5.3)",
"loguru (>=0.7.3,<0.8.0)",
"pydantic-settings (>=2.7.1,<3.0.0)",
"sqlalchemy (>=2.0.38,<3.0.0)",
"requests (>=2.25.1)",
"pyopenssl (>=23.3.0)",
]
[tool.poetry.group.worker.dependencies]
watchdog = ">=6.0.0,<7.0.0"
setuptools = "^75.8.0"
[tool.poetry.group.web.dependencies]
fastapi = ">=0.115.8,<0.116.0"
requests = ">=2.32.3,<3.0.0"
aiosqlite = ">=0.21.0,<0.22.0"
alembic = ">=1.14.1,<2.0.0"
fastapi-utils = ">=0.8.0,<0.9.0"
prometheus-fastapi-instrumentator = ">=7.0.2,<8.0.0"
fastapi-mail = ">=1.4.2,<2.0.0"
uvicorn = ">=0.13.4"
pyyaml = "^6.0.2"
[tool.poetry.group.dev.dependencies]
pytest = ">=8.3.4,<9.0.0"
httpx = ">=0.28.1,<0.29.0"
coverage = ">=7.6.11,<8.0.0"
pytest-asyncio = ">=0.25.3,<0.26.0"

View File

@@ -1,7 +0,0 @@
DATABASE_PATH="sqlite:///./auto-archiver.db"
USER_GROUPS_FILENAME=user-groups.yaml
CHROME_APP_IDS=000000000000000000000000000000000000000000000.apps.googleusercontent.com,000000000000000000000000000000000000000000001.apps.googleusercontent.com
#ALLOWED_ORIGINS="http://localhost:8004" # dev only
API_BEARER_TOKEN=TODO

View File

@@ -1,22 +0,0 @@
# From python:3.10
FROM bellingcat/auto-archiver
# set work directory
WORKDIR /app
RUN curl -fsSL https://get.docker.com -o get-docker.sh && \
sh get-docker.sh
# set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
# install dependencies
RUN pip install --upgrade pip && \
apt-get update
COPY Pipfile* ./
RUN pipenv install
# copy src code over
COPY . .
ENTRYPOINT ["pipenv", "run"]

View File

@@ -1,32 +0,0 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"
[packages]
aiofiles = "==0.6.0"
celery = ">=5.0"
fastapi = "*"
jinja2 = "*"
redis = "==3.5.3"
requests = ">=2.25.1"
uvicorn = ">=0.13.4"
aiosqlite = "*"
python-dotenv = "*"
loguru = "*"
sqlalchemy = "*"
alembic = "*"
fastapi-utils = "*"
prometheus-fastapi-instrumentator = "*"
auto-archiver = "*"
pydantic-settings = "*"
[dev-packages]
watchdog = "*"
pytest = "*"
httpx = "*"
coverage = "*"
pytest-asyncio = "*"
[requires]
python_version = "3.10"

3517
src/Pipfile.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,41 +0,0 @@
import asyncio
import logging
import alembic.config
from fastapi import FastAPI
from contextlib import asynccontextmanager
from fastapi_utils.tasks import repeat_every
from loguru import logger
from db import crud, models
from db.database import get_db, make_engine
from shared.settings import get_settings
from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
@asynccontextmanager
async def lifespan(app: FastAPI):
# see https://fastapi.tiangolo.com/advanced/events/#lifespan
# STARTUP
engine = make_engine(get_settings().DATABASE_PATH)
models.Base.metadata.create_all(bind=engine)
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
# disabling uvicorn logger since we use loguru in logging_middleware
logging.getLogger("uvicorn.access").disabled = True
asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL, get_settings().CELERY_BROKER_URL))
asyncio.create_task(repeat_measure_regular_metrics())
with get_db() as db:
crud.upsert_user_groups(db)
yield # separates startup from shutdown instructions
# SHUTDOWN
logger.info("shutting down")
# CRON JOBS
@repeat_every(seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS)
async def repeat_measure_regular_metrics():
await measure_regular_metrics(get_settings().DATABASE_PATH, get_settings().REPEAT_COUNT_METRICS_SECONDS)

View File

@@ -1,26 +0,0 @@
import traceback
from loguru import logger
from fastapi import Request
# logging configurations
logger.add("logs/api_logs.log", retention="30 days", rotation="3 days")
logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
def log_error(e: Exception, traceback_str: str = None, extra:str = ""):
# EXCEPTION_COUNTER.labels(type(e).__name__).inc()
if not traceback_str: traceback_str = traceback.format_exc()
if extra: extra = f"{extra}\n"
logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")
async def logging_middleware(request: Request, call_next):
try:
response = await call_next(request)
logger.info(f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}")
return response
except Exception as e:
from utils.metrics import EXCEPTION_COUNTER
EXCEPTION_COUNTER.labels(type(e).__name__).inc()
log_error(e)
raise e

View File

@@ -1 +0,0 @@
based on https://fastapi-users.github.io/fastapi-users/10.4/configuration/oauth/

View File

@@ -1 +0,0 @@
# https://fastapi.tiangolo.com/tutorial/sql-databases/#review-all-the-files

View File

@@ -1,259 +0,0 @@
from collections import defaultdict
from functools import cache
from sqlalchemy.orm import Session, load_only
from sqlalchemy import Column, or_, func
from loguru import logger
from datetime import datetime, timedelta
from core.config import ALLOW_ANY_EMAIL
from shared.settings import get_settings
from . import models, schemas
import yaml
DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
# --------------- TASK = Archive
def get_limit(user_limit: int):
return max(1, min(user_limit, DATABASE_QUERY_LIMIT))
def get_archive(db: Session, id: str, email: str):
email = email.lower()
query = base_query(db).filter(models.Archive.id == id)
if email != ALLOW_ANY_EMAIL:
groups = get_user_groups(db, email)
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
return query.first()
def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False):
# searches for partial URLs, if email is * no ownership filtering happens
query = base_query(db)
if email != ALLOW_ANY_EMAIL:
email = email.lower()
groups = get_user_groups(db, email)
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
if absolute_search:
query = query.filter(models.Archive.url == url)
else:
query = query.filter(models.Archive.url.like(f'%{url}%'))
if archived_after:
query = query.filter(models.Archive.created_at > archived_after)
if archived_before:
query = query.filter(models.Archive.created_at < archived_before)
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
email = email.lower()
return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]):
db_task = models.Archive(id=task.id, url=task.url, result=task.result, public=task.public, author_id=task.author_id, group_id=task.group_id)
db_task.tags = tags
db_task.urls = urls
db.add(db_task)
db.commit()
db.refresh(db_task)
return db_task
def soft_delete_task(db: Session, task_id: str, email: str) -> bool:
# TODO: implement hard-delete with cronjob that deletes from S3
db_task = db.query(models.Archive).filter(models.Archive.id == task_id, models.Archive.author_id == email, models.Archive.deleted == False).first()
if db_task:
db_task.deleted = True
db.commit()
return db_task is not None
def count_archives(db: Session):
return db.query(func.count(models.Archive.id)).scalar()
def count_archive_urls(db: Session):
return db.query(func.count(models.ArchiveUrl.url)).scalar()
def count_users(db: Session):
return db.query(func.count(models.User.email)).scalar()
def count_by_user_since(db: Session, seconds_delta: int = 15):
time_threshold = datetime.now() - timedelta(seconds=seconds_delta)
return db.query(models.Archive.author_id, func.count().label('total'))\
.filter(models.Archive.created_at >= time_threshold)\
.group_by(models.Archive.author_id)\
.order_by(func.count().desc())\
.limit(500).all()
def base_query(db: Session):
# NOTE: load_only is for optimization and not obfuscation, use .with_entities() if needed
return db.query(models.Archive)\
.filter(models.Archive.deleted == False)\
.options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result))
# --------------- TAG
def create_tag(db: Session, tag: str):
db_tag = db.query(models.Tag).filter(models.Tag.id == tag).first()
if not db_tag:
db_tag = models.Tag(id=tag)
db.add(db_tag)
db.commit()
db.refresh(db_tag)
return db_tag
def is_active_user(db: Session, email: str) -> bool:
email = email.lower()
if not email or not len(email) or "@" not in email: return False
domain = email.split('@')[1]
explicitly_active = db.query(models.User).filter(models.User.email == email, models.User.is_active == True).first() is not None
if explicitly_active: return True
return db.query(models.Group).filter(models.Group.domains.contains(domain)).first() is not None
def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
if email == ALLOW_ANY_EMAIL: return True
return len(group_name) and len(email) and group_name in get_user_groups(db, email)
def get_user_groups(db: Session, email: str):
"""
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user. User does not need to be active.
"""
if not email or not len(email) or "@" not in email: return []
email = email.lower()
# get user groups
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
user_level_groups = [g[0] for g in user_groups]
# get domain groups
domain = email.split('@')[1]
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
domain_level_groups = [g[0] for g in domain_level_groups]
# combine and return
return list(set(user_level_groups + domain_level_groups))
# --------------- INIT User-Groups
def create_or_get_user(db: Session, author_id: str, is_active: bool = models.User.is_active.default.arg) -> models.User:
if type(author_id) == str: author_id = author_id.lower()
db_user = db.query(models.User).filter(models.User.email == author_id).first()
if not db_user:
db_user = models.User(email=author_id, is_active=is_active)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, permissions: dict, domains: list) -> models.Group:
db_group = db.query(models.Group).filter(models.Group.id == group_name).first()
if db_group is None:
db_group = models.Group(id=group_name, description=description, orchestrator=orchestrator, orchestrator_sheet=orchestrator_sheet, permissions=permissions, domains=domains)
db.add(db_group)
else:
db_group.description = description
db_group.orchestrator = orchestrator
db_group.orchestrator_sheet = orchestrator_sheet
db_group.permissions = permissions
db_group.domains = domains
db.commit()
db.refresh(db_group)
return db_group
def upsert_user(db: Session, email: str, active: bool):
db_user = db.query(models.User).filter(models.User.email == email).first()
if db_user is None:
db_user = models.User(email=email, is_active=active)
db.add(db_user)
else:
db_user.is_active = active
db.commit()
return db_user
def upsert_user_groups(db: Session):
def display_email_pii(email: str):
return f"'{email[0:3]}...@{email.split('@')[1]}'"
"""
reads the user_groups yaml file and inserts any new users, groups,
along with new participation of users in groups
"""
logger.debug("Updating user-groups configuration.")
filename = get_settings().USER_GROUPS_FILENAME
# read yaml safely
try:
with open(filename) as inf:
user_groups_yaml = yaml.safe_load(inf)
except Exception as e:
logger.error(f"could not open user groups filename {filename}: {e}")
raise e
# delete all user-groups relationships
db.query(models.association_table_user_groups).delete()
# set all users to inactive
db.query(models.User).update({models.User.is_active: False})
# create a map of group_id -> domains and another of domain -> groups
group_domains = defaultdict(set)
domain_groups = defaultdict(list)
for domain, explicit_groups in user_groups_yaml.get("domains", {}).items():
domain_groups[domain] = list(set(explicit_groups))
for group in explicit_groups:
group_domains[group].add(domain)
# upsert groups and save a map of groupid -> dbobject
for group_id, g in user_groups_yaml.get("groups", {}).items():
upsert_group(db, group_id, g["description"], g["orchestrator"], g["orchestrator_sheet"], g["permissions"], list(group_domains.get(group_id, [])))
db_groups: dict[str, models.Group] = {g.id: g for g in db.query(models.Group).all()}
# integrity checks
for group_in_domains in group_domains:
if group_in_domains not in db_groups:
logger.error(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.")
if group_in_domains not in user_groups_yaml.get("groups", {}):
logger.error(f"[CONFIG] Group '{group_in_domains}' does not exist in the config file: domains setting will not work.")
# reinsert users in their EXPLICITLY DEFINED groups
# domain groups are check live, as there may be new users that are not explicitly registered but belong to a domain
for email, explicit_groups in user_groups_yaml.get("users", {}).items():
explicit_groups = explicit_groups or []
email = email.lower().strip()
if '@' not in email:
logger.error(f'[CONFIG] Invalid user email {email}, skipping.')
continue
logger.info(f"{display_email_pii(email)} => {explicit_groups}")
# upsert active user
db_user = upsert_user(db, email, active=True)
# connect users to groups
for group_id in explicit_groups:
if group_id not in db_groups:
logger.error(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.")
continue
db_groups[group_id].users.append(db_user)
db.commit()
count_user_groups = db.query(models.association_table_user_groups).count()
count_groups = db.query(func.count(models.Group.id)).scalar()
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")

View File

@@ -1,36 +0,0 @@
from functools import lru_cache
from sqlalchemy import Engine, create_engine, event
from sqlalchemy.orm import sessionmaker
from shared.settings import get_settings
from contextlib import contextmanager
@lru_cache
def make_engine(database_url: str):
engine = create_engine(database_url, connect_args={"check_same_thread": False})
@event.listens_for(engine, "connect")
def set_sqlite_pragma(conn, _) -> None:
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.close()
return engine
def make_session_local(engine: Engine):
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return session_local
@contextmanager
def get_db():
session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
try: yield session
finally: session.close()
def get_db_dependency():
# to use with Depends and ensure proper session closing
with get_db() as db:
yield db

View File

@@ -1,66 +0,0 @@
from pydantic import BaseModel
from datetime import datetime
class Tag(BaseModel):
id: str
created_at: datetime
model_config = { "from_attributes": True }
__hash__ = object.__hash__
class ArchiveCreate(BaseModel):
id: str | None = None
url: str
result: dict | None = None
public: bool = True
author_id: str | None = None
group_id: str | None = None
tags: set[Tag] | None = set()
rearchive: bool = True
# urls: list = []
class Archive(ArchiveCreate):
created_at: datetime
updated_at: datetime | None
deleted: bool
model_config = { "from_attributes": True }
class SubmitSheet(BaseModel):
sheet_name: str | None = None
sheet_id: str | None = None
header: int = 1
public: bool = False
author_id: str | None = None
group_id: str | None = None
tags: set[str] | None = set()
columns: dict | None = {} # TODO: implement
class SubmitManual(BaseModel):
result: str # should be a Metadata.to_json()
public: bool = False
author_id: str | None = None
group_id: str | None = None
tags: set[str] | None = set()
# API RESPONSES BELOW
class ArchiveResult(BaseModel):
id: str
url: str
result: dict
created_at: datetime
class Task(BaseModel):
id: str
class TaskResult(Task):
status: str
result: str
class TaskDelete(Task):
deleted: bool
class ActiveUser(BaseModel):
active: bool

View File

@@ -1,5 +0,0 @@
from endpoints.default import default_router
from endpoints.url import url_router
from endpoints.task import task_router
from endpoints.interoperability import interoperability_router
from endpoints.sheet import sheet_router

View File

@@ -1,45 +0,0 @@
from fastapi import APIRouter, Depends, Request, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from sqlalchemy.orm import Session
from core.config import VERSION, BREAKING_CHANGES
from core.logging import log_error
from db import crud, schemas
from db.database import get_db_dependency, get_db
from web.security import get_user_auth, bearer_security
default_router = APIRouter()
@default_router.get("/")
async def home(request: Request):
# TODO: maybe split into 2 routes: one non authenticated and one authenticated for the groups info only
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
try:
email = await get_user_auth(await bearer_security(request))
with get_db() as db:
status["groups"] = crud.get_user_groups(db, email)
except HTTPException: pass # not authenticated is fine
except Exception as e: log_error(e)
return JSONResponse(status)
@default_router.get("/health")
async def health():
return JSONResponse({"status": "ok"})
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
async def active(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> schemas.ActiveUser:
return {"active": crud.is_active_user(db, email)}
@default_router.get("/groups")
def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> list[str]:
return crud.get_user_groups(db, email)
@default_router.get('/favicon.ico', include_in_schema=False)
async def favicon():
return FileResponse("static/favicon.ico")

View File

@@ -1,27 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from auto_archiver import Metadata
from loguru import logger
import sqlalchemy
from web.security import token_api_key_auth
from db import models, schemas
from worker.main import insert_result_into_db
from core.logging import log_error
interoperability_router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."])
# ----- endpoint to submit data archived elsewhere
@interoperability_router.post("/submit-archive", status_code=201, summary="Submit a manual archive entry, for data that was archived elsewhere.")
def submit_manual_archive(manual: schemas.SubmitManual, auth=Depends(token_api_key_auth)):
result = Metadata.from_json(manual.result)
logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
manual.tags.add("manual")
try:
archive_id = insert_result_into_db(result, manual.tags, manual.public, manual.group_id, manual.author_id, models.generate_uuid())
except sqlalchemy.exc.IntegrityError as e:
log_error(e)
raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error")
return JSONResponse({"id": archive_id}, status_code=201)

View File

@@ -1,24 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from loguru import logger
from core.config import ALLOW_ANY_EMAIL
from web.security import get_token_or_user_auth
from db import schemas
from worker.main import create_sheet_task
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
@sheet_router.post("/archive", status_code=201, summary="Submit a Google Sheet archive request, starts a sheet archiving task.", response_description="task_id for the archiving task.")
def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_token_or_user_auth)) -> schemas.Task:
logger.info(f"SHEET TASK for {sheet=}")
if email == ALLOW_ANY_EMAIL:
email = sheet.author_id or "api-endpoint"
sheet.author_id = email
if not sheet.sheet_name and not sheet.sheet_id:
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
task = create_sheet_task.delay(sheet.model_dump_json())
return JSONResponse({"id": task.id}, status_code=201)

View File

@@ -1,60 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from datetime import datetime
from loguru import logger
from web.security import get_user_auth, get_token_or_user_auth
from sqlalchemy.orm import Session
from db import crud, schemas
from db.database import get_db_dependency
from worker.main import create_archive_task
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)) -> schemas.Task:
archive.author_id = email
url = archive.url
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
if type(url) != str or len(url) <= 5:
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
logger.info("creating task")
task = create_archive_task.delay(archive.model_dump_json())
task_response = schemas.Task(id=task.id)
return JSONResponse(task_response.model_dump(), status_code=201)
@url_router.get("/search", summary="Search for archive entries by URL.")
def search_by_url(
url: str, skip: int = 0, limit: int = 25,
archived_after: datetime = None, archived_before: datetime = None,
db: Session = Depends(get_db_dependency),
email=Depends(get_token_or_user_auth)
) -> list[schemas.ArchiveResult]:
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
@url_router.get("/latest", summary="Fetch latest URL archives for the authenticated user.")
def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> list[schemas.ArchiveResult]:
return crud.search_archives_by_email(db, email, skip=skip, limit=limit)
@url_router.get("/{id}", summary="Fetch a single URL archive by the associated id.")
def lookup(id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)) -> schemas.ArchiveResult:
archive = crud.get_archive(db, id, email)
if archive is None:
raise HTTPException(status_code=404, detail="Archive not found")
return archive
@url_router.delete("/{id}", summary="Delete a single URL archive by id.")
def delete_task(id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> schemas.TaskDelete:
logger.info(f"deleting url archive task {id} request by {email}")
return JSONResponse({
"id": id,
"deleted": crud.soft_delete_task(db, id, email)
})

View File

@@ -1,18 +0,0 @@
# email-level group access
users:
email1@example.com:
- group1
- group2
email2@example.com:
- group2
email3@example-no-group.com:
# domain-level group access (taken from the emails)
domains:
example.com:
- group3
orchestrators:
group1: secrets/orchestration-group1.yaml
group2: secrets/orchestration-group2.yaml
default: secrets/orchestration-default.yaml

View File

@@ -1 +0,0 @@
Generic single-database configuration.

Some files were not shown because too many files have changed in this diff Show More