mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-08 03:28:35 +03:00
3
.coveragerc
Normal file
3
.coveragerc
Normal file
@@ -0,0 +1,3 @@
|
||||
[run]
|
||||
omit =
|
||||
app/migrations/*
|
||||
@@ -1,5 +1,5 @@
|
||||
CHROME_APP_IDS='["1234567890"]'
|
||||
ALLOWED_ORIGINS='["allowed"]'
|
||||
BLOCKED_EMAILS='[]'
|
||||
DATABASE_PATH="sqlite:///./auto-archiver.db"
|
||||
DATABASE_PATH="sqlite:///./database/auto-archiver.db"
|
||||
API_BEARER_TOKEN=THIS_API_TOKEN_SHOULD_NEVER_BE_USED
|
||||
38
.env.example
Normal file
38
.env.example
Normal file
@@ -0,0 +1,38 @@
|
||||
# main settings
|
||||
USER_GROUPS_FILENAME=app/user-groups.yaml
|
||||
# database
|
||||
DATABASE_PATH="sqlite:///./database/auto-archiver.db"
|
||||
DATABASE_QUERY_LIMIT=100
|
||||
|
||||
# security settings
|
||||
API_BEARER_TOKEN=TODO-MODIFY-THIS-API-TOKEN
|
||||
ALLOWED_ORIGINS='["http://localhost:8000","http://localhost:8004","http://localhost:8081","https://auto-archiver.bellingcat.com"]'
|
||||
CHROME_APP_IDS='["PROJECT_ID.apps.googleusercontent.com"]'
|
||||
BLOCKED_EMAILS='[]'
|
||||
# redis configuration
|
||||
REDIS_PASSWORD=TODO-MODIFY-THIS-REDIS-PASSWORD
|
||||
REDIS_HOSTNAME="localhost"
|
||||
|
||||
# cronjobs management, enable as needed
|
||||
CRON_ARCHIVE_SHEETS=true
|
||||
CRON_DELETE_STALE_SHEETS=true
|
||||
DELETE_STALE_SHEETS_DAYS=7
|
||||
CRON_DELETE_SCHEDULED_ARCHIVES=false
|
||||
DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS=14
|
||||
|
||||
# observability for prometheus
|
||||
REPEAT_COUNT_METRICS_SECONDS=30
|
||||
|
||||
# mail service settings, if you want to email users
|
||||
MAIL_FROM="noreply@auto-archiver.com"
|
||||
MAIL_FROM_NAME="My Auto Archiver deployment"
|
||||
MAIL_USERNAME="USERNAME"
|
||||
MAIL_PASSWORD="PASSWORD"
|
||||
MAIL_SERVER="mail.server.com"
|
||||
MAIL_PORT=587
|
||||
MAIL_STARTTLS=False
|
||||
MAIL_SSL_TLS=True
|
||||
|
||||
|
||||
# celery workers config
|
||||
CONCURRENCY=2
|
||||
@@ -5,5 +5,4 @@ BLOCKED_EMAILS='["blocked@example.com"]'
|
||||
|
||||
DATABASE_PATH="sqlite:///auto-archiver.test.db"
|
||||
API_BEARER_TOKEN=this_is_the_test_api_token
|
||||
USER_GROUPS_FILENAME=tests/user-groups.test.yaml
|
||||
SHEET_ORCHESTRATION_YAML=tests/orchestration.test.yaml
|
||||
USER_GROUPS_FILENAME=app/tests/user-groups.test.yaml
|
||||
@@ -1 +0,0 @@
|
||||
REDIS_PASSWORD=TODO
|
||||
21
.github/workflows/ci.yml
vendored
21
.github/workflows/ci.yml
vendored
@@ -21,25 +21,24 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install pipenv
|
||||
run: pip install pipenv
|
||||
working-directory: src
|
||||
- name: Install Poetry
|
||||
run: pipx install poetry
|
||||
|
||||
- name: Install dependencies
|
||||
run: pipenv install --dev
|
||||
working-directory: src
|
||||
run: poetry install --no-interaction --with dev
|
||||
|
||||
- name: Set dev environment variable
|
||||
run: echo "ENVIRONMENT_FILE=.env.test" >> $GITHUB_ENV
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: PYTHONPATH=. PIPENV_DOTENV_LOCATION=.env.test pipenv run coverage run -m pytest -v --color=yes tests/
|
||||
working-directory: src
|
||||
run: poetry run coverage run -m pytest -v -ra --color=yes app/tests/
|
||||
|
||||
- name: Report coverage
|
||||
run: pipenv run coverage report
|
||||
working-directory: src
|
||||
run: poetry run coverage report
|
||||
22
.gitignore
vendored
22
.gitignore
vendored
@@ -1,26 +1,32 @@
|
||||
user-groups.dev.yaml
|
||||
user-groups.yaml
|
||||
orchestration.yaml
|
||||
my-archives
|
||||
*.pyc
|
||||
.DS_Store
|
||||
secrets
|
||||
secrets/*
|
||||
*.log
|
||||
__pycache
|
||||
.pytest_cach
|
||||
__pycache__
|
||||
.pytest_cache
|
||||
.env
|
||||
.env.dev
|
||||
.env.prod
|
||||
*.db
|
||||
redis/data/*
|
||||
.ipynb_checkpoints*
|
||||
src/user-groups.yaml
|
||||
src/user-groups.dev.yaml
|
||||
app/user-groups.yaml
|
||||
app/user-groups.dev.yaml
|
||||
wit*
|
||||
src/crawls
|
||||
app/crawls
|
||||
.coverage
|
||||
.pytest_cache/*
|
||||
.pytest_cache/
|
||||
htmlcov
|
||||
local_archive
|
||||
local_archive_test
|
||||
*db-wal
|
||||
*db-shm
|
||||
copy-files.sh
|
||||
copy-files.sh
|
||||
temp/
|
||||
.python-version
|
||||
orchestration2.yaml
|
||||
database
|
||||
2
LICENSE
2
LICENSE
@@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Stichting Bellingcat
|
||||
Copyright (c) 2025 Stichting Bellingcat
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
13
Makefile
13
Makefile
@@ -3,15 +3,20 @@ clean-dev:
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes --remove-orphans
|
||||
|
||||
dev:
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml build
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans
|
||||
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build
|
||||
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans
|
||||
|
||||
|
||||
dev-redis-only:
|
||||
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build redis
|
||||
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans redis
|
||||
|
||||
stop-dev:
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes
|
||||
|
||||
prod:
|
||||
docker compose build
|
||||
docker compose up -d --remove-orphans
|
||||
docker compose --env-file .env.prod build
|
||||
docker compose --env-file .env.prod up -d --remove-orphans
|
||||
docker buildx prune --keep-storage 20gb -f
|
||||
docker image prune -f
|
||||
docker system df
|
||||
|
||||
163
README.md
163
README.md
@@ -1,89 +1,120 @@
|
||||
# Auto Archiver API
|
||||
|
||||
An api that uses celery workers to process URL archive requests via [bellingcat/auto-archiver](https://github.com/bellingcat/auto-archiver), it allows authentication via Google OAuth Apps and enables CORS, everything runs on docker but development can be done without docker (except for redis).
|
||||
[](https://github.com/bellingcat/auto-archiver-api/actions/workflows/ci.yml)
|
||||
|
||||
A web API that uses celery workers to process URL archive requests via [bellingcat/auto-archiver](https://github.com/bellingcat/auto-archiver), it allows authentication via Google OAuth Apps and enables CORS, everything runs on docker but development can be done without docker (except for redis).
|
||||
|
||||

|
||||
|
||||
## setup
|
||||
To properly set up the API you need to install `docker` and to edit 3 files:
|
||||
1. a `.env.prod` and `.env.dev` to configure the API, stays at the root level
|
||||
2. a `user-groups.yaml` to manage user permissions
|
||||
1. note that all local files referenced in `user-groups.yaml` and any orchestration.yaml files should be relative to the home directory so if your service account is in `secrets/orchestration.yaml` use that path and not just `orchestration.yaml`.
|
||||
|
||||
Do not commit those files, they are .gitignored by default.
|
||||
We also advise you to keep any sensitive files in the `secrets/` folder which is pinned and gitignored.
|
||||
|
||||
|
||||
## Development
|
||||
http://localhost:8004
|
||||
We have examples for both of those files (`.env.example` and `user-groups.example.yaml`), and here's how to set them up whether you're in development or production:
|
||||
|
||||
TODO: update .env file instructions, should use .env.prod and .env.dev and only use .env for always overwriting dev/prod settings.
|
||||
### setup for DEVELOPMENT
|
||||
```bash
|
||||
# copy and modify the .env.dev file according to your needs
|
||||
cp .env.example .env.dev
|
||||
# copy the user-groups.example.yaml and modify it accordingly
|
||||
cp user-groups.example.yaml user-groups.dev.yaml
|
||||
# run the APP, make sure VPNs are off
|
||||
make dev
|
||||
# check it's running by calling the health endpoint
|
||||
curl 'http://localhost:8004/health'
|
||||
# > {"status":"ok"}
|
||||
```
|
||||
now go to http://localhost:8004/docs#/ and you should see the API documentation
|
||||
|
||||
requires `src/.env`
|
||||
### setup for PRODUCTION
|
||||
```bash
|
||||
# copy and modify the .env.prod file according to your needs
|
||||
cp .env.example .env.prod
|
||||
# copy the user-groups.example.yaml and modify it accordingly
|
||||
cp user-groups.example.yaml user-groups.yaml
|
||||
# deploy the app
|
||||
make prod
|
||||
# check it's running by calling the health endpoint
|
||||
curl 'http://localhost:8004/health'
|
||||
# > {"status":"ok"}
|
||||
```
|
||||
now go to http://localhost:8004/docs#/ and you should see the API documentation
|
||||
|
||||
## User, Domains, Groups, and permissions management
|
||||
there are 2 ways to access the API
|
||||
1. via an API token which has full control/privileges to archive/search
|
||||
2. via a Google Auth token which goes through the user access model
|
||||
|
||||
#### User access model
|
||||
The permissions are defined solely via the `user-groups.yaml` file
|
||||
- users belong to groups which determine their access level/quotas/orchestration setup
|
||||
- users are assigned to groups explicitly (via email)
|
||||
- users are assigned to groups implicitly (via email domains) as domains can be associated to groups
|
||||
- users that are not explicitly or implicitly in the system belong to the `default` group, restrict their permissions if you do not wish them to be able to search/archive
|
||||
- if a user is assigned to a group which is not explicitly defined, a warning will be thrown; this may be necessary if you discontinue a given group but the database still has entries for it
|
||||
- groups determine
|
||||
- which orchestrator to use for single URL archives and for spreadsheet archives see [GroupPermissions](app/shared/user_groups.py)
|
||||
- a set of permissions
|
||||
- `read` can be [`all`], [] or a comma separated list of group names, meaning people in this group can access either all, none, or those belonging to explicitly listed groups.
|
||||
- the group itself must be included in the list, otherwise the user cannot search archives of that group
|
||||
- `read_public` a boolean that enables the user to search public archives
|
||||
- `archive_url` a boolean that enables the user to archive links in this group
|
||||
- `archive_sheet` a boolean that enables the user to archive spreadsheets
|
||||
- `manually_trigger_sheet` a boolean that enables the user to manually trigger a sheet archive for sheets in this group
|
||||
- `sheet_frequency` a list of options for the sheet archiving frequency, currently max permissions is `["hourly", "daily"]`
|
||||
- `max_sheets` defines the maximum amount of spreadsheets someone can have in total (`-1` means no limit)
|
||||
- `max_archive_lifespan_months` defines the lifespan of an archive before being deleted from S3, users will be notified 1 month in advance with instructions to download TODO
|
||||
- `max_monthly_urls` how many total URLs someone can archive per month (`-1` means no limit)
|
||||
- `max_monthly_mbs` how many MBs of data someone can archive per month (`-1` means no limit)
|
||||
- `priority` one of `high` or `low`, this will be used to give archiving priority
|
||||
- group names are all lower-case
|
||||
|
||||
|
||||
## development of web/worker without docker
|
||||
|
||||
cd /src
|
||||
<!-- * `pipenv install --editable ../../auto-archiver` -->
|
||||
* console 1 - `docker compose up redis` optionally add `web` if not running uvicorn locally
|
||||
* console 2 - `pipenv shell` + `celery worker --app=worker.celery --loglevel=info --logfile=logs/celery_dev.log`
|
||||
* `celery --app=worker.celery worker --loglevel=info --logfile=logs/celery_dev.log` celery 5
|
||||
* or with watchdog for dev auto-reload `watchmedo auto-restart -d ./ -- celery --app=worker.celery worker --loglevel=info --logfile=logs/celery_dev.log`
|
||||
* console 3 - `pipenv shell` + `uvicorn main:app --host 0.0.0.0 --reload`
|
||||
orchestration must be from the console(?)
|
||||
* turn off VPNs if connection to docker is not working
|
||||
We advise you to use `make prod` but you can also spin up redis and run the API (uvicorn) and worker (celery) individually like so:
|
||||
* console 1 - `make dev-redis-only` to spin up redis, turn off any VPNs
|
||||
* console 2 - `export ENVIRONMENT_FILE=.env.dev` then `poetry run celery --app=app.worker.main.celery worker --loglevel=debug --logfile=/aa-api/logs/celery.log -Q high_priority,low_priority --concurrency=1`
|
||||
* or with watchdog for dev auto-reload `watchmedo auto-restart --patterns="*.py" --recursive --ignore-directories -- celery -- --app=app.worker.main.celery worker --loglevel=debug --logfile=/aa-api/logs/celery.log -Q high_priority,low_priority --concurrency=1`
|
||||
* console 3 - `export ENVIRONMENT_FILE=.env.dev` then `poetry run uvicorn main:app --host 0.0.0.0 --reload`
|
||||
|
||||
## User management
|
||||
Copy [example.user-groups.yaml](src/example.user-groups.yaml) into a new file and set the environment variable `USER_GROUPS_FILENAME` to that filename (defaults to `user-groups.yaml`).
|
||||
|
||||
This file contains 2 parts: the user-groups specification and the orchestrators. Each user can archive URLs publicly, privately, or privately for a group so long as they are declared as part of that group. In the example below `email1` has 2 groups while `email3` has none.
|
||||
```yaml
|
||||
users:
|
||||
email1@example.com:
|
||||
- group1
|
||||
- group2
|
||||
email2@example.com:
|
||||
- group2
|
||||
email3@example-no-group.com:
|
||||
```
|
||||
|
||||
Auto-archiver orchestrator files configurations. For each archiving task an orchestrator is chosen, either from a specified group (if group-level visibility) or the first group the user is assigned to in the above file or the `default` orchestration file which is a required config.
|
||||
```yaml
|
||||
orchestrators:
|
||||
group1: secrets/orchestration-group1.yaml
|
||||
group2: secrets/orchestration-group2.yaml
|
||||
default: secrets/orchestration-default:orchestration.yaml
|
||||
```
|
||||
|
||||
## Database migrations
|
||||
check https://alembic.sqlalchemy.org/en/latest/tutorial.html#the-migration-environment
|
||||
|
||||
* create migrations with `alembic revision -m "create account table"`
|
||||
* if running in the normal pipenv environment use `PIPENV_DOTENV_LOCATION=.env.alembic pipenv run` followed by:
|
||||
* migrate to most recent with `alembic upgrade head`
|
||||
* downgrade with `alembic downgrade -1`
|
||||
```bash
|
||||
# set the env variables
|
||||
export ENVIRONMENT_FILE=.env.alembic
|
||||
# create a new migration with description in app/migrations
|
||||
poetry run alembic revision -m "create account table"
|
||||
# perform all migrations
|
||||
poetry run alembic upgrade head
|
||||
# downgrade by one migration
|
||||
poetry run alembic downgrade -1
|
||||
```
|
||||
|
||||
## Release
|
||||
Update `main.py:VERSION`.
|
||||
Update the version in [config.py](app/web/config.py)
|
||||
|
||||
Copy `.env` and `src/.env` to deployment, along with the contents of `secrets/` including `secrets/orchestration.yaml`.
|
||||
Make sure environment and user-groups files are up to date.
|
||||
|
||||
Then `make prod`.
|
||||
|
||||
#### updating packages/app/access
|
||||
If pipenv packages are updated: `make prod` to build images with new packages.
|
||||
|
||||
New users should be added to the `src/.env` file `ALLOWED_EMAILS` prop.
|
||||
|
||||
Run `pipenv update auto-archiver` inside `src` to update the auto-archiver version being used, then test with `make dev`.
|
||||
|
||||
|
||||
```bash
|
||||
# CALL /sheet POST endpoint
|
||||
curl -XPOST -H "Authorization: Bearer GOOGLE_OAUTH_TOKEN" -H "Content-type: application/json" -d '{"sheet_id": "SHEET_ID", "header": 1}' 'http://localhost:8004/sheet'
|
||||
|
||||
```
|
||||
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
# can be done from top level but let's do it from the src folder for consistency with CI etc
|
||||
cd src
|
||||
# set the testing environment variables
|
||||
export ENVIRONMENT_FILE=.env.test
|
||||
# run tests and generate coverage
|
||||
PYTHONPATH=. PIPENV_DOTENV_LOCATION=.env.test pipenv run coverage run -m pytest -vv --disable-warnings --color=yes tests/ && pipenv run coverage html
|
||||
|
||||
poetry run coverage run -m pytest -vv --disable-warnings --color=yes app/tests/
|
||||
# get coverage report in command line
|
||||
pipenv run coverage report
|
||||
|
||||
# get coverage HTML
|
||||
pipenv run coverage html
|
||||
|
||||
# > open/run server on htmlcov/index.html to navigate through line coverage
|
||||
```
|
||||
poetry run coverage report
|
||||
# get coverage report in HTML format
|
||||
poetry run coverage html
|
||||
```
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = migrations
|
||||
script_location = app/migrations
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
@@ -1,11 +1,10 @@
|
||||
from logging.config import fileConfig
|
||||
import os
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
from shared.settings import get_settings
|
||||
from app.shared.settings import get_settings
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
@@ -0,0 +1,34 @@
|
||||
"""create archives.store_until column
|
||||
|
||||
Revision ID: 02b2f6d17ed0
|
||||
Revises: 1636724ec4b1
|
||||
Create Date: 2025-02-08 15:22:20.392522
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '02b2f6d17ed0'
|
||||
down_revision = '1636724ec4b1'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
STORE_UNTIL_COL = "store_until"
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable ``store_until`` DateTime column to ``archives``.

    The column is only added when the table does not already have it, so
    running the migration against an already-migrated table does nothing.
    """
    bind = op.get_bind()
    existing = {entry['name'] for entry in sa.inspect(bind).get_columns('archives')}
    if STORE_UNTIL_COL in existing:
        return

    new_column = sa.Column(STORE_UNTIL_COL, sa.DateTime(), nullable=True, default=None)
    op.add_column('archives', new_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``store_until`` column from ``archives`` if it is present."""
    bind = op.get_bind()
    existing = {entry['name'] for entry in sa.inspect(bind).get_columns('archives')}
    if STORE_UNTIL_COL not in existing:
        return

    op.drop_column('archives', STORE_UNTIL_COL)
|
||||
@@ -0,0 +1,32 @@
|
||||
"""rename sheets last_archived col
|
||||
|
||||
Revision ID: 1636724ec4b1
|
||||
Revises: a23aaf3ae930
|
||||
Create Date: 2025-02-05 19:19:01.984396
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '1636724ec4b1'
|
||||
down_revision = 'a23aaf3ae930'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Rename ``sheets.last_archived_at`` to ``last_url_archived_at``.

    Nothing happens when the old column name is not found on the table.
    """
    bind = op.get_bind()
    sheet_columns = {entry['name'] for entry in sa.inspect(bind).get_columns('sheets')}
    if 'last_archived_at' not in sheet_columns:
        return

    op.alter_column('sheets', 'last_archived_at', new_column_name='last_url_archived_at')
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Rename ``sheets.last_url_archived_at`` back to ``last_archived_at``.

    Nothing happens when the new column name is not found on the table.
    """
    bind = op.get_bind()
    sheet_columns = {entry['name'] for entry in sa.inspect(bind).get_columns('sheets')}
    if 'last_url_archived_at' not in sheet_columns:
        return

    op.alter_column('sheets', 'last_url_archived_at', new_column_name='last_archived_at')
|
||||
@@ -0,0 +1,36 @@
|
||||
"""add new service_account_email column to groups table
|
||||
|
||||
Revision ID: 63ac79df4ad0
|
||||
Revises: 02b2f6d17ed0
|
||||
Create Date: 2025-02-11 21:53:23.293274
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '63ac79df4ad0'
|
||||
down_revision = '02b2f6d17ed0'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
NEW_COL = "service_account_email"
|
||||
TABLE = "groups"
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable string column ``NEW_COL`` to ``TABLE``.

    The column is only added when the table does not already have it.
    """
    bind = op.get_bind()
    existing = {entry['name'] for entry in sa.inspect(bind).get_columns(TABLE)}
    if NEW_COL in existing:
        return

    new_column = sa.Column(NEW_COL, sa.String, nullable=True, default=None)
    op.add_column(TABLE, new_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop ``NEW_COL`` from ``TABLE`` if the column is present."""
    bind = op.get_bind()
    existing = {entry['name'] for entry in sa.inspect(bind).get_columns(TABLE)}
    if NEW_COL not in existing:
        return

    op.drop_column(TABLE, NEW_COL)
|
||||
@@ -0,0 +1,42 @@
|
||||
"""add sheet_id to archive table
|
||||
|
||||
Revision ID: 89121d2c96d8
|
||||
Revises: fa012ec405b8
|
||||
Create Date: 2024-11-04 11:12:30.237299
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '89121d2c96d8'
|
||||
down_revision = 'fa012ec405b8'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add a nullable ``sheet_id`` column to ``archives`` plus its foreign key.

    Skipped when ``archives`` already has a ``sheet_id`` column. The column
    and the ``fk_sheet_id`` constraint (``sheet_id`` -> ``sheets.id``) are
    both created inside one ``batch_alter_table`` block.
    """
    bind = op.get_bind()
    present = {entry['name'] for entry in sa.inspect(bind).get_columns('archives')}
    if 'sheet_id' in present:
        return

    sheet_id_column = sa.Column('sheet_id', sa.String(), nullable=True, default=None)
    with op.batch_alter_table('archives') as batch_op:
        batch_op.add_column(sheet_id_column)
        batch_op.create_foreign_key('fk_sheet_id', 'sheets', ['sheet_id'], ['id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the ``fk_sheet_id`` constraint and ``sheet_id`` column from ``archives``.

    Each of the two is dropped only if the inspector reports it, and both
    drops happen inside one ``batch_alter_table`` block.
    """
    reflection = sa.inspect(op.get_bind())
    fk_names = {fk['name'] for fk in reflection.get_foreign_keys('archives')}
    column_names = {entry['name'] for entry in reflection.get_columns('archives')}

    has_fk = 'fk_sheet_id' in fk_names
    has_column = 'sheet_id' in column_names

    with op.batch_alter_table('archives') as batch_op:
        if has_fk:
            batch_op.drop_constraint('fk_sheet_id', type_='foreignkey')
        if has_column:
            batch_op.drop_column('sheet_id')
|
||||
34
app/migrations/versions/a23aaf3ae930_drop_active_column.py
Normal file
34
app/migrations/versions/a23aaf3ae930_drop_active_column.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""drop active column
|
||||
|
||||
Revision ID: a23aaf3ae930
|
||||
Revises: 89121d2c96d8
|
||||
Create Date: 2025-02-04 12:19:20.753570
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'a23aaf3ae930'
|
||||
down_revision = '89121d2c96d8'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Drop the ``is_active`` column from ``users`` if it is present."""
    bind = op.get_bind()
    user_columns = {entry['name'] for entry in sa.inspect(bind).get_columns('users')}
    if 'is_active' not in user_columns:
        return

    op.drop_column('users', 'is_active')
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Re-create ``users.is_active`` when the column is missing.

    The column is a non-nullable Boolean with a server-side default of
    false, so existing rows get a value.
    """
    bind = op.get_bind()
    user_columns = {entry['name'] for entry in sa.inspect(bind).get_columns('users')}
    if 'is_active' in user_columns:
        return

    is_active_column = sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.false())
    op.add_column('users', is_active_column)
|
||||
@@ -19,7 +19,7 @@ depends_on = None
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('groups')]
|
||||
|
||||
if 'description' not in columns:
|
||||
@@ -35,8 +35,11 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('groups', 'description')
|
||||
op.drop_column('groups', 'orchestrator')
|
||||
op.drop_column('groups', 'orchestrator_sheet')
|
||||
op.drop_column('groups', 'permissions')
|
||||
op.drop_column('groups', 'domains')
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('groups')]
|
||||
|
||||
column_names = ['description', 'orchestrator', 'orchestrator_sheet', 'permissions', 'domains']
|
||||
for column_name in column_names:
|
||||
if column_name in columns:
|
||||
op.drop_column('groups', column_name)
|
||||
32
app/shared/aa_utils.py
Normal file
32
app/shared/aa_utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# TODO: code in this file should eventually be moved to the auto-archiver code base
|
||||
|
||||
from typing import List
|
||||
from loguru import logger
|
||||
from auto_archiver.core import Media, Metadata
|
||||
|
||||
from app.shared.db import models
|
||||
|
||||
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
    """Build an ``ArchiveUrl`` for every URL reachable from ``result.media``.

    For each media item, URLs are collected from three places, in this order:
    the media's own ``urls``; any property value that ``convert_if_media``
    accepts; and, for list-valued properties, every list element that
    ``convert_if_media`` accepts. Each key is the object's ``get("id", ...)``
    value, with a positional fallback built from the property name and index.
    """
    collected: List[models.ArchiveUrl] = []

    def _add(url, key) -> None:
        collected.append(models.ArchiveUrl(url=url, key=key))

    for media in result.media:
        # URLs attached directly to the media item.
        for idx, url in enumerate(media.urls):
            _add(url, media.get("id", f"media_{idx}"))

        for name, value in media.properties.items():
            # A property that is itself (convertible to) a Media.
            single = convert_if_media(value)
            if single:
                for idx, url in enumerate(single.urls):
                    _add(url, single.get("id", f"{name}_{idx}"))

            # A property holding a list whose elements may be Media.
            if not isinstance(value, list):
                continue
            for pos, candidate in enumerate(value):
                nested = convert_if_media(candidate)
                if not nested:
                    continue
                for sub, url in enumerate(nested.urls):
                    _add(url, nested.get("id", f"{name}{nested.key}_{pos}.{sub}"))

    return collected
|
||||
|
||||
|
||||
|
||||
def convert_if_media(media):
    """Return ``media`` as a ``Media`` instance, or ``False`` if that fails.

    ``Media`` instances are returned unchanged and dicts are parsed with
    ``Media.from_dict``. A dict that fails to parse is logged at debug level
    and, like every other input type, yields ``False``.
    """
    if isinstance(media, Media):
        return media
    if not isinstance(media, dict):
        return False
    try:
        return Media.from_dict(media)
    except Exception as e:
        logger.debug(f"error parsing {media} : {e}")
        return False
|
||||
|
||||
16
app/shared/business_logic.py
Normal file
16
app/shared/business_logic.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do decide
|
||||
|
||||
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.db import worker_crud
|
||||
|
||||
|
||||
def get_store_archive_until(db: Session, group_id: str) -> "datetime.datetime | None":
    """Compute until when an archive created now for ``group_id`` is stored.

    The lifespan is read from the group's ``max_archive_lifespan_months``
    permission, where ``-1`` (also the fallback when the permission is
    absent) means the archive has no expiry.

    Args:
        db: database session used to look up the group.
        group_id: id of the group the archive belongs to.

    Returns:
        ``None`` when there is no expiry, otherwise the current local time
        plus ``30 * max_archive_lifespan_months`` days (naive datetime).

    Raises:
        AssertionError: if the group does not exist.
    """
    group = worker_crud.get_group(db, group_id)
    if not group:
        # Raised explicitly rather than via `assert` so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        raise AssertionError(f"Group {group_id} not found.")

    # `permissions` may be NULL on rows created before the column defaulted
    # to an empty dict; treat that the same as "no permissions set".
    permissions = group.permissions or {}
    max_lifespan = permissions.get("max_archive_lifespan_months", -1)
    if max_lifespan == -1:
        return None

    # A month is approximated as 30 days.
    return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan)
|
||||
74
app/shared/db/database.py
Normal file
74
app/shared/db/database.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from functools import lru_cache
|
||||
from sqlalchemy import Engine, create_engine, event, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine, async_sessionmaker
|
||||
|
||||
from app.shared.settings import get_settings
|
||||
|
||||
|
||||
@lru_cache
def make_engine(database_url: str):
    """Create the synchronous engine for ``database_url``, cached per URL.

    The engine is created with ``check_same_thread`` disabled, a pool of 15
    connections with up to 20 overflow connections, and connections recycled
    after 1800 seconds (30 minutes). Every new DBAPI connection is switched
    to WAL journal mode.
    """
    def _enable_wal(dbapi_connection, _connection_record) -> None:
        cur = dbapi_connection.cursor()
        cur.execute("PRAGMA journal_mode=WAL")
        cur.close()

    pool_options = {
        "pool_size": 15,  # larger pool than the default
        "max_overflow": 20,  # extra temporary connections allowed
        "pool_recycle": 1800,  # recycle connections every 30 minutes
    }
    sync_engine = create_engine(
        database_url,
        connect_args={"check_same_thread": False},
        **pool_options,
    )
    event.listen(sync_engine, "connect", _enable_wal)
    return sync_engine
|
||||
|
||||
|
||||
def make_session_local(engine: Engine):
    """Return a ``sessionmaker`` bound to ``engine`` with autocommit and autoflush off."""
    return sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
|
||||
|
||||
@contextmanager
def get_db():
    """Context manager yielding a session on the configured database.

    The engine comes from ``make_engine`` for ``DATABASE_PATH`` in the
    settings; the session is always closed when the ``with`` block exits.
    """
    engine = make_engine(get_settings().DATABASE_PATH)
    session_factory = make_session_local(engine)
    db_session = session_factory()
    try:
        yield db_session
    finally:
        db_session.close()
|
||||
|
||||
|
||||
def get_db_dependency():
    """Generator wrapper around ``get_db``.

    Meant to be used with ``Depends`` so that the session is properly closed
    once the consumer is done with it.
    """
    with get_db() as db_session:
        yield db_session
|
||||
|
||||
|
||||
def wal_checkpoint():
    """Run a truncating WAL checkpoint on the database.

    Makes sure the .sqlite file receives the latest changes. To be called at
    startup, as it halts writes.
    """
    checkpoint_stmt = text("PRAGMA wal_checkpoint(TRUNCATE)")
    with get_db() as db_session:
        db_session.execute(checkpoint_stmt)
|
||||
|
||||
|
||||
# ASYNC connections
|
||||
async def make_async_engine(database_url: str) -> AsyncEngine:
    """Create an async engine for ``database_url`` and switch it to WAL mode.

    The engine is created with ``check_same_thread`` disabled; the journal
    mode pragma is executed once, inside a transaction, before returning.
    """
    def _enable_wal(sync_conn):
        return sync_conn.execute(text("PRAGMA journal_mode=WAL;"))

    async_engine = create_async_engine(
        database_url,
        connect_args={"check_same_thread": False},
    )
    async with async_engine.begin() as connection:
        await connection.run_sync(_enable_wal)
    return async_engine
|
||||
|
||||
|
||||
async def make_async_session_local(engine: AsyncEngine) -> async_sessionmaker:
    """Return an ``async_sessionmaker`` bound to ``engine``.

    The returned object is a session factory, not a session: call it to
    obtain an ``AsyncSession``. Sessions are created with
    ``expire_on_commit``, ``autoflush`` and ``autocommit`` all disabled.
    """
    return async_sessionmaker(engine, expire_on_commit=False, autoflush=False, autocommit=False)
|
||||
|
||||
|
||||
@asynccontextmanager
async def get_db_async():
    """Async context manager yielding an ``AsyncSession`` on the async database.

    A fresh engine is created for ``ASYNC_DATABASE_PATH`` on every use and
    disposed when the context exits.
    """
    engine = await make_async_engine(get_settings().ASYNC_DATABASE_PATH)
    try:
        async_session = await make_async_session_local(engine)
        async with async_session() as session:
            yield session
    finally:
        # Dispose only after the session context has exited, so the session
        # releases its connection before the pool is torn down. This also
        # covers failures while creating the session.
        await engine.dispose()
|
||||
@@ -6,9 +6,11 @@ import uuid
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def generate_uuid():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
# many to many association tables
|
||||
association_table_archive_tags = Table(
|
||||
"mtm_archives_tags",
|
||||
@@ -23,6 +25,7 @@ association_table_user_groups = Table(
|
||||
Column("group_id", ForeignKey("groups.id")),
|
||||
)
|
||||
|
||||
|
||||
# data model tables
|
||||
class Archive(Base):
|
||||
__tablename__ = "archives"
|
||||
@@ -30,18 +33,22 @@ class Archive(Base):
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
url = Column(String, index=True)
|
||||
result = Column(JSON, default=None)
|
||||
public = Column(Boolean, default=True) # if public=false, access to group and author
|
||||
public = Column(Boolean, default=True) # if public=false, access by group and author
|
||||
deleted = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
store_until = Column(DateTime(timezone=True), default=None)
|
||||
|
||||
group_id = Column(String, ForeignKey("groups.id"), default=None)
|
||||
author_id = Column(String, ForeignKey("users.email"))
|
||||
sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
|
||||
|
||||
tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags)
|
||||
group = relationship("Group", back_populates="archives")
|
||||
author = relationship("User", back_populates="archives")
|
||||
urls = relationship("ArchiveUrl", back_populates="archive")
|
||||
sheet = relationship("Sheet", back_populates="archives")
|
||||
|
||||
|
||||
class ArchiveUrl(Base):
|
||||
__tablename__ = "archive_urls"
|
||||
@@ -61,15 +68,17 @@ class Tag(Base):
|
||||
|
||||
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
|
||||
|
||||
|
||||
class User(Base):
    """ORM mapping for the ``users`` table; the e-mail address is the primary key."""

    __tablename__ = "users"

    # columns
    email = Column(String, primary_key=True, index=True)
    is_active = Column(Boolean, default=False)

    # relationships (the many-to-many link to groups goes through the association table)
    archives = relationship("Archive", back_populates="author")
    sheets = relationship("Sheet", back_populates="author")
    groups = relationship(
        "Group",
        back_populates="users",
        secondary=association_table_user_groups,
    )
|
||||
|
||||
|
||||
class Group(Base):
|
||||
__tablename__ = "groups"
|
||||
|
||||
@@ -77,8 +86,29 @@ class Group(Base):
|
||||
description = Column(String, default=None)
|
||||
orchestrator = Column(String, default=None)
|
||||
orchestrator_sheet = Column(String, default=None)
|
||||
permissions = Column(JSON, default=None)
|
||||
permissions = Column(JSON, default={})
|
||||
service_account_email = Column(String, default=None)
|
||||
domains = Column(JSON, default=[])
|
||||
|
||||
archives = relationship("Archive", back_populates="group")
|
||||
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
|
||||
sheets = relationship("Sheet", back_populates="group")
|
||||
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
|
||||
|
||||
|
||||
class Sheet(Base):
    """ORM mapping for the ``sheets`` table."""

    __tablename__ = "sheets"

    # identity and ownership
    id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
    name = Column(String, default=None)
    author_id = Column(String, ForeignKey("users.email"))
    group_id = Column(
        String,
        ForeignKey("groups.id"),
        doc="Group ID, user must be in a group to create a sheet.",
    )

    # archiving behaviour
    frequency = Column(
        String,
        default="daily",
        doc="Frequency of archiving: hourly, daily, weekly.",
    )
    # TODO: stats is not being used, consider removing
    # NOTE(review): the default is a single dict literal; confirm it is never mutated in place.
    stats = Column(
        JSON,
        default={},
        doc="Sheet statistics like total links, total rows, ...",
    )

    # timestamps
    last_url_archived_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        doc="Last time a new link was archived.",
    )
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    # relationships
    group = relationship("Group", back_populates="sheets")
    author = relationship("User", back_populates="sheets")
    archives = relationship("Archive", back_populates="sheet")
|
||||
60
app/shared/db/worker_crud.py
Normal file
60
app/shared/db/worker_crud.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared import schemas
|
||||
|
||||
# TODO: isolate database operations away from worker and into WEB
|
||||
# ONLY WORKER
|
||||
def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
    """Set ``last_url_archived_at`` of the given sheet to ``datetime.now()``.

    Returns True when a sheet with ``sheet_id`` exists and the change was
    committed, False when no such sheet is found.
    """
    sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
    if not sheet:
        return False
    sheet.last_url_archived_at = datetime.now()
    db.commit()
    return True
|
||||
|
||||
|
||||
# ONLY WORKER and INTEROP
|
||||
|
||||
def get_group(db: Session, group_name: str) -> models.Group:
    """Return the first group whose id equals ``group_name`` (None if there is none)."""
    matching_groups = db.query(models.Group).filter(models.Group.id == group_name)
    return matching_groups.first()
|
||||
|
||||
def create_or_get_user(db: Session, author_id: str) -> models.User:
    """Return the user whose e-mail is ``author_id``, creating the row if needed.

    String ids are lower-cased before the lookup so that e-mails are matched
    case-insensitively; non-string values are used as given.
    """
    # isinstance (rather than an exact type comparison) also covers str subclasses
    if isinstance(author_id, str):
        author_id = author_id.lower()
    db_user = db.query(models.User).filter(models.User.email == author_id).first()
    if not db_user:
        db_user = models.User(email=author_id)
        db.add(db_user)
        db.commit()
        # reload so that database-side defaults are present on the returned object
        db.refresh(db_user)
    return db_user
|
||||
|
||||
|
||||
def create_tag(db: Session, tag: str) -> models.Tag:
    """Fetch the tag with id ``tag``; insert and commit it first when it is missing."""
    existing = db.query(models.Tag).filter(models.Tag.id == tag).first()
    if existing:
        return existing
    new_tag = models.Tag(id=tag)
    db.add(new_tag)
    db.commit()
    db.refresh(new_tag)
    return new_tag
|
||||
|
||||
|
||||
def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]) -> models.Archive:
    """Persist a new archive row built from ``archive`` and attach ``tags`` and ``urls``.

    The row is added, committed and refreshed before being returned.
    """
    db_archive = models.Archive(
        id=archive.id,
        url=archive.url,
        result=archive.result,
        public=archive.public,
        author_id=archive.author_id,
        group_id=archive.group_id,
        sheet_id=archive.sheet_id,
        store_until=archive.store_until,
    )
    # related collections are assigned after construction, as separate steps
    db_archive.tags = tags
    db_archive.urls = urls

    db.add(db_archive)
    db.commit()
    db.refresh(db_archive)
    return db_archive
|
||||
|
||||
|
||||
def store_archived_url(db: Session, archive: schemas.ArchiveCreate) -> models.Archive:
    """Store ``archive``, first making sure its author and tags exist in the database."""
    # the author row must exist before the archive references it
    create_or_get_user(db, archive.author_id)

    # a missing tag collection is treated as "no tags"
    db_tags = []
    for tag in archive.tags or []:
        db_tags.append(create_tag(db, tag))

    # insert everything
    return create_archive(db, archive=archive, tags=db_tags, urls=archive.urls)
|
||||
13
app/shared/log.py
Normal file
13
app/shared/log.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import traceback
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# logging configurations
|
||||
logger.add("logs/api_logs.log", retention="30 days")
|
||||
logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
|
||||
|
||||
|
||||
def log_error(e: Exception, traceback_str: str = None, extra: str = ""):
    """Log ``e`` at ERROR level together with a traceback.

    ``traceback_str`` falls back to ``traceback.format_exc()`` when empty, and a
    non-empty ``extra`` is written on its own line before the exception.
    """
    traceback_str = traceback_str or traceback.format_exc()
    extra = f"{extra}\n" if extra else extra
    logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")
|
||||
100
app/shared/schemas.py
Normal file
100
app/shared/schemas.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from typing import Annotated
|
||||
from annotated_types import Len
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class SubmitSheet(BaseModel):
    # sheet_id has no default, so it must always be supplied (None is accepted)
    sheet_id: str | None
    author_id: str | None = None
    group_id: str = "default"
    tags: set[str] | None = set()
|
||||
|
||||
class ArchiveUrl(BaseModel):
    url: str
    public: bool = False
    # no defaults below: both must be passed explicitly, though None is allowed
    author_id: str | None
    group_id: str | None
    tags: set[str] | None = set()
|
||||
|
||||
class ArchiveResult(BaseModel):
    id: str
    url: str
    result: dict
    created_at: datetime
    # required field, but None is an accepted value
    store_until: datetime | None
|
||||
|
||||
|
||||
class Task(BaseModel):
    # base schema carrying only an identifier; extended by TaskResult and DeleteResponse
    id: str
|
||||
|
||||
|
||||
class TaskResult(Task):
    # adds status and result strings to the inherited id
    status: str
    result: str
|
||||
|
||||
|
||||
class DeleteResponse(Task):
    # adds a deleted flag to the inherited id
    deleted: bool
|
||||
|
||||
|
||||
class ActiveUser(BaseModel):
    # single required boolean flag
    active: bool
|
||||
|
||||
|
||||
class SheetAdd(BaseModel):
    # all four fields are required strings
    id: str
    name: str
    group_id: str
    frequency: str
|
||||
|
||||
|
||||
class SheetResponse(SheetAdd):
    # extends SheetAdd with the author and timestamps
    author_id: str
    created_at: datetime
    # required, but None is an accepted value
    last_url_archived_at: datetime | None
|
||||
|
||||
|
||||
class ArchiveTrigger(BaseModel):
    author_id: str | None = None
    # url must be at least 5 characters long
    url: Annotated[str, Len(min_length=5)]
    public: bool = False
    # group_id may not be empty; it falls back to "default"
    group_id: Annotated[str, Len(min_length=1)] = "default"
    tags: set[str] | None = None
|
||||
|
||||
|
||||
class ArchiveCreate(ArchiveTrigger):
    # every field added on top of ArchiveTrigger is optional and defaults to None
    id: str | None = None
    result: dict | None = None
    sheet_id: str | None = None
    urls: list | None = None
    store_until: datetime | None = None
|
||||
|
||||
|
||||
class Archive(ArchiveCreate):
    created_at: datetime
    # required, but None is an accepted value
    updated_at: datetime | None
    deleted: bool

    # allow building this schema from objects by reading their attributes
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class Usage(BaseModel):
    # integer counters, all starting at zero
    monthly_urls: int = 0
    monthly_mbs: int = 0
    total_sheets: int = 0
|
||||
|
||||
|
||||
class UsageResponse(Usage):
    # Usage counters of its own plus one Usage entry per string key
    groups: dict[str, Usage]
|
||||
|
||||
|
||||
class CelerySheetTask(BaseModel):
    # all fields are required
    success: bool
    sheet_id: str
    time: datetime
    stats: dict
|
||||
|
||||
|
||||
class SubmitManualArchive(ArchiveTrigger):
    # should be a Metadata.to_json()
    result: str
|
||||
76
app/shared/settings.py
Normal file
76
app/shared/settings.py
Normal file
@@ -0,0 +1,76 @@
|
||||
|
||||
from functools import lru_cache
|
||||
import os
|
||||
from fastapi_mail import ConnectionConfig
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from typing import Annotated, Set
|
||||
from annotated_types import Len
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    # Application settings. Values come from the environment and, when the
    # ENVIRONMENT_FILE variable is set, from the env file it points to.
    # Unknown keys are ignored and string values are whitespace-stripped.

    model_config = SettingsConfigDict(env_file=os.environ.get("ENVIRONMENT_FILE") , env_file_encoding='utf-8', extra='ignore', str_strip_whitespace=True)

    # general
    SERVE_LOCAL_ARCHIVE: str | None = None
    USER_GROUPS_FILENAME: str = "app/user-groups.yaml"

    # database
    DATABASE_PATH: str
    DATABASE_QUERY_LIMIT: int = 100

    @property
    def ASYNC_DATABASE_PATH(self) -> str:
        """DATABASE_PATH with the ``sqlite://`` scheme swapped for ``sqlite+aiosqlite://``.

        Only a leading ``sqlite://`` is rewritten. An unanchored replace would
        also match inside ``aiosqlite://`` and corrupt a URL that is already async.
        Any other value is returned unchanged.
        """
        sync_scheme = "sqlite://"
        if self.DATABASE_PATH.startswith(sync_scheme):
            return "sqlite+aiosqlite://" + self.DATABASE_PATH[len(sync_scheme):]
        return self.DATABASE_PATH

    # security
    API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
    ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)]
    CHROME_APP_IDS: Annotated[Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
    BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()

    # redis
    REDIS_PASSWORD: str = ""
    REDIS_HOSTNAME: str = "localhost"
    REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"

    @property
    def CELERY_BROKER_URL(self) -> str:
        """Redis URL on port 6379; the password is included only when one is configured."""
        # NOTE(review): the password is inserted verbatim; confirm it never needs URL-escaping.
        if self.REDIS_PASSWORD:
            return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379"
        return f"redis://{self.REDIS_HOSTNAME}:6379"

    # cronjobs
    CRON_ARCHIVE_SHEETS: bool = False
    CRON_DELETE_STALE_SHEETS: bool = False
    DELETE_STALE_SHEETS_DAYS: int = 14
    CRON_DELETE_SCHEDULED_ARCHIVES: bool = False
    DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7

    # observability
    REPEAT_COUNT_METRICS_SECONDS: int = 30

    # email configuration, if needed
    MAIL_FROM: str = "noreply@bellingcat.com"
    MAIL_FROM_NAME: str = "Bellingcat's Auto Archiver"
    MAIL_USERNAME: str = ""
    MAIL_PASSWORD: str = ""
    MAIL_SERVER: str = ""
    MAIL_PORT: int = 587
    MAIL_STARTTLS: bool = False
    MAIL_SSL_TLS: bool = True

    @property
    def MAIL_CONFIG(self) -> ConnectionConfig:
        """Build a fastapi_mail ``ConnectionConfig`` from the MAIL_* settings."""
        return ConnectionConfig(
            MAIL_FROM=self.MAIL_FROM,
            MAIL_FROM_NAME=self.MAIL_FROM_NAME,
            MAIL_USERNAME=self.MAIL_USERNAME,
            MAIL_PASSWORD=self.MAIL_PASSWORD,
            MAIL_SERVER=self.MAIL_SERVER,
            MAIL_PORT=self.MAIL_PORT,
            MAIL_STARTTLS=self.MAIL_STARTTLS,
            MAIL_SSL_TLS=self.MAIL_SSL_TLS,
        )
|
||||
|
||||
|
||||
@lru_cache
def get_settings():
    """Create the ``Settings`` object on first use; ``lru_cache`` returns the same one afterwards."""
    settings = Settings()
    return settings
|
||||
23
app/shared/task_messaging.py
Normal file
23
app/shared/task_messaging.py
Normal file
@@ -0,0 +1,23 @@
|
||||
|
||||
from functools import lru_cache
|
||||
from celery import Celery
|
||||
import redis
|
||||
|
||||
from app.shared.settings import get_settings
|
||||
|
||||
|
||||
@lru_cache
def get_celery(name: str = "") -> Celery:
    """Return a Celery app for ``name``, cached per name by ``lru_cache``.

    The same Redis URL from the settings serves as broker and result backend.
    """
    redis_url = get_settings().CELERY_BROKER_URL
    transport_options = {
        'queue_order_strategy': 'priority',
    }
    return Celery(
        name,
        broker_url=redis_url,
        result_backend=redis_url,
        broker_connection_retry_on_startup=False,
        broker_transport_options=transport_options,
    )
|
||||
|
||||
|
||||
def get_redis() -> redis.Redis:
    """Create a Redis client from the broker URL configured in the settings."""
    broker_url = get_settings().CELERY_BROKER_URL
    return redis.Redis.from_url(broker_url)
|
||||
169
app/shared/user_groups.py
Normal file
169
app/shared/user_groups.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import json
|
||||
import os
|
||||
import yaml
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, computed_field, field_validator, Field, model_validator
|
||||
from typing import Dict, List, Set
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class UserGroups:
    """Loads a user-groups YAML file and exposes its validated ``users``, ``domains`` and ``groups``."""

    def __init__(self, filename):
        self.validate_and_load(self.read_yaml(filename))

    def read_yaml(self, user_groups_filename):
        """Parse the YAML file with ``yaml.safe_load``; YAML errors are logged and re-raised."""
        with open(user_groups_filename) as inf:
            try:
                parsed = yaml.safe_load(inf)
            except yaml.YAMLError as e:
                logger.error(f"could not open user groups filename {user_groups_filename}: {e}")
                raise e
        return parsed

    def validate_and_load(self, user_groups):
        """Validate the parsed content through ``UserGroupModel`` and copy its sections onto self.

        Any exception raised during validation is logged and re-raised.
        """
        try:
            configs = UserGroupModel(**user_groups)
        except Exception as e:
            logger.error(f"Validation error: {e}")
            raise e
        else:
            self.users = configs.users
            self.domains = configs.domains
            self.groups = configs.groups
|
||||
|
||||
|
||||
class GroupPermissions(BaseModel):
    # Permission and quota settings attached to a group.
    # NOTE(review): the two Set[str] fields use default_factory=list, so an
    # unset field holds a list rather than a set; confirm whether callers rely on that.
    read: Set[str] | bool = Field(default_factory=list)
    read_public: bool = False
    archive_url: bool = False
    archive_sheet: bool = False
    manually_trigger_sheet: bool = False
    sheet_frequency: Set[str] = Field(default_factory=list)
    max_sheets: int = 0
    max_archive_lifespan_months: int = 12
    max_monthly_urls: int = 0
    max_monthly_mbs: int = 0
    priority: str = "low"

    # Validators are declared as explicit classmethods, matching UserGroupModel.

    @field_validator('max_sheets', 'max_archive_lifespan_months', 'max_monthly_urls', 'max_monthly_mbs', mode='before')
    @classmethod
    def validate_max_values(cls, v):
        """Reject any limit below -1; -1 itself stands for "no limit"."""
        if v < -1:
            raise ValueError("max_* values should be positive integers or -1 (for no limit).")
        return v

    @field_validator('sheet_frequency', mode='before')
    @classmethod
    def validate_sheet_frequency(cls, v):
        """Accept only "daily" and "hourly" entries; an empty value becomes an empty list."""
        if not v:
            return []
        allowed = ["daily", "hourly"]
        for k in v:
            if k not in allowed:
                raise ValueError(f"Invalid sheet frequency: '{k}', expected one of {allowed}")
        return v

    @field_validator('priority', mode='before')
    @classmethod
    def validate_priority(cls, v):
        """Lower-case the priority and require it to be "low" or "high"."""
        v = v.lower()
        if v not in ["low", "high"]:
            raise ValueError("priority must be either 'low' or 'high'.")
        return v
|
||||
|
||||
|
||||
class GroupModel(BaseModel):
    # One group: a description, two orchestrator file paths (both must exist
    # on disk) and the group's permissions.
    description: str
    orchestrator: str
    orchestrator_sheet: str
    permissions: GroupPermissions

    @field_validator('orchestrator', 'orchestrator_sheet', mode='before')
    @classmethod
    def validate_orchestrator(cls, v):
        """Require that the configured orchestrator path exists."""
        if not os.path.exists(v):
            raise ValueError(f"Orchestrator file not found with this path: {v}")
        return v

    @computed_field
    @property
    def service_account_email(self) -> str:
        # Resolves the "client_email" of the service account referenced by the
        # orchestrator sheet file. The result is cached on the instance after
        # the first successful lookup. Raises ValueError when the
        # service_account key or the client_email entry is missing.
        # (Comments instead of a docstring: a docstring here would become the
        # computed field's description in the generated schema.)
        if hasattr(self, "_service_account_email"):
            return self._service_account_email

        # context manager so the file handle is closed deterministically
        with open(self.orchestrator_sheet) as orch_file:
            orch = yaml.safe_load(orch_file)

        def find_service_account_email(d):
            # depth-first search through nested dicts for a "service_account" key
            for k, v in d.items():
                if k == "service_account":
                    return v
                if isinstance(v, dict):
                    if result := find_service_account_email(v):
                        return result
            return False

        service_account_json = find_service_account_email(orch)
        if not service_account_json:
            raise ValueError(f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}.")

        with open(service_account_json) as f:
            client_email = json.load(f).get("client_email")

        # validate before caching, so a failed lookup is never stored and
        # returned silently on the next access
        if not client_email:
            raise ValueError(f"Service account email not found in {service_account_json}.")

        self._service_account_email = client_email
        return self._service_account_email
|
||||
|
||||
|
||||
class UserGroupModel(BaseModel):
    # Top-level shape of the user-groups file: users -> group names,
    # domains -> group names, and group name -> group definition.
    users: Dict[str, List[str]] = Field(default_factory=dict)
    domains: Dict[str, List[str]] = Field(default_factory=dict)
    groups: Dict[str, GroupModel] = Field(default_factory=dict)

    @field_validator('users', mode='before')
    @classmethod
    def validate_emails(cls, v):
        """Check every user key and normalise keys and group names to stripped lower case.

        Each user must contain an '@' and list at least one group.
        """
        for email in v.keys():
            if '@' not in email:
                raise ValueError(f"Invalid user, it should be an address: {email}")
            if not v[email]:
                raise ValueError(f"User {email} has no explicitly listed groups, only include them here if they should be in a group.")

        normalised = {}
        for user, user_groups in v.items():
            cleaned = [g.lower().strip() for g in user_groups]
            # all users belong to the default group
            normalised[user.lower().strip()] = list(set(["default"] + cleaned))
        return normalised

    @field_validator('domains', mode='before')
    @classmethod
    def validate_domains(cls, v):
        """Check every domain key and normalise keys and group names to stripped lower case.

        Each domain must contain a dot and list at least one member.
        """
        for domain, members in v.items():
            if '.' not in domain:
                raise ValueError(f"Invalid domain, it should contain a dot: {domain}")
            if not members:
                raise ValueError(f"Domain {domain} should have at least one member.")

        normalised = {}
        for domain_name, domain_groups in v.items():
            cleaned = [g.lower().strip() for g in domain_groups]
            normalised[domain_name.lower().strip()] = list(set(cleaned))
        return normalised

    @field_validator('groups', mode='before')
    @classmethod
    def validate_groups(cls, v):
        """Require a 'default' group, forbid the reserved name 'all', and enforce lower-case names."""
        group_names = v.keys()
        if "default" not in group_names:
            raise ValueError("Please include a 'default' group.")
        if "all" in group_names:
            raise ValueError("'all' is a reserved group name.")
        for group in group_names:
            if group != group.lower():
                raise ValueError(f"Group names should be lowercase: {group}")
        return v

    @model_validator(mode='after')
    def check_groups_consistency(self) -> Self:
        """Warn about groups referenced by domains or users that are not defined under groups."""
        groups_in_domains = set([g for domain_groups in self.domains.values() for g in domain_groups])
        groups_in_users = set([g for user_groups in self.users.values() for g in user_groups])
        configured_groups = set(self.groups.keys())

        # groups mentioned in domains and users should be defined, but this is not a ValueError since historical DB data may require it
        if groups_in_domains - configured_groups:
            logger.warning(f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}")
        if groups_in_users - configured_groups:
            logger.warning(f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}")

        return self
|
||||
|
||||
# for the API return values
|
||||
|
||||
|
||||
class GroupInfo(GroupPermissions):
    # GroupPermissions plus two string fields that default to empty
    description: str = ""
    service_account_email: str = ""
|
||||
10
app/shared/utils/misc.py
Normal file
10
app/shared/utils/misc.py
Normal file
@@ -0,0 +1,10 @@
|
||||
|
||||
def fnv1a_hash_mod(s: str, modulo: int) -> int:
    """Hash ``s`` with a 32-bit FNV-1a style loop and reduce the result modulo ``modulo``.

    Each character's code point is XOR-ed into the running value, which is
    then multiplied by the FNV prime and truncated to 32 bits. The final value
    is read as a signed 32-bit integer before the modulo is taken, so for a
    positive ``modulo`` the result lies in ``[0, modulo - 1]``.
    """
    fnv_offset_basis = 0x811c9dc5
    fnv_prime = 0x01000193
    mask_32_bits = 0xFFFFFFFF

    digest = fnv_offset_basis
    for code_point in map(ord, s):
        digest = ((digest ^ code_point) * fnv_prime) & mask_32_bits

    # reinterpret the unsigned 32-bit value as signed
    if digest >= 0x80000000:
        digest -= 0x100000000
    return digest % modulo
|
||||
333
app/test.ipynb
Normal file
333
app/test.ipynb
Normal file
File diff suppressed because one or more lines are too long
160
app/tests/conftest.py
Normal file
160
app/tests/conftest.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.shared.settings import Settings
|
||||
from app.web.db.user_state import UserState
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_logger_add():
    """Fixture to mock loguru.logger.add for all tests."""
    patcher = patch('loguru.logger.add')
    mock_add = patcher.start()
    try:
        yield mock_add  # This makes the mock available to tests
    finally:
        patcher.stop()
|
||||
|
||||
|
||||
@pytest.fixture()
def get_settings():
    """Provide a ``Settings`` object loaded from the ``.env.test`` file."""
    test_settings = Settings(_env_file=".env.test")
    return test_settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_settings():
    """Patch ``app.shared.settings.Settings`` for every test so it returns ``.env.test`` settings."""
    test_settings = Settings(_env_file=".env.test")
    with patch('app.shared.settings.Settings', return_value=test_settings) as patched_settings:
        yield patched_settings
|
||||
|
||||
|
||||
@pytest.fixture()
def test_db(get_settings: Settings):
    """Yield a connection to a freshly created SQLite test database.

    Tables are created before the test. Afterwards they are dropped and the
    database file is removed together with its ``-wal`` and ``-shm`` companions.
    """
    from app.shared.db import models
    from app.shared.db.database import make_engine
    from app.web.db.crud import get_user_group_names

    # start from clean caches
    for cached in (get_user_group_names, make_engine):
        cached.cache_clear()
    engine = make_engine(get_settings.DATABASE_PATH)

    # make sure the database file exists
    db_file = get_settings.DATABASE_PATH.replace("sqlite:///", "")
    if not os.path.exists(db_file):
        with open(db_file, 'w'):
            pass

    models.Base.metadata.create_all(engine)

    connection = engine.connect()
    yield connection
    connection.close()

    models.Base.metadata.drop_all(bind=engine)
    for leftover in (db_file, db_file + "-wal", db_file + "-shm"):
        if os.path.exists(leftover):
            os.remove(leftover)
|
||||
|
||||
|
||||
@pytest.fixture()
def db_session(test_db):
    """Yield a session from a session factory built on the ``test_db`` connection."""
    from app.shared.db.database import make_session_local

    make_session = make_session_local(test_db)
    with make_session() as session:
        yield session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
async def async_test_db(get_settings: Settings):
    """Yield an async engine bound to a freshly created SQLite test database.

    Tables are created before the test. Afterwards they are dropped, the engine
    is disposed and the database file is removed together with its ``-wal`` and
    ``-shm`` companions.
    """
    from app.shared.db import models
    from app.shared.db.database import make_async_engine
    from app.web.db.crud import get_user_group_names

    get_user_group_names.cache_clear()
    engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH)

    # make sure the database file exists
    fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "")
    if not os.path.exists(fs):
        open(fs, 'w').close()

    async with engine.begin() as conn:
        await conn.run_sync(models.Base.metadata.create_all)

    yield engine

    async with engine.begin() as conn:
        await conn.run_sync(models.Base.metadata.drop_all)

    # dispose() on an async engine is a coroutine: without the await it never
    # runs, so connections can still be open when the files are removed
    await engine.dispose()
    for suffix in ["", "-wal", "-shm"]:
        new_fs = fs + suffix
        if os.path.exists(new_fs):
            os.remove(new_fs)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
    """Yield an async session from a session factory built on the ``async_test_db`` engine."""
    from app.shared.db.database import make_async_session_local

    make_session = await make_async_session_local(async_test_db)
    async with make_session() as session:
        yield session
|
||||
|
||||
|
||||
@pytest.fixture()
def app(db_session):
    """Build the web application, then upsert the user groups through ``db_session``."""
    from app.web.main import app_factory
    from app.web.db import crud

    application = app_factory()
    crud.upsert_user_groups(db_session)
    return application
|
||||
|
||||
|
||||
@pytest.fixture()
def client(app):
    """Provide a ``TestClient`` for the plain application."""
    return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture()
def app_with_auth(app, db_session):
    """Return the application with its three auth dependencies replaced by fixed users."""
    from app.web.security import get_token_or_user_auth, get_user_auth, get_user_state

    overrides = {
        get_token_or_user_auth: lambda: "rick@example.com",
        get_user_auth: lambda: "morty@example.com",
        get_user_state: lambda: UserState(db_session, "MORTY@example.com"),
    }
    for dependency, replacement in overrides.items():
        app.dependency_overrides[dependency] = replacement
    return app
|
||||
|
||||
|
||||
@pytest.fixture()
def client_with_auth(app_with_auth):
    """Provide a ``TestClient`` for the application with overridden user auth."""
    return TestClient(app_with_auth)
|
||||
|
||||
|
||||
@pytest.fixture()
def app_with_token(app):
    """Return the application with both token-related dependencies yielding ``ALLOW_ANY_EMAIL``."""
    from app.web.security import token_api_key_auth, get_token_or_user_auth

    for dependency in (token_api_key_auth, get_token_or_user_auth):
        app.dependency_overrides[dependency] = lambda: ALLOW_ANY_EMAIL
    return app
|
||||
|
||||
|
||||
@pytest.fixture()
def client_with_token(app_with_token):
    """Provide a ``TestClient`` for the application with overridden token auth."""
    return TestClient(app_with_token)
|
||||
|
||||
|
||||
@pytest.fixture()
def test_no_auth():
    """Provide a helper asserting that a method/endpoint pair answers 403 "Not authenticated"."""
    # reusable code to ensure a method/endpoint combination is unauthorized
    def no_auth(http_method, endpoint):
        reply = http_method(endpoint)
        assert reply.status_code == 403
        assert reply.json() == {"detail": "Not authenticated"}

    return no_auth
|
||||
3
app/tests/fake_service_account.json
Normal file
3
app/tests/fake_service_account.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"client_email": "fake_service_account@fake_service_account.iam.gserviceaccount.com"
|
||||
}
|
||||
@@ -12,6 +12,8 @@ steps:
|
||||
- console_db
|
||||
|
||||
configurations:
|
||||
gsheet_feeder:
|
||||
service_account: "app/tests/fake_service_account.json"
|
||||
cli_feeder:
|
||||
urls:
|
||||
- "url1"
|
||||
@@ -1,5 +1,5 @@
|
||||
def test_generate_uuid():
|
||||
from db.models import generate_uuid
|
||||
from app.shared.db.models import generate_uuid
|
||||
|
||||
assert generate_uuid() != generate_uuid()
|
||||
assert len(generate_uuid()) == 36
|
||||
117
app/tests/shared/db/test_worker_crud.py
Normal file
117
app/tests/shared/db/test_worker_crud.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from app.shared.db import models
|
||||
from app.shared.db import worker_crud, models
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
from app.tests.web.db.test_crud import test_data
|
||||
|
||||
def test_update_sheet_last_url_archived_at(db_session):
    """The timestamp moves forward for an existing sheet; an unknown sheet id yields False."""
    sheet_id = "sheet-123"

    # Create test sheet
    test_sheet = models.Sheet(id=sheet_id)
    db_session.add(test_sheet)
    db_session.commit()

    # Test updating existing sheet
    assert isinstance(test_sheet.last_url_archived_at, datetime)
    before = test_sheet.last_url_archived_at
    was_updated = worker_crud.update_sheet_last_url_archived_at(db_session, sheet_id)
    assert was_updated is True
    db_session.refresh(test_sheet)
    after = test_sheet.last_url_archived_at
    assert isinstance(after, datetime)
    assert after > before

    # Test non-existent sheet
    was_updated = worker_crud.update_sheet_last_url_archived_at(db_session, "non-existent-sheet")
    assert was_updated is False
|
||||
|
||||
def test_get_group(test_data, db_session):
    """Known group ids resolve to a group; an unknown id resolves to None."""
    from app.shared.db import worker_crud

    for known_group in ("spaceship", "interdimensional", "animated-characters"):
        assert worker_crud.get_group(db_session, known_group) is not None
    assert worker_crud.get_group(db_session, "non-existent!@#!%!") is None
|
||||
|
||||
|
||||
def test_create_or_get_user(test_data, db_session):
    """An existing user is returned as is; an unknown e-mail adds exactly one user."""
    from app.shared.db import worker_crud

    def user_count():
        return db_session.query(models.User).count()

    assert user_count() == 3

    # already exists
    u1 = worker_crud.create_or_get_user(db_session, "rick@example.com")
    assert u1 is not None
    assert u1.email == "rick@example.com"

    # new user
    u2 = worker_crud.create_or_get_user(db_session, "beth@example.com")
    assert u2 is not None
    assert u2.email == "beth@example.com"

    assert user_count() == 4
|
||||
|
||||
|
||||
def test_create_tag(db_session):
    """Tags are created once per id; asking for an existing id returns the stored tag."""
    from app.shared.db import worker_crud

    def tag_count():
        return db_session.query(models.Tag).count()

    assert tag_count() == 0

    # create first
    first_tag = worker_crud.create_tag(db_session, "tag-101")
    assert first_tag is not None
    assert first_tag.id == "tag-101"
    assert tag_count() == 1
    stored = db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first()
    assert stored == first_tag

    # same id does not add new db entry
    existing_tag = worker_crud.create_tag(db_session, "tag-101")
    assert existing_tag == first_tag
    assert tag_count() == 1

    # create second
    second_tag = worker_crud.create_tag(db_session, "tag-102")
    assert second_tag is not None
    assert second_tag.id == "tag-102"
    assert tag_count() == 2
|
||||
|
||||
|
||||
def test_create_task(db_session):
    """``create_archive`` stores the archive fields, with and without tags and urls."""
    from app.shared.db import worker_crud
    from app.shared import schemas

    task = schemas.ArchiveCreate(
        id="archive-id-456-101",
        url="https://example-0.com",
        result={},
        public=False,
        author_id="rick@example.com",
        group_id="spaceship",
        tags=[],
        urls=[],
    )

    def check_common_fields(created, expected_id):
        # assertions shared by both scenarios, in the original order
        assert created is not None
        assert created.id == expected_id
        assert created.url == "https://example-0.com"
        assert created.author_id == "rick@example.com"
        assert created.public == False
        assert created.group_id == "spaceship"

    # with tags and urls
    one_tag = [models.Tag(id="tag-101")]
    one_url = [models.ArchiveUrl(url="https://example-0.com/0", key="media_0")]
    nt = worker_crud.create_archive(db_session, task, one_tag, one_url)

    check_common_fields(nt, "archive-id-456-101")
    assert len(nt.tags) == 1
    assert nt.tags[0].id == "tag-101"
    assert len(nt.urls) == 1
    assert nt.urls[0].url == "https://example-0.com/0"
    assert nt.urls[0].key == "media_0"
    assert nt.created_at is not None

    # without tags and urls
    task.id = "archive-id-456-102"
    nt = worker_crud.create_archive(db_session, task, [], [])

    check_common_fields(nt, "archive-id-456-102")
    assert len(nt.tags) == 0
    assert len(nt.urls) == 0
    assert nt.created_at is not None
|
||||
36
app/tests/shared/test_business_logic.py
Normal file
36
app/tests/shared/test_business_logic.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from app.shared.business_logic import get_store_archive_until
|
||||
|
||||
class Test_get_store_archive_until:
    """Tests for ``get_store_archive_until`` with a missing group and with mocked groups."""

    GROUP_ID = "test-group"

    def test_group_not_found(self, db_session):
        with pytest.raises(AssertionError) as exc:
            get_store_archive_until(db_session, self.GROUP_ID)
        assert str(exc.value) == f"Group {self.GROUP_ID} not found."

    @patch("app.shared.db.worker_crud.get_group")
    def test_no_max_lifespan(self, mock_get_group, db_session):
        # a lifespan of -1 is expected to produce no expiry date
        unlimited_group = MagicMock()
        unlimited_group.permissions = {"max_archive_lifespan_months": -1}
        mock_get_group.return_value = unlimited_group

        assert get_store_archive_until(db_session, self.GROUP_ID) is None
        mock_get_group.assert_called_once_with(db_session, self.GROUP_ID)

    @patch("app.shared.db.worker_crud.get_group")
    def test_with_max_lifespan(self, mock_get_group, db_session):
        limited_group = MagicMock()
        limited_group.permissions = {"max_archive_lifespan_months": 6}
        mock_get_group.return_value = limited_group

        result = get_store_archive_until(db_session, self.GROUP_ID)
        six_months = timedelta(days=180)  # 6 months
        expected = datetime.now() + six_months

        assert isinstance(result, datetime)
        # Allow 1 second difference due to execution time
        tolerance = timedelta(seconds=1)
        assert abs(result - expected) < tolerance
        mock_get_group.assert_called_once_with(db_session, self.GROUP_ID)
|
||||
31
app/tests/shared/utils/test_misc.py
Normal file
31
app/tests/shared/utils/test_misc.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from app.shared.utils.misc import fnv1a_hash_mod
|
||||
|
||||
|
||||
def test_fnv1a_hash_mod():
    """Check determinism, output range and a few edge cases of ``fnv1a_hash_mod``."""
    # the same input and modulo always land in the same bucket
    assert fnv1a_hash_mod("test", 10) == fnv1a_hash_mod("test", 10)

    # different strings give different hashes
    assert fnv1a_hash_mod("test1", 100) != fnv1a_hash_mod("test2", 100)

    # the result stays within [0, modulo) for regular input, different
    # modulos, the empty string, a long string and a unicode string
    ranged_samples = [
        ("test", 10),
        ("test", 5),
        ("", 10),
        ("a" * 1000, 20),
        ("测试", 10),
    ]
    for text, modulo in ranged_samples:
        assert 0 <= fnv1a_hash_mod(text, modulo) < modulo

    # empty and unicode strings still produce a plain int
    for text in ("", "测试"):
        assert isinstance(fnv1a_hash_mod(text, 10), int)

    # modulo = 1 edge case: the only possible bucket is 0
    assert fnv1a_hash_mod("test", 1) == 0
|
||||
87
app/tests/user-groups.test.yaml
Normal file
87
app/tests/user-groups.test.yaml
Normal file
@@ -0,0 +1,87 @@
|
||||
# NOTE: all emails should be lower-cased
|
||||
users:
|
||||
rick@example.com:
|
||||
- spaceship
|
||||
- interdimensional
|
||||
morty@example.com:
|
||||
- spaceship
|
||||
jerry@example.com:
|
||||
- the-jerrys-club
|
||||
# summer@herself.com:
|
||||
# badyemail.com:
|
||||
|
||||
domains:
|
||||
example.com:
|
||||
- animated-characters
|
||||
birdy.com:
|
||||
- animated-characters
|
||||
- this-does-not-exist
|
||||
|
||||
|
||||
orchestrators:
|
||||
spaceship: app/tests/orchestration.test.yaml
|
||||
interdimensional: app/tests/orchestration.test.yaml
|
||||
default: app/tests/orchestration.test.yaml
|
||||
|
||||
default_orchestrator: app/tests/orchestration.test.yaml
|
||||
|
||||
groups:
|
||||
spaceship:
|
||||
description: "The spaceship crew"
|
||||
orchestrator: app/tests/orchestration.test.yaml
|
||||
orchestrator_sheet: app/tests/orchestration.test.yaml
|
||||
permissions:
|
||||
read: ["all"]
|
||||
archive_url: true
|
||||
archive_sheet: true
|
||||
manually_trigger_sheet: true
|
||||
sheet_frequency: ["hourly", "daily"]
|
||||
max_sheets: -1
|
||||
max_archive_lifespan_months: -1
|
||||
max_monthly_urls: -1
|
||||
max_monthly_mbs: -1
|
||||
priority: "high"
|
||||
interdimensional:
|
||||
description: "Interdimensional travelers"
|
||||
orchestrator: app/tests/orchestration.test.yaml
|
||||
orchestrator_sheet: app/tests/orchestration.test.yaml
|
||||
permissions:
|
||||
read: ["interdimensional", "animated-characters"]
|
||||
archive_url: true
|
||||
archive_sheet: true
|
||||
manually_trigger_sheet: true
|
||||
sheet_frequency: ["hourly", "daily"]
|
||||
max_sheets: 5
|
||||
max_archive_lifespan_months: 12
|
||||
max_monthly_urls: 1000
|
||||
max_monthly_mbs: 1000
|
||||
priority: "high"
|
||||
animated-characters:
|
||||
description: "Animated characters"
|
||||
orchestrator: app/tests/orchestration.test.yaml
|
||||
orchestrator_sheet: app/tests/orchestration.test.yaml
|
||||
permissions:
|
||||
read: ["animated-characters"]
|
||||
archive_url: true
|
||||
archive_sheet: true
|
||||
sheet_frequency: ["daily"]
|
||||
max_sheets: 1
|
||||
max_archive_lifespan_months: 12
|
||||
max_monthly_urls: 2
|
||||
max_monthly_mbs: 10
|
||||
priority: "low"
|
||||
default:
|
||||
description: "Public access"
|
||||
orchestrator: app/tests/orchestration.test.yaml
|
||||
orchestrator_sheet: app/tests/orchestration.test.yaml
|
||||
permissions:
|
||||
# read: []
|
||||
archive_url: true
|
||||
# manually_trigger_sheet: false
|
||||
# archive_sheet: false
|
||||
# sheet_frequency: []
|
||||
# max_sheets: 0
|
||||
# max_archive_lifespan_months: 12
|
||||
max_monthly_urls: 1
|
||||
# max_monthly_mbs: 50
|
||||
priority: "low"
|
||||
438
app/tests/web/db/test_crud.py
Normal file
438
app/tests/web/db/test_crud.py
Normal file
@@ -0,0 +1,438 @@
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from app.shared.db import models
|
||||
from app.shared.settings import Settings
|
||||
|
||||
from app.web.db import crud
|
||||
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
|
||||
|
||||
|
||||
@pytest.fixture()
def test_data(db_session):
    """Seed the database with 3 users, 100 archives (with tags and urls), 4 sheets and the configured groups."""
    # 3 users, inserted in the order of `authors`
    for address in authors:
        db_session.add(models.User(email=address))
    db_session.commit()
    assert db_session.query(models.User).count() == 3

    # 100 archives spread over the 3 users and 2 months, with repeating URLs
    for n in range(100):
        owner = authors[n % 3]
        in_spaceship = owner == "morty@example.com" and n % 2 == 0
        archive = models.Archive(
            id=f"archive-id-456-{n}",
            url=f"https://example-{n % 3}.com",
            result={},
            public=(owner == "jerry@example.com"),
            author_id=owner,
            group_id="spaceship" if in_spaceship else None,
            created_at=datetime(2021, (n % 2) + 1, (n % 25) + 1),
        )
        # each tag family is attached when `n` is a multiple of its divisor
        for divisor, prefix in ((5, "tag"), (10, "tag-second"), (4, "tag-third")):
            if n % divisor == 0:
                archive.tags.append(models.Tag(id=f"{prefix}-{n}"))
        # 10 media urls per archive
        for k in range(10):
            media = models.ArchiveUrl(url=f"https://example-{n}.com/{k}", key=f"media_{k}")
            archive.urls.append(media)
        db_session.add(archive)

    # one daily sheet per user, plus an extra hourly spaceship sheet for rick
    for position, address in enumerate(authors):
        personal = models.Sheet(
            id=f"sheet-{position}",
            name=f"sheet-{position}",
            author_id=address,
            group_id=None,
            frequency="daily",
        )
        db_session.add(personal)
        if address == "rick@example.com":
            extra = models.Sheet(
                id=f"sheet-{position}-2",
                name=f"sheet-{position}-2",
                author_id=address,
                group_id="spaceship",
                frequency="hourly",
            )
            db_session.add(extra)

    db_session.commit()

    # sanity-check what was just inserted
    assert db_session.query(models.Archive).count() == 100
    assert db_session.query(models.Tag).count() == 20 + 10 + 25
    assert db_session.query(models.ArchiveUrl).count() == 1000
    first_archive_urls = db_session.query(models.ArchiveUrl).filter(
        models.ArchiveUrl.archive_id == "archive-id-456-0"
    )
    assert first_archive_urls.count() == 10

    # setup groups
    assert db_session.query(models.Group).count() == 0
    crud.upsert_user_groups(db_session)
    assert db_session.query(models.Group).count() == 4
    assert db_session.query(models.User).count() == 3
|
||||
|
||||
|
||||
def test_search_archives_by_url(test_data, db_session):
    """Exercise url search: visibility, fuzzy/absolute matching, date windows, limit and skip."""
    from app.web.config import ALLOW_ANY_EMAIL

    rick = "rick@example.com"
    morty = "morty@example.com"
    jerry = "jerry@example.com"

    def hits(*args, **kwargs):
        # number of archives returned for the given search arguments
        return len(crud.search_archives_by_url(db_session, *args, **kwargs))

    # rick's archives are private
    assert hits("https://example-0.com", rick, True, False) == 34
    assert hits("https://example-0.com", rick, [], False) == 34
    assert hits("https://example-0.com", rick, [], True) == 34
    assert hits("https://example-0.com", ALLOW_ANY_EMAIL, [], False) == 34
    assert hits("https://example-0.com", ALLOW_ANY_EMAIL, True, False) == 34
    assert hits("https://example-0.com", morty, [], False) == 0
    assert hits("https://example-0.com", morty, [], True) == 0

    # morty's archives are not public, but roughly half of them belong to the spaceship group
    assert hits("https://example-1.com", rick, ["spaceship"], False) == 16
    assert hits("https://example-1.com", rick, True, False) == 16
    assert hits("https://example-1.com", jerry, True, True) == 16

    # jerry's archives are public
    assert hits("https://example-2.com", jerry, [], True) == 33
    assert hits("https://example-2.com", rick, [], True) == 33

    # fuzzy search
    assert hits("https://example", ALLOW_ANY_EMAIL, False, False) == 100
    assert hits("https://EXAMPLE", ALLOW_ANY_EMAIL, False, False) == 100
    assert hits("2.com", ALLOW_ANY_EMAIL, False, False) == 33

    # absolute search
    assert hits("example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True) == 0
    assert hits("https://example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True) == 33

    # archived_after
    assert hits("https://example", ALLOW_ANY_EMAIL, True, True, archived_after=datetime(2010, 1, 1)) == 100
    assert hits("https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 15)) == 70
    assert hits("https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2031, 1, 1)) == 0

    # archived_before
    assert hits("https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2010, 1, 1)) == 0
    assert hits("https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2021, 1, 15)) == 28
    assert hits("https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2031, 1, 1)) == 100

    # archived before and after
    assert hits(
        "https://example", ALLOW_ANY_EMAIL, False, False,
        archived_after=datetime(2001, 1, 1), archived_before=datetime(2031, 1, 11),
    ) == 100
    assert hits(
        "https://example", ALLOW_ANY_EMAIL, False, False,
        archived_after=datetime(2021, 1, 14), archived_before=datetime(2021, 1, 16),
    ) == 2

    # limit
    assert hits("https://example", ALLOW_ANY_EMAIL, False, False, limit=10) == 10
    assert hits("https://example", ALLOW_ANY_EMAIL, False, False, limit=-1) == 1

    # skip
    assert hits("https://example", ALLOW_ANY_EMAIL, False, False, skip=10) == 90
|
||||
|
||||
|
||||
def test_search_archives_by_email(test_data, db_session):
    """Search by author email returns that author's archives, newest first."""
    from app.web.config import ALLOW_ANY_EMAIL

    rick = "rick@example.com"

    # all of rick's archives are found
    assert len(crud.search_archives_by_email(db_session, rick)) == 34

    # ALLOW_ANY_EMAIL is not a user
    assert len(crud.search_archives_by_email(db_session, ALLOW_ANY_EMAIL)) == 0

    # most recent first
    newest = crud.search_archives_by_email(db_session, rick, limit=1)
    assert len(newest) == 1
    assert newest[0].created_at == datetime(2021, 2, 25)

    # earliest is the last
    oldest = crud.search_archives_by_email(db_session, rick, skip=33)
    assert len(oldest) == 1
    assert oldest[0].created_at == datetime(2021, 1, 1)
|
||||
|
||||
|
||||
@patch("app.web.db.crud.DATABASE_QUERY_LIMIT", new=25)
def test_max_query_limit(test_data, db_session):
    """With DATABASE_QUERY_LIMIT patched to 25, searches never return more than 25 rows."""
    from app.web.config import ALLOW_ANY_EMAIL

    cap = 25

    # url search: default limit, then an explicit limit above the cap
    url_default = crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, [], False)
    assert len(url_default) == cap
    url_oversized = crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, limit=1000)
    assert len(url_oversized) == cap

    # email search: same two scenarios
    email_default = crud.search_archives_by_email(db_session, "rick@example.com")
    assert len(email_default) == cap
    email_oversized = crud.search_archives_by_email(db_session, "rick@example.com", limit=1000)
    assert len(email_oversized) == cap
|
||||
|
||||
|
||||
def test_soft_delete(test_data, db_session):
    """Soft-deleting an archive flags it as deleted, keeps its row, and only succeeds once."""
    archive_id = "archive-id-456-0"
    owner = "rick@example.com"

    def stored_archive():
        # the row with this id, whether or not it is flagged as deleted
        return db_session.query(models.Archive).filter(models.Archive.id == archive_id).first()

    def deleted_count():
        # `== True` is required here: it builds the SQL comparison
        return db_session.query(models.Archive).filter(models.Archive.deleted == True).count()  # noqa: E712

    # none deleted yet (previously a bare expression that asserted nothing)
    assert stored_archive() is not None
    assert deleted_count() == 0

    # delete
    assert crud.soft_delete_archive(db_session, archive_id, owner) == True  # noqa: E712

    # ensure soft delete: exactly one row is flagged, and the row itself is
    # kept rather than removed (previously a bare, never-checked expression)
    assert deleted_count() == 1
    flagged = stored_archive()
    assert flagged is not None
    assert flagged.deleted

    # already deleted
    assert crud.soft_delete_archive(db_session, archive_id, owner) == False  # noqa: E712
|
||||
|
||||
|
||||
def test_count_archives(test_data, db_session):
    """The archive count drops by one after a single archive row is deleted."""
    assert crud.count_archives(db_session) == 100

    first_archive = db_session.query(models.Archive).filter(
        models.Archive.id == "archive-id-456-0"
    )
    first_archive.delete()
    db_session.commit()

    assert crud.count_archives(db_session) == 99
|
||||
|
||||
|
||||
def test_count_archive_urls(test_data, db_session):
    """The url count follows url deletions but is unaffected by deleting the parent archive."""
    assert crud.count_archive_urls(db_session) == 1000

    # remove a single url
    one_url = db_session.query(models.ArchiveUrl).filter(
        models.ArchiveUrl.url == "https://example-0.com/0"
    )
    one_url.delete()
    db_session.commit()
    assert crud.count_archive_urls(db_session) == 999

    # remove the archive that url belonged to
    parent = db_session.query(models.Archive).filter(
        models.Archive.id == "archive-id-456-0"
    )
    parent.delete()
    db_session.commit()

    # no cascade is enabled: the archive is gone, its remaining urls are not
    assert crud.count_archives(db_session) == 99
    assert crud.count_archive_urls(db_session) == 999
|
||||
|
||||
|
||||
def test_count_users(test_data, db_session):
    """The user count drops by one after a user row is deleted."""
    assert crud.count_users(db_session) == 3

    rick_row = db_session.query(models.User).filter(
        models.User.email == "rick@example.com"
    )
    rick_row.delete()
    db_session.commit()

    assert crud.count_users(db_session) == 2
|
||||
|
||||
|
||||
def test_count_by_users_since(test_data, db_session):
    """Over a window that covers all seeded data, each of the 3 users gets a total."""
    # 100y window, expressed in seconds
    hundred_years = 60 * 60 * 24 * 31 * 12 * 100

    per_user = crud.count_by_user_since(db_session, hundred_years)

    assert len(per_user) == 3
    totals = [per_user[position].total for position in range(3)]
    assert totals == [34, 33, 33]
|
||||
|
||||
|
||||
def test_upsert_group(test_data, db_session):
    """Upserting updates existing groups in place and creates a row for a new group id."""
    assert db_session.query(models.Group).count() == 4

    # positional arguments shared by every upsert call below
    repeatable_params = [
        "desc 1",
        "orch.yaml",
        "sheet.yaml",
        "service_account_email@example.com",
        {"read": ["all"]},
        ["example.com"],
    ]

    # existing group: every field is overwritten, members are rick and morty
    spaceship = crud.upsert_group(db_session, "spaceship", *repeatable_params)
    assert spaceship is not None
    assert spaceship.id == "spaceship"
    assert spaceship.description == "desc 1"
    assert spaceship.orchestrator == "orch.yaml"
    assert spaceship.orchestrator_sheet == "sheet.yaml"
    assert spaceship.service_account_email == "service_account_email@example.com"
    assert spaceship.permissions == {"read": ["all"]}
    assert spaceship.domains == ["example.com"]
    assert len(spaceship.users) == 2
    assert [member.email for member in spaceship.users] == ["rick@example.com", "morty@example.com"]

    # existing group with a single member
    interdimensional = crud.upsert_group(db_session, "interdimensional", *repeatable_params)
    assert interdimensional is not None
    assert interdimensional.id == "interdimensional"
    assert len(interdimensional.users) == 1
    assert [member.email for member in interdimensional.users] == ["rick@example.com"]

    # brand-new group: created without members
    created = crud.upsert_group(db_session, "this-is-a-new-group", *repeatable_params)
    assert created is not None
    assert created.id == "this-is-a-new-group"
    assert len(created.users) == 0

    assert db_session.query(models.Group).count() == 5
|
||||
|
||||
|
||||
def test_upsert_user_groups(db_session):
    """``upsert_user_groups`` raises when the user-groups file is missing or is broken YAML."""
    bad_settings = Settings(_env_file=".env.test")

    # (user-groups filename, exception expected while loading it)
    failing_files = (
        ("app/tests/user-groups.test.missing.yaml", FileNotFoundError),
        ("app/tests/user-groups.test.broken.yaml", yaml.YAMLError),
    )

    for filename, expected_error in failing_files:
        bad_settings.USER_GROUPS_FILENAME = filename
        # crud resolves its settings through get_settings, so swap that for the bad ones
        with patch('app.web.db.crud.get_settings', new=lambda: bad_settings):
            with pytest.raises(expected_error):
                crud.upsert_user_groups(db_session)
|
||||
|
||||
|
||||
def test_create_sheet(db_session):
    """A sheet is stored with the given fields; re-using its id raises an IntegrityError."""
    from sqlalchemy.exc import IntegrityError

    assert db_session.query(models.Sheet).count() == 0

    sheet = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly")
    assert sheet is not None

    expected_fields = {
        "id": "sheet-id-123",
        "name": "sheet name",
        "author_id": "email@example.com",
        "group_id": "group-id",
        "frequency": "hourly",
    }
    for field, value in expected_fields.items():
        assert getattr(sheet, field) == value

    assert db_session.query(models.Sheet).count() == 1

    # duplicate id
    with pytest.raises(IntegrityError):
        crud.create_sheet(db_session, "sheet-id-123", "I thought this was another sheet", "email", "group-id", "hourly")
|
||||
|
||||
|
||||
def test_get_user_sheet(test_data, db_session):
    """A sheet is only returned when the email matches the sheet's author."""
    # (email, sheet id) pairs that must not resolve to a sheet
    not_found = [
        ("", "sheet-0"),
        ("morty@example.com", "sheet-0"),
    ]
    for email, sheet_id in not_found:
        assert crud.get_user_sheet(db_session, email, sheet_id) is None

    # (email, sheet id) pairs where the email is the sheet's author
    found = [
        ("rick@example.com", "sheet-0"),
        ("rick@example.com", "sheet-0-2"),
        ("morty@example.com", "sheet-1"),
    ]
    for email, sheet_id in found:
        assert crud.get_user_sheet(db_session, email, sheet_id) is not None
|
||||
|
||||
|
||||
def test_get_user_sheets(test_data, db_session):
    """Each user gets only their own sheets; an empty email gets none."""
    nobody_sheets = crud.get_user_sheets(db_session, "")
    assert len(nobody_sheets) == 0

    # rick owns the daily sheet and the extra spaceship sheet
    rick_sheets = crud.get_user_sheets(db_session, "rick@example.com")
    assert len(rick_sheets) == 2
    rick_ids = [sheet.id for sheet in rick_sheets]
    assert rick_ids == ["sheet-0", "sheet-0-2"]

    morty_sheets = crud.get_user_sheets(db_session, "morty@example.com")
    assert len(morty_sheets) == 1
|
||||
|
||||
|
||||
def test_delete_sheet(test_data, db_session):
    """Only the author can delete a sheet, and only once."""
    # (email attempting the delete, expected outcome), executed in this order
    attempts = [
        ("", False),                  # not the author
        ("rick@example.com", True),   # author: sheet is deleted
        ("rick@example.com", False),  # already deleted
    ]
    for email, expected in attempts:
        assert crud.delete_sheet(db_session, "sheet-0", email) == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_by_store_until(async_db_session):
    """Only non-deleted archives whose ``store_until`` lies before the cutoff are found."""
    now = datetime.now()

    # label -> offset of store_until relative to now
    offsets = {
        "expired-1": -timedelta(days=1),
        "expired-2": -timedelta(hours=1),
        "active": timedelta(days=1),
    }
    archives = [
        models.Archive(
            id=f"archive-{label}",
            url=f"https://example-{label}.com",
            result={},
            author_id="rick@example.com",
            store_until=now + offset,
        )
        for label, offset in offsets.items()
    ]
    async_db_session.add_all(archives)
    await async_db_session.commit()

    async def expired_count(cutoff):
        # how many archives the query reports as expired at `cutoff`
        found = await crud.find_by_store_until(async_db_session, cutoff)
        return len(list(found))

    # both expired archives are found when the cutoff is now
    assert await expired_count(now) == 2

    # only the one that expired a day ago predates a cutoff of 2 hours ago
    assert await expired_count(now - timedelta(hours=2)) == 1

    # nothing had expired 2 days ago
    assert await expired_count(now - timedelta(days=2)) == 0

    # deleted archives are not reported
    archives[0].deleted = True
    await async_db_session.commit()
    assert await expired_count(now) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_sheets_by_id_hash(async_db_session):
    """Sheets are selected by frequency and by the hash bucket of their id."""
    rick, morty, jerry = "rick@example.com", "morty@example.com", "jerry@example.com"

    # (sheet id, author, group, frequency); the id doubles as the sheet name
    rows = [
        ("sheet-0", rick, None, "daily"),
        ("sheet-0-2", rick, "spaceship", "hourly"),
        ("sheet-1", morty, None, "daily"),
        ("sheet-2", jerry, None, "daily"),
    ]
    async_db_session.add_all([
        models.Sheet(id=sheet_id, name=sheet_id, author_id=author, group_id=group, frequency=frequency)
        for sheet_id, author, group, frequency in rows
    ])
    await async_db_session.commit()

    # the patched hash always returns 1, whatever the sheet id
    with patch("app.web.db.crud.fnv1a_hash_mod", return_value=1):
        # hourly sheets in bucket 1
        hourly = await crud.get_sheets_by_id_hash(async_db_session, "hourly", 4, 1)
        assert len(hourly) == 1
        only_hourly = hourly[0]
        assert only_hourly.id == "sheet-0-2"
        assert only_hourly.frequency == "hourly"

        # daily sheets in bucket 1
        daily = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 1)
        assert len(daily) == 3
        assert all(sheet.frequency == "daily" for sheet in daily)
        assert {sheet.id for sheet in daily} == {"sheet-0", "sheet-1", "sheet-2"}

        # a bucket no sheet hashes to
        unmatched = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 3)
        assert len(unmatched) == 0

        # a frequency no sheet has
        weekly = await crud.get_sheets_by_id_hash(async_db_session, "weekly", 4, 1)
        assert len(weekly) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_stale_sheets(async_db_session):
    """Sheets whose last archived url is older than the threshold are deleted, grouped by author."""
    from sqlalchemy.sql import select

    now = datetime.now()
    active_date = now - timedelta(days=5)
    stale_date = now - timedelta(days=15)

    # (id, name, author, frequency, last_url_archived_at)
    specs = [
        ("sheet-active-1", "Active Sheet 1", "rick@example.com", "daily", active_date),
        ("sheet-active-2", "Active Sheet 2", "morty@example.com", "hourly", active_date),
        ("sheet-stale-1", "Stale Sheet 1", "rick@example.com", "daily", stale_date),
        ("sheet-stale-2", "Stale Sheet 2", "morty@example.com", "daily", stale_date),
    ]
    async_db_session.add_all([
        models.Sheet(
            id=sheet_id,
            name=name,
            author_id=author,
            frequency=frequency,
            last_url_archived_at=archived_at,
        )
        for sheet_id, name, author, frequency, archived_at in specs
    ])
    await async_db_session.commit()

    async def remaining_ids():
        # ids of every sheet still stored
        result = await async_db_session.execute(select(models.Sheet))
        return [sheet.id for sheet in result.scalars()]

    # a 20 day inactivity threshold deletes nothing
    deleted = await crud.delete_stale_sheets(async_db_session, 20)
    assert len(deleted) == 0
    assert len(await remaining_ids()) == 4

    # a 7 day threshold deletes the two stale sheets, one per author
    deleted = await crud.delete_stale_sheets(async_db_session, 7)
    assert len(deleted) == 2
    rick_deleted = deleted["rick@example.com"]
    morty_deleted = deleted["morty@example.com"]
    assert len(rick_deleted) == 1
    assert len(morty_deleted) == 1
    assert rick_deleted[0].id == "sheet-stale-1"
    assert morty_deleted[0].id == "sheet-stale-2"

    # only the active sheets remain
    survivors = await remaining_ids()
    assert len(survivors) == 2
    assert set(survivors) == {"sheet-active-1", "sheet-active-2"}

    # running again deletes nothing
    deleted = await crud.delete_stale_sheets(async_db_session, 7)
    assert len(deleted) == 0
|
||||
433
app/tests/web/db/test_user_state.py
Normal file
433
app/tests/web/db/test_user_state.py
Normal file
@@ -0,0 +1,433 @@
|
||||
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
import pytest
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared.user_groups import GroupInfo, GroupPermissions
|
||||
from app.web.db.user_state import UserState
|
||||
|
||||
|
||||
def fresh_user_state():
    """Build a new ``UserState`` for test@example.com, passing ``None`` as its first argument (presumably the db session)."""
    test_email = "test@example.com"
    return UserState(None, email=test_email)
|
||||
|
||||
|
||||
@pytest.fixture
def user_state():
    """Provide each test with its own newly built ``UserState``."""
    state = fresh_user_state()
    return state
|
||||
|
||||
|
||||
@pytest.fixture
def user_state_with_groups(user_state):
    """Yield ``user_state`` while ``UserState.user_groups`` is patched to return three sample groups."""
    group1_permissions = {
        "read": ["group1", "no-permissions"],
        "read_public": True,
        "archive_url": True,
        "archive_sheet": True,
        "max_archive_lifespan_months": 24,
        "max_monthly_urls": 100,
        "max_monthly_mbs": 1000,
        "priority": "high",
    }
    group2_permissions = {
        "read": ["all"],
        "read_public": True,
        "archive_url": False,
        "archive_sheet": False,
        "max_archive_lifespan_months": -1,
        "max_monthly_urls": -1,
        "max_monthly_mbs": -1,
        "priority": "low",
        "sheet_frequency": {"daily"},
    }
    user_groups = [
        models.Group(id="no-permissions", permissions={}),
        models.Group(id="group1", description="this is g1", service_account_email="sa1@example.com", permissions=group1_permissions),
        models.Group(id="group2", description="this is g2", service_account_email="sa2@example.com", permissions=group2_permissions),
    ]

    # the patch stays active for the whole duration of the dependent test
    groups_patch = patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=user_groups)
    with groups_patch:
        yield user_state
|
||||
|
||||
|
||||
def test_permissions(user_state_with_groups):
    """Check every permission field of the "all", "group1" and "group2" entries."""
    permissions = user_state_with_groups.permissions

    # entry -> expected value per field; "all" presumably merges the individual groups
    expected = {
        "all": {
            "read": True,
            "read_public": True,
            "archive_url": True,
            "archive_sheet": True,
            "max_archive_lifespan_months": -1,
            "max_monthly_urls": -1,
            "max_monthly_mbs": -1,
            "priority": "high",
        },
        "group1": {
            "read": {"group1", "no-permissions"},
            "read_public": True,
            "archive_url": True,
            "archive_sheet": True,
            "max_archive_lifespan_months": 24,
            "max_monthly_urls": 100,
            "max_monthly_mbs": 1000,
            "priority": "high",
        },
        "group2": {
            "read": {"all"},
            "read_public": True,
            "archive_url": False,
            "archive_sheet": False,
            "max_archive_lifespan_months": -1,
            "max_monthly_urls": -1,
            "max_monthly_mbs": -1,
            "priority": "low",
        },
    }
    for entry, fields in expected.items():
        for field, value in fields.items():
            assert getattr(permissions[entry], field) == value

    assert len(permissions) == 3
|
||||
|
||||
|
||||
def test_user_groups_names(user_state):
    """Group names come from crud, with "default" appended at the end."""
    names_patch = patch('app.web.db.crud.get_user_group_names', return_value=["group1", "group2"])
    with names_patch as get_names:
        assert user_state.user_groups_names == ["group1", "group2", "default"]
        # the lookup receives the state's first constructor argument (None) and its email
        get_names.assert_called_once_with(None, "test@example.com")
|
||||
|
||||
|
||||
def test_user_groups(user_state):
    """Groups are fetched from crud using the names already stored on the state."""
    fake_groups = [MagicMock(), MagicMock()]
    groups_patch = patch('app.web.db.crud.get_user_groups_by_name', return_value=fake_groups)
    with groups_patch as get_groups:
        # pre-set the names so no name lookup is needed
        user_state._user_groups_names = ["group1", "group2"]
        assert len(user_state.user_groups) == 2
        get_groups.assert_called_once_with(None, ["group1", "group2"])
|
||||
|
||||
|
||||
def test_read():
    """``read`` is empty without permissions, a set of group ids otherwise, and True for "all"."""

    def only_group(group_id, permissions):
        # patch UserState.user_groups so the user belongs to exactly this one group
        return patch.object(
            UserState, 'user_groups', new_callable=PropertyMock,
            return_value=[models.Group(id=group_id, permissions=permissions)],
        )

    # no permissions: empty set, stored on the instance after the first access
    us = fresh_user_state()
    with only_group("no-permissions", {}) as groups_prop:
        assert not hasattr(us, "_read")
        assert us.read == set()
        assert us._read == set()
        groups_prop.assert_called_once()

    # explicit list of readable groups
    us = fresh_user_state()
    with only_group("group1", {"read": ["group1", "no-permissions"]}):
        assert us.read == {"group1", "no-permissions"}

    # "all" turns read into True
    us = fresh_user_state()
    with only_group("group1", {"read": ["all"]}):
        assert us.read == True  # noqa: E712
|
||||
|
||||
|
||||
def test_read_public():
    """``read_public`` is False without permissions, is cached, and otherwise mirrors the group's flag."""

    def only_group(group_id, permissions):
        # patch UserState.user_groups so the user belongs to exactly this one group
        return patch.object(
            UserState, 'user_groups', new_callable=PropertyMock,
            return_value=[models.Group(id=group_id, permissions=permissions)],
        )

    us = fresh_user_state()
    with only_group("no-permissions", {}) as groups_prop:
        assert not hasattr(us, "_read_public")
        assert us.read_public == False  # noqa: E712
        assert us._read_public == False  # noqa: E712
        groups_prop.assert_called_once()
        # a second access does not read the groups again
        assert us.read_public == False  # noqa: E712
        groups_prop.assert_called_once()

    # the flag of a single group is reported as-is
    for flag in (True, False):
        us = fresh_user_state()
        with only_group("group1", {"read_public": flag}):
            assert us.read_public == flag
|
||||
|
||||
|
||||
def test_archive_url():
    """``archive_url`` is False without permissions, is cached, and otherwise mirrors the group's flag."""

    def only_group(group_id, permissions):
        # patch UserState.user_groups so the user belongs to exactly this one group
        return patch.object(
            UserState, 'user_groups', new_callable=PropertyMock,
            return_value=[models.Group(id=group_id, permissions=permissions)],
        )

    us = fresh_user_state()
    with only_group("no-permissions", {}) as groups_prop:
        assert not hasattr(us, "_archive_url")
        assert us.archive_url == False  # noqa: E712
        assert us._archive_url == False  # noqa: E712
        groups_prop.assert_called_once()
        # a second access does not read the groups again
        assert us.archive_url == False  # noqa: E712
        groups_prop.assert_called_once()

    # the flag of a single group is reported as-is
    for flag in (False, True):
        us = fresh_user_state()
        with only_group("group1", {"archive_url": flag}):
            assert us.archive_url == flag
|
||||
|
||||
|
||||
def test_archive_sheet():
    """``archive_sheet`` is False without permissions, is cached, and otherwise mirrors the group's flag."""

    def only_group(group_id, permissions):
        # patch UserState.user_groups so the user belongs to exactly this one group
        return patch.object(
            UserState, 'user_groups', new_callable=PropertyMock,
            return_value=[models.Group(id=group_id, permissions=permissions)],
        )

    us = fresh_user_state()
    with only_group("no-permissions", {}) as groups_prop:
        assert not hasattr(us, "_archive_sheet")
        assert us.archive_sheet == False  # noqa: E712
        assert us._archive_sheet == False  # noqa: E712
        groups_prop.assert_called_once()
        # a second access does not read the groups again
        assert us.archive_sheet == False  # noqa: E712
        groups_prop.assert_called_once()

    # the flag of a single group is reported as-is
    for flag in (False, True):
        us = fresh_user_state()
        with only_group("group1", {"archive_sheet": flag}):
            assert us.archive_sheet == flag
|
||||
|
||||
|
||||
def test_sheet_frequency():
    """Check that ``sheet_frequency`` is exposed as a set of allowed frequencies.

    Without permissions the set is empty, stored on ``_sheet_frequency`` and
    computed with a single ``user_groups`` lookup; a configured list is
    returned as the equivalent set.
    """
    state = fresh_user_state()
    no_grants = [models.Group(id="no-permissions", permissions={})]
    with patch.object(UserState, "user_groups", new_callable=PropertyMock, return_value=no_grants) as groups_prop:
        assert not hasattr(state, "_sheet_frequency")
        assert state.sheet_frequency == set()
        assert state._sheet_frequency == set()
        groups_prop.assert_called_once()
        # a second read must not touch user_groups again
        assert state.sheet_frequency == set()
        groups_prop.assert_called_once()

    scenarios = (
        (["daily", "hourly"], {"daily", "hourly"}),
        ([], set()),
    )
    for configured, expected in scenarios:
        state = fresh_user_state()
        granted = [models.Group(id="group1", permissions={"sheet_frequency": configured})]
        with patch.object(UserState, "user_groups", new_callable=PropertyMock, return_value=granted):
            assert state.sheet_frequency == expected
|
||||
|
||||
|
||||
def test_max_archive_lifespan_months():
    """Check how ``max_archive_lifespan_months`` is resolved.

    Falls back to the ``GroupPermissions`` field default when no group sets
    it (with a single ``user_groups`` lookup), returns a single group's value,
    and returns -1 when one of several groups sets -1.
    """
    state = fresh_user_state()
    default = GroupPermissions.model_fields["max_archive_lifespan_months"].default
    no_grants = [models.Group(id="no-permissions", permissions={})]
    with patch.object(UserState, "user_groups", new_callable=PropertyMock, return_value=no_grants) as groups_prop:
        assert not hasattr(state, "_max_archive_lifespan_months")
        assert state.max_archive_lifespan_months == default
        assert state._max_archive_lifespan_months == default
        groups_prop.assert_called_once()
        # a second read must not touch user_groups again
        assert state.max_archive_lifespan_months == default
        groups_prop.assert_called_once()

    # (group id, configured value) pairs -> expected resolved value
    scenarios = (
        ([("group1", 24)], 24),
        ([("group1", 150), ("group2", -1)], -1),
    )
    for configured, expected in scenarios:
        state = fresh_user_state()
        groups = [
            models.Group(id=group_id, permissions={"max_archive_lifespan_months": value})
            for group_id, value in configured
        ]
        with patch.object(UserState, "user_groups", new_callable=PropertyMock, return_value=groups):
            assert state.max_archive_lifespan_months == expected
|
||||
|
||||
|
||||
def test_max_monthly_urls():
    """Check how ``max_monthly_urls`` is resolved.

    Falls back to the ``GroupPermissions`` field default when no group sets
    it (with a single ``user_groups`` lookup), returns a single group's value,
    and returns -1 when one of several groups sets -1.
    """
    state = fresh_user_state()
    default = GroupPermissions.model_fields["max_monthly_urls"].default
    no_grants = [models.Group(id="no-permissions", permissions={})]
    with patch.object(UserState, "user_groups", new_callable=PropertyMock, return_value=no_grants) as groups_prop:
        assert not hasattr(state, "_max_monthly_urls")
        assert state.max_monthly_urls == default
        assert state._max_monthly_urls == default
        groups_prop.assert_called_once()
        # a second read must not touch user_groups again
        assert state.max_monthly_urls == default
        groups_prop.assert_called_once()

    # (group id, configured value) pairs -> expected resolved value
    scenarios = (
        ([("group1", 100)], 100),
        ([("group1", 150), ("group2", -1)], -1),
    )
    for configured, expected in scenarios:
        state = fresh_user_state()
        groups = [
            models.Group(id=group_id, permissions={"max_monthly_urls": value})
            for group_id, value in configured
        ]
        with patch.object(UserState, "user_groups", new_callable=PropertyMock, return_value=groups):
            assert state.max_monthly_urls == expected
|
||||
|
||||
|
||||
def test_max_monthly_mbs():
    """Check how ``max_monthly_mbs`` is resolved.

    Falls back to the ``GroupPermissions`` field default when no group sets
    it (with a single ``user_groups`` lookup), returns a single group's value,
    and returns -1 when one of several groups sets -1.
    """
    state = fresh_user_state()
    default = GroupPermissions.model_fields["max_monthly_mbs"].default
    no_grants = [models.Group(id="no-permissions", permissions={})]
    with patch.object(UserState, "user_groups", new_callable=PropertyMock, return_value=no_grants) as groups_prop:
        assert not hasattr(state, "_max_monthly_mbs")
        assert state.max_monthly_mbs == default
        assert state._max_monthly_mbs == default
        groups_prop.assert_called_once()
        # a second read must not touch user_groups again
        assert state.max_monthly_mbs == default
        groups_prop.assert_called_once()

    # (group id, configured value) pairs -> expected resolved value
    scenarios = (
        ([("group1", 1000)], 1000),
        ([("group1", 1500), ("group2", -1)], -1),
    )
    for configured, expected in scenarios:
        state = fresh_user_state()
        groups = [
            models.Group(id=group_id, permissions={"max_monthly_mbs": value})
            for group_id, value in configured
        ]
        with patch.object(UserState, "user_groups", new_callable=PropertyMock, return_value=groups):
            assert state.max_monthly_mbs == expected
|
||||
|
||||
|
||||
def test_priority(user_state):
    """Check how ``priority`` is resolved.

    Falls back to the ``GroupPermissions`` field default when no group sets
    it (with a single ``user_groups`` lookup). A single "high" group yields
    "high"; the "low" + "medium" combination is expected to yield "low".
    """
    default = GroupPermissions.model_fields["priority"].default
    no_grants = [models.Group(id="no-permissions", permissions={})]
    with patch.object(UserState, "user_groups", new_callable=PropertyMock, return_value=no_grants) as groups_prop:
        assert not hasattr(user_state, "_priority")
        assert user_state.priority == default
        assert user_state._priority == default
        groups_prop.assert_called_once()
        # a second read must not touch user_groups again
        assert user_state.priority == default
        groups_prop.assert_called_once()

    # (group id, configured priority) pairs -> expected resolved priority
    scenarios = (
        ([("group1", "high")], "high"),
        ([("group1", "low"), ("group2", "medium")], "low"),
    )
    for configured, expected in scenarios:
        state = fresh_user_state()
        groups = [
            models.Group(id=group_id, permissions={"priority": value})
            for group_id, value in configured
        ]
        with patch.object(UserState, "user_groups", new_callable=PropertyMock, return_value=groups):
            assert state.priority == expected
|
||||
|
||||
|
||||
def test_active():
    """Check that ``active`` is True as soon as any one permission flag is True.

    Each scenario patches the four flag properties on ``UserState`` and
    compares ``active`` against the expected outcome.
    """
    flag_names = ("read", "read_public", "archive_url", "archive_sheet")
    scenarios = [
        ((False, False, False, False), False),
        ((True, False, False, False), True),
        ((False, True, False, False), True),
        ((False, False, True, False), True),
        ((False, False, False, True), True),
    ]
    for flags, is_active in scenarios:
        us = fresh_user_state()
        # one PropertyMock per flag, installed on the class for this scenario only
        overrides = {
            name: PropertyMock(return_value=value)
            for name, value in zip(flag_names, flags)
        }
        with patch.multiple(UserState, **overrides):
            assert us.active == is_active
|
||||
|
||||
|
||||
def test_in_group(user_state):
    """Check ``in_group`` against a patched list of the user's group names."""
    memberships = ["group1", "group2"]
    expectations = (("group1", True), ("group2", True), ("group3", False))
    with patch.object(UserState, "user_groups_names", new_callable=PropertyMock, return_value=memberships):
        for group_name, expected in expectations:
            assert user_state.in_group(group_name) == expected
|
||||
|
||||
|
||||
def test_usage(db_session):
    """Check that ``UserState.usage()`` combines sheet and URL aggregates per group.

    ``db_session.query`` is patched to return two canned query chains in
    order: first the per-group sheet counts, then the per-group URL counts and
    byte totals. Groups present in only one of the two results are expected to
    report zero for the missing figures.
    """
    user_state = UserState(db_session, email="test@example.com")
    user_sheets = [
        MagicMock(group_id="group1", sheet_count=5),
        MagicMock(group_id="group2", sheet_count=10),
        MagicMock(group_id="group3", sheet_count=100),
    ]
    # deliberately not called ``bytes`` so the builtin is not shadowed
    byte_totals = [1000000, 2000000, 3000000]
    urls_by_group = [
        MagicMock(group_id="group1", url_count=50, total_bytes=byte_totals[0]),
        MagicMock(group_id="group2", url_count=100, total_bytes=byte_totals[1]),
        MagicMock(group_id="group4", url_count=5, total_bytes=byte_totals[2]),
    ]
    megabytes = int(sum(byte_totals) / 1024 / 1024)

    def canned_query(rows):
        # stands in for query(...).filter(...).group_by(...).all()
        query = MagicMock()
        query.filter.return_value.group_by.return_value.all.return_value = rows
        return query

    with patch.object(db_session, "query", side_effect=[canned_query(user_sheets), canned_query(urls_by_group)]):
        usage_response = user_state.usage()

    assert usage_response.monthly_urls == 155
    assert usage_response.monthly_mbs == megabytes
    assert usage_response.total_sheets == 115

    # group id -> (monthly_urls, monthly_mbs, total_sheets)
    expected_by_group = {
        "group1": (50, int(byte_totals[0] / 1024 / 1024), 5),
        "group2": (100, int(byte_totals[1] / 1024 / 1024), 10),
        "group3": (0, 0, 100),
        "group4": (5, int(byte_totals[2] / 1024 / 1024), 0),
    }
    for group_id, (urls, mbs, sheets) in expected_by_group.items():
        group_usage = usage_response.groups[group_id]
        assert group_usage.monthly_urls == urls
        assert group_usage.monthly_mbs == mbs
        assert group_usage.total_sheets == sheets
|
||||
|
||||
|
||||
def test_has_quota_monthly_sheets(db_session):
    """Check ``has_quota_monthly_sheets("group1")`` against the group's ``max_sheets``.

    Each case patches ``permissions`` and makes the sheet-count query return
    ``count``. The cases expect quota when the limit is -1 or the count is
    below the limit, and no quota at or above the limit or when "group1" has
    no permissions entry.
    """
    us = UserState(db_session, email="test@example.com")

    test_cases = [
        ({"unkonwn": GroupInfo(max_sheets=5)}, 1, False),
        ({"group1": GroupInfo(max_sheets=-1)}, 1000, True),
        ({"group1": GroupInfo(max_sheets=5)}, 3, True),
        ({"group1": GroupInfo(max_sheets=5)}, 5, False),
        ({"group1": GroupInfo(max_sheets=5)}, 6, False),
    ]

    for permissions, count, expected in test_cases:
        # stands in for query(...).filter(...).count()
        counting_query = MagicMock()
        counting_query.filter.return_value.count.return_value = count
        with patch.object(UserState, "permissions", new_callable=PropertyMock, return_value=permissions), \
                patch.object(us.db, "query", return_value=counting_query):
            assert us.has_quota_monthly_sheets("group1") == expected
|
||||
|
||||
|
||||
def test_has_quota_max_monthly_urls(db_session):
    """Check ``has_quota_max_monthly_urls`` for a named group and for the empty group.

    With "group1" the limit comes from the patched ``permissions``; with ""
    it comes from the patched ``max_monthly_urls`` property. In both halves
    the URL-count query is mocked to return ``count``.
    """
    us = UserState(db_session, email="test@example.com")

    def counting_query(count):
        # stands in for query(...).filter(...).count()
        query = MagicMock()
        query.filter.return_value.count.return_value = count
        return query

    group_cases = [
        ({"group1": GroupInfo(max_monthly_urls=-1)}, 1000, True),
        ({"group1": GroupInfo(max_monthly_urls=100)}, 50, True),
        ({"group1": GroupInfo(max_monthly_urls=100)}, 100, False),
        ({"group1": GroupInfo(max_monthly_urls=100)}, 150, False),
    ]
    for permissions, count, expected in group_cases:
        with patch.object(UserState, "permissions", new_callable=PropertyMock, return_value=permissions), \
                patch.object(us.db, "query", return_value=counting_query(count)):
            assert us.has_quota_max_monthly_urls("group1") == expected

    global_cases = [
        (-1, 1000, True),
        (100, 50, True),
        (100, 100, False),
        (100, 150, False),
    ]
    for max_urls, count, expected in global_cases:
        with patch.object(UserState, "max_monthly_urls", new_callable=PropertyMock, return_value=max_urls), \
                patch.object(us.db, "query", return_value=counting_query(count)):
            assert us.has_quota_max_monthly_urls("") == expected
|
||||
|
||||
|
||||
def test_has_quota_max_monthly_mbs(db_session):
    """Check ``has_quota_max_monthly_mbs`` for a named group and for the empty group.

    With "group1" the limit comes from the patched ``permissions``; with ""
    it comes from the patched ``max_monthly_mbs`` property. The mocked query
    reports usage in bytes (``mbs * 1024 * 1024``).
    """
    us = UserState(db_session, email="test@example.com")

    def bytes_query(mbs):
        # stands in for query(...).filter(...).with_entities(...).scalar()
        query = MagicMock()
        query.filter.return_value.with_entities.return_value.scalar.return_value = mbs * 1024 * 1024
        return query

    group_cases = [
        ({"group1": GroupInfo(max_monthly_mbs=-1)}, 1000, True),
        ({"group1": GroupInfo(max_monthly_mbs=100)}, 50, True),
        ({"group1": GroupInfo(max_monthly_mbs=100)}, 100, False),
        ({"group1": GroupInfo(max_monthly_mbs=100)}, 150, False),
    ]
    for permissions, mbs, expected in group_cases:
        with patch.object(UserState, "permissions", new_callable=PropertyMock, return_value=permissions), \
                patch.object(us.db, "query", return_value=bytes_query(mbs)):
            assert us.has_quota_max_monthly_mbs("group1") == expected

    global_cases = [
        (-1, 1000, True),
        (100, 50, True),
        (100, 100, False),
        (100, 150, False),
    ]
    for max_mbs, mbs, expected in global_cases:
        with patch.object(UserState, "max_monthly_mbs", new_callable=PropertyMock, return_value=max_mbs), \
                patch.object(us.db, "query", return_value=bytes_query(mbs)):
            assert us.has_quota_max_monthly_mbs("") == expected
|
||||
|
||||
|
||||
def test_can_manually_trigger(user_state):
    """Check ``can_manually_trigger`` per group, including a group with no entry."""
    permissions = {
        "group1": GroupInfo(manually_trigger_sheet=True),
        "group2": GroupInfo(manually_trigger_sheet=False),
    }
    expectations = (("group1", True), ("group2", False), ("group3", False))

    with patch.object(UserState, "permissions", new_callable=PropertyMock, return_value=permissions):
        for group_id, expected in expectations:
            assert user_state.can_manually_trigger(group_id) == expected
|
||||
|
||||
|
||||
def test_is_sheet_frequency_allowed(user_state):
    """Check ``is_sheet_frequency_allowed`` against each group's allowed set.

    A frequency outside the group's set, or any frequency for a group with no
    permissions entry, is expected to be rejected.
    """
    permissions = {
        "group1": GroupInfo(sheet_frequency={"daily", "hourly"}),
        "group2": GroupInfo(sheet_frequency={"daily"}),
    }
    expectations = (
        ("group1", "daily", True),
        ("group1", "hourly", True),
        ("group1", "weekly", False),
        ("group2", "hourly", False),
        ("group2", "daily", True),
        ("group3", "daily", False),
    )

    with patch.object(UserState, "permissions", new_callable=PropertyMock, return_value=permissions):
        for group_id, frequency, expected in expectations:
            assert user_state.is_sheet_frequency_allowed(group_id, frequency) == expected
|
||||
|
||||
|
||||
def test_priority_group(user_state):
    """Check that ``priority_group`` maps each group's priority to its queue dict.

    A group the user does not belong to ("group4") is expected to map to the
    "low" queue dict.
    """
    from app.web.utils.misc import convert_priority_to_queue_dict

    groups = [
        models.Group(id="group1", permissions={"priority": "high"}),
        models.Group(id="group2", permissions={"priority": "medium"}),
        models.Group(id="group3", permissions={"priority": "low"}),
    ]
    expectations = (
        ("group1", "high"),
        ("group2", "medium"),
        ("group3", "low"),
        ("group4", "low"),
    )
    with patch.object(UserState, "user_groups", new_callable=PropertyMock, return_value=groups):
        for group_id, priority in expectations:
            assert user_state.priority_group(group_id) == convert_priority_to_queue_dict(priority)
|
||||
175
app/tests/web/endpoints/test_default.py
Normal file
175
app/tests/web/endpoints/test_default.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from unittest.mock import MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from app.shared.schemas import Usage, UsageResponse
|
||||
from app.shared.user_groups import GroupInfo
|
||||
from app.web.config import VERSION
|
||||
from app.tests.web.db.test_crud import test_data
|
||||
|
||||
|
||||
def test_endpoint_home(client_with_auth):
    """Check that ``GET /`` returns the version and breaking changes, but no groups."""
    response = client_with_auth.get("/")
    assert response.status_code == 200

    payload = response.json()
    assert "version" in payload
    assert payload["version"] == VERSION
    assert "breakingChanges" in payload
    assert "groups" not in payload
|
||||
|
||||
|
||||
def test_endpoint_health(client_with_auth):
    """Check that ``GET /health`` answers 200 with an "ok" status body."""
    response = client_with_auth.get("/health")
    expected_body = {"status": "ok"}
    assert response.status_code == 200
    assert response.json() == expected_body
|
||||
|
||||
|
||||
def test_endpoint_active_no_auth(client, test_no_auth):
    """Run the shared no-auth check on ``GET /user/active`` with a plain client."""
    endpoint = "/user/active"
    test_no_auth(client.get, endpoint)
|
||||
|
||||
|
||||
def test_endpoint_active(app):
    """Check that ``GET /user/active`` reports the user state's ``active`` flag.

    The ``get_user_state`` dependency is overridden with a mock, and the
    endpoint is called once for an inactive and once for an active user.
    """
    from app.web.security import get_user_state

    m_user_state = MagicMock()
    app.dependency_overrides[get_user_state] = lambda: m_user_state

    # inactive user first, then active user
    for is_active in (False, True):
        m_user_state.active = is_active
        client = TestClient(app)
        response = client.get("/user/active")
        assert response.status_code == 200
        assert response.json() == {"active": is_active}
|
||||
|
||||
|
||||
def test_no_serve_local_archive_by_default(client_with_auth):
    """Check that a local-archive file path answers 404 in the default setup."""
    archive_path = "/app/local_archive_test/temp.txt"
    response = client_with_auth.get(archive_path)
    assert response.status_code == 404
|
||||
|
||||
|
||||
def test_favicon(client_with_auth):
    """Check that ``GET /favicon.ico`` answers 200 with the icon content type."""
    response = client_with_auth.get("/favicon.ico")
    assert response.status_code == 200
    content_type = response.headers["content-type"]
    assert content_type == "image/vnd.microsoft.icon"
|
||||
|
||||
|
||||
def test_endpoint_test_prometheus_no_auth(client, test_no_auth):
    """Run the shared no-auth check on ``GET /metrics`` with a plain client."""
    endpoint = "/metrics"
    test_no_auth(client.get, endpoint)
|
||||
|
||||
|
||||
def test_endpoint_test_prometheus_no_user_auth(client_with_auth, test_no_auth):
    """Run the shared no-auth check on ``GET /metrics`` with the ``client_with_auth`` client."""
    endpoint = "/metrics"
    test_no_auth(client_with_auth.get, endpoint)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_prometheus_metrics(test_data, client_with_token, get_settings):
    """Check the ``/metrics`` output before and after metrics are measured.

    Before any measurement the metric families are present but the
    ``disk_utilization`` "used" sample is not. After a measurement over a very
    large window the disk and database samples appear with the expected
    counts. A second measurement over a 30 second window must leave the
    database gauges and counter totals unchanged.
    """
    # before metrics calculation
    r = client_with_token.get("/metrics")
    assert r.status_code == 200
    assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
    for family in ("disk_utilization", "database_metrics", "exceptions", "worker_exceptions_total"):
        assert family in r.text
    assert 'disk_utilization{type="used"}' not in r.text

    # database samples expected after both measurements below
    database_samples = [
        'database_metrics{query="count_archives"} 100.0',
        'database_metrics{query="count_archive_urls"} 1000.0',
        'database_metrics{query="count_users"} 3.0',
        'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0',
        'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0',
        'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0',
    ]

    # after metrics calculation
    from app.web.utils.metrics import measure_regular_metrics
    await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100)
    r2 = client_with_token.get("/metrics")
    for disk_sample in (
        'disk_utilization{type="used"}',
        'disk_utilization{type="free"}',
        'disk_utilization{type="database"}',
    ):
        assert disk_sample in r2.text
    for sample in database_samples:
        assert sample in r2.text

    # 30s window, should not change the gauges nor the total in the counters
    await measure_regular_metrics(get_settings.DATABASE_PATH, 30)
    r3 = client_with_token.get("/metrics")
    for sample in database_samples:
        assert sample in r3.text
|
||||
|
||||
|
||||
def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth):
    """Run the shared no-auth check on ``GET /user/permissions`` with a plain client."""
    endpoint = "/user/permissions"
    test_no_auth(client.get, endpoint)
|
||||
|
||||
|
||||
def test_endpoint_get_user_permissions(app):
    """Check that ``GET /user/permissions`` serialises the user's per-group permissions.

    The ``get_user_state`` dependency is overridden with a mock whose
    ``permissions`` holds two ``GroupInfo`` entries. The response must contain
    both groups, with the explicitly set flags truthy and the unset ones at
    the values asserted below.
    """
    from app.web.security import get_user_state

    m_user_state = MagicMock()
    m_user_state.permissions = {
        "all": GroupInfo(read=True),
        "group1": GroupInfo(archive_url=True),
    }

    app.dependency_overrides[get_user_state] = lambda: m_user_state

    client = TestClient(app)
    r = client.get("/user/permissions")
    assert r.status_code == 200
    response = r.json()
    assert response.keys() == {"all", "group1"}
    assert response["all"]["read"]
    assert response["group1"]["read"] == []
    assert response["group1"]["archive_url"]
    assert response["all"]["archive_url"] == False
|
||||
|
||||
|
||||
def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth):
    """Run the shared no-auth check on ``GET /user/usage`` with a plain client."""
    endpoint = "/user/usage"
    test_no_auth(client.get, endpoint)
|
||||
|
||||
|
||||
def test_endpoint_get_user_usage_inactive(app):
    """Check that ``GET /user/usage`` answers 403 when the user state is inactive."""
    from app.web.security import get_user_state

    inactive_state = MagicMock()
    inactive_state.active = False
    app.dependency_overrides[get_user_state] = lambda: inactive_state

    response = TestClient(app).get("/user/usage")
    assert response.status_code == 403
    assert response.json() == {"detail": "User is not active."}
|
||||
|
||||
|
||||
def test_endpoint_get_user_usage_active(app):
    """Check that ``GET /user/usage`` returns the user state's ``usage()`` result.

    The mocked user state is active and returns a fixed ``UsageResponse``; the
    JSON body must round-trip back into an equal ``UsageResponse``.
    """
    from app.web.security import get_user_state

    per_group = {
        "group1": Usage(monthly_urls=4, monthly_mbs=5, total_sheets=6),
        "group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9),
    }
    mock_usage = UsageResponse(monthly_urls=1, monthly_mbs=2, total_sheets=3, groups=per_group)

    active_state = MagicMock()
    active_state.active = True
    active_state.usage.return_value = mock_usage
    app.dependency_overrides[get_user_state] = lambda: active_state

    response = TestClient(app).get("/user/usage")
    assert response.status_code == 200
    assert UsageResponse(**response.json()) == mock_usage
|
||||
56
app/tests/web/endpoints/test_interoperability.py
Normal file
56
app/tests/web/endpoints/test_interoperability.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.shared.db import models
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.web.db import crud
|
||||
|
||||
|
||||
def test_submit_manual_archive_unauthenticated(client, test_no_auth):
    """Run the shared no-auth check on ``POST /interop/submit-archive`` with a plain client."""
    endpoint = "/interop/submit-archive"
    test_no_auth(client.post, endpoint)
|
||||
|
||||
|
||||
def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
    """Run the shared no-auth check on ``POST /interop/submit-archive`` with ``client_with_auth``."""
    endpoint = "/interop/submit-archive"
    test_no_auth(client_with_auth.post, endpoint)
|
||||
|
||||
|
||||
@patch("app.web.endpoints.interoperability.business_logic", return_value=MagicMock(get_store_archive_until=MagicMock(return_value=datetime)))
def test_submit_manual_archive(m1, client_with_token, db_session):
    """Check the manual archive submission workflow on ``POST /interop/submit-archive``.

    A valid submission answers 201 and is stored with the submitted fields
    plus an extra "manual" tag. A submission whose media lists the same URL
    twice is rejected with 422 and an integrity-error message.
    """
    # normal workflow
    aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]})
    r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"})
    assert r.status_code == 201
    assert "id" in r.json()

    inserted = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"]).first()
    assert inserted.url == "http://example.com"
    assert inserted.group_id == "spaceship"
    assert inserted.author_id == "jerry@gmail.com"
    assert sorted([t.id for t in inserted.tags]) == sorted(["test", "manual"])
    assert inserted.public
    # isinstance is the idiomatic type check and accepts everything the exact-type check did
    assert isinstance(inserted.result, dict)
    assert [u.url for u in inserted.urls] == ["http://example.s3.com"]
    assert isinstance(inserted.store_until, datetime)

    # cannot have the same URL twice
    aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]})
    r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "tags": ["test"], "url": "http://example.com"})
    assert r.status_code == 422
    assert r.json() == {"detail": "Cannot insert into DB due to integrity error, likely duplicate urls."}
|
||||
|
||||
|
||||
# test with invalid JSON
def test_submit_manual_archive_invalid_json(client_with_token):
    """Check that a ``result`` field that is not valid JSON is rejected with 422."""
    payload = {
        "result": "invalid json",
        "public": False,
        "author_id": "jer",
        "tags": ["test"],
        "url": "http://example.com",
    }
    response = client_with_token.post("/interop/submit-archive", json=payload)
    assert response.status_code == 422
    assert response.json() == {"detail": "Invalid JSON in result field."}
|
||||
|
||||
|
||||
@patch("app.web.endpoints.interoperability.business_logic")
def test_submit_manual_archive_no_store_until(m_b, client_with_token, db_session):
    """Check that an ``AssertionError`` from ``get_store_archive_until`` becomes a 422.

    The error's message is expected verbatim in the response ``detail``.
    """
    m_b.get_store_archive_until.side_effect = AssertionError("AssertionError")

    archive_result = {
        "status": "test: success",
        "metadata": {"url": "http://example.com"},
        "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}],
    }
    payload = {
        "result": json.dumps(archive_result),
        "public": True,
        "author_id": "jerry@gmail.com",
        "group_id": "spaceship",
        "tags": ["test"],
        "url": "http://example.com",
    }
    response = client_with_token.post("/interop/submit-archive", json=payload)
    assert response.status_code == 422
    assert response.json() == {"detail": "AssertionError"}
|
||||
193
app/tests/web/endpoints/test_sheet.py
Normal file
193
app/tests/web/endpoints/test_sheet.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.shared.schemas import TaskResult
|
||||
|
||||
|
||||
def test_endpoints_no_auth(client, test_no_auth):
    """Run the shared no-auth check on every sheet endpoint with a plain client."""
    protected_routes = (
        (client.post, "/sheet/create"),
        (client.get, "/sheet/mine"),
        (client.delete, "/sheet/123-sheet-id"),
        (client.post, "/sheet/123-sheet-id/archive"),
    )
    for http_method, route in protected_routes:
        test_no_auth(http_method, route)
|
||||
|
||||
|
||||
def test_create_sheet_endpoint(app_with_auth, db_session):
    """Walk ``POST /sheet/create`` through its success and failure responses.

    As the default authenticated user: a valid sheet is created (201), a
    repeat of the same ID is refused (400) and an unknown group is refused
    (403). After switching the user state to jerry: a frequency not allowed
    for the group is refused (422), the first sheet succeeds (201) and the
    next one hits the sheet quota (429).
    """
    client_with_auth = TestClient(app_with_auth)
    create_url = "/sheet/create"
    good_data = {
        "id": "123-sheet-id",
        "name": "Test Sheet",
        "group_id": "spaceship",
        "frequency": "daily"
    }

    # with good data
    created = client_with_auth.post(create_url, json=good_data)
    assert created.status_code == 201
    body = created.json()
    assert datetime.fromisoformat(body.pop("created_at"))
    assert datetime.fromisoformat(body.pop("last_url_archived_at"))
    assert body.pop("author_id") == 'morty@example.com'
    assert body == good_data

    # already exists
    duplicate = client_with_auth.post(create_url, json=good_data)
    assert duplicate.status_code == 400
    assert duplicate.json() == {"detail": "Sheet with this ID is already being archived."}

    # bad group
    bad_data = {**good_data, "group_id": "not a group"}
    forbidden = client_with_auth.post(create_url, json=bad_data)
    assert forbidden.status_code == 403
    assert forbidden.json() == {"detail": "User does not have access to this group."}

    # switch to jerry who's got less quota/permissions
    from app.web.security import get_user_state
    from app.web.db.user_state import UserState
    app_with_auth.dependency_overrides[get_user_state] = lambda: UserState(db_session, "jerry@example.com")
    client_jerry = TestClient(app_with_auth)

    # frequency not allowed
    jerry_data = {
        **good_data,
        "group_id": "animated-characters",
        "frequency": "hourly",
        "id": "jerry-sheet-id",
    }
    rejected = client_jerry.post(create_url, json=jerry_data)
    assert rejected.status_code == 422
    assert rejected.json() == {"detail": "Invalid frequency selected for this group."}

    # success for the first sheet, bad quota on second
    jerry_data["frequency"] = "daily"
    first = client_jerry.post(create_url, json=jerry_data)
    assert first.status_code == 201

    second = client_jerry.post(create_url, json=jerry_data)
    assert second.status_code == 429
    assert second.json() == {"detail": "User has reached their sheet quota for this group."}
|
||||
|
||||
|
||||
def test_get_user_sheets_endpoint(client_with_auth, db_session):
    """Check that ``GET /sheet/mine`` lists only the authenticated user's sheets.

    The list is empty at first. After three sheets are inserted (two by
    morty, one by rick) only morty's two are returned, in the asserted order.
    """
    # no data
    empty = client_with_auth.get("/sheet/mine")
    assert empty.status_code == 200
    assert empty.json() == []

    # with data, committed in two steps exactly as laid out here
    from app.shared.db import models
    first_sheet = models.Sheet(id="123", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")
    db_session.add(first_sheet)
    db_session.commit()
    later_sheets = [
        models.Sheet(id="456", name="Test Sheet 2", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
        models.Sheet(id="789", name="Test Sheet 3", author_id="rick@example.com", group_id="interdimensional", frequency="hourly"),
    ]
    db_session.add_all(later_sheets)
    db_session.commit()

    listed = client_with_auth.get("/sheet/mine")
    assert listed.status_code == 200
    rows = listed.json()
    assert isinstance(rows, list)
    assert len(rows) == 2

    # timestamps only need to parse; strip them before comparing the rest
    for row in rows:
        assert datetime.fromisoformat(row.pop("created_at"))
        assert datetime.fromisoformat(row.pop("last_url_archived_at"))

    expected_rows = [
        {
            'id': '123',
            'author_id': 'morty@example.com',
            'frequency': 'hourly',
            'group_id': 'spaceship',
            'name': 'Test Sheet 1',
        },
        {
            'id': '456',
            'author_id': 'morty@example.com',
            'frequency': 'daily',
            'group_id': 'interdimensional',
            'name': 'Test Sheet 2',
        },
    ]
    assert rows[0] == expected_rows[0]
    assert rows[1] == expected_rows[1]
|
||||
|
||||
|
||||
def test_delete_sheet_endpoint(client_with_auth, db_session):
    """Check ``DELETE /sheet/{id}`` for missing, own, repeated and foreign sheets.

    Every call answers 200. ``deleted`` is True only for the first deletion of
    the authenticated user's own sheet.
    """
    def delete_and_check(sheet_id, expect_deleted):
        # every scenario answers 200; only the ``deleted`` flag varies
        response = client_with_auth.delete("/sheet/" + sheet_id)
        assert response.status_code == 200
        assert response.json() == {"id": sheet_id, "deleted": expect_deleted}

    # missing sheet
    delete_and_check("123-sheet-id", False)

    # add sheets for deletion
    from app.shared.db import models
    db_session.add_all([
        models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
        models.Sheet(id="456-sheet-id", name="Test Sheet 2", author_id="rick@example.com", group_id="spaceship", frequency="hourly"),
    ])
    db_session.commit()

    # morty can delete his
    delete_and_check("123-sheet-id", True)
    # but only once
    delete_and_check("123-sheet-id", False)
    # and not rick's
    delete_and_check("456-sheet-id", False)
|
||||
|
||||
|
||||
class TestArchiveUserSheetEndpoint:
    """Tests for ``POST /sheet/123-sheet-id/archive`` (manual sheet archiving)."""

    @staticmethod
    def _store_sheet(db_session, author_id, group_id):
        # Persist the one sheet every scenario works on; only author and group vary.
        from app.shared.db import models
        db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id=author_id, group_id=group_id, frequency="hourly"))
        db_session.commit()

    @patch("app.web.endpoints.sheet.celery", return_value=MagicMock())
    def test_normal_flow(self, m_celery, client_with_auth, db_session):
        """An owned sheet in a permitted group is queued and the task id is returned."""
        self._store_sheet(db_session, "morty@example.com", "spaceship")

        m_signature = MagicMock()
        m_signature.apply_async.return_value = TaskResult(id="123-taskid", status="PENDING", result="")
        m_celery.signature.return_value = m_signature

        response = client_with_auth.post("/sheet/123-sheet-id/archive")
        assert response.status_code == 201
        assert response.json() == {"id": "123-taskid"}
        m_celery.signature.assert_called_once()
        m_signature.apply_async.assert_called_once()

    def test_token_auth(self, client_with_token, test_no_auth):
        """The shared no-auth check is applied to the token-authenticated client."""
        test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")

    def test_missing_data(self, client_with_auth):
        """A sheet that does not exist answers 403 'No access to this sheet.'"""
        response = client_with_auth.post("/sheet/123-sheet-id/archive")
        assert response.status_code == 403
        assert response.json() == {"detail": "No access to this sheet."}

    def test_no_access(self, client_with_auth, db_session):
        """A sheet authored by another user answers 403 'No access to this sheet.'"""
        self._store_sheet(db_session, "rick@example.com", "spaceship")
        response = client_with_auth.post("/sheet/123-sheet-id/archive")
        assert response.status_code == 403
        assert response.json() == {"detail": "No access to this sheet."}

    def test_user_not_in_group(self, client_with_auth, db_session):
        """An owned sheet in the "interdimensional" group answers 403 for this user."""
        self._store_sheet(db_session, "morty@example.com", "interdimensional")
        response = client_with_auth.post("/sheet/123-sheet-id/archive")
        assert response.status_code == 403
        assert response.json() == {"detail": "User does not have access to this group."}

    def test_user_cannot_manually_trigger(self, client_with_auth, db_session):
        """An owned sheet in the "default" group answers 429: manual triggering is not allowed."""
        self._store_sheet(db_session, "morty@example.com", "default")
        response = client_with_auth.post("/sheet/123-sheet-id/archive")
        assert response.status_code == 429
        assert response.json() == {"detail": "User cannot manually trigger sheet archiving in this group."}
|
||||
@@ -5,7 +5,7 @@ def test_endpoint_task_status_no_auth(client, test_no_auth):
|
||||
test_no_auth(client.get, "/task/test-task-id")
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
@patch("app.web.endpoints.task.AsyncResult")
|
||||
def test_get_status_success(mock_async_result, client_with_auth):
|
||||
mock_async_result.return_value.status = "SUCCESS"
|
||||
mock_async_result.return_value.result = {"data": "some result"}
|
||||
@@ -20,7 +20,7 @@ def test_get_status_success(mock_async_result, client_with_auth):
|
||||
}
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
@patch("app.web.endpoints.task.AsyncResult")
|
||||
def test_get_status_failure(mock_async_result, client_with_auth):
|
||||
|
||||
mock_async_result.return_value.status = "FAILURE"
|
||||
@@ -36,7 +36,7 @@ def test_get_status_failure(mock_async_result, client_with_auth):
|
||||
}
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
@patch("app.web.endpoints.task.AsyncResult")
|
||||
def test_get_status_pending(mock_async_result, client_with_auth):
|
||||
mock_async_result.return_value.status = "PENDING"
|
||||
mock_async_result.return_value.result = None
|
||||
193
app/tests/web/endpoints/test_url.py
Normal file
193
app/tests/web/endpoints/test_url.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.shared.schemas import ArchiveCreate, TaskResult
|
||||
|
||||
|
||||
def test_archive_url_unauthenticated(client, test_no_auth):
    """Run POST /url/archive through the shared ``test_no_auth`` fixture check."""
    endpoint = "/url/archive"
    test_no_auth(client.post, endpoint)
|
||||
|
||||
|
||||
@patch("app.web.endpoints.url.UserState")
@patch("app.web.endpoints.url.celery", return_value=MagicMock())
def test_archive_url(m_celery, m2, client_with_auth):
    """Walk POST /url/archive through validation, group-access and quota branches."""
    queued = MagicMock()
    queued.apply_async.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
    m_celery.signature.return_value = queued

    user_state = MagicMock()
    m2.return_value = user_state

    def submit(payload):
        # every scenario below posts a JSON body to the same endpoint
        return client_with_auth.post("/url/archive", json=payload)

    # url is too short
    reply = submit({"url": "bad"})
    assert reply.status_code == 422
    assert reply.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
    m_celery.signature.assert_not_called()

    # url is invalid
    reply = submit({"url": "example.com"})
    assert reply.status_code == 400
    assert reply.json()["detail"] == "Invalid URL received."

    # valid request
    user_state.has_quota_max_monthly_urls.return_value = True
    user_state.has_quota_max_monthly_mbs.return_value = True
    reply = submit({"url": "https://example.com"})
    assert reply.status_code == 201
    assert reply.json() == {'id': '123-456-789'}
    m_celery.signature.assert_called_once()
    queued.apply_async.assert_called_once()
    last_call = m_celery.signature.call_args
    assert last_call.args[0] == "create_archive_task"
    expected_archive = {
        "id": None,
        "url": "https://example.com",
        "result": None,
        "public": False,
        "author_id": "rick@example.com",
        "group_id": "default",
        "tags": None,
        "sheet_id": None,
        "store_until": None,
        "urls": None,
    }
    assert json.loads(last_call.kwargs['args'][0]) == expected_archive
    user_state.has_quota_max_monthly_urls.assert_called_once()
    user_state.has_quota_max_monthly_mbs.assert_called_once()
    user_state.in_group.assert_called_once_with("default")

    # user is not in group
    user_state.in_group.return_value = False
    reply = submit({"url": "https://example.com", "group_id": "new-group"})
    assert reply.status_code == 403
    assert reply.json()["detail"] == "User does not have access to this group."
    user_state.in_group.assert_called_with("new-group")

    # user is in group
    user_state.in_group.return_value = True
    reply = submit({"url": "https://example.com", "group_id": "spaceship"})
    assert reply.status_code == 201
    assert reply.json() == {'id': '123-456-789'}
    assert m_celery.signature.call_count == 2
    assert queued.apply_async.call_count == 2
    last_call = m_celery.signature.call_args
    assert json.loads(last_call.kwargs['args'][0])["group_id"] == "spaceship"
    user_state.in_group.assert_called_with("spaceship")

    # user is over monthly URL quota
    user_state.has_quota_max_monthly_urls.return_value = False
    user_state.has_quota_max_monthly_mbs.return_value = True
    reply = submit({"url": "https://example.com", "group_id": "spaceship"})
    assert reply.status_code == 429
    assert reply.json()["detail"] == "User has reached their monthly URL quota."
    user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")

    # user is over monthly MB quota
    user_state.has_quota_max_monthly_urls.return_value = True
    user_state.has_quota_max_monthly_mbs.return_value = False
    reply = submit({"url": "https://example.com", "group_id": "spacesuit"})
    assert reply.status_code == 429
    assert reply.json()["detail"] == "User has reached their monthly MB quota."
    user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
    # neither quota rejection dispatched another task
    assert m_celery.signature.call_count == 2
    assert queued.apply_async.call_count == 2
|
||||
|
||||
|
||||
@patch("app.web.endpoints.url.UserState")
def test_archive_url_quotas(m1, client_with_auth):
    """POST /url/archive answers 429 when either monthly quota check fails."""
    user_state = MagicMock()
    m1.return_value = user_state
    payload = {"url": "https://example.com"}

    # misses on monthly URLs quota
    user_state.has_quota_max_monthly_urls.return_value = False
    reply = client_with_auth.post("/url/archive", json=payload)
    assert reply.status_code == 429
    assert reply.json()["detail"] == "User has reached their monthly URL quota."
    user_state.has_quota_max_monthly_urls.assert_called_once()

    # misses on monthly MBs quota
    user_state.has_quota_max_monthly_urls.return_value = True
    user_state.has_quota_max_monthly_mbs.return_value = False
    reply = client_with_auth.post("/url/archive", json=payload)
    assert reply.status_code == 429
    assert reply.json()["detail"] == "User has reached their monthly MB quota."
    user_state.has_quota_max_monthly_mbs.assert_called_once()
|
||||
|
||||
|
||||
@patch("app.web.endpoints.url.celery", return_value=MagicMock())
def test_archive_url_with_api_token(m_celery, client_with_token):
    """The API-token client can submit a URL and gets the queued task id back."""
    queued = MagicMock()
    queued.apply_async.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
    m_celery.signature.return_value = queued

    reply = client_with_token.post("/url/archive", json={"url": "https://example.com"})

    assert reply.status_code == 201
    assert reply.json() == {'id': '123-456-789'}
    m_celery.signature.assert_called_once()
    queued.apply_async.assert_called_once()
    # first positional argument of the signature call is the task name
    assert m_celery.signature.call_args.args[0] == "create_archive_task"
|
||||
|
||||
|
||||
def test_search_by_url_unauthenticated(client, test_no_auth):
    """Run GET /url/search through the shared ``test_no_auth`` fixture check."""
    endpoint = "/url/search"
    test_no_auth(client.get, endpoint)
|
||||
|
||||
|
||||
def test_search_by_url(client_with_auth, client_with_token, db_session):
    """Exercise GET /url/search and its query parameters against seeded archives."""
    search = "/url/search?url=https://example.com"

    # the url parameter is mandatory
    reply = client_with_auth.get("/url/search")
    assert reply.status_code == 422
    assert reply.json()["detail"][0]["msg"] == "Field required"

    # nothing archived yet
    reply = client_with_auth.get(search)
    assert reply.status_code == 200
    assert reply.json() == []

    from app.shared import schemas
    from app.shared.db import worker_crud

    # ten archives for the searched URL plus one (index 10) for an unrelated URL
    for i in range(11):
        target = "https://example.com" if i < 10 else "https://something-else.com"
        entry = ArchiveCreate(id=f"url-456-{i}", url=target, result={}, public=True, author_id="rick@example.com")
        worker_crud.create_archive(db_session, entry, [], [])
    # NB: this insertion is too fast for the ordering to be correct as they are within the same second

    reply = client_with_auth.get(search)
    assert reply.status_code == 200
    found = reply.json()
    assert len(found) == 10
    found_ids = [item["id"] for item in found]
    assert "url-456-0" in found_ids
    assert "url-456-9" in found_ids
    assert "url-456-10" not in found_ids
    assert found[0].keys() == schemas.ArchiveResult.model_fields.keys()

    reply = client_with_auth.get(search + "&limit=5")
    assert reply.status_code == 200
    assert len(reply.json()) == 5

    reply = client_with_auth.get(search + "&skip=5&limit=2")
    assert reply.status_code == 200
    assert len(reply.json()) == 2

    reply = client_with_auth.get(search + "&archived_before=2010-01-01")
    assert reply.status_code == 200
    assert len(reply.json()) == 0

    reply = client_with_auth.get(search + "&archived_after=2010-01-01")
    assert reply.status_code == 200
    assert len(reply.json()) == 10

    # API token will also work
    reply = client_with_token.get(search + "&archived_after=2010-01-01")
    assert reply.status_code == 200
    assert len(reply.json()) == 10
|
||||
|
||||
|
||||
@patch("app.web.endpoints.url.UserState")
def test_search_no_read_access(mock_user_state, client_with_auth):
    """A user state with neither ``read`` nor ``read_public`` is refused with 403."""
    user_state = mock_user_state.return_value
    user_state.read = False
    user_state.read_public = False

    reply = client_with_auth.get("/url/search?url=https://example.com")

    assert reply.status_code == 403
    assert reply.json() == {"detail": "User does not have read access."}
|
||||
|
||||
|
||||
def test_delete_task_unauthenticated(client, test_no_auth):
    """Run DELETE /url/{id} through the shared ``test_no_auth`` fixture check."""
    endpoint = "/url/123-456-789"
    test_no_auth(client.delete, endpoint)
|
||||
|
||||
|
||||
def test_delete_task(client_with_auth, db_session):
    """DELETE /url/{id} reports ``deleted: False`` for unknown ids and ``True`` once the archive exists."""
    archive_id = "delete-123-456-789"
    endpoint = f"/url/{archive_id}"

    # nothing stored yet: the call succeeds but deletes nothing
    reply = client_with_auth.delete(endpoint)
    assert reply.status_code == 200
    assert reply.json() == {"id": "delete-123-456-789", "deleted": False}

    from app.shared.db import worker_crud

    stored = ArchiveCreate(id=archive_id, url="https://example.com", result={}, public=True, author_id="morty@example.com")
    worker_crud.create_archive(db_session, stored, [], [])

    reply = client_with_auth.delete(endpoint)
    assert reply.status_code == 200
    assert reply.json() == {"id": "delete-123-456-789", "deleted": True}
|
||||
@@ -17,12 +17,12 @@ def test_alembic(db_session):
|
||||
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
|
||||
alembic.config.main(argv=['--raiseerr', 'downgrade', 'base'])
|
||||
|
||||
@patch("endpoints.default.crud.get_user_groups", side_effect=Exception('mocked error'))
|
||||
@patch("app.web.endpoints.url.crud.soft_delete_archive", side_effect=Exception('mocked error'))
|
||||
def test_logging_middleware(m1, client_with_auth):
|
||||
from utils.metrics import EXCEPTION_COUNTER
|
||||
from app.web.utils.metrics import EXCEPTION_COUNTER
|
||||
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 0
|
||||
with pytest.raises(Exception, match="mocked error"):
|
||||
client_with_auth.get("/groups")
|
||||
client_with_auth.delete("/url/123")
|
||||
# creates one empty and one from above
|
||||
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 2
|
||||
|
||||
@@ -36,7 +36,7 @@ def test_serve_local_archive_logic(get_settings):
|
||||
try:
|
||||
# modify the settings
|
||||
get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test"
|
||||
from web.main import app_factory
|
||||
from app.web.main import app_factory
|
||||
app = app_factory(get_settings)
|
||||
|
||||
# test
|
||||
@@ -1,14 +1,14 @@
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
import pytest
|
||||
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
|
||||
|
||||
def test_secure_compare():
|
||||
from web.security import secure_compare
|
||||
from app.web.security import secure_compare
|
||||
|
||||
assert secure_compare("test", "test")
|
||||
assert not secure_compare("test", "test2")
|
||||
@@ -16,14 +16,14 @@ def test_secure_compare():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_token_or_user_auth_with_api():
|
||||
from web.security import get_token_or_user_auth
|
||||
from app.web.security import get_token_or_user_auth
|
||||
mock_api = HTTPAuthorizationCredentials(scheme="lorem", credentials="this_is_the_test_api_token")
|
||||
assert await get_token_or_user_auth(mock_api) == ALLOW_ANY_EMAIL
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_token_or_user_auth_with_user():
|
||||
from web.security import get_token_or_user_auth
|
||||
from app.web.security import get_token_or_user_auth
|
||||
bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="invalid")
|
||||
e: pytest.ExceptionInfo = None
|
||||
with pytest.raises(HTTPException) as e:
|
||||
@@ -32,18 +32,18 @@ async def test_get_token_or_user_auth_with_user():
|
||||
assert e.value.detail == "invalid access_token"
|
||||
|
||||
|
||||
@patch("web.security.authenticate_user", return_value=(True, "summer@example.com"))
|
||||
@patch("app.web.security.authenticate_user", return_value=(True, "summer@example.com"))
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_auth(m1):
|
||||
from web.security import get_user_auth
|
||||
bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good")
|
||||
assert await get_user_auth(bad_user) == "summer@example.com"
|
||||
from app.web.security import get_user_auth
|
||||
good_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good")
|
||||
assert await get_user_auth(good_user) == "summer@example.com"
|
||||
|
||||
|
||||
@patch("web.security.secure_compare", return_value=False)
|
||||
@patch("app.web.security.secure_compare", return_value=False)
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_api_key_auth_exception(m1):
|
||||
from web.security import token_api_key_auth
|
||||
from app.web.security import token_api_key_auth
|
||||
|
||||
e: pytest.ExceptionInfo = None
|
||||
with pytest.raises(HTTPException) as e:
|
||||
@@ -54,12 +54,12 @@ async def test_token_api_key_auth_exception(m1):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user():
|
||||
from web.security import authenticate_user
|
||||
from app.web.security import authenticate_user
|
||||
|
||||
assert authenticate_user("test") == (False, "invalid access_token")
|
||||
assert authenticate_user(123) == (False, "invalid access_token")
|
||||
|
||||
with patch("web.security.requests.get") as mock_get:
|
||||
with patch("app.web.security.requests.get") as mock_get:
|
||||
# bad response from oauth2
|
||||
mock_get.return_value.status_code = 403
|
||||
assert authenticate_user("this-will-call-requests") == (False, "invalid token")
|
||||
@@ -100,9 +100,22 @@ async def test_authenticate_user():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_exception():
|
||||
from web.security import authenticate_user
|
||||
|
||||
with patch("web.security.requests.get") as mock_get:
|
||||
from app.web.security import authenticate_user
|
||||
with patch("app.web.security.requests.get") as mock_get:
|
||||
mock_get.return_value.status_code = 200
|
||||
mock_get.return_value.json.side_effect = Exception("mocked error")
|
||||
assert authenticate_user("this-will-call-requests") == (False, "exception occurred")
|
||||
|
||||
|
||||
def test_get_user_state():
    """``get_user_state`` returns a ``UserState`` carrying the given email and session."""
    from app.web.security import get_user_state
    from app.web.db.user_state import UserState

    session = Mock()
    email = "test@example.com"

    result = get_user_state(email, session)

    assert isinstance(result, UserState)
    assert result.email == email
    assert result.db == session
|
||||
137
app/tests/worker/test_worker_main.py
Normal file
137
app/tests/worker/test_worker_main.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from datetime import datetime
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared import schemas
|
||||
from auto_archiver.core import Media, Metadata
|
||||
|
||||
|
||||
class Test_create_archive_task():
    """Tests for the ``create_archive_task`` worker entry point with its collaborators patched."""

    URL = "https://example-live.com"
    archive = schemas.ArchiveCreate(url=URL, tags=["tag-celery"], public=True, author_id="rick@example.com", group_id="interstellar")

    @patch("app.worker.main.ArchivingOrchestrator")
    @patch("app.worker.main.get_all_urls", return_value=[])
    @patch("app.worker.main.insert_result_into_db")
    @patch("app.worker.main.get_store_until", return_value=datetime.now())
    @patch("app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"])
    @patch("celery.app.task.Task.request")
    def test_success(self, m_req, m_args, m_store, m_insert, m_urls, m_orchestrator, db_session):
        """A successful orchestrator result is stored and returned as a dict."""
        from app.worker.main import create_archive_task

        m_req.id = "this-just-in"
        orchestrator = m_orchestrator.return_value
        orchestrator.feed.return_value = iter([Metadata().set_url(self.URL).success()])

        outcome = create_archive_task(self.archive.model_dump_json())

        # every patched collaborator is used exactly once
        m_args.assert_called_once()
        m_store.assert_called_once_with("interstellar")
        m_insert.assert_called_once()
        m_urls.assert_called_once()
        orchestrator.feed.assert_called_once()
        orchestrator.setup.assert_called_once()

        assert outcome["status"] == "success"
        assert outcome["metadata"]["url"] == self.URL
        assert len(outcome["media"]) == 0

    def test_raise_invalid(self):
        """With nothing patched, running the task raises."""
        from app.worker.main import create_archive_task

        payload = self.archive.model_dump_json()
        with pytest.raises(Exception):
            create_archive_task(payload)

    @patch("app.worker.main.ArchivingOrchestrator")
    @patch("app.worker.main.get_orchestrator_args")
    def test_raise_db_error(self, m_args, m_orchestrator):
        """An exception raised by ``feed`` propagates out of the task unchanged."""
        from app.worker.main import create_archive_task

        orchestrator = m_orchestrator.return_value
        orchestrator.feed.side_effect = Exception("Orchestrator failed")

        with pytest.raises(Exception) as raised:
            create_archive_task(self.archive.model_dump_json())

        assert str(raised.value) == "Orchestrator failed"
        m_args.assert_called_once()
        orchestrator.feed.assert_called_once()

    @patch("app.worker.main.ArchivingOrchestrator")
    @patch("app.worker.main.insert_result_into_db", return_value=None)
    @patch("app.worker.main.get_orchestrator_args")
    def test_raise_empty_result(self, m_args, m_insert, m_orchestrator):
        """A ``None`` result from the orchestrator is reported as an archiving failure."""
        from app.worker.main import create_archive_task

        orchestrator = m_orchestrator.return_value
        orchestrator.feed.return_value = iter([None])

        with pytest.raises(Exception) as raised:
            create_archive_task(self.archive.model_dump_json())

        assert str(raised.value) == "UNABLE TO archive: https://example-live.com"
        orchestrator.feed.assert_called_once()
|
||||
|
||||
|
||||
class Test_create_sheet_task():
    """Tests for the ``create_sheet_task`` worker entry point."""

    URL = "https://example-live.com"
    sheet = schemas.SubmitSheet(sheet_id="123", author_id="rick@example.com", group_id="interstellar", tags=["spaceship"])

    @patch("app.worker.main.get_all_urls", return_value=[])
    @patch("app.worker.main.ArchivingOrchestrator")
    @patch("app.worker.main.models.generate_uuid", return_value="constant-uuid")
    @patch("app.worker.main.get_store_until", return_value=datetime.now())
    @patch("app.worker.main.get_orchestrator_args")
    def test_success(self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session):
        """
        Feed one falsy and two identical successful results through the task and
        check the returned statistics and the single archive row that is stored.
        """
        from app.worker.main import create_sheet_task

        # precondition: nothing archived for this URL yet
        assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0

        mock_metadata = Metadata().set_url(self.URL).success()
        mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"]))

        m_orchestrator.return_value.feed.return_value = iter([False, mock_metadata, mock_metadata])

        res = create_sheet_task(self.sheet.model_dump_json())

        m_args.assert_called_once_with("interstellar", True, ["--gsheet_feeder.sheet_id", "123"])
        m_orchestrator.return_value.setup.assert_called_once()
        m_orchestrator.return_value.feed.assert_called_once()
        m_store.assert_called_with("interstellar")
        # These two checks were bare `==` expressions (no `assert`) and never
        # verified anything; they are now real assertions.
        assert m_store.call_count == 2
        assert m_uuid.call_count == 2
        assert isinstance(res, dict)
        assert res["stats"]["archived"] == 1
        assert res["stats"]["failed"] == 1
        assert len(res["stats"]["errors"]) == 1
        assert res["sheet_id"] == "123"
        assert res["success"]
        assert isinstance(res["time"], datetime)

        # query created archive entry; `.one()` also asserts exactly one row exists
        inserted = db_session.query(models.Archive).filter(models.Archive.url == self.URL).one()
        assert inserted is not None
        assert inserted.url == self.URL
        assert len(inserted.tags) == 1
        assert inserted.tags[0].id == "spaceship"
        assert inserted.group_id == "interstellar"
        assert inserted.author_id == "rick@example.com"
        assert inserted.public == False
|
||||
|
||||
|
||||
def test_get_all_urls(db_session):
    """``get_all_urls`` collects media URLs plus those nested in screenshot, thumbnails and ssl_data properties."""
    from app.worker.main import get_all_urls

    meta = Metadata().set_url("https://example.com")
    first = meta.add_media(Media("fn1.txt", urls=["outcome1.com"]))
    second = meta.add_media(Media("fn2.txt", urls=["outcome2.com"]))
    third = meta.add_media(Media("fn3.txt", urls=["outcome3.com"]))

    # nested media in three shapes: a Media, a list of Media, and a Media dict
    first.set("screenshot", Media("screenshot.png", urls=["screenshot.com"]))
    second.set("thumbnails", [Media("thumb1.png", urls=["thumb1.com"]), Media("thumb2.png", urls=["thumb2.com"])])
    third.set("ssl_data", Media("ssl_data.txt", urls=["ssl_data.com"]).to_dict())
    third.set("bad_data", {"bad": "dict is ignored"})

    urls = [entry.url for entry in get_all_urls(meta)]

    assert len(urls) == 7
    expected_urls = (
        "outcome1.com",
        "outcome2.com",
        "outcome3.com",
        "screenshot.com",
        "thumb1.com",
        "thumb2.com",
        "ssl_data.com",
    )
    for expected in expected_urls:
        assert expected in urls
|
||||
3
app/web/__init__.py
Normal file
3
app/web/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.web.main import app_factory
|
||||
|
||||
app = app_factory
|
||||
@@ -1,4 +1,5 @@
|
||||
VERSION = "0.8.0"
|
||||
VERSION = "0.9.0"
|
||||
|
||||
API_DESCRIPTION = """
|
||||
#### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets.
|
||||
|
||||
@@ -7,7 +8,7 @@ API_DESCRIPTION = """
|
||||
- You can use this API to archive single URLs or entire Google Sheets.
|
||||
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
|
||||
"""
|
||||
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
|
||||
BREAKING_CHANGES = {"minVersion": "0.4.0", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
|
||||
|
||||
# changing this will corrupt the database logic
|
||||
ALLOW_ANY_EMAIL = "*"
|
||||
273
app/web/db/crud.py
Normal file
273
app/web/db/crud.py
Normal file
@@ -0,0 +1,273 @@
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
from sqlalchemy import Column, or_, func, select
|
||||
from loguru import logger
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from cachetools import LRUCache, cached
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.shared.db import models
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.user_groups import UserGroups
|
||||
from app.shared.utils.misc import fnv1a_hash_mod
|
||||
from app.web.utils.misc import convert_priority_to_queue_dict
|
||||
|
||||
|
||||
DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
|
||||
|
||||
|
||||
def get_limit(user_limit: int):
    """Cap a requested page size at DATABASE_QUERY_LIMIT, but never return less than 1."""
    capped = min(user_limit, DATABASE_QUERY_LIMIT)
    return max(1, capped)
|
||||
|
||||
# --------------- TASK = Archive
|
||||
|
||||
|
||||
def base_query(db: Session):
    """Start an Archive query restricted to non-deleted rows and a reduced set of loaded columns."""
    # NOTE: load_only is for optimization and not obfuscation, use .with_entities() if needed
    loaded_columns = load_only(
        models.Archive.id,
        models.Archive.created_at,
        models.Archive.url,
        models.Archive.result,
        models.Archive.store_until,
    )
    not_deleted = db.query(models.Archive).filter(models.Archive.deleted == False)
    return not_deleted.options(loaded_columns)
|
||||
|
||||
|
||||
def search_archives_by_url(
    db: Session,
    url: str,
    email: str,
    read_groups: bool | set[str],
    read_public: bool,
    skip: int = 0,
    limit: int = 100,
    archived_after: datetime = None,
    archived_before: datetime = None,
    absolute_search: bool = False,
) -> list[models.Archive]:
    """
    Search non-deleted archives by URL, newest first.

    The URL is matched as a substring unless `absolute_search` is set, in which
    case it must match exactly. When `email` equals ALLOW_ANY_EMAIL no
    ownership (or read/read_public) filtering happens. Otherwise a row
    qualifies if it is authored by `email`, or is public (only when
    `read_public`), or belongs to a readable group: any non-null group when
    `read_groups` is True, else a group contained in `read_groups`.

    `archived_after` / `archived_before` are exclusive bounds on `created_at`.
    The page size goes through `get_limit` and the offset is `skip`.
    """
    conditions = []

    if email != ALLOW_ANY_EMAIL:
        visible = [models.Archive.author_id == email]
        if read_public:
            visible.append(models.Archive.public == True)
        group_column = models.Archive.group_id
        # NOTE(review): a literal False would reach .in_(); callers appear to
        # pass True or a set of group ids - confirm.
        visible.append(group_column.isnot(None) if read_groups == True else group_column.in_(read_groups))
        conditions.append(or_(*visible))

    conditions.append(models.Archive.url == url if absolute_search else models.Archive.url.like(f'%{url}%'))

    if archived_after:
        conditions.append(models.Archive.created_at > archived_after)
    if archived_before:
        conditions.append(models.Archive.created_at < archived_before)

    query = base_query(db)
    for condition in conditions:
        query = query.filter(condition)

    newest_first = query.order_by(models.Archive.created_at.desc())
    return newest_first.offset(skip).limit(get_limit(limit)).all()
|
||||
|
||||
|
||||
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
    """Return one page of the non-deleted archives authored by `email`, newest first."""
    authored = base_query(db).filter(models.Archive.author_id == email)
    newest_first = authored.order_by(models.Archive.created_at.desc())
    page = newest_first.offset(skip).limit(get_limit(limit))
    return page.all()
|
||||
|
||||
|
||||
def soft_delete_archive(db: Session, id: str, email: str) -> bool:
    """
    Flag the archive `id` as deleted if it is authored by `email` and not yet deleted.

    Returns True when a matching archive was found (and flagged), False otherwise.
    """
    # TODO: implement hard-delete with cronjob that deletes from S3
    candidate = (
        db.query(models.Archive)
        .filter(
            models.Archive.id == id,
            models.Archive.author_id == email,
            models.Archive.deleted == False,
        )
        .first()
    )
    found = candidate is not None
    if candidate:
        candidate.deleted = True
        db.commit()
    return found
|
||||
|
||||
|
||||
def count_archives(db: Session):
    """Return the count of Archive ids; no `deleted` filter is applied."""
    total = func.count(models.Archive.id)
    return db.query(total).scalar()
|
||||
|
||||
|
||||
def count_archive_urls(db: Session):
    """Return the count of ArchiveUrl url values."""
    total = func.count(models.ArchiveUrl.url)
    return db.query(total).scalar()
|
||||
|
||||
|
||||
def count_users(db: Session):
    """Return the count of User email values."""
    total = func.count(models.User.email)
    return db.query(total).scalar()
|
||||
|
||||
|
||||
def count_by_user_since(db: Session, seconds_delta: int = 15):
    """
    Count archives created in the last `seconds_delta` seconds, grouped by author.

    Returns at most 500 (author_id, total) rows, largest totals first.
    """
    since = datetime.now() - timedelta(seconds=seconds_delta)
    recent_by_author = (
        db.query(models.Archive.author_id, func.count().label('total'))
        .filter(models.Archive.created_at >= since)
        .group_by(models.Archive.author_id)
    )
    busiest_first = recent_by_author.order_by(func.count().desc())
    return busiest_first.limit(500).all()
|
||||
|
||||
|
||||
async def find_by_store_until(db: AsyncSession, store_until_is_before: datetime) -> list[models.Archive]:
    """
    Select the non-deleted archives whose `store_until` lies before the given moment.

    NOTE(review): the value returned is `res.scalars()`, i.e. the scalar result
    of the executed statement rather than a materialized list, despite the
    annotation.
    """
    expired = select(models.Archive).filter(
        models.Archive.deleted == False,
        models.Archive.store_until < store_until_is_before,
    )
    res = await db.execute(expired)
    return res.scalars()
|
||||
|
||||
|
||||
async def soft_delete_expired_archives(db: AsyncSession) -> int:
    """
    Flag every non-deleted archive whose `store_until` is already in the past as deleted.

    The changes are committed once, after all rows have been flagged.

    Returns:
        The number of archives that were flagged. The previous annotation said
        `dict`, but an integer counter is what has always been returned.
    """
    to_delete = await find_by_store_until(db, datetime.now())
    counter = 0
    for archive in to_delete:
        archive.deleted = True
        counter += 1
    await db.commit()
    return counter
|
||||
# --------------- TAG
|
||||
|
||||
|
||||
async def get_group_priority_async(db: AsyncSession, group_id: str) -> dict:
    """
    Look up the group's "priority" permission and convert it with `convert_priority_to_queue_dict`.

    "low" is used when the group does not exist or has no "priority" entry.
    """
    group = await db.get(models.Group, group_id)
    if group:
        priority = group.permissions.get("priority", "low")
    else:
        priority = "low"
    return convert_priority_to_queue_dict(priority)
|
||||
|
||||
|
||||
@cached(cache=LRUCache(maxsize=128), key=lambda db, email: hashkey(email))
def get_user_group_names(db: Session, email: str) -> list[str]:
    """
    Return the de-duplicated ids of every group that applies to `email`.

    Two lookups are merged: explicit memberships from the user/group
    association table, and groups whose `domains` contain the domain part of
    the address. The email does not need to belong to an existing user.

    Results are cached per email (the session is not part of the cache key) in
    an LRU cache of 128 entries.
    """
    # TODO: the read: [group1, group2] permissions don't currently work
    if not email or "@" not in email:
        return []

    # groups the user was explicitly added to
    membership_rows = (
        db.query(models.association_table_user_groups)
        .filter_by(user_id=email)
        .with_entities(Column("group_id"))
        .all()
    )
    explicit_names = [row[0] for row in membership_rows]

    # groups granted through the email's domain
    domain = email.split('@')[1]
    domain_rows = (
        db.query(models.Group.id)
        .filter(models.Group.domains.contains(domain))
        .with_entities(Column("id"))
        .all()
    )
    domain_names = [row[0] for row in domain_rows]

    return list(set(explicit_names + domain_names))
|
||||
|
||||
|
||||
def get_user_groups_by_name(db: Session, groups: list[str]) -> list[models.Group]:
    """
    Load the Group rows whose id is one of the given group names.
    """
    id_is_requested = models.Group.id.in_(groups)
    query = db.query(models.Group).filter(id_is_requested)
    return query.all()
|
||||
|
||||
# --------------- INIT User-Groups
|
||||
|
||||
|
||||
def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, service_account_email: str, permissions: dict, domains: list) -> models.Group:
    """
    Create the group identified by `group_name`, or overwrite every field of
    the existing one, then commit and return the refreshed row.
    """
    # insertion order matters: it is the order attributes are assigned on update
    fields = {
        "description": description,
        "orchestrator": orchestrator,
        "orchestrator_sheet": orchestrator_sheet,
        "service_account_email": service_account_email,
        "permissions": permissions,
        "domains": domains,
    }
    db_group = db.query(models.Group).filter(models.Group.id == group_name).first()
    if db_group is not None:
        for attribute, value in fields.items():
            setattr(db_group, attribute, value)
    else:
        db_group = models.Group(id=group_name, **fields)
        db.add(db_group)
    db.commit()
    db.refresh(db_group)
    return db_group
|
||||
|
||||
|
||||
def upsert_user(db: Session, email: str):
    """
    Return the user with this email (compared in lowercase), creating and
    committing a new row first when none exists yet.
    """
    normalised = email.lower()
    existing = db.query(models.User).filter(models.User.email == normalised).first()
    if existing is not None:
        return existing
    created = models.User(email=normalised)
    db.add(created)
    db.commit()
    return created
|
||||
|
||||
|
||||
def upsert_user_groups(db: Session):
    """
    Reads the user_groups yaml file and inserts any new users and groups,
    along with the participation of users in groups.

    Steps, in order:
      1. all existing user<->group relationships are deleted;
      2. every group in the file is upserted, with the domains that reference it;
      3. every user in the file is upserted and attached to the groups that are
         EXPLICITLY listed for them (unknown groups are skipped with a warning).

    Domain-based membership is not stored per user: it is resolved live, since
    new users may belong to a domain without being registered in the file.
    """
    import json

    def display_email_pii(email: str):
        # masks most of the address so logs do not carry full emails;
        # entries without "@" must not abort the whole configuration load
        parts = email.split('@')
        domain = parts[1] if len(parts) > 1 else ""
        return f"'{email[0:3]}...@{domain}'"

    filename = get_settings().USER_GROUPS_FILENAME
    logger.debug(f"Updating user-groups configuration with file {filename}.")

    ug = UserGroups(filename)

    # delete all user-groups relationships
    db.query(models.association_table_user_groups).delete()

    # create a map of group_id -> domains
    group_domains = defaultdict(set)
    for domain, explicit_groups in ug.domains.items():
        for group in explicit_groups:
            group_domains[group].add(domain)

    # upsert groups and save a map of groupid -> dbobject
    for group_id, g in ug.groups.items():
        upsert_group(db, group_id, g.description, g.orchestrator, g.orchestrator_sheet, g.service_account_email, json.loads(g.permissions.model_dump_json()), list(group_domains.get(group_id, [])))
    db_groups: dict[str, models.Group] = {g.id: g for g in db.query(models.Group).all()}

    # integrity checks
    for group_in_domains in group_domains:
        if group_in_domains not in db_groups:
            logger.warning(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.")

    # reinsert users in their EXPLICITLY DEFINED groups
    # domain groups are checked live, as there may be new users that are not explicitly registered but belong to a domain
    for email, explicit_groups in ug.users.items():
        explicit_groups = explicit_groups or []
        logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}")

        db_user = upsert_user(db, email)

        # connect users to groups
        for group_id in explicit_groups:
            if group_id not in db_groups:
                logger.warning(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.")
                continue
            db_groups[group_id].users.append(db_user)

    db.commit()
    count_user_groups = db.query(models.association_table_user_groups).count()
    count_groups = db.query(func.count(models.Group.id)).scalar()

    logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
|
||||
|
||||
|
||||
# --------------- SHEET
|
||||
def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: str, frequency: str):
    """
    Persist a new Sheet authored by `email` and return the refreshed row.
    """
    new_sheet = models.Sheet(
        id=sheet_id,
        name=name,
        author_id=email,
        group_id=group_id,
        frequency=frequency,
    )
    db.add(new_sheet)
    db.commit()
    db.refresh(new_sheet)
    return new_sheet
|
||||
|
||||
|
||||
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
    """
    Return the sheet with this id if it is authored by `email`; the query's
    `.first()` yields None when there is no such sheet.
    """
    authored_by_user = models.Sheet.author_id == email
    has_id = models.Sheet.id == sheet_id
    query = db.query(models.Sheet).filter(authored_by_user, has_id)
    return query.first()
|
||||
|
||||
|
||||
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
    """
    List the sheets authored by `email`, ordered by `last_url_archived_at`
    descending.
    """
    authored = db.query(models.Sheet).filter(models.Sheet.author_id == email)
    newest_first = authored.order_by(models.Sheet.last_url_archived_at.desc())
    return newest_first.all()
|
||||
|
||||
|
||||
async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: int) -> list[models.Sheet]:
    """
    Select the sheets with the given frequency and keep only those for which
    `fnv1a_hash_mod(sheet.id, modulo)` equals `id_hash`.

    The hash comparison happens in Python, after the frequency query.
    """
    statement = select(models.Sheet).filter(models.Sheet.frequency == frequency)
    rows = await db.execute(statement)
    return [
        sheet
        for sheet in rows.scalars()
        if fnv1a_hash_mod(sheet.id, modulo) == id_hash
    ]
|
||||
|
||||
|
||||
async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
    """
    Delete every sheet whose `last_url_archived_at` is more than
    `inactivity_days` days before now, committing once at the end.

    Returns:
        a plain dict mapping author_id -> list of the sheets deleted for them.
    """
    cutoff = datetime.now() - timedelta(days=inactivity_days)
    statement = select(models.Sheet).filter(models.Sheet.last_url_archived_at < cutoff)
    stale = await db.execute(statement)
    removed_by_author = defaultdict(list)
    for stale_sheet in stale.scalars():
        await db.delete(stale_sheet)
        removed_by_author[stale_sheet.author_id].append(stale_sheet)
    await db.commit()
    return dict(removed_by_author)
|
||||
|
||||
|
||||
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
    """
    Delete the sheet with this id when it is authored by `email`.

    Returns:
        True when a matching sheet was found, False otherwise.
    """
    owned_sheet_query = db.query(models.Sheet).filter(
        models.Sheet.id == sheet_id,
        models.Sheet.author_id == email,
    )
    target = owned_sheet_query.first()
    found = target is not None
    if target:
        db.delete(target)
        db.commit()
    return found
|
||||
341
app/web/db/user_state.py
Normal file
341
app/web/db/user_state.py
Normal file
@@ -0,0 +1,341 @@
|
||||
|
||||
from typing import Dict, Set
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from datetime import datetime
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared.user_groups import GroupInfo, GroupPermissions
|
||||
from app.shared.schemas import Usage, UsageResponse
|
||||
from app.web.db import crud
|
||||
from app.web.utils.misc import convert_priority_to_queue_dict
|
||||
|
||||
|
||||
class UserState:
    """
    Manage a user's state and permissions.

    Permission values are derived from the groups the user belongs to (their
    explicit/domain groups plus the "default" group name) and are computed
    lazily: each one is cached on the instance the first time it is read.
    """

    def __init__(self, db: Session, email: str):
        self.db = db
        # emails are always handled in lowercase
        self.email = email.lower()

    # ---------------------------------------------------------------- helpers

    def _groups_with_permissions(self):
        """Yields the user's groups that have a non-empty permissions dict."""
        for group in self.user_groups:
            if group.permissions:
                yield group

    def _any_group_allows(self, permission_name: str) -> bool:
        """True as soon as one group has a truthy value for this permission."""
        return any(group.permissions.get(permission_name, False) for group in self._groups_with_permissions())

    def _helper_for_grouping_max_numerical_permissions(self, permission_name: str) -> int:
        """
        Iterates one of the numerical permissions where -1 means no restrictions and returns either -1 or the maximum value, defaults according to GroupPermissions
        """
        default = GroupPermissions.model_fields[permission_name].default
        max_value = default
        for group in self._groups_with_permissions():
            group_value = group.permissions.get(permission_name, default)
            if group_value == -1:
                # unlimited in one group means unlimited overall
                return -1
            max_value = max(max_value, group_value)
        return max_value

    @staticmethod
    def _current_month_filters() -> tuple:
        """
        Filters restricting Archive.created_at to the current calendar month.
        The clock is read once so that month and year always belong together.
        """
        now = datetime.now()
        return (
            func.extract('month', models.Archive.created_at) == now.month,
            func.extract('year', models.Archive.created_at) == now.year,
        )

    @staticmethod
    def _archive_group_filter(group_id: str):
        """
        Filter selecting the archives of one group. A falsy group_id means the
        global bucket, which covers archives stored with either a NULL or an
        empty-string group (usage() treats both as the "" group as well).
        """
        if group_id:
            return models.Archive.group_id == group_id
        return sqlalchemy.or_(models.Archive.group_id.is_(None), models.Archive.group_id == "")

    @staticmethod
    def _total_bytes_expression():
        """SQL expression summing `$.metadata.total_bytes` of Archive.result, 0 when absent."""
        return func.coalesce(func.sum(
            func.coalesce(
                func.cast(
                    func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
                    sqlalchemy.Integer
                ), 0
            )
        ), 0)

    # ------------------------------------------------------------ permissions

    @property
    def permissions(self) -> Dict[str, GroupInfo]:
        """
        Returns a dict of all group permissions and a special {"all": read/archive_url/archive_sheet} key
        """
        if not hasattr(self, '_permissions'):
            # built locally so a failure cannot leave a half-filled cache behind
            permissions = {}
            permissions["all"] = GroupInfo(
                read=self.read,
                read_public=self.read_public,
                archive_url=self.archive_url,
                archive_sheet=self.archive_sheet,
                # below are relevant only for /url endpoints
                max_archive_lifespan_months=self.max_archive_lifespan_months,
                max_monthly_urls=self.max_monthly_urls,
                max_monthly_mbs=self.max_monthly_mbs,
                priority=self.priority
            )
            for group in self._groups_with_permissions():
                permissions[group.id] = GroupInfo(**group.permissions, description=group.description, service_account_email=group.service_account_email)
            self._permissions = permissions
        return self._permissions

    @property
    def user_groups_names(self):
        """Names of the user's groups, always including "default"."""
        if not hasattr(self, '_user_groups_names'):
            self._user_groups_names = crud.get_user_group_names(self.db, self.email) + ["default"]
        return self._user_groups_names

    @property
    def user_groups(self):
        """Group rows loaded for `user_groups_names`."""
        if not hasattr(self, '_user_groups'):
            self._user_groups = crud.get_user_groups_by_name(self.db, self.user_groups_names)
        return self._user_groups

    @property
    def read(self) -> Set[str] | bool:
        """
        Read can be a list of group names or True, if all can be read.
        """
        if not hasattr(self, '_read'):
            readable = set()
            for group in self._groups_with_permissions():
                group_read_permissions = group.permissions.get("read", []) or []
                if "all" in group_read_permissions:
                    readable = True
                    break
                readable.update(group_read_permissions)
            self._read = readable
        return self._read

    @property
    def read_public(self) -> bool:
        """
        Read public permission
        """
        if not hasattr(self, '_read_public'):
            self._read_public = self._any_group_allows("read_public")
        return self._read_public

    @property
    def archive_url(self) -> bool:
        """
        Archive URL permission
        """
        if not hasattr(self, '_archive_url'):
            self._archive_url = self._any_group_allows("archive_url")
        return self._archive_url

    @property
    def archive_sheet(self) -> bool:
        """
        Archive sheet permission
        """
        if not hasattr(self, '_archive_sheet'):
            self._archive_sheet = self._any_group_allows("archive_sheet")
        return self._archive_sheet

    @property
    def sheet_frequency(self):
        """Union of the sheet frequencies allowed by any of the user's groups."""
        if not hasattr(self, '_sheet_frequency'):
            frequencies = set()
            for group in self._groups_with_permissions():
                # a group without the key contributes nothing (set.update(None) would raise)
                frequencies.update(group.permissions.get("sheet_frequency") or [])
            self._sheet_frequency = frequencies
        return self._sheet_frequency

    @property
    def max_archive_lifespan_months(self) -> int:
        if not hasattr(self, '_max_archive_lifespan_months'):
            self._max_archive_lifespan_months = self._helper_for_grouping_max_numerical_permissions("max_archive_lifespan_months")
        return self._max_archive_lifespan_months

    @property
    def max_monthly_urls(self) -> int:
        if not hasattr(self, '_max_monthly_urls'):
            self._max_monthly_urls = self._helper_for_grouping_max_numerical_permissions("max_monthly_urls")
        return self._max_monthly_urls

    @property
    def max_monthly_mbs(self) -> int:
        if not hasattr(self, '_max_monthly_mbs'):
            self._max_monthly_mbs = self._helper_for_grouping_max_numerical_permissions("max_monthly_mbs")
        return self._max_monthly_mbs

    @property
    def priority(self) -> str:
        """ "high" if any group grants high priority, otherwise "low". """
        if not hasattr(self, '_priority'):
            self._priority = "low"
            for group in self._groups_with_permissions():
                if group.permissions.get("priority", self._priority) == "high":
                    self._priority = "high"
                    break
        return self._priority

    @property
    def active(self) -> bool:
        """
        A user is active if they can read/archive anything
        """
        if not hasattr(self, '_active'):
            self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet)
        return self._active

    def in_group(self, group_id: str) -> bool:
        return group_id in self.user_groups_names

    # ------------------------------------------------------------------ usage

    def usage(self) -> UsageResponse:
        """
        returns the monthly quotas for the URLs/MBs and the totals for Sheets
        """
        # count all of the user's sheets per group (sheets are not limited to this month)
        user_sheets = self.db.query(
            models.Sheet.group_id,
            func.count(models.Sheet.id).label('sheet_count')
        ).filter(models.Sheet.author_id == self.email).group_by(models.Sheet.group_id).all()

        sheets_by_group = {sheet.group_id: sheet.sheet_count for sheet in user_sheets}

        # find and sum all user urls over this month
        urls_by_group = self.db.query(
            models.Archive.group_id,
            func.count(models.Archive.id).label('url_count'),
            self._total_bytes_expression().label('total_bytes')
        ).filter(
            models.Archive.author_id == self.email,
            *self._current_month_filters()
        ).group_by(models.Archive.group_id).all()

        # merge the two queries, archives/sheets without a group go under ""
        usage_by_group: Dict[str, Usage] = {
            (url.group_id or ""):
            Usage(monthly_urls=url.url_count, monthly_mbs=int(url.total_bytes / 1024 / 1024))
            for url in urls_by_group
        }
        for group_id, sheet_count in sheets_by_group.items():
            group_id = group_id or ""
            if group_id in usage_by_group:
                usage_by_group[group_id].total_sheets = sheet_count
            else:
                usage_by_group[group_id] = Usage(total_sheets=sheet_count)

        # calculate totals
        total_sheets = sum(sheet.sheet_count for sheet in user_sheets)
        total_bytes = sum(url.total_bytes for url in urls_by_group)
        total_urls = sum(url.url_count for url in urls_by_group)

        return UsageResponse(
            monthly_urls=total_urls,
            monthly_mbs=int(total_bytes / 1024 / 1024),
            total_sheets=total_sheets,
            groups=usage_by_group
        )

    # ----------------------------------------------------------------- quotas

    def has_quota_monthly_sheets(self, group_id: str) -> bool:
        """
        checks if a user has reached their sheet quota for a given group
        """
        if group_id not in self.permissions:
            return False

        sheet_quota = self.permissions[group_id].max_sheets
        user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email, models.Sheet.group_id == group_id).count()

        if sheet_quota == -1:
            return True
        return user_sheets < sheet_quota

    def has_quota_max_monthly_urls(self, group_id: str) -> bool:
        """
        checks if a user has reached their monthly url quota for a group, if global then group should be empty string
        """
        if not group_id:
            quota = self.max_monthly_urls
        elif group_id in self.permissions:
            quota = self.permissions[group_id].max_monthly_urls
        else:
            return False

        if quota == -1:
            return True

        user_urls = self.db.query(models.Archive).filter(
            models.Archive.author_id == self.email,
            self._archive_group_filter(group_id),
            *self._current_month_filters()
        ).count()

        return user_urls < quota

    def has_quota_max_monthly_mbs(self, group_id: str) -> bool:
        """
        checks if a user has reached their monthly MBs quota for a group, if global then group should be empty string
        """
        if not group_id:
            quota = self.max_monthly_mbs
        elif group_id in self.permissions:
            quota = self.permissions[group_id].max_monthly_mbs
        else:
            return False

        if quota == -1:
            return True

        # find and sum all user bytes over this month
        user_bytes = self.db.query(models.Archive).filter(
            models.Archive.author_id == self.email,
            self._archive_group_filter(group_id),
            *self._current_month_filters()
        ).with_entities(self._total_bytes_expression().label('total')).scalar()

        # convert bytes to mb
        user_mbs = int(user_bytes / 1024 / 1024)
        return user_mbs < quota

    # ------------------------------------------------------- per-group checks

    def can_manually_trigger(self, group_id: str) -> bool:
        """
        checks if a user is allowed to manually trigger a sheet
        """
        if group_id not in self.permissions:
            return False

        return self.permissions[group_id].manually_trigger_sheet

    def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool:
        """
        checks if a user is allowed to create a sheet with this frequency for this group
        """
        if group_id not in self.permissions:
            return False

        return frequency in self.permissions[group_id].sheet_frequency

    def priority_group(self, group_id: str) -> dict:
        """
        Queue settings for one of the user's groups: the group's "priority"
        permission (default "low") passed through convert_priority_to_queue_dict.
        Callers unpack the result as keyword arguments.
        """
        priority = "low"
        for group in self.user_groups:
            if group.id != group_id: continue
            if not group.permissions: continue
            priority = group.permissions.get("priority", priority)
            break
        return convert_priority_to_queue_dict(priority)
|
||||
50
app/web/endpoints/default.py
Normal file
50
app/web/endpoints/default.py
Normal file
@@ -0,0 +1,50 @@
|
||||
|
||||
from typing import Dict
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
|
||||
from app.web.config import VERSION, BREAKING_CHANGES
|
||||
from app.shared.schemas import ActiveUser, UsageResponse
|
||||
from app.web.db.user_state import UserState
|
||||
from app.web.security import get_user_state
|
||||
from app.shared.user_groups import GroupInfo
|
||||
|
||||
default_router = APIRouter()
|
||||
|
||||
|
||||
@default_router.get("/")
async def home():
    # root endpoint: reports the API version and its breaking changes
    payload = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
    return JSONResponse(payload)
|
||||
|
||||
|
||||
@default_router.get("/health")
async def health():
    # liveness probe: always answers with a static "ok" status
    status_payload = {"status": "ok"}
    return JSONResponse(status_payload)
|
||||
|
||||
|
||||
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
async def active(
    user: UserState = Depends(get_user_state),
) -> ActiveUser:
    # the flag comes straight from the authenticated user's state
    is_active = user.active
    return {"active": is_active}
|
||||
|
||||
|
||||
@default_router.get("/user/permissions", summary="Get the user's global 'all' permissions and the permissions for each group they belong to.")
def get_user_permissions(
    user: UserState = Depends(get_user_state),
) -> Dict[str, GroupInfo]:
    # mapping of "all" plus each group id to its GroupInfo, as built by UserState
    permissions_by_group = user.permissions
    return permissions_by_group
|
||||
|
||||
@default_router.get("/user/usage", summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.")
def get_user_usage(
    user: UserState = Depends(get_user_state),
) -> UsageResponse:
    # only active users may see their usage; everyone else gets a 403
    if user.active:
        return user.usage()
    raise HTTPException(status_code=403, detail="User is not active.")
|
||||
|
||||
|
||||
|
||||
@default_router.get('/favicon.ico', include_in_schema=False)
async def favicon() -> FileResponse:
    # served from the static folder and kept out of the OpenAPI schema
    favicon_path = "app/web/static/favicon.ico"
    return FileResponse(favicon_path)
|
||||
61
app/web/endpoints/interoperability.py
Normal file
61
app/web/endpoints/interoperability.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import json
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
import sqlalchemy
|
||||
from auto_archiver.core import Metadata
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.aa_utils import get_all_urls
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.shared import business_logic, schemas
|
||||
from app.shared.db import worker_crud
|
||||
from app.shared.db.database import get_db_dependency
|
||||
from app.web.security import token_api_key_auth
|
||||
from app.shared.db import models
|
||||
from app.shared.log import log_error
|
||||
|
||||
|
||||
interoperability_router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."])
|
||||
|
||||
|
||||
# ----- endpoint to submit data archived elsewhere
|
||||
@interoperability_router.post("/submit-archive", status_code=201, summary="Submit a manual archive entry, for data that was archived elsewhere.")
def submit_manual_archive(
    manual: schemas.SubmitManualArchive,
    auth=Depends(token_api_key_auth),
    db: Session = Depends(get_db_dependency)
):
    # Stores an archive produced outside this service.
    # - `manual.result` must be the JSON of an auto-archiver Metadata object;
    # - the entry is attributed to `manual.author_id` (ALLOW_ANY_EMAIL when empty)
    #   and always tagged "manual";
    # - answers 201 with the new archive id, or 422 when the JSON is invalid,
    #   the store-until computation asserts, or the DB rejects the insert.
    # Every HTTPException is chained to its cause so the original error stays
    # visible in tracebacks.
    try:
        result: Metadata = Metadata.from_json(manual.result)
    except json.JSONDecodeError as e:
        log_error(e)
        raise HTTPException(status_code=422, detail="Invalid JSON in result field.") from e
    manual.author_id = manual.author_id or ALLOW_ANY_EMAIL
    manual.tags.add("manual")

    try:
        store_until = business_logic.get_store_archive_until(db, manual.group_id)
    except AssertionError as e:
        # get_store_archive_until reports invalid input through assertions
        log_error(e)
        raise HTTPException(status_code=422, detail=str(e)) from e

    try:
        archive = schemas.ArchiveCreate(
            author_id=manual.author_id,
            url=result.get_url(),
            public=manual.public,
            group_id=manual.group_id,
            tags=manual.tags,
            id=models.generate_uuid(),
            result=json.loads(result.to_json()),
            urls=get_all_urls(result),
            store_until=store_until,
        )

        db_archive = worker_crud.store_archived_url(db, archive)
        logger.debug(f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}")
        return JSONResponse({"id": db_archive.id}, status_code=201)
    except sqlalchemy.exc.IntegrityError as e:
        log_error(e)
        raise HTTPException(status_code=422, detail="Cannot insert into DB due to integrity error, likely duplicate urls.") from e
|
||||
81
app/web/endpoints/sheet.py
Normal file
81
app/web/endpoints/sheet.py
Normal file
@@ -0,0 +1,81 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.web.db.user_state import UserState
|
||||
from app.shared import schemas
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.security import get_user_state
|
||||
from app.web.db import crud
|
||||
from app.shared.db.database import get_db_dependency
|
||||
|
||||
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
def create_sheet(
    sheet: schemas.SheetAdd,
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency),
) -> schemas.SheetResponse:
    # checks run in order: group membership (403), sheet quota (429),
    # allowed frequency (422); a duplicate sheet id surfaces as 400
    target_group = sheet.group_id

    if not user.in_group(target_group):
        raise HTTPException(status_code=403, detail="User does not have access to this group.")

    if not user.has_quota_monthly_sheets(target_group):
        raise HTTPException(status_code=429, detail="User has reached their sheet quota for this group.")

    if not user.is_sheet_frequency_allowed(target_group, sheet.frequency):
        raise HTTPException(status_code=422, detail="Invalid frequency selected for this group.")

    try:
        stored_sheet = crud.create_sheet(db, sheet.id, sheet.name, user.email, target_group, sheet.frequency)
    except exc.IntegrityError as e:
        raise HTTPException(status_code=400, detail="Sheet with this ID is already being archived.") from e
    return stored_sheet
|
||||
|
||||
|
||||
@sheet_router.get("/mine", status_code=200, summary="Get the authenticated user's Google Sheets.")
def get_user_sheets(
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency)
) -> list[schemas.SheetResponse]:
    # sheets authored by the authenticated user, as returned by the crud layer
    own_sheets = crud.get_user_sheets(db, user.email)
    return own_sheets
|
||||
|
||||
|
||||
@sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.")
def delete_sheet(
    id: str,
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency),
) -> schemas.DeleteResponse:
    # only a sheet authored by the requesting user can be deleted;
    # "deleted" reports whether such a sheet was found
    was_deleted = crud.delete_sheet(db, id, user.email)
    body = {"id": id, "deleted": was_deleted}
    return JSONResponse(body)
|
||||
|
||||
|
||||
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
def archive_user_sheet(
    id: str,
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency),
) -> schemas.Task:
    # the sheet must be authored by the user, the user must still be in the
    # sheet's group, and that group must allow manual triggering
    sheet = crud.get_user_sheet(db, user.email, sheet_id=id)
    if not sheet:
        raise HTTPException(status_code=403, detail="No access to this sheet.")

    sheet_group = sheet.group_id
    if not user.in_group(sheet_group):
        raise HTTPException(status_code=403, detail="User does not have access to this group.")

    if not user.can_manually_trigger(sheet_group):
        raise HTTPException(status_code=429, detail="User cannot manually trigger sheet archiving in this group.")

    # the task is routed according to the priority of the sheet's group
    group_queue = user.priority_group(sheet_group)
    submission = schemas.SubmitSheet(sheet_id=id, author_id=user.email, group_id=sheet_group)
    signature = celery.signature("create_sheet_task", args=[submission.model_dump_json()])
    task = signature.apply_async(**group_queue)

    return JSONResponse({"id": task.id}, status_code=201)
|
||||
@@ -3,21 +3,19 @@ from fastapi import APIRouter, Depends
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from loguru import logger
|
||||
from web.security import get_token_or_user_auth
|
||||
|
||||
from db import schemas
|
||||
from core.logging import log_error
|
||||
from worker.main import celery
|
||||
from utils.mics import custom_jsonable_encoder
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.security import get_token_or_user_auth
|
||||
from app.shared import schemas
|
||||
from app.shared.log import log_error
|
||||
from app.web.utils.misc import custom_jsonable_encoder
|
||||
|
||||
|
||||
task_router = APIRouter(prefix="/task", tags=["Async task operations"])
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
@task_router.get("/{task_id}", summary="Check the status of an async task by its id, works for URLs and Sheet tasks.")
|
||||
def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult:
|
||||
logger.info(f"status check for user {email} task {task_id}")
|
||||
task = AsyncResult(task_id, app=celery)
|
||||
try:
|
||||
if task.status == "FAILURE":
|
||||
84
app/web/endpoints/url.py
Normal file
84
app/web/endpoints/url.py
Normal file
@@ -0,0 +1,84 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.shared import schemas
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.security import get_token_or_user_auth, get_user_state
|
||||
from app.web.db import crud
|
||||
from app.web.db.user_state import UserState
|
||||
from app.shared.db.database import get_db_dependency
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from app.web.utils.misc import convert_priority_to_queue_dict
|
||||
|
||||
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
def archive_url(
    archive: schemas.ArchiveTrigger,
    email=Depends(get_token_or_user_auth),
    db: Session = Depends(get_db_dependency)
) -> schemas.Task:
    # the author is always the authenticated caller, whatever the payload says
    archive.author_id = email
    logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")

    # a URL needs at least a scheme and a network location
    parsed_url = urlparse(archive.url)
    if not (parsed_url.scheme and parsed_url.netloc):
        raise HTTPException(status_code=400, detail="Invalid URL received.")

    archive_create = schemas.ArchiveCreate(**archive.model_dump())
    if email == ALLOW_ANY_EMAIL:
        # ALLOW_ANY_EMAIL callers skip group/quota checks and use the high priority queue
        group_queue = convert_priority_to_queue_dict("high")
    else:
        user = UserState(db, email)
        requested_group = archive.group_id
        if requested_group and not user.in_group(requested_group):
            raise HTTPException(status_code=403, detail="User does not have access to this group.")
        if not user.has_quota_max_monthly_urls(requested_group):
            raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.")
        if not user.has_quota_max_monthly_mbs(requested_group):
            raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.")
        group_queue = user.priority_group(archive_create.group_id)

    signature = celery.signature("create_archive_task", args=[archive_create.model_dump_json()])
    task = signature.apply_async(**group_queue)
    task_response = schemas.Task(id=task.id)
    return JSONResponse(task_response.model_dump(), status_code=201)
|
||||
|
||||
|
||||
@url_router.get("/search", summary="Search for archive entries by URL.")
def search_by_url(
    url: str, skip: int = 0, limit: int = 25,
    archived_after: datetime = None, archived_before: datetime = None,
    db: Session = Depends(get_db_dependency),
    email: str = Depends(get_token_or_user_auth)
) -> list[schemas.ArchiveResult]:
    # ALLOW_ANY_EMAIL callers search with both read flags left as False;
    # regular users need some read permission and search with their own
    read_groups = False
    read_public = False
    if email != ALLOW_ANY_EMAIL:
        user = UserState(db, email)
        if not (user.read or user.read_public):
            raise HTTPException(status_code=403, detail="User does not have read access.")
        read_groups, read_public = user.read, user.read_public

    search_term = url.strip()
    return crud.search_archives_by_url(
        db, search_term, email, read_groups, read_public,
        skip=skip, limit=limit,
        archived_after=archived_after, archived_before=archived_before,
    )
|
||||
|
||||
|
||||
@url_router.delete("/{id}", summary="Delete a single URL archive by id.")
def delete_archive(
    id: str,
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency),
) -> schemas.DeleteResponse:
    """Soft-delete the archive ``id`` on behalf of the authenticated user.

    Returns a JSON body with the requested ``id`` and the value returned by
    ``crud.soft_delete_archive`` under ``deleted``.
    """
    logger.info(f"deleting url archive task {id} request by {user.email}")
    was_deleted = crud.soft_delete_archive(db, id, user.email)
    payload = {"id": id, "deleted": was_deleted}
    return JSONResponse(payload)
|
||||
186
app/web/events.py
Normal file
186
app/web/events.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
import datetime
|
||||
import logging
|
||||
import alembic.config
|
||||
from fastapi import FastAPI
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
from loguru import logger
|
||||
from fastapi_mail import FastMail, MessageSchema, MessageType
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared.db.database import get_db, get_db_async, make_engine, wal_checkpoint
|
||||
from app.shared import schemas
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.db import crud
|
||||
from app.web.middleware import increase_exceptions_counter
|
||||
from app.web.utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and log on shutdown.

    Startup creates the tables, applies the alembic migrations, disables the
    uvicorn access logger (loguru is used instead), starts the metrics
    background tasks, upserts the user groups, starts whichever cronjobs are
    enabled in the settings and finally runs ``wal_checkpoint``.
    """
    # see https://fastapi.tiangolo.com/advanced/events/#lifespan
    settings = get_settings()

    # The event loop only keeps weak references to tasks, so a task whose
    # result is discarded can be garbage-collected mid-execution. This list
    # lives in the generator frame, i.e. for as long as the app is running.
    background_tasks = []

    def spawn(coro):
        background_tasks.append(asyncio.create_task(coro))

    # STARTUP
    engine = make_engine(settings.DATABASE_PATH)
    models.Base.metadata.create_all(bind=engine)
    alembic.config.main(prog="alembic", argv=['--raiseerr', 'upgrade', 'head'])
    logging.getLogger("uvicorn.access").disabled = True  # loguru
    spawn(redis_subscribe_worker_exceptions(settings.REDIS_EXCEPTIONS_CHANNEL))
    spawn(repeat_measure_regular_metrics())
    with get_db() as db:
        crud.upsert_user_groups(db)

    # setup archive cronjobs
    if settings.CRON_ARCHIVE_SHEETS:
        spawn(archive_hourly_sheets_cronjob())
        spawn(archive_daily_sheets_cronjob())
    else:
        logger.warning("[CRON] Sheet archive cronjobs are disabled.")

    if settings.CRON_DELETE_STALE_SHEETS:
        spawn(delete_stale_sheets())
    else:
        logger.warning("[CRON] Delete stale sheets cronjob is disabled.")

    if settings.CRON_DELETE_SCHEDULED_ARCHIVES:
        spawn(notify_about_expired_archives())
    else:
        logger.warning("[CRON] Delete scheduled archives cronjob is disabled.")

    wal_checkpoint()

    yield  # separates startup from shutdown instructions

    # SHUTDOWN
    logger.info("shutting down")
|
||||
|
||||
|
||||
# CRON JOBS
|
||||
@repeat_every(seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS, on_exception=increase_exceptions_counter)
async def repeat_measure_regular_metrics():
    """Periodically refresh the regular metrics for the configured database."""
    settings = get_settings()
    await measure_regular_metrics(settings.DATABASE_PATH, settings.REPEAT_COUNT_METRICS_SECONDS)
|
||||
|
||||
|
||||
@repeat_every(seconds=60, wait_first=120, on_exception=increase_exceptions_counter)
async def archive_hourly_sheets_cronjob():
    """Every minute, trigger the "hourly" sheets selected for the current minute."""
    current_minute = datetime.datetime.now().minute
    await archive_sheets_cronjob("hourly", 60, current_minute)
|
||||
|
||||
|
||||
@repeat_every(seconds=3600, wait_first=120, on_exception=increase_exceptions_counter)
async def archive_daily_sheets_cronjob():
    """Every hour, trigger the "daily" sheets selected for the current hour."""
    current_hour = datetime.datetime.now().hour
    await archive_sheets_cronjob("daily", 24, current_hour)
|
||||
|
||||
|
||||
async def archive_sheets_cronjob(frequency: str, interval: int, current_time_unit: int):
    """Queue a ``create_sheet_task`` for every sheet due in the current slot.

    The sheets are chosen by ``crud.get_sheets_by_id_hash`` from the given
    frequency, interval and current time unit; each task is sent with the
    queue options returned for the sheet's group.
    """
    triggered_jobs = []

    async with get_db_async() as db:
        due_sheets = await crud.get_sheets_by_id_hash(db, frequency, interval, current_time_unit)
        for s in due_sheets:
            queue_options = await crud.get_group_priority_async(db, s.group_id)
            submission = schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group_id=s.group_id)
            signature = celery.signature("create_sheet_task", args=[submission.model_dump_json()])
            task = signature.apply_async(**queue_options)
            triggered_jobs.append({"sheet_id": s.id, "task_id": task.id})

    logger.debug(f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}")
|
||||
|
||||
|
||||
# TODO: on exception, log the error and also increment the Prometheus counter
|
||||
DELETE_WINDOW = get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS * 24 * 60 * 60
|
||||
|
||||
|
||||
@repeat_every(seconds=DELETE_WINDOW, wait_first=180, on_exception=increase_exceptions_counter)
async def notify_about_expired_archives():
    """E-mail authors about archives nearing deletion, then schedule the deletion.

    Archives are looked up with ``crud.find_by_store_until`` using "now plus
    the configured check window" and grouped by ``author_id``, which is used
    as the recipient address. Once the notifications are sent,
    ``delete_expired_archives`` is scheduled as a one-off task.
    """
    from html import escape  # local import: only needed to build the HTML body

    check_every_n_days = get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS
    notify_from = datetime.datetime.now() + datetime.timedelta(days=check_every_n_days)
    async with get_db_async() as db:
        scheduled_deletions = await crud.find_by_store_until(db, notify_from)

    user_archives = defaultdict(list)
    for archive in scheduled_deletions:
        user_archives[archive.author_id].append(archive)

    if user_archives:
        fastmail = FastMail(get_settings().MAIL_CONFIG)
        # notify users
        for email, archives in user_archives.items():
            # URLs and ids originate from user submissions: escape them so
            # characters such as "<" or "&" cannot break or inject markup.
            list_of_archives = "\n".join(
                f'{escape(str(a.url))}, {escape(str(a.id))}, {a.store_until.isoformat()}<br/>'
                for a in archives
            )
            # TODO: how can users download them in bulk?
            message = MessageSchema(
                subject="Auto Archiver: Archives Scheduled for Deletion",
                recipients=[email],
                body=f"""
                <html>
                    <body>
                        <p>Hi {escape(str(email))},</p>
                        <p>Some of your archives will be deleted in the next {check_every_n_days} days, as they are reaching their expiration date according to our retention policy for their groups.</p>
                        <p>If you want to preserve any, make sure to download them now.</p>
                        <p>Here is a CSV list of URLs:</p>
                        <code>
                            url,archive_id,time_of_deletion<br/>
                            {list_of_archives}
                        </code>
                        <p>Best,<br>The Auto Archiver team</p>
                    </body>
                </html>
                """,
                subtype=MessageType.html
            )
            await fastmail.send_message(message)
            logger.debug(f"[CRON] Email sent to {email} about {len(archives)} scheduled archives deletion.")

    # now schedule the deletion event; keep a reference on the function object
    # so the task cannot be garbage-collected before it has run
    notify_about_expired_archives._deletion_task = asyncio.create_task(delete_expired_archives())
|
||||
|
||||
|
||||
@repeat_every(max_repetitions=1, wait_first=10, seconds=0, on_exception=increase_exceptions_counter)
async def delete_expired_archives():
    """One-off job: soft-delete expired archives and log how many were removed."""
    async with get_db_async() as db:
        count_deleted = await crud.soft_delete_expired_archives(db)

    if not count_deleted:
        return
    logger.debug(f"[CRON] Deleted {count_deleted} archives.")
|
||||
|
||||
|
||||
@repeat_every(seconds=86400, wait_first=150, on_exception=increase_exceptions_counter)
async def delete_stale_sheets():
    """Daily job: delete stale sheets and e-mail each affected user a list.

    ``crud.delete_stale_sheets`` returns the removed sheets keyed by the
    e-mail address that is notified; nothing is sent when it is empty.
    """
    from html import escape  # local import: only needed to build the HTML body

    STALE_DAYS = get_settings().DELETE_STALE_SHEETS_DAYS
    logger.debug(f"[CRON] Deleting stale sheets older than {STALE_DAYS} days.")
    async with get_db_async() as db:
        user_sheets = await crud.delete_stale_sheets(db, STALE_DAYS)

    if not user_sheets:
        return

    fastmail = FastMail(get_settings().MAIL_CONFIG)
    # notify users
    for email, sheets in user_sheets.items():
        # Sheet names and ids are user-controlled: escape them before they are
        # placed in the markup (escape() also escapes quotes for the href).
        list_of_sheets = "\n".join(
            f'<li><a href="https://docs.google.com/spreadsheets/d/{escape(str(s.id))}">{escape(str(s.name))}</a></li>'
            for s in sheets
        )
        message = MessageSchema(
            subject="Auto Archiver: Stale Sheets Removed",
            recipients=[email],
            body=f"""
            <html>
                <body>
                    <p>Hi {escape(str(email))},</p>
                    <p>Your stale sheets have been removed from our system as no new URL was archived in the past {STALE_DAYS} days:</p>
                    <ul>
                        {list_of_sheets}
                    </ul>
                    <p>You can always re-add them at https://auto-archiver.bellingcat.com/.</p>
                    <p>Best,<br>The Auto Archiver team</p>
                </body>
            </html>
            """,
            subtype=MessageType.html
        )
        await fastmail.send_message(message)
        logger.debug(f"[CRON] Email sent to {email} about stale sheets deletion.")
|
||||
|
||||
|
||||
# @repeat_at
async def generate_users_export_csv():
    """Placeholder for a future cronjob; currently does nothing.

    TODO: implement a cronjob that regularly exports the requested user data
    to a CSV file, see
    https://colab.research.google.com/drive/1QDbo3QXHPBdiTuANlA1AWVvN-rqxuCPa?authuser=0#scrollTo=4nPXeSdK8RBT
    """
    return None
|
||||
60
app/web/main.py
Normal file
60
app/web/main.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import os
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from loguru import logger
|
||||
|
||||
from app.web.middleware import logging_middleware
|
||||
from app.shared.task_messaging import get_celery
|
||||
|
||||
from app.web.security import token_api_key_auth
|
||||
from app.web.config import VERSION, API_DESCRIPTION
|
||||
from app.web.events import lifespan
|
||||
from app.shared.settings import get_settings
|
||||
|
||||
|
||||
from app.web.endpoints.default import default_router
|
||||
from app.web.endpoints.url import url_router
|
||||
from app.web.endpoints.sheet import sheet_router
|
||||
from app.web.endpoints.task import task_router
|
||||
from app.web.endpoints.interoperability import interoperability_router
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
def app_factory(settings = get_settings()):
    """Build and configure the FastAPI application.

    Registers CORS and the logging middleware, includes all routers, exposes
    the Prometheus metrics behind token authentication and, when
    ``settings.SERVE_LOCAL_ARCHIVE`` points at an existing directory, mounts
    it as static files (development only).
    """
    app = FastAPI(
        title="Auto-Archiver API",
        description=API_DESCRIPTION,
        version=VERSION,
        contact={"name": "GitHub", "url": "https://github.com/bellingcat/auto-archiver-api"},
        lifespan=lifespan
    )

    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.ALLOWED_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.middleware("http")(logging_middleware)

    routers = (default_router, url_router, sheet_router, task_router, interoperability_router)
    for router in routers:
        app.include_router(router)

    # prometheus exposed in /metrics with authentication
    instrumentator = Instrumentator(
        should_group_status_codes=False,
        excluded_handlers=["/metrics", "/health", "/openapi.json", "/favicon.ico"],
    )
    instrumentator.instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])

    mount_path = settings.SERVE_LOCAL_ARCHIVE
    if not mount_path:
        return app

    # fall back to the path relative to the working directory when the
    # configured "/app..." path does not exist
    local_dir = mount_path
    relative_dir = local_dir.replace("/app", ".")
    if not os.path.isdir(local_dir) and os.path.isdir(relative_dir):
        local_dir = relative_dir

    if len(mount_path) > 1 and os.path.isdir(local_dir):
        logger.warning(f"MOUNTing local archive, use this in development only {settings.SERVE_LOCAL_ARCHIVE}")
        app.mount(mount_path, StaticFiles(directory=local_dir), name=mount_path)

    return app
|
||||
31
app/web/middleware.py
Normal file
31
app/web/middleware.py
Normal file
@@ -0,0 +1,31 @@
|
||||
|
||||
import traceback
|
||||
from loguru import logger
|
||||
from fastapi import Request
|
||||
from app.shared.log import log_error
|
||||
from app.web.utils.metrics import EXCEPTION_COUNTER
|
||||
|
||||
|
||||
async def logging_middleware(request: Request, call_next):
    """Log every request and count unhandled exceptions before re-raising them.

    Args:
        request: the incoming request.
        call_next: callable that forwards the request to the next handler.

    Returns:
        The response produced by ``call_next``.

    Raises:
        Exception: whatever ``call_next`` raised, after it has been counted
            and logged.
    """
    # request.client is None when the server provides no peer address; guard
    # so that the log line itself cannot raise and mask the real outcome.
    client = f"{request.client.host}:{request.client.port}" if request.client else "unknown"
    location = f"{request.method} {request.url}"
    try:
        response = await call_next(request)
    except Exception as e:
        await increase_exceptions_counter(e, location)
        logger.info(f"{client} {location} - {e.__class__.__name__} {e}")
        raise
    #TODO: use Origin to have summary prometheus metrics on where requests come from
    # origin = request.headers.get("origin")
    logger.info(f"{client} {location} - HTTP {response.status_code}")
    return response
|
||||
|
||||
async def increase_exceptions_counter(e: Exception, location: str = "cronjob"):
    """Increment the exception counter for ``e`` and log the error.

    Args:
        e: the exception to record.
        location: label describing where it happened. With the default
            ``"cronjob"`` the name of the function in the last traceback
            frame is used instead, when it can be determined.
    """
    if location == "cronjob":
        try:
            location = traceback.extract_tb(e.__traceback__)[-1].name
        except Exception as tb_error:
            # Use a separate name here: "except ... as e" would unbind the
            # parameter `e` when the handler exits and break the lines below.
            logger.error(f"Unable to get function name from cronjob exception traceback: {tb_error}")
    EXCEPTION_COUNTER.labels(type=e.__class__.__name__, location=location).inc()
    log_error(e)
|
||||
@@ -2,8 +2,12 @@ from loguru import logger
|
||||
import requests, secrets
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
from shared.settings import get_settings
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.db.database import get_db_dependency
|
||||
from app.web.db.user_state import UserState
|
||||
|
||||
settings = get_settings()
|
||||
bearer_security = HTTPBearer()
|
||||
@@ -45,7 +49,7 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
|
||||
# validates the Bearer token in the case that it requires it
|
||||
valid_user, info = authenticate_user(credentials.credentials)
|
||||
if valid_user:
|
||||
return info
|
||||
return info.lower()
|
||||
logger.debug(f"TOKEN FAILURE: {valid_user=} {info=}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -69,7 +73,11 @@ def authenticate_user(access_token):
|
||||
return False, f"email '{j.get('email')}' not verified"
|
||||
if int(j.get("expires_in", -1)) <= 0:
|
||||
return False, "Token expired"
|
||||
return True, j.get('email')
|
||||
return True, j.get('email').lower()
|
||||
except Exception as e:
|
||||
logger.warning(f"AUTH EXCEPTION occurred: {e}")
|
||||
return False, "exception occurred"
|
||||
|
||||
|
||||
def get_user_state(email:str=Depends(get_user_auth), db:Session=Depends(get_db_dependency)):
    """FastAPI dependency returning the UserState of the authenticated user."""
    user_state = UserState(db, email)
    return user_state
|
||||
|
Before Width: | Height: | Size: 93 KiB After Width: | Height: | Size: 93 KiB |
@@ -3,23 +3,23 @@ import json
|
||||
import os
|
||||
import shutil
|
||||
from prometheus_client import Counter, Gauge
|
||||
import redis
|
||||
|
||||
from db import crud
|
||||
from db.database import get_db
|
||||
from core.logging import log_error
|
||||
from app.web.db import crud
|
||||
from app.shared.db.database import get_db
|
||||
from app.shared.log import log_error
|
||||
from app.shared.task_messaging import get_redis
|
||||
|
||||
|
||||
# Custom metrics
|
||||
EXCEPTION_COUNTER = Counter(
|
||||
"exceptions",
|
||||
"Number of times a certain exception has occurred.",
|
||||
labelnames=["types"]
|
||||
labelnames=["type", "location"]
|
||||
)
|
||||
WORKER_EXCEPTION = Counter(
|
||||
"worker_exceptions_total",
|
||||
"Number of times a certain exception has occurred on the worker.",
|
||||
labelnames=["types", "exception", "task", "traceback"]
|
||||
labelnames=["type", "exception", "task", "traceback"]
|
||||
)
|
||||
DISK_UTILIZATION = Gauge(
|
||||
"disk_utilization",
|
||||
@@ -38,16 +38,16 @@ DATABASE_METRICS_COUNTER = Counter(
|
||||
)
|
||||
|
||||
|
||||
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL, CELERY_BROKER_URL):
|
||||
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL: str):
|
||||
# Subscribe to Redis channel and increment the counter for each exception with info on the exception and task
|
||||
Rdis = redis.Redis.from_url(CELERY_BROKER_URL)
|
||||
PubSubExceptions = Rdis.pubsub()
|
||||
Redis = get_redis()
|
||||
PubSubExceptions = Redis.pubsub()
|
||||
PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL)
|
||||
while True:
|
||||
message = PubSubExceptions.get_message()
|
||||
if message and message["type"] == "message":
|
||||
data = json.loads(message["data"].decode("utf-8"))
|
||||
WORKER_EXCEPTION.labels(types=type(data["exception"]).__name__, exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc()
|
||||
WORKER_EXCEPTION.labels(type=data["type"], exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
15
app/web/utils/misc.py
Normal file
15
app/web/utils/misc.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import base64
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
|
||||
def custom_jsonable_encoder(obj):
    """Encode ``obj`` for JSON; bytes become a base64 text string.

    Everything that is not ``bytes`` is delegated to FastAPI's
    ``jsonable_encoder``.
    """
    if not isinstance(obj, bytes):
        return jsonable_encoder(obj)
    encoded = base64.b64encode(obj)
    return encoded.decode('utf-8')
|
||||
|
||||
|
||||
def convert_priority_to_queue_dict(priority: str) -> dict:
    """Build the ``priority`` / ``queue`` options for a priority name.

    ``"high"`` maps to priority 0, any other value to 10; the queue name is
    always ``"<priority>_priority"``.
    """
    if priority == "high":
        level = 0
    else:
        level = 10
    queue_name = f"{priority}_priority"
    return {"priority": level, "queue": queue_name}
|
||||
0
app/worker/__init__.py
Normal file
0
app/worker/__init__.py
Normal file
147
app/worker/main.py
Normal file
147
app/worker/main.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import json
|
||||
|
||||
import traceback, datetime
|
||||
from celery.signals import task_failure
|
||||
from loguru import logger
|
||||
from sqlalchemy import exc
|
||||
from auto_archiver.core.orchestrator import ArchivingOrchestrator
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared.db.database import get_db
|
||||
from app.shared import business_logic, schemas
|
||||
from app.shared.task_messaging import get_celery, get_redis
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.log import log_error
|
||||
from app.shared.aa_utils import get_all_urls
|
||||
from app.shared.db import worker_crud
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
celery = get_celery("worker")
|
||||
Redis = get_redis()
|
||||
|
||||
USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
|
||||
|
||||
# TODO: these are temporary PATCHES for new aa's functionality
|
||||
# logger.add("app/worker/worker_log.log", level="DEBUG")
|
||||
logger.remove = lambda x: print(f"logger.remove({x})")
|
||||
|
||||
# TODO: after release, as it requires updating past entries with sheet_id where tag is used, drop tags
|
||||
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 1})
def create_archive_task(self, archive_json: str):
    """Archive a single URL with auto-archiver and store the result.

    Args:
        archive_json: JSON serialisation of a ``schemas.ArchiveCreate``.

    Returns:
        The archive result as a JSON-compatible dict.

    Raises:
        AssertionError: when auto-archiver produced no result (including when
            it terminated with ``SystemExit``).
        Exception: any other error from auto-archiver, after logging it.
    """
    archive = schemas.ArchiveCreate.model_validate_json(archive_json)

    # call auto-archiver
    args = get_orchestrator_args(archive.group_id, False, [archive.url])
    # Pre-bind so the check below reports a clear failure instead of an
    # UnboundLocalError when SystemExit is swallowed before any assignment.
    result = None
    try:
        orchestrator = ArchivingOrchestrator()
        orchestrator.setup(args)
        result = next(orchestrator.feed())
    except SystemExit as e:
        log_error(e, "create_archive_task: SystemExit from AA")
    except Exception as e:
        log_error(e, "create_archive_task")
        raise
    # explicit raise (same exception type as before) so that the check is not
    # stripped when Python runs with -O
    if not result:
        raise AssertionError(f"UNABLE TO archive: {archive.url}")

    # prepare and insert in DB
    archive.store_until = get_store_until(archive.group_id)
    archive.id = self.request.id
    archive.urls = get_all_urls(result)
    archive.result = json.loads(result.to_json())
    insert_result_into_db(archive)

    return archive.result
|
||||
|
||||
|
||||
@celery.task(name="create_sheet_task", bind=True)
def create_sheet_task(self, sheet_json: str):
    """Archive every URL fed by a Google Sheet and store each result.

    Failures of individual URLs are logged, published to redis and counted,
    without aborting the rest of the sheet. When at least one URL was
    archived, the sheet's last-archived timestamp is updated.

    Args:
        sheet_json: JSON serialisation of a ``schemas.SubmitSheet``.

    Returns:
        A dumped ``schemas.CelerySheetTask`` with the run statistics.
    """
    sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
    delivery_info = create_sheet_task.request.delivery_info or {}
    queue_name = delivery_info.get('routing_key', 'unknown')
    logger.info(f"[queue={queue_name}] SHEET START {sheet=}")

    cli_args = get_orchestrator_args(sheet.group_id, True, ["--gsheet_feeder.sheet_id", sheet.sheet_id])
    orchestrator = ArchivingOrchestrator()
    orchestrator.setup(cli_args)

    summary = {"archived": 0, "failed": 0, "errors": []}
    try:
        for item in orchestrator.feed():
            try:
                assert item, f"ERROR archiving URL for sheet {sheet.sheet_id}"
                record = schemas.ArchiveCreate(
                    author_id=sheet.author_id,
                    url=item.get_url(),
                    group_id=sheet.group_id,
                    tags=sheet.tags,
                    id=models.generate_uuid(),
                    result=json.loads(item.to_json()),
                    sheet_id=sheet.sheet_id,
                    urls=get_all_urls(item),
                    store_until=get_store_until(sheet.group_id)
                )
                insert_result_into_db(record)
                summary["archived"] += 1
            except exc.IntegrityError as integrity_error:
                # the entry is already stored; not counted as a failure
                logger.warning(f"cached result detected: {integrity_error}")
            except Exception as error:
                log_error(error, extra=f"{self.name}: {sheet_json}")
                redis_publish_exception(error, self.name, traceback.format_exc())
                summary["failed"] += 1
                summary["errors"].append(str(error))
    except SystemExit as aa_exit:
        log_error(aa_exit, "create_sheet_task: SystemExit from AA")

    if summary["archived"] > 0:
        with get_db() as session:
            worker_crud.update_sheet_last_url_archived_at(session, sheet.sheet_id)

    logger.info(f"SHEET DONE {sheet=}")
    # TODO: is this used anywhere? maybe drop it
    finished_at = datetime.datetime.now().isoformat()
    return schemas.CelerySheetTask(success=True, sheet_id=sheet.sheet_id, time=finished_at, stats=summary).model_dump()
|
||||
|
||||
|
||||
def get_orchestrator_args(group_id: str, orchestrator_for_sheet: bool, cli_args: list | None = None) -> list:
    """Build the auto-archiver CLI arguments for a group.

    Args:
        group_id: id of the group whose orchestrator configuration is used.
        orchestrator_for_sheet: use the group's sheet orchestrator when True,
            its URL orchestrator otherwise.
        cli_args: extra arguments appended after ``--config <file>``.

    Returns:
        ``["--config", <orchestrator file>, *cli_args]``.

    Raises:
        AssertionError: when the group or its orchestrator setting is missing.
    """
    with get_db() as session:
        # single lookup; the group is reused for both orchestrator kinds
        group = worker_crud.get_group(session, group_id)
        if group is None:
            orchestrator_fn = None
        elif orchestrator_for_sheet:
            orchestrator_fn = group.orchestrator_sheet
        else:
            orchestrator_fn = group.orchestrator
    assert orchestrator_fn, f"no orchestrator found for {group_id}"

    aa_configs = ["--config", orchestrator_fn]
    aa_configs.extend(cli_args or [])
    return aa_configs
|
||||
|
||||
|
||||
def insert_result_into_db(archive: schemas.ArchiveCreate) -> str:
    """Store ``archive`` through ``worker_crud`` and return the stored id."""
    with get_db() as session:
        stored = worker_crud.store_archived_url(session, archive)
        author, url = stored.author_id, stored.url
        logger.debug(f"[ARCHIVE STORED] {author} {url}")
        return stored.id
|
||||
|
||||
|
||||
def get_store_until(group_id: str) -> datetime.datetime:
    """Return the retention deadline computed by the business logic for a group."""
    with get_db() as session:
        store_until = business_logic.get_store_archive_until(session, group_id)
    return store_until
|
||||
|
||||
|
||||
def redis_publish_exception(exception, task_name, traceback: str = ""):
    """Publish a worker exception to the redis exceptions channel.

    The payload is a JSON object with ``task``, ``type``, ``exception`` and
    ``traceback``; values that are not JSON-serialisable are stringified.
    A failure to publish is logged rather than raised.
    """
    channel = settings.REDIS_EXCEPTIONS_CHANNEL
    try:
        payload = json.dumps(
            {
                "task": task_name,
                "type": exception.__class__.__name__,
                "exception": exception,
                "traceback": traceback,
            },
            default=str,
        )
        Redis.publish(channel, payload)
    except Exception as publish_error:
        log_error(publish_error, f"[CRITICAL] Could not publish to {channel}")
|
||||
|
||||
|
||||
@task_failure.connect(sender=create_sheet_task)
@task_failure.connect(sender=create_archive_task)
def task_failure_notifier(sender, **kwargs):
    """Signal handler: log a failed worker task and publish it to redis."""
    # automatically capture exceptions in the worker tasks
    task_name = sender.name
    logger.warning(f"⚠️ worker task failed: {task_name}")
    # format_tb(tb) is the shorthand for format_list(extract_tb(tb))
    traceback_msg = "\n".join(traceback.format_tb(kwargs['traceback']))
    failure = kwargs['exception']
    log_error(failure, traceback_msg, f"task_failure: {task_name}")
    redis_publish_exception(failure, task_name, traceback_msg)
|
||||
0
database/.gitkeep
Normal file
0
database/.gitkeep
Normal file
@@ -1,19 +1,30 @@
|
||||
services:
|
||||
web:
|
||||
command: uvicorn app.web:app --factory --host 0.0.0.0 --reload
|
||||
restart: "no"
|
||||
env_file: src/.env.dev
|
||||
env_file: .env.dev
|
||||
volumes:
|
||||
- ./app/web:/aa-api/app/web # for --reload to work
|
||||
- ./app/shared:/aa-api/app/shared # for --reload to work
|
||||
environment:
|
||||
- SERVE_LOCAL_ARCHIVE=/app/local_archive # See orchestration.yaml local_storage.save_to
|
||||
- ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp
|
||||
- USER_GROUPS_FILENAME=user-groups.dev.yaml
|
||||
- DATABASE_PATH=sqlite:////app/auto-archiver.db
|
||||
- ENVIRONMENT_FILE=.env.dev
|
||||
- SERVE_LOCAL_ARCHIVE=/aa-api/app/local_archive # See orchestration.yaml local_storage.save_to
|
||||
- ALLOWED_ORIGINS=["http://localhost:8000","http://localhost:8004","http://localhost:8081","chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp"]
|
||||
- USER_GROUPS_FILENAME=/aa-api/app/user-groups.dev.yaml
|
||||
- DATABASE_PATH=sqlite:////aa-api/database/auto-archiver.db
|
||||
|
||||
|
||||
worker:
|
||||
# command: watchmedo auto-restart --patterns="*.py" --recursive --ignore-directories -- celery -- --app=app.worker.main.celery worker --loglevel=debug --logfile=/aa-api/logs/celery.log -Q high_priority,low_priority --concurrency=${CONCURRENCY}
|
||||
command: celery --app=app.worker.main.celery worker --loglevel=debug --logfile=/aa-api/logs/celery.log -Q high_priority,low_priority --concurrency=${CONCURRENCY}
|
||||
restart: "no"
|
||||
env_file: src/.env.dev
|
||||
env_file: .env.dev
|
||||
volumes:
|
||||
- ./app/worker:/aa-api/app/worker # for watchmedo to work
|
||||
- ./app/shared:/aa-api/app/shared # for watchmedo to work
|
||||
|
||||
redis:
|
||||
restart: "no"
|
||||
env_file: src/.env.dev
|
||||
env_file: .env.dev
|
||||
ports:
|
||||
- 6379:6379
|
||||
|
||||
@@ -1,13 +1,3 @@
|
||||
# reusable YAML variables
|
||||
x-broker-url: &broker-url "redis://:${REDIS_PASSWORD}@redis:6379/0"
|
||||
|
||||
x-base-setup: &base-setup
|
||||
build: ./src
|
||||
restart: always
|
||||
env_file: src/.env.prod
|
||||
environment:
|
||||
CELERY_BROKER_URL: *broker-url
|
||||
CELERY_RESULT_BACKEND: *broker-url
|
||||
|
||||
volumes:
|
||||
crawls:
|
||||
@@ -15,31 +5,45 @@ volumes:
|
||||
name: "auto-archiver-api"
|
||||
services:
|
||||
web:
|
||||
<<: *base-setup
|
||||
build:
|
||||
context: .
|
||||
dockerfile: web.Dockerfile
|
||||
restart: always
|
||||
env_file: .env.prod
|
||||
environment:
|
||||
ENVIRONMENT_FILE: .env.prod
|
||||
REDIS_HOSTNAME: redis
|
||||
ports:
|
||||
- "127.0.0.1:8004:8000"
|
||||
command: uvicorn web:app --factory --host 0.0.0.0 --reload
|
||||
command: uvicorn app.web:app --factory --host 0.0.0.0
|
||||
volumes:
|
||||
- ./src:/app
|
||||
- ./logs:/aa-api/logs
|
||||
- ./database:/aa-api/database
|
||||
- ./secrets:/aa-api/secrets
|
||||
depends_on:
|
||||
- redis
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
test: ["CMD", "python3", "-c", 'import sys, urllib.request; sys.exit(urllib.request.urlopen("http://localhost:8000/health").getcode() != 200)']
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
worker:
|
||||
<<: *base-setup
|
||||
command: celery --app=worker.main.celery worker --loglevel=info --logfile=logs/celery.log
|
||||
build:
|
||||
context: .
|
||||
dockerfile: worker.Dockerfile
|
||||
restart: always
|
||||
env_file: .env.prod
|
||||
command: celery --app=app.worker.main.celery worker --loglevel=warning --logfile=/aa-api/logs/celery.log -Q high_priority,low_priority --concurrency=${CONCURRENCY}
|
||||
volumes:
|
||||
- ./src:/app
|
||||
- ./logs:/aa-api/logs
|
||||
- ./database:/aa-api/database
|
||||
- ./secrets:/aa-api/secrets
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- crawls:/crawls # BROWSERTRIX_HOME_HOST:BROWSERTRIX_HOME_CONTAINER, do not change /crawls
|
||||
environment:
|
||||
# celery broker-url needs to be duplicated here, do not remove
|
||||
CELERY_BROKER_URL: *broker-url
|
||||
CELERY_RESULT_BACKEND: *broker-url
|
||||
REDIS_HOSTNAME: redis
|
||||
ENVIRONMENT_FILE: .env.prod
|
||||
WACZ_ENABLE_DOCKER: 1 # Enable calling docker from this container
|
||||
BROWSERTRIX_HOME_HOST: auto-archiver-api_crawls
|
||||
BROWSERTRIX_HOME_CONTAINER: /crawls
|
||||
@@ -47,7 +51,7 @@ services:
|
||||
- web
|
||||
- redis
|
||||
healthcheck:
|
||||
test: ["CMD", "pipenv", "run", "celery", "-A", "worker.main.celery", "status"]
|
||||
test: ["CMD-SHELL", "./poetry-venv/bin/poetry run celery -A app.worker.main.celery inspect ping || exit 1"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
@@ -55,10 +59,11 @@ services:
|
||||
redis:
|
||||
image: redis:6-alpine
|
||||
restart: always
|
||||
env_file: .env.prod
|
||||
command: redis-server /conf/redis.conf --requirepass ${REDIS_PASSWORD}
|
||||
volumes:
|
||||
- "./redis/data:/data"
|
||||
- "./redis/config:/conf"
|
||||
- ./redis/data:/data
|
||||
- ./redis/config:/conf
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
|
||||
interval: 30s
|
||||
|
||||
0
logs/.gitkeep
Normal file
0
logs/.gitkeep
Normal file
3690
poetry.lock
generated
Normal file
3690
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
55
pyproject.toml
Normal file
55
pyproject.toml
Normal file
@@ -0,0 +1,55 @@
|
||||
[tool.poetry]
|
||||
package-mode = false
|
||||
|
||||
[project]
|
||||
name = "auto-archiver-api"
|
||||
description = "API wrapper for Bellingcat's Auto Archiver, supports users, groups, sheet and url archives."
|
||||
authors = [
|
||||
{ name = "Bellingcat", email = "contact-tech@bellingcat.com" },
|
||||
]
|
||||
license = {text = "MIT"}
|
||||
readme = "README.md"
|
||||
keywords = ["archive", "oosi", "osint", "scraping"]
|
||||
classifiers = [
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3"
|
||||
]
|
||||
|
||||
requires-python = ">=3.10,<3.13"
|
||||
|
||||
|
||||
dependencies = [
|
||||
"auto-archiver (>=0.13.1)",
|
||||
"oscrypto @ git+https://github.com/wbond/oscrypto.git@d5f3437ed24257895ae1edd9e503cfb352e635a8",
|
||||
"celery (>=5.0)",
|
||||
"redis (==3.5.3)",
|
||||
"loguru (>=0.7.3,<0.8.0)",
|
||||
"pydantic-settings (>=2.7.1,<3.0.0)",
|
||||
"sqlalchemy (>=2.0.38,<3.0.0)",
|
||||
"requests (>=2.25.1)",
|
||||
"pyopenssl (>=23.3.0)",
|
||||
]
|
||||
[tool.poetry.group.worker.dependencies]
|
||||
watchdog = ">=6.0.0,<7.0.0"
|
||||
setuptools = "^75.8.0"
|
||||
|
||||
[tool.poetry.group.web.dependencies]
|
||||
fastapi = ">=0.115.8,<0.116.0"
|
||||
requests = ">=2.32.3,<3.0.0"
|
||||
aiosqlite = ">=0.21.0,<0.22.0"
|
||||
alembic = ">=1.14.1,<2.0.0"
|
||||
fastapi-utils = ">=0.8.0,<0.9.0"
|
||||
prometheus-fastapi-instrumentator = ">=7.0.2,<8.0.0"
|
||||
fastapi-mail = ">=1.4.2,<2.0.0"
|
||||
uvicorn = ">=0.13.4"
|
||||
pyyaml = "^6.0.2"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = ">=8.3.4,<9.0.0"
|
||||
httpx = ">=0.28.1,<0.29.0"
|
||||
coverage = ">=7.6.11,<8.0.0"
|
||||
pytest-asyncio = ">=0.25.3,<0.26.0"
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
DATABASE_PATH="sqlite:///./auto-archiver.db"
|
||||
USER_GROUPS_FILENAME=user-groups.yaml
|
||||
CHROME_APP_IDS=000000000000000000000000000000000000000000000.apps.googleusercontent.com,000000000000000000000000000000000000000000001.apps.googleusercontent.com
|
||||
#ALLOWED_ORIGINS="http://localhost:8004" # dev only
|
||||
|
||||
|
||||
API_BEARER_TOKEN=TODO
|
||||
@@ -1,22 +0,0 @@
|
||||
# From python:3.10
|
||||
FROM bellingcat/auto-archiver
|
||||
|
||||
# set work directory
|
||||
WORKDIR /app
|
||||
|
||||
RUN curl -fsSL https://get.docker.com -o get-docker.sh && \
|
||||
sh get-docker.sh
|
||||
# set environment variables
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
|
||||
# install dependencies
|
||||
RUN pip install --upgrade pip && \
|
||||
apt-get update
|
||||
COPY Pipfile* ./
|
||||
RUN pipenv install
|
||||
|
||||
# copy src code over
|
||||
COPY . .
|
||||
|
||||
ENTRYPOINT ["pipenv", "run"]
|
||||
32
src/Pipfile
32
src/Pipfile
@@ -1,32 +0,0 @@
|
||||
[[source]]
|
||||
url = "https://pypi.org/simple"
|
||||
verify_ssl = true
|
||||
name = "pypi"
|
||||
|
||||
[packages]
|
||||
aiofiles = "==0.6.0"
|
||||
celery = ">=5.0"
|
||||
fastapi = "*"
|
||||
jinja2 = "*"
|
||||
redis = "==3.5.3"
|
||||
requests = ">=2.25.1"
|
||||
uvicorn = ">=0.13.4"
|
||||
aiosqlite = "*"
|
||||
python-dotenv = "*"
|
||||
loguru = "*"
|
||||
sqlalchemy = "*"
|
||||
alembic = "*"
|
||||
fastapi-utils = "*"
|
||||
prometheus-fastapi-instrumentator = "*"
|
||||
auto-archiver = "*"
|
||||
pydantic-settings = "*"
|
||||
|
||||
[dev-packages]
|
||||
watchdog = "*"
|
||||
pytest = "*"
|
||||
httpx = "*"
|
||||
coverage = "*"
|
||||
pytest-asyncio = "*"
|
||||
|
||||
[requires]
|
||||
python_version = "3.10"
|
||||
3517
src/Pipfile.lock
generated
3517
src/Pipfile.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,41 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import alembic.config
|
||||
from fastapi import FastAPI
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
from loguru import logger
|
||||
|
||||
from db import crud, models
|
||||
from db.database import get_db, make_engine
|
||||
from shared.settings import get_settings
|
||||
from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan handler: everything before `yield` runs at startup,
    everything after it runs at shutdown.
    see https://fastapi.tiangolo.com/advanced/events/#lifespan
    """
    settings = get_settings()

    # STARTUP
    engine = make_engine(settings.DATABASE_PATH)
    models.Base.metadata.create_all(bind=engine)
    alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
    # disabling uvicorn logger since we use loguru in logging_middleware
    logging.getLogger("uvicorn.access").disabled = True

    # The event loop only keeps weak references to tasks, so a task whose
    # result is discarded may be garbage-collected before it finishes.
    # Holding them in this local keeps them alive until the lifespan exits.
    background_tasks = [
        asyncio.create_task(redis_subscribe_worker_exceptions(settings.REDIS_EXCEPTIONS_CHANNEL, settings.CELERY_BROKER_URL)),
        asyncio.create_task(repeat_measure_regular_metrics()),
    ]

    with get_db() as db:
        crud.upsert_user_groups(db)

    yield  # separates startup from shutdown instructions

    # SHUTDOWN
    logger.info("shutting down")
    del background_tasks  # references are no longer needed past this point
|
||||
|
||||
|
||||
# CRON JOBS
|
||||
|
||||
|
||||
@repeat_every(seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS)
async def repeat_measure_regular_metrics():
    """Periodic job: forwards the database path and the repeat interval to measure_regular_metrics."""
    settings = get_settings()
    await measure_regular_metrics(
        settings.DATABASE_PATH,
        settings.REPEAT_COUNT_METRICS_SECONDS,
    )
|
||||
@@ -1,26 +0,0 @@
|
||||
import traceback
|
||||
from loguru import logger
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
# logging configurations
|
||||
logger.add("logs/api_logs.log", retention="30 days", rotation="3 days")
|
||||
logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
|
||||
|
||||
|
||||
def log_error(e: Exception, traceback_str: str = None, extra: str = ""):
    """
    Log an exception at ERROR level as "<extra>\\n<ExceptionClass>: <message>\\n<traceback>".

    traceback_str: pre-formatted traceback; when empty/None the traceback of
                   the exception currently being handled is used.
    extra: optional context printed on its own line before the exception.
    """
    # EXCEPTION_COUNTER.labels(type(e).__name__).inc()
    trace = traceback_str or traceback.format_exc()
    # a falsy `extra` is interpolated unchanged, exactly as before
    prefix = f"{extra}\n" if extra else extra
    logger.error(f"{prefix}{e.__class__.__name__}: {e}\n{trace}")
|
||||
|
||||
async def logging_middleware(request: Request, call_next):
    """
    HTTP middleware that writes one access-log line per request and records
    unhandled exceptions.

    On success the response is returned unchanged. If the downstream handler
    raises, the exception is counted in EXCEPTION_COUNTER (labelled by class
    name), logged with its traceback, and re-raised.
    """
    try:
        response = await call_next(request)
    except Exception as e:
        # imported here, as in the original, rather than at module level
        from utils.metrics import EXCEPTION_COUNTER
        EXCEPTION_COUNTER.labels(type(e).__name__).inc()
        log_error(e)
        raise  # bare raise keeps the original traceback untouched

    # request.client can be None; building the log line must never turn a
    # successful response into an error.
    client = request.client
    origin = f"{client.host}:{client.port}" if client else "unknown"
    logger.info(f"{origin} {request.method} {request.url} - HTTP {response.status_code}")
    return response
|
||||
@@ -1 +0,0 @@
|
||||
based on https://fastapi-users.github.io/fastapi-users/10.4/configuration/oauth/
|
||||
@@ -1 +0,0 @@
|
||||
# https://fastapi.tiangolo.com/tutorial/sql-databases/#review-all-the-files
|
||||
259
src/db/crud.py
259
src/db/crud.py
@@ -1,259 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from functools import cache
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
from sqlalchemy import Column, or_, func
|
||||
from loguru import logger
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
from shared.settings import get_settings
|
||||
from . import models, schemas
|
||||
import yaml
|
||||
|
||||
DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
|
||||
|
||||
# --------------- TASK = Archive
|
||||
|
||||
|
||||
def get_limit(user_limit: int):
    """Clamp a requested page size: capped at DATABASE_QUERY_LIMIT, and never below 1."""
    capped = min(user_limit, DATABASE_QUERY_LIMIT)
    return max(1, capped)
|
||||
|
||||
|
||||
def get_archive(db: Session, id: str, email: str):
    """
    Fetch one non-deleted archive by id, or None.

    Unless `email` equals ALLOW_ANY_EMAIL, the archive is only returned when
    it is public, authored by `email`, or belongs to one of the user's groups.
    """
    email = email.lower()
    query = base_query(db).filter(models.Archive.id == id)
    if email == ALLOW_ANY_EMAIL:
        return query.first()

    groups = get_user_groups(db, email)
    visible_to_user = or_(
        models.Archive.public == True,  # noqa: E712 - SQL expression, not a Python comparison
        models.Archive.author_id == email,
        models.Archive.group_id.in_(groups),
    )
    return query.filter(visible_to_user).first()
|
||||
|
||||
|
||||
def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False):
    """
    Search non-deleted archives by URL, newest first.

    url: exact URL when `absolute_search` is True, otherwise a substring.
    email: if it equals ALLOW_ANY_EMAIL no ownership filtering happens;
           otherwise only public, own, or group-visible archives are returned.
    skip / limit: pagination; `limit` is clamped by get_limit().
    archived_after / archived_before: optional exclusive bounds on created_at.
    """
    query = base_query(db)
    if email != ALLOW_ANY_EMAIL:
        email = email.lower()
        groups = get_user_groups(db, email)
        query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))  # noqa: E712

    if absolute_search:
        query = query.filter(models.Archive.url == url)
    else:
        # autoescape: "_" and "%" are common in URLs and must match literally
        # instead of acting as LIKE wildcards.
        query = query.filter(models.Archive.url.contains(url, autoescape=True))

    if archived_after:
        query = query.filter(models.Archive.created_at > archived_after)
    if archived_before:
        query = query.filter(models.Archive.created_at < archived_before)

    return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
|
||||
|
||||
|
||||
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
    """Return the non-deleted archives authored by `email` (lower-cased), newest first, paginated."""
    author = email.lower()
    authored = base_query(db).filter(models.Archive.author_id == author)
    newest_first = authored.order_by(models.Archive.created_at.desc())
    page = newest_first.offset(skip).limit(get_limit(limit))
    return page.all()
|
||||
|
||||
|
||||
def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]):
    """Persist a new Archive built from `task`, attach `tags` and `urls`, commit, and return the refreshed row."""
    archive = models.Archive(
        id=task.id,
        url=task.url,
        result=task.result,
        public=task.public,
        author_id=task.author_id,
        group_id=task.group_id,
    )
    archive.tags = tags
    archive.urls = urls

    db.add(archive)
    db.commit()
    db.refresh(archive)
    return archive
|
||||
|
||||
|
||||
def soft_delete_task(db: Session, task_id: str, email: str) -> bool:
    """
    Flag the archive `task_id` as deleted when it is authored by `email` and
    not already deleted. Returns True when a matching archive was found.
    """
    # TODO: implement hard-delete with cronjob that deletes from S3
    # NOTE(review): unlike the search helpers, `email` is not lower-cased here - confirm callers normalise it
    archive = db.query(models.Archive).filter(
        models.Archive.id == task_id,
        models.Archive.author_id == email,
        models.Archive.deleted == False,  # noqa: E712
    ).first()
    if archive is None:
        return False
    archive.deleted = True
    db.commit()
    return True
|
||||
|
||||
|
||||
def count_archives(db: Session):
    """Count Archive ids; no filter is applied."""
    total = func.count(models.Archive.id)
    return db.query(total).scalar()
|
||||
|
||||
|
||||
def count_archive_urls(db: Session):
    """Count ArchiveUrl urls; no filter is applied."""
    total = func.count(models.ArchiveUrl.url)
    return db.query(total).scalar()
|
||||
|
||||
|
||||
def count_users(db: Session):
    """Count User emails; no filter is applied."""
    total = func.count(models.User.email)
    return db.query(total).scalar()
|
||||
|
||||
|
||||
def count_by_user_since(db: Session, seconds_delta: int = 15):
    """
    Per-author count of archives created within the last `seconds_delta`
    seconds, as (author_id, total) rows, largest first, at most 500 rows.
    """
    cutoff = datetime.now() - timedelta(seconds=seconds_delta)
    per_author = (
        db.query(models.Archive.author_id, func.count().label('total'))
        .filter(models.Archive.created_at >= cutoff)
        .group_by(models.Archive.author_id)
        .order_by(func.count().desc())
        .limit(500)
    )
    return per_author.all()
|
||||
|
||||
|
||||
def base_query(db: Session):
    """Base Archive query shared by the lookups: skips deleted rows and loads only id, created_at, url and result."""
    # NOTE: load_only is for optimization and not obfuscation, use .with_entities() if needed
    lean_columns = load_only(
        models.Archive.id,
        models.Archive.created_at,
        models.Archive.url,
        models.Archive.result,
    )
    not_deleted = db.query(models.Archive).filter(models.Archive.deleted == False)  # noqa: E712
    return not_deleted.options(lean_columns)
|
||||
|
||||
# --------------- TAG
|
||||
|
||||
|
||||
def create_tag(db: Session, tag: str):
    """Return the Tag whose id is `tag`, creating and committing it first when it does not exist."""
    existing = db.query(models.Tag).filter(models.Tag.id == tag).first()
    if existing:
        return existing

    created = models.Tag(id=tag)
    db.add(created)
    db.commit()
    db.refresh(created)
    return created
|
||||
|
||||
|
||||
def is_active_user(db: Session, email: str) -> bool:
    """
    Tell whether `email` may use the tool.

    A user is active when there is a User row for the (lower-cased) email
    with is_active set, or when some Group's `domains` matches the email's
    domain. Empty, None, or "@"-less emails are never active.
    """
    # validate before calling .lower(): a missing email must yield False, not an AttributeError
    if not email or "@" not in email:
        return False
    email = email.lower()
    domain = email.split('@')[1]

    explicitly_active = db.query(models.User).filter(models.User.email == email, models.User.is_active == True).first() is not None  # noqa: E712
    if explicitly_active:
        return True

    # NOTE(review): `.contains` on `domains` may be a substring match depending on the column type - verify against the model
    return db.query(models.Group).filter(models.Group.domains.contains(domain)).first() is not None
|
||||
|
||||
|
||||
def is_user_in_group(db: Session, group_name: str, email: str) -> bool:
    """
    Tell whether `email` belongs to `group_name`.

    ALLOW_ANY_EMAIL is treated as a member of every group. An empty or
    missing group name or email is never a member. Always returns a bool.
    """
    if email == ALLOW_ANY_EMAIL:
        return True
    if not group_name or not email:
        return False
    return group_name in get_user_groups(db, email)
|
||||
|
||||
|
||||
def get_user_groups(db: Session, email: str):
    """
    Return the de-duplicated ids of the groups an email belongs to: the
    groups linked to it in the user-groups association table plus the groups
    whose `domains` match the email's domain. The email does not need to
    belong to an existing user, and the user does not need to be active.
    Empty or "@"-less emails yield an empty list.
    """
    if not email or "@" not in email:
        return []
    email = email.lower()

    # get user groups
    membership_rows = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
    user_level_groups = [row[0] for row in membership_rows]

    # get domain groups
    domain = email.split('@')[1]
    domain_rows = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
    domain_level_groups = [row[0] for row in domain_rows]

    # combine and return
    return list({*user_level_groups, *domain_level_groups})
|
||||
|
||||
|
||||
# --------------- INIT User-Groups
|
||||
|
||||
|
||||
def create_or_get_user(db: Session, author_id: str, is_active: bool = models.User.is_active.default.arg) -> models.User:
    """
    Return the User whose email is `author_id`, creating and committing it
    first (with the given `is_active`) when it does not exist.
    String ids are lower-cased before the lookup.
    """
    # isinstance also covers str subclasses, which `type(x) == str` missed
    if isinstance(author_id, str):
        author_id = author_id.lower()
    db_user = db.query(models.User).filter(models.User.email == author_id).first()
    if not db_user:
        db_user = models.User(email=author_id, is_active=is_active)
        db.add(db_user)
        db.commit()
        db.refresh(db_user)
    return db_user
|
||||
|
||||
|
||||
def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, permissions: dict, domains: list) -> models.Group:
    """Create the Group `group_name` or overwrite its settings, then commit and return the refreshed row."""
    settings = {
        "description": description,
        "orchestrator": orchestrator,
        "orchestrator_sheet": orchestrator_sheet,
        "permissions": permissions,
        "domains": domains,
    }
    group = db.query(models.Group).filter(models.Group.id == group_name).first()
    if group is not None:
        # existing group: overwrite every configurable attribute
        for attribute, value in settings.items():
            setattr(group, attribute, value)
    else:
        group = models.Group(id=group_name, **settings)
        db.add(group)
    db.commit()
    db.refresh(group)
    return group
|
||||
|
||||
|
||||
def upsert_user(db: Session, email: str, active: bool):
    """Create the User `email` with the given active flag, or update the flag on the existing row; commits and returns it."""
    user = db.query(models.User).filter(models.User.email == email).first()
    if user is not None:
        user.is_active = active
    else:
        user = models.User(email=email, is_active=active)
        db.add(user)
    db.commit()
    return user
|
||||
|
||||
|
||||
def upsert_user_groups(db: Session):
    """
    reads the user_groups yaml file and inserts any new users, groups,
    along with new participation of users in groups

    All explicit user-group links are deleted and every user is set to
    inactive first; users listed in the file are then re-activated and
    re-linked. Raises whatever opening/parsing the file raises.
    """
    def display_email_pii(email: str):
        # keep only the first 3 characters of the local part in the logs
        return f"'{email[0:3]}...@{email.split('@')[1]}'"

    logger.debug("Updating user-groups configuration.")
    filename = get_settings().USER_GROUPS_FILENAME

    # read yaml safely
    try:
        with open(filename) as inf:
            user_groups_yaml = yaml.safe_load(inf)
    except Exception as e:
        logger.error(f"could not open user groups filename {filename}: {e}")
        raise

    # an empty file parses to None, and a section declared without content
    # ("users:" with nothing below) parses to None as well
    user_groups_yaml = user_groups_yaml or {}
    yaml_domains = user_groups_yaml.get("domains") or {}
    yaml_groups = user_groups_yaml.get("groups") or {}
    yaml_users = user_groups_yaml.get("users") or {}

    # delete all user-groups relationships
    db.query(models.association_table_user_groups).delete()

    # set all users to inactive
    db.query(models.User).update({models.User.is_active: False})

    # create a map of group_id -> domains
    group_domains = defaultdict(set)
    for domain, explicit_groups in yaml_domains.items():
        for group in explicit_groups or []:
            group_domains[group].add(domain)

    # upsert groups and save a map of groupid -> dbobject
    for group_id, g in yaml_groups.items():
        upsert_group(db, group_id, g["description"], g["orchestrator"], g["orchestrator_sheet"], g["permissions"], list(group_domains.get(group_id, [])))
    db_groups: dict[str, models.Group] = {g.id: g for g in db.query(models.Group).all()}

    # integrity checks
    for group_in_domains in group_domains:
        if group_in_domains not in db_groups:
            logger.error(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.")
        if group_in_domains not in yaml_groups:
            logger.error(f"[CONFIG] Group '{group_in_domains}' does not exist in the config file: domains setting will not work.")

    # reinsert users in their EXPLICITLY DEFINED groups
    # domain groups are checked live, as there may be new users that are not explicitly registered but belong to a domain
    for email, explicit_groups in yaml_users.items():
        explicit_groups = explicit_groups or []
        email = email.lower().strip()
        if '@' not in email:
            logger.error(f'[CONFIG] Invalid user email {email}, skipping.')
            continue

        logger.info(f"{display_email_pii(email)} => {explicit_groups}")

        # upsert active user
        db_user = upsert_user(db, email, active=True)

        # connect users to groups; dict.fromkeys drops repeated group ids
        # (order kept) so the same membership is never appended twice
        for group_id in dict.fromkeys(explicit_groups):
            if group_id not in db_groups:
                logger.error(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.")
                continue
            db_groups[group_id].users.append(db_user)

    db.commit()
    count_user_groups = db.query(models.association_table_user_groups).count()
    count_groups = db.query(func.count(models.Group.id)).scalar()

    logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
|
||||
@@ -1,36 +0,0 @@
|
||||
from functools import lru_cache
|
||||
from sqlalchemy import Engine, create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from shared.settings import get_settings
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@lru_cache
def make_engine(database_url: str):
    """
    Build (and cache per URL) an engine created with check_same_thread=False
    whose every new connection switches the journal mode to WAL.
    """
    db_engine = create_engine(
        database_url,
        connect_args={"check_same_thread": False},
    )

    @event.listens_for(db_engine, "connect")
    def set_sqlite_pragma(dbapi_connection, _) -> None:
        # runs once per new DBAPI connection
        pragma_cursor = dbapi_connection.cursor()
        pragma_cursor.execute("PRAGMA journal_mode=WAL")
        pragma_cursor.close()

    return db_engine
|
||||
|
||||
|
||||
def make_session_local(engine: Engine):
    """Return a session factory bound to `engine`, with autocommit and autoflush both disabled."""
    return sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=engine,
    )
|
||||
|
||||
|
||||
@contextmanager
def get_db():
    """Context manager yielding a session on the configured database; the session is always closed on exit."""
    db_engine = make_engine(get_settings().DATABASE_PATH)
    session_factory = make_session_local(db_engine)
    session = session_factory()
    try:
        yield session
    finally:
        session.close()
|
||||
|
||||
|
||||
def get_db_dependency():
    """Generator wrapper around get_db(), to use with Depends and ensure proper session closing."""
    with get_db() as session:
        yield session
|
||||
@@ -1,66 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class Tag(BaseModel):
    """Tag schema; can be populated from object attributes (from_attributes)."""

    # identity-based hash, so instances can be stored in a set
    __hash__ = object.__hash__
    model_config = {"from_attributes": True}

    id: str
    created_at: datetime
|
||||
|
||||
class ArchiveCreate(BaseModel):
    """Payload describing an archive to create; only `url` is required."""

    id: str | None = None
    url: str
    result: dict | None = None
    public: bool = True  # public unless stated otherwise
    author_id: str | None = None
    group_id: str | None = None
    tags: set[Tag] | None = set()
    rearchive: bool = True
    # urls: list = []
|
||||
|
||||
|
||||
class Archive(ArchiveCreate):
    """ArchiveCreate plus timestamps and the deleted flag; can be populated from object attributes."""

    model_config = {"from_attributes": True}

    created_at: datetime
    updated_at: datetime | None
    deleted: bool
|
||||
|
||||
class SubmitSheet(BaseModel):
    """Request body for archiving a spreadsheet; every field has a default."""

    sheet_name: str | None = None
    sheet_id: str | None = None
    header: int = 1
    public: bool = False  # private unless stated otherwise
    author_id: str | None = None
    group_id: str | None = None
    tags: set[str] | None = set()
    columns: dict | None = {}  # TODO: implement
|
||||
|
||||
class SubmitManual(BaseModel):
    """Request body for submitting an archive result that was produced elsewhere."""

    result: str  # should be a Metadata.to_json()
    public: bool = False  # private unless stated otherwise
    author_id: str | None = None
    group_id: str | None = None
    tags: set[str] | None = set()
|
||||
|
||||
# API RESPONSES BELOW
|
||||
class ArchiveResult(BaseModel):
    """API response: the id, url, result and creation time of one archive."""

    id: str
    url: str
    result: dict
    created_at: datetime
|
||||
|
||||
class Task(BaseModel):
    """API response carrying only a task id."""

    id: str
|
||||
|
||||
class TaskResult(Task):
    """Task id together with its status and result, both as strings."""

    status: str
    result: str
|
||||
|
||||
class TaskDelete(Task):
    """Task id together with a flag telling whether it was deleted."""

    deleted: bool
|
||||
|
||||
class ActiveUser(BaseModel):
    """API response telling whether the user is active."""

    active: bool
|
||||
@@ -1,5 +0,0 @@
|
||||
from endpoints.default import default_router
|
||||
from endpoints.url import url_router
|
||||
from endpoints.task import task_router
|
||||
from endpoints.interoperability import interoperability_router
|
||||
from endpoints.sheet import sheet_router
|
||||
@@ -1,45 +0,0 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, HTTPException
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.config import VERSION, BREAKING_CHANGES
|
||||
from core.logging import log_error
|
||||
from db import crud, schemas
|
||||
from db.database import get_db_dependency, get_db
|
||||
from web.security import get_user_auth, bearer_security
|
||||
|
||||
default_router = APIRouter()
|
||||
|
||||
|
||||
@default_router.get("/")
async def home(request: Request):
    """
    Report version and breaking-changes info; when the request is
    authenticated, the caller's groups are added under "groups".
    """
    # TODO: maybe split into 2 routes: one non authenticated and one authenticated for the groups info only
    status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
    try:
        credentials = await bearer_security(request)
        email = await get_user_auth(credentials)
        with get_db() as db:
            status["groups"] = crud.get_user_groups(db, email)
    except HTTPException:
        pass  # not authenticated is fine
    except Exception as e:
        log_error(e)
    return JSONResponse(status)
|
||||
|
||||
|
||||
@default_router.get("/health")
async def health():
    """Health probe: always answers {"status": "ok"}."""
    payload = {"status": "ok"}
    return JSONResponse(payload)
|
||||
|
||||
|
||||
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
async def active(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> schemas.ActiveUser:
    """Return {"active": <bool>} for the authenticated user, as computed by crud.is_active_user."""
    is_active = crud.is_active_user(db, email)
    return {"active": is_active}
|
||||
|
||||
|
||||
@default_router.get("/groups")
def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> list[str]:
    """Return the group ids of the authenticated user, as computed by crud.get_user_groups."""
    groups = crud.get_user_groups(db, email)
    return groups
|
||||
|
||||
|
||||
@default_router.get('/favicon.ico', include_in_schema=False)
async def favicon():
    """Serve the favicon file; the route is hidden from the API schema."""
    icon_path = "static/favicon.ico"
    return FileResponse(icon_path)
|
||||
@@ -1,27 +0,0 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from auto_archiver import Metadata
|
||||
from loguru import logger
|
||||
import sqlalchemy
|
||||
|
||||
from web.security import token_api_key_auth
|
||||
from db import models, schemas
|
||||
from worker.main import insert_result_into_db
|
||||
from core.logging import log_error
|
||||
|
||||
|
||||
interoperability_router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."])
|
||||
|
||||
|
||||
# ----- endpoint to submit data archived elsewhere
|
||||
@interoperability_router.post("/submit-archive", status_code=201, summary="Submit a manual archive entry, for data that was archived elsewhere.")
def submit_manual_archive(manual: schemas.SubmitManual, auth=Depends(token_api_key_auth)):
    """
    Store an archive result produced elsewhere.

    The JSON in `manual.result` is parsed into a Metadata object, the tag
    "manual" is added to the submitted tags, and the entry is inserted under
    a freshly generated id. Responds 201 with {"id": <archive id>}, or 422
    when the insert violates a database integrity constraint.
    """
    result = Metadata.from_json(manual.result)
    logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
    # `tags` is declared `set[str] | None`: a null value must not crash on .add
    tags = set(manual.tags or ())
    tags.add("manual")
    try:
        archive_id = insert_result_into_db(result, tags, manual.public, manual.group_id, manual.author_id, models.generate_uuid())
    except sqlalchemy.exc.IntegrityError as e:
        log_error(e)
        raise HTTPException(status_code=422, detail="Cannot insert into DB due to integrity error") from e
    return JSONResponse({"id": archive_id}, status_code=201)
|
||||
@@ -1,24 +0,0 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from core.config import ALLOW_ANY_EMAIL
|
||||
from web.security import get_token_or_user_auth
|
||||
from db import schemas
|
||||
from worker.main import create_sheet_task
|
||||
|
||||
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
|
||||
|
||||
|
||||
@sheet_router.post("/archive", status_code=201, summary="Submit a Google Sheet archive request, starts a sheet archiving task.", response_description="task_id for the archiving task.")
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_token_or_user_auth)) -> schemas.Task:
    """
    Queue a sheet archiving task and answer 201 with its id.

    With ALLOW_ANY_EMAIL credentials the author is the submitted `author_id`,
    or "api-endpoint" when none is given; otherwise it is the authenticated
    email. Answers 422 when neither a sheet name nor a sheet id is provided.
    """
    logger.info(f"SHEET TASK for {sheet=}")
    author = email
    if author == ALLOW_ANY_EMAIL:
        author = sheet.author_id or "api-endpoint"
    sheet.author_id = author

    has_identifier = sheet.sheet_name or sheet.sheet_id
    if not has_identifier:
        raise HTTPException(status_code=422, detail="sheet name or id is required")

    task = create_sheet_task.delay(sheet.model_dump_json())
    return JSONResponse({"id": task.id}, status_code=201)
|
||||
@@ -1,60 +0,0 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from datetime import datetime
|
||||
|
||||
from loguru import logger
|
||||
from web.security import get_user_auth, get_token_or_user_auth
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from db import crud, schemas
|
||||
from db.database import get_db_dependency
|
||||
|
||||
from worker.main import create_archive_task
|
||||
|
||||
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
|
||||
|
||||
|
||||
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)) -> schemas.Task:
    """
    Queue an archiving task for one URL, authored by the authenticated email,
    and answer 201 with the task id. URLs that are not strings or have 5
    characters or fewer are rejected with 422.
    """
    archive.author_id = email
    url = archive.url
    logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")

    url_is_valid = type(url) is str and len(url) > 5
    if not url_is_valid:
        raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")

    logger.info("creating task")
    queued = create_archive_task.delay(archive.model_dump_json())
    body = schemas.Task(id=queued.id).model_dump()
    return JSONResponse(body, status_code=201)
|
||||
|
||||
|
||||
@url_router.get("/search", summary="Search for archive entries by URL.")
def search_by_url(
    url: str,
    skip: int = 0,
    limit: int = 25,
    archived_after: datetime = None,
    archived_before: datetime = None,
    db: Session = Depends(get_db_dependency),
    email=Depends(get_token_or_user_auth),
) -> list[schemas.ArchiveResult]:
    """Search archives by the whitespace-stripped `url`, delegating filtering and pagination to crud.search_archives_by_url."""
    needle = url.strip()
    return crud.search_archives_by_url(
        db,
        needle,
        email,
        skip=skip,
        limit=limit,
        archived_after=archived_after,
        archived_before=archived_before,
    )
|
||||
|
||||
|
||||
@url_router.get("/latest", summary="Fetch latest URL archives for the authenticated user.")
def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> list[schemas.ArchiveResult]:
    """Return one page of the authenticated user's archives via crud.search_archives_by_email."""
    own_archives = crud.search_archives_by_email(db, email, skip=skip, limit=limit)
    return own_archives
|
||||
|
||||
|
||||
@url_router.get("/{id}", summary="Fetch a single URL archive by the associated id.")
def lookup(id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)) -> schemas.ArchiveResult:
    """Return the archive `id` when crud.get_archive finds it for this caller, else answer 404."""
    found = crud.get_archive(db, id, email)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Archive not found")
|
||||
|
||||
|
||||
@url_router.delete("/{id}", summary="Delete a single URL archive by id.")
def delete_task(id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)) -> schemas.TaskDelete:
    """Soft-delete the archive `id` on behalf of the authenticated user and report whether anything was deleted."""
    logger.info(f"deleting url archive task {id} request by {email}")
    was_deleted = crud.soft_delete_task(db, id, email)
    return JSONResponse({"id": id, "deleted": was_deleted})
|
||||
@@ -1,18 +0,0 @@
|
||||
# email-level group access
|
||||
users:
|
||||
email1@example.com:
|
||||
- group1
|
||||
- group2
|
||||
email2@example.com:
|
||||
- group2
|
||||
email3@example-no-group.com:
|
||||
|
||||
# domain-level group access (taken from the emails)
|
||||
domains:
|
||||
example.com:
|
||||
- group3
|
||||
|
||||
orchestrators:
|
||||
group1: secrets/orchestration-group1.yaml
|
||||
group2: secrets/orchestration-group2.yaml
|
||||
default: secrets/orchestration-default.yaml
|
||||
@@ -1 +0,0 @@
|
||||
Generic single-database configuration.
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user