Merge pull request #72 from bellingcat/dev

Community contributions, code standardization, and AA v1.0.0
This commit is contained in:
Miguel Sozinho Ramalho
2025-04-03 19:51:56 +01:00
committed by GitHub
88 changed files with 5245 additions and 2550 deletions

View File

@@ -1,3 +0,0 @@
[run]
omit =
app/migrations/*

23
.github/pull_request_template.md vendored Normal file
View File

@@ -0,0 +1,23 @@
<!---
Please write your PR name in the present imperative tense. Examples of that tense are:
"Fix issue in the dispatcher where…", "Improve our handling of…", etc.
For more information on Pull Requests, you can reference here:
https://success.vanillaforums.com/kb/articles/228-using-pull-requests-to-contribute
-->
## Describe your changes
## Non-obvious technical information
## Checklist before requesting a review
<!---
These are suggested things you could add, but what you add will be dependent on
your repository's standards.
--->
- [ ] The code runs successfully.
```commandline
HERE IS SOME COMMAND LINE OUTPUT
```

View File

@@ -1,16 +1,12 @@
name: CI name: CI
on: on:
push: push:
branches: branches: [ main, dev ]
- main
- dev
pull_request: pull_request:
branches: branches: [ main, dev ]
- main
- dev
jobs: jobs:
test: test-with-coverage:
runs-on: ubuntu-latest runs-on: ubuntu-latest
services: services:

16
.github/workflows/format-and-fail.yml vendored Normal file
View File

@@ -0,0 +1,16 @@
# Runs pre-commit/action on pushes and pull requests that target main or dev.
name: Format and Fail

on:
  push:
    branches: [main, dev]
  pull_request:
    branches: [main, dev]

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          # Quoted so the version is read as a string, not a float.
          python-version: "3.11"
      - uses: pre-commit/action@v3.0.0

40
.github/workflows/test.yml vendored Normal file
View File

@@ -0,0 +1,40 @@
# Runs pytest on app/tests for pushes and pull requests that target main or
# dev, once per Python version listed in the matrix.
name: Run Tests

on:
  push:
    branches: [main, dev]
  pull_request:
    branches: [main, dev]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Versions are quoted so that 3.10 is not parsed as the float 3.1.
        python-version: ['3.10', '3.11', '3.12']
    services:
      redis:
        image: redis:6-alpine
        ports:
          - 6379:6379
    steps:
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Install Poetry
        run: pipx install poetry
      - name: Install dependencies
        run: poetry install --no-interaction --with dev
      # Writing to GITHUB_ENV exposes the variable to the steps that follow.
      - name: Set dev environment variable
        run: echo "ENVIRONMENT_FILE=.env.test" >> "$GITHUB_ENV"
      - name: Run tests
        run: poetry run pytest app/tests

139
.gitignore vendored
View File

@@ -1,16 +1,10 @@
# Misc.
user-groups.dev.yaml user-groups.dev.yaml
user-groups.yaml user-groups.yaml
orchestration.yaml orchestration.yaml
my-archives my-archives
*.pyc *.pyc
.DS_Store
secrets/* secrets/*
*.log
__pycache__
.pytest_cache
.env
.env.dev
.env.prod
*.db *.db
redis/data/* redis/data/*
.ipynb_checkpoints* .ipynb_checkpoints*
@@ -18,8 +12,6 @@ app/user-groups.yaml
app/user-groups.dev.yaml app/user-groups.dev.yaml
wit* wit*
app/crawls app/crawls
.coverage
.pytest_cache/
htmlcov htmlcov
local_archive local_archive
local_archive_test local_archive_test
@@ -27,6 +19,133 @@ local_archive_test
*db-shm *db-shm
copy-files.sh copy-files.sh
temp/ temp/
.python-version
orchestration2.yaml orchestration2.yaml
database database
# IDE files
.idea
.vscode
**/.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env*
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

78
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,78 @@
# pre-commit configuration: nbQA hooks (ruff, black, isort), generic
# pre-commit-hooks checks, pygrep checks, isort, and ruff lint + format.
repos:
  - repo: https://github.com/nbQA-dev/nbQA
    rev: 1.9.1
    hooks:
      - id: nbqa-ruff
        args:
          - --fix
          - --target-version=py310
          - --ignore=E721,E722
          - --line-length=80
      - id: nbqa-black
        args:
          - --line-length=80
      - id: nbqa-isort
        args:
          - --float-to-top
          - --profile=black

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
      - id: check-docstring-first
      - id: check-executables-have-shebangs
      - id: check-json
      - id: check-case-conflict
      - id: check-toml
      - id: check-merge-conflict
      - id: check-xml
      - id: check-yaml
        # Skipped by the YAML check; the file name suggests it is an
        # intentionally broken test fixture (confirm).
        exclude: app/tests/user-groups.test.broken.yaml
      - id: end-of-file-fixer
      - id: check-symlinks
      - id: mixed-line-ending
      - id: sort-simple-yaml
      - id: fix-encoding-pragma
        args:
          - --remove
      - id: pretty-format-json
        args:
          - --autofix

  - repo: https://github.com/pre-commit/pygrep-hooks
    rev: v1.10.0
    hooks:
      - id: python-check-blanket-noqa
      - id: python-check-mock-methods
      - id: python-no-eval
      - id: python-no-log-warn

  - repo: https://github.com/PyCQA/isort
    rev: 6.0.1
    hooks:
      - id: isort
        name: Run isort to sort imports
        files: \.py$
        # To keep consistent with the global isort skip config defined in
        # setup.cfg
        exclude: ^build/.*$|^.tox/.*$|^venv/.*$
        args:
          - --lines-after-imports=2
          - --profile=black
          - --line-length=80

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.9.7
    hooks:
      - id: ruff
        types_or: [python, pyi]
        args:
          - --fix
          - --select=B,C,E,F,W,B9
          - --line-length=80
          - --ignore=E203,E402,E501,E261
      - id: ruff-format
        types_or: [python, pyi]
        args:
          - --target-version=py310
          - --line-length=80

1
CODEOWNERS Normal file
View File

@@ -0,0 +1 @@
* @msramalho

View File

@@ -1,19 +1,33 @@
.PHONY: lint
lint:
poetry run pre-commit run --all-files
.PHONY: test
# Run the test suite under coverage, then print the coverage report.
# ENVIRONMENT_FILE is set on the pytest command itself: make runs every recipe
# line in its own shell, so a standalone `export` line would not be visible to
# the commands on the following lines.
test:
	ENVIRONMENT_FILE=.env.test poetry run coverage run -m pytest -v --disable-warnings --color=yes app/tests/
	poetry run coverage report
.PHONY: clean-dev
clean-dev: clean-dev:
@echo -n "Are you sure? [yes/N] (this will delete volumes) " && read ans && [ $${ans:-N} = yes ] @echo -n "Are you sure? [yes/N] (this will delete volumes) " && read ans && [ $${ans:-N} = yes ]
docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes --remove-orphans docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes --remove-orphans
.PHONY: dev
dev: dev:
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans
.PHONY: dev-redis-only
dev-redis-only: dev-redis-only:
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build redis docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build redis
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans redis docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans redis
.PHONY: stop-dev
stop-dev: stop-dev:
docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes
.PHONY: prod
prod: prod:
docker compose --env-file .env.prod build docker compose --env-file .env.prod build
docker compose --env-file .env.prod up -d --remove-orphans docker compose --env-file .env.prod up -d --remove-orphans
@@ -21,5 +35,6 @@ prod:
docker image prune -f docker image prune -f
docker system df docker system df
.PHONY: stop-prod
stop-prod: stop-prod:
docker compose down docker compose down

View File

@@ -12,7 +12,7 @@ To properly set up the API you need to install `docker` and to have these files,
2. a `user-groups.yaml` to manage user permissions 2. a `user-groups.yaml` to manage user permissions
1. note that all local files referenced in `user-groups.yaml` and any orchestration.yaml files should be relative to the home directory so if your service account is in `secrets/orchestration.yaml` use that path and not just `orchestration.yaml`. 1. note that all local files referenced in `user-groups.yaml` and any orchestration.yaml files should be relative to the home directory so if your service account is in `secrets/orchestration.yaml` use that path and not just `orchestration.yaml`.
2. go through the example file and configure it according to your needs. 2. go through the example file and configure it according to your needs.
3. you will need to create and reference at least one `secrets/orchestration.yaml` file, you can do so by following the instructions in the [auto-archiver](https://github.com/bellingcat/auto-archiver#installation) that automatically generates one for you. If you use the archive sheets feature you will need to create a `orchestrationsheets-sheets.yaml` file as well that should have the `gsheet_feeder` and `gsheet_db` enabled and configured, the auto-archiver has [extensive documentation](https://auto-archiver.readthedocs.io/en/latest/) on how to set this up. 3. you will need to create and reference at least one `secrets/orchestration.yaml` file, you can do so by following the instructions in the [auto-archiver](https://github.com/bellingcat/auto-archiver#installation) that automatically generates one for you. If you use the archive sheets feature you will need to create a `orchestrationsheets-sheets.yaml` file as well that should have the `gsheet_feeder_db` feeder and database enabled and configured, the auto-archiver has [extensive documentation](https://auto-archiver.readthedocs.io/en/latest/) on how to set this up.
Do not commit those files, they are .gitignored by default. Do not commit those files, they are .gitignored by default.
We also advise you to keep any sensitive files in the `secrets/` folder which is pinned and gitignored. We also advise you to keep any sensitive files in the `secrets/` folder which is pinned and gitignored.
@@ -108,6 +108,27 @@ Make sure environment and user-groups files are up to date.
Then `make prod`. Then `make prod`.
## Development
```bash
# make sure all development dependencies are installed
poetry install --with dev
# this project uses pre-commit to enforce code style and formatting, set that up locally
poetry run pre-commit install
# you can test pre-commit with
poetry run pre-commit run --all-files
# this means pre-commit will always run with git commit, to skip it use
git commit --no-verify
# see the Makefile for more commands, but linting and formatting can be done with
make lint
# run all tests
make test
```
### Testing ### Testing
```bash ```bash
# set the testing environment variables # set the testing environment variables

View File

@@ -2,7 +2,7 @@
[alembic] [alembic]
# path to migration scripts # path to migration scripts
script_location = app/migrations script_location = ./app/migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time # Uncomment the line below if you want the files to be prepended with date and time
@@ -10,10 +10,6 @@ script_location = app/migrations
# for all available tokens # for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file # timezone to use when rendering the date within the migration file
# as well as the filename. # as well as the filename.
# If specified, requires the python-dateutil library that can be # If specified, requires the python-dateutil library that can be

View File

@@ -1,19 +1,20 @@
from logging.config import fileConfig from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context from alembic import context
from sqlalchemy import engine_from_config, pool
from app.shared.settings import get_settings from app.shared.settings import get_settings
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
config = context.config config = context.config
config.set_main_option('sqlalchemy.url', get_settings().DATABASE_PATH) config.set_main_option("sqlalchemy.url", get_settings().DATABASE_PATH)
# Interpret the config file for Python logging. # Interpret the config file for Python logging.
# This line sets up loggers basically. # This line sets up loggers basically.
if config.config_file_name is not None: if config.config_file_name is not None:
fileConfig(config.config_file_name, disable_existing_loggers=False) # disable_existing_loggers prevents loguru disabling # disable_existing_loggers prevents loguru disabling
fileConfig(config.config_file_name, disable_existing_loggers=False)
# add your model's MetaData object here # add your model's MetaData object here
# for 'autogenerate' support # for 'autogenerate' support

View File

@@ -5,13 +5,14 @@ Revises: 1636724ec4b1
Create Date: 2025-02-08 15:22:20.392522 Create Date: 2025-02-08 15:22:20.392522
""" """
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '02b2f6d17ed0' revision = "02b2f6d17ed0"
down_revision = '1636724ec4b1' down_revision = "1636724ec4b1"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
STORE_UNTIL_COL = "store_until" STORE_UNTIL_COL = "store_until"
@@ -20,15 +21,20 @@ STORE_UNTIL_COL = "store_until"
def upgrade() -> None: def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('archives')] columns = [col["name"] for col in inspector.get_columns("archives")]
if STORE_UNTIL_COL not in columns: if STORE_UNTIL_COL not in columns:
op.add_column('archives', sa.Column(STORE_UNTIL_COL, sa.DateTime(), nullable=True, default=None)) op.add_column(
"archives",
sa.Column(
STORE_UNTIL_COL, sa.DateTime(), nullable=True, default=None
),
)
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('archives')] columns = [col["name"] for col in inspector.get_columns("archives")]
if STORE_UNTIL_COL in columns: if STORE_UNTIL_COL in columns:
op.drop_column('archives', STORE_UNTIL_COL) op.drop_column("archives", STORE_UNTIL_COL)

View File

@@ -5,13 +5,14 @@ Revises: a23aaf3ae930
Create Date: 2025-02-05 19:19:01.984396 Create Date: 2025-02-05 19:19:01.984396
""" """
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '1636724ec4b1' revision = "1636724ec4b1"
down_revision = 'a23aaf3ae930' down_revision = "a23aaf3ae930"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -19,14 +20,18 @@ depends_on = None
def upgrade() -> None: def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('sheets')] columns = [col["name"] for col in inspector.get_columns("sheets")]
if 'last_archived_at' in columns: if "last_archived_at" in columns:
op.alter_column('sheets', 'last_archived_at', new_column_name='last_url_archived_at') op.alter_column(
"sheets", "last_archived_at", new_column_name="last_url_archived_at"
)
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('sheets')] columns = [col["name"] for col in inspector.get_columns("sheets")]
if 'last_url_archived_at' in columns: if "last_url_archived_at" in columns:
op.alter_column('sheets', 'last_url_archived_at', new_column_name='last_archived_at') op.alter_column(
"sheets", "last_url_archived_at", new_column_name="last_archived_at"
)

View File

@@ -5,13 +5,14 @@ Revises: 02b2f6d17ed0
Create Date: 2025-02-11 21:53:23.293274 Create Date: 2025-02-11 21:53:23.293274
""" """
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '63ac79df4ad0' revision = "63ac79df4ad0"
down_revision = '02b2f6d17ed0' down_revision = "02b2f6d17ed0"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -22,15 +23,17 @@ TABLE = "groups"
def upgrade() -> None: def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns(TABLE)] columns = [col["name"] for col in inspector.get_columns(TABLE)]
if NEW_COL not in columns: if NEW_COL not in columns:
op.add_column(TABLE, sa.Column(NEW_COL, sa.String, nullable=True, default=None)) op.add_column(
TABLE, sa.Column(NEW_COL, sa.String, nullable=True, default=None)
)
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns(TABLE)] columns = [col["name"] for col in inspector.get_columns(TABLE)]
if NEW_COL in columns: if NEW_COL in columns:
op.drop_column(TABLE, NEW_COL) op.drop_column(TABLE, NEW_COL)

View File

@@ -5,14 +5,14 @@ Revises: fa012ec405b8
Create Date: 2024-11-04 11:12:30.237299 Create Date: 2024-11-04 11:12:30.237299
""" """
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '89121d2c96d8' revision = "89121d2c96d8"
down_revision = 'fa012ec405b8' down_revision = "fa012ec405b8"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -20,23 +20,27 @@ depends_on = None
def upgrade() -> None: def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('archives')] columns = [col["name"] for col in inspector.get_columns("archives")]
if 'sheet_id' not in columns: if "sheet_id" not in columns:
with op.batch_alter_table('archives') as batch_op: with op.batch_alter_table("archives") as batch_op:
batch_op.add_column(sa.Column('sheet_id', sa.String(), nullable=True, default=None)) batch_op.add_column(
batch_op.create_foreign_key('fk_sheet_id', 'sheets', ['sheet_id'], ['id']) sa.Column("sheet_id", sa.String(), nullable=True, default=None)
)
batch_op.create_foreign_key(
"fk_sheet_id", "sheets", ["sheet_id"], ["id"]
)
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
foreign_keys = [fk['name'] for fk in inspector.get_foreign_keys('archives')] foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("archives")]
columns = [col['name'] for col in inspector.get_columns('archives')] columns = [col["name"] for col in inspector.get_columns("archives")]
with op.batch_alter_table('archives') as batch_op: with op.batch_alter_table("archives") as batch_op:
if 'fk_sheet_id' in foreign_keys: if "fk_sheet_id" in foreign_keys:
batch_op.drop_constraint('fk_sheet_id', type_='foreignkey') batch_op.drop_constraint("fk_sheet_id", type_="foreignkey")
if 'sheet_id' in columns: if "sheet_id" in columns:
batch_op.drop_column('sheet_id') batch_op.drop_column("sheet_id")

View File

@@ -5,10 +5,12 @@ Revises:
Create Date: 2023-12-20 17:24:59.320691 Create Date: 2023-12-20 17:24:59.320691
""" """
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '9369a264945b' revision = "9369a264945b"
down_revision = None down_revision = None
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -18,10 +20,11 @@ def upgrade() -> None:
# since the primary key constraint is not named, we have to recreate it first # since the primary key constraint is not named, we have to recreate it first
with op.batch_alter_table("archive_urls") as batch_op: with op.batch_alter_table("archive_urls") as batch_op:
batch_op.create_primary_key("pk_url", ["url"]) batch_op.create_primary_key("pk_url", ["url"])
batch_op.drop_constraint("pk_url", type_='primary') batch_op.drop_constraint("pk_url", type_="primary")
batch_op.create_primary_key("pk_url_archive_id", ["url", "archive_id"]) batch_op.create_primary_key("pk_url_archive_id", ["url", "archive_id"])
def downgrade() -> None: def downgrade() -> None:
with op.batch_alter_table("archive_urls") as batch_op: with op.batch_alter_table("archive_urls") as batch_op:
batch_op.drop_constraint("pk_url_archive_id", type_='primary') batch_op.drop_constraint("pk_url_archive_id", type_="primary")
batch_op.create_primary_key("url", ["url"]) batch_op.create_primary_key("url", ["url"])

View File

@@ -5,12 +5,13 @@ Revises: 9369a264945b
Create Date: 2023-12-20 18:33:27.132566 Create Date: 2023-12-20 18:33:27.132566
""" """
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '93a611e4c066' revision = "93a611e4c066"
down_revision = '9369a264945b' down_revision = "9369a264945b"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -20,7 +21,9 @@ def upgrade() -> None:
with op.get_context().autocommit_block(): with op.get_context().autocommit_block():
op.execute("VACUUM") op.execute("VACUUM")
except Exception as e: except Exception as e:
print("Unable to run vacuum, maybe there's not enough disk space. it should be 2x the size of the database") print(
"Unable to run vacuum, maybe there's not enough disk space. it should be 2x the size of the database"
)
print(e) print(e)

View File

@@ -5,13 +5,14 @@ Revises: 89121d2c96d8
Create Date: 2025-02-04 12:19:20.753570 Create Date: 2025-02-04 12:19:20.753570
""" """
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'a23aaf3ae930' revision = "a23aaf3ae930"
down_revision = '89121d2c96d8' down_revision = "89121d2c96d8"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -19,16 +20,24 @@ depends_on = None
def upgrade() -> None: def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('users')] columns = [col["name"] for col in inspector.get_columns("users")]
if 'is_active' in columns: if "is_active" in columns:
op.drop_column('users', 'is_active') op.drop_column("users", "is_active")
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('users')] columns = [col["name"] for col in inspector.get_columns("users")]
if 'is_active' not in columns: if "is_active" not in columns:
op.add_column('users', sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.false())) op.add_column(
"users",
sa.Column(
"is_active",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)

View File

@@ -5,14 +5,14 @@ Revises: 93a611e4c066
Create Date: 2024-10-31 09:36:50.360710 Create Date: 2024-10-31 09:36:50.360710
""" """
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'fa012ec405b8' revision = "fa012ec405b8"
down_revision = '93a611e4c066' down_revision = "93a611e4c066"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -20,26 +20,41 @@ depends_on = None
def upgrade() -> None: def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('groups')] columns = [col["name"] for col in inspector.get_columns("groups")]
if 'description' not in columns: if "description" not in columns:
op.add_column('groups', sa.Column('description', sa.String(), nullable=True)) op.add_column(
if 'orchestrator' not in columns: "groups", sa.Column("description", sa.String(), nullable=True)
op.add_column('groups', sa.Column('orchestrator', sa.String(), nullable=True)) )
if 'orchestrator_sheet' not in columns: if "orchestrator" not in columns:
op.add_column('groups', sa.Column('orchestrator_sheet', sa.String(), nullable=True)) op.add_column(
if 'permissions' not in columns: "groups", sa.Column("orchestrator", sa.String(), nullable=True)
op.add_column('groups', sa.Column('permissions', sa.JSON(), nullable=True)) )
if 'domains' not in columns: if "orchestrator_sheet" not in columns:
op.add_column('groups', sa.Column('domains', sa.JSON(), nullable=True)) op.add_column(
"groups",
sa.Column("orchestrator_sheet", sa.String(), nullable=True),
)
if "permissions" not in columns:
op.add_column(
"groups", sa.Column("permissions", sa.JSON(), nullable=True)
)
if "domains" not in columns:
op.add_column("groups", sa.Column("domains", sa.JSON(), nullable=True))
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('groups')] columns = [col["name"] for col in inspector.get_columns("groups")]
column_names = ['description', 'orchestrator', 'orchestrator_sheet', 'permissions', 'domains'] column_names = [
"description",
"orchestrator",
"orchestrator_sheet",
"permissions",
"domains",
]
for column_name in column_names: for column_name in column_names:
if column_name in columns: if column_name in columns:
op.drop_column('groups', column_name) op.drop_column("groups", column_name)

View File

@@ -1,32 +0,0 @@
# TODO: code in this file should eventually be moved to the auto-archiver code base
from typing import List
from loguru import logger
from auto_archiver.core import Media, Metadata
from app.shared.db import models
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
db_urls = []
for m in result.media:
for i, url in enumerate(m.urls): db_urls.append(models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}")))
for k, prop in m.properties.items():
if prop_converted := convert_if_media(prop):
for i, url in enumerate(prop_converted.urls): db_urls.append(models.ArchiveUrl(url=url, key=prop_converted.get("id", f"{k}_{i}")))
if isinstance(prop, list):
for i, prop_media in enumerate(prop):
if prop_media := convert_if_media(prop_media):
for j, url in enumerate(prop_media.urls):
db_urls.append(models.ArchiveUrl(url=url, key=prop_media.get("id", f"{k}{prop_media.key}_{i}.{j}")))
return db_urls
def convert_if_media(media):
if isinstance(media, Media): return media
elif isinstance(media, dict):
try: return Media.from_dict(media)
except Exception as e:
logger.debug(f"error parsing {media} : {e}")
return False

View File

@@ -1,25 +1,35 @@
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do decide
import datetime import datetime
from typing import Union
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.shared.db import worker_crud from app.shared.db import worker_crud
def get_store_archive_until(db: Session, group_id: str) -> datetime.datetime: # TODO: temporary file for this code, maybe other code belongs here, maybe not. do
# decide
def get_store_archive_until(
db: Session, group_id: str
) -> Union[datetime.datetime, None]:
group = worker_crud.get_group(db, group_id) group = worker_crud.get_group(db, group_id)
assert group, f"Group {group_id} not found." assert group, f"Group {group_id} not found."
assert group.permissions and type(group.permissions) == dict, f"Group {group_id} has no permissions." assert group.permissions and isinstance(group.permissions, dict), (
f"Group {group_id} has no permissions."
)
max_lifespan = group.permissions.get("max_archive_lifespan_months", -1) max_lifespan = group.permissions.get("max_archive_lifespan_months", -1)
if max_lifespan == -1: return None if max_lifespan == -1:
return None
return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan) return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan)
def get_store_archive_until_or_never(db: Session, group_id: str) -> datetime.datetime: def get_store_archive_until_or_never(
db: Session, group_id: str
) -> Union[datetime.datetime, None]:
try: try:
return get_store_archive_until(db, group_id) return get_store_archive_until(db, group_id)
except AssertionError as e: except AssertionError:
return None return None

7
app/shared/constants.py Normal file
View File

@@ -0,0 +1,7 @@
# Status string constants (failure / pending / success).
STATUS_FAILURE = "FAILURE"
STATUS_PENDING = "PENDING"
STATUS_SUCCESS = "SUCCESS"
# AA CLI configs: command-line flag name for the gsheet_feeder_db sheet id.
# NOTE(review): "AA" presumably stands for auto-archiver — confirm.
SHEET_ID = "--gsheet_feeder_db.sheet_id"

View File

@@ -1,8 +1,14 @@
from functools import lru_cache
from sqlalchemy import Engine, create_engine, event, text
from sqlalchemy.orm import sessionmaker
from contextlib import asynccontextmanager, contextmanager from contextlib import asynccontextmanager, contextmanager
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine, async_sessionmaker from functools import lru_cache
from sqlalchemy import Engine, create_engine, event, text
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import sessionmaker
from app.shared.settings import get_settings from app.shared.settings import get_settings
@@ -12,9 +18,9 @@ def make_engine(database_url: str):
engine = create_engine( engine = create_engine(
database_url, database_url,
connect_args={"check_same_thread": False}, connect_args={"check_same_thread": False},
pool_size=15, # Increase pool size pool_size=15, # Increase pool size
max_overflow=20, # Allow more temporary connections max_overflow=20, # Allow more temporary connections
pool_recycle=1800 # Recycle connections every 30 minutes pool_recycle=1800, # Recycle connections every 30 minutes
) )
@event.listens_for(engine, "connect") @event.listens_for(engine, "connect")
@@ -34,8 +40,10 @@ def make_session_local(engine: Engine):
@contextmanager
def get_db():
    """Context manager yielding a database session that is closed on exit.

    The session is built from the engine for the configured DATABASE_PATH and
    is closed whether or not the managed block raises.
    """
    engine = make_engine(get_settings().DATABASE_PATH)
    session_factory = make_session_local(engine)
    session = session_factory()
    try:
        yield session
    finally:
        session.close()
def get_db_dependency(): def get_db_dependency():
@@ -53,22 +61,32 @@ def wal_checkpoint():
# ASYNC connections # ASYNC connections
async def make_async_engine(database_url: str) -> AsyncEngine:
    """Create an async engine for database_url and switch it to WAL journaling.

    The engine is created with check_same_thread=False and, before being
    returned, runs "PRAGMA journal_mode=WAL;" inside a transaction.
    """

    def _enable_wal(sync_conn):
        # run_sync hands us the underlying synchronous connection
        return sync_conn.execute(text("PRAGMA journal_mode=WAL;"))

    connect_args = {"check_same_thread": False}
    engine = create_async_engine(database_url, connect_args=connect_args)
    async with engine.begin() as conn:
        await conn.run_sync(_enable_wal)
    return engine
async def make_async_session_local(
    engine: AsyncEngine,
) -> async_sessionmaker[AsyncSession]:
    """Build an async session factory bound to engine.

    Returns the async_sessionmaker itself (not a session): callers invoke the
    returned factory to obtain an AsyncSession. Sessions are configured with
    expire_on_commit, autoflush and autocommit all disabled.

    The function stays a coroutine (although it awaits nothing) because
    existing callers `await` it.
    """
    return async_sessionmaker(
        engine, expire_on_commit=False, autoflush=False, autocommit=False
    )
@asynccontextmanager
async def get_db_async():
    """Async context manager yielding an AsyncSession on a dedicated engine.

    A new async engine is created for the configured async database path on
    every use. On exit the session is closed first (by leaving the
    `async with` block) and only then is the engine disposed, so the session
    never outlives its engine. The engine is also disposed when creating the
    session factory or the session fails.
    """
    engine = await make_async_engine(get_settings().async_database_path)
    try:
        async_session = await make_async_session_local(engine)
        async with async_session() as session:
            yield session
    finally:
        # runs after the session context has exited, i.e. after session close
        await engine.dispose()

View File

@@ -1,8 +1,17 @@
from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship, declarative_base
import uuid import uuid
from sqlalchemy import (
JSON,
Boolean,
Column,
DateTime,
ForeignKey,
String,
Table,
)
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.sql import func
Base = declarative_base() Base = declarative_base()
@@ -11,7 +20,7 @@ def generate_uuid():
return str(uuid.uuid4()) return str(uuid.uuid4())
# many to many association tables # many-to-many association tables
association_table_archive_tags = Table( association_table_archive_tags = Table(
"mtm_archives_tags", "mtm_archives_tags",
Base.metadata, Base.metadata,
@@ -33,7 +42,9 @@ class Archive(Base):
id = Column(String, primary_key=True, index=True) id = Column(String, primary_key=True, index=True)
url = Column(String, index=True) url = Column(String, index=True)
result = Column(JSON, default=None) result = Column(JSON, default=None)
public = Column(Boolean, default=True) # if public=false, access by group and author public = Column(
Boolean, default=True
) # if public=false, access by group and author
deleted = Column(Boolean, default=False) deleted = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())
@@ -43,7 +54,11 @@ class Archive(Base):
author_id = Column(String, ForeignKey("users.email")) author_id = Column(String, ForeignKey("users.email"))
sheet_id = Column(String, ForeignKey("sheets.id"), default=None) sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags) tags = relationship(
"Tag",
back_populates="archives",
secondary=association_table_archive_tags,
)
group = relationship("Group", back_populates="archives") group = relationship("Group", back_populates="archives")
author = relationship("User", back_populates="archives") author = relationship("User", back_populates="archives")
urls = relationship("ArchiveUrl", back_populates="archive") urls = relationship("ArchiveUrl", back_populates="archive")
@@ -66,7 +81,11 @@ class Tag(Base):
id = Column(String, primary_key=True, index=True) id = Column(String, primary_key=True, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags) archives = relationship(
"Archive",
back_populates="tags",
secondary=association_table_archive_tags,
)
class User(Base): class User(Base):
@@ -76,7 +95,9 @@ class User(Base):
archives = relationship("Archive", back_populates="author") archives = relationship("Archive", back_populates="author")
sheets = relationship("Sheet", back_populates="author") sheets = relationship("Sheet", back_populates="author")
groups = relationship("Group", back_populates="users", secondary=association_table_user_groups) groups = relationship(
"Group", back_populates="users", secondary=association_table_user_groups
)
class Group(Base): class Group(Base):
@@ -92,7 +113,9 @@ class Group(Base):
archives = relationship("Archive", back_populates="group") archives = relationship("Archive", back_populates="group")
sheets = relationship("Sheet", back_populates="group") sheets = relationship("Sheet", back_populates="group")
users = relationship("User", back_populates="groups", secondary=association_table_user_groups) users = relationship(
"User", back_populates="groups", secondary=association_table_user_groups
)
class Sheet(Base): class Sheet(Base):
@@ -101,11 +124,27 @@ class Sheet(Base):
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID") id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
name = Column(String, default=None) name = Column(String, default=None)
author_id = Column(String, ForeignKey("users.email")) author_id = Column(String, ForeignKey("users.email"))
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.") group_id = Column(
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.") String,
ForeignKey("groups.id"),
doc="Group ID, user must be in a group to create a sheet.",
)
frequency = Column(
String,
default="daily",
doc="Frequency of archiving: hourly, daily, weekly.",
)
# TODO: stats is not being used, consider removing # TODO: stats is not being used, consider removing
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...") stats = Column(
last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.") JSON,
default={},
doc="Sheet statistics like total links, total rows, ...",
)
last_url_archived_at = Column(
DateTime(timezone=True),
server_default=func.now(),
doc="Last time a new link was archived.",
)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())

View File

@@ -1,13 +1,17 @@
from sqlalchemy.orm import Session
from datetime import datetime from datetime import datetime
from app.shared.db import models from sqlalchemy.orm import Session
from app.shared import schemas from app.shared import schemas
from app.shared.db import models
# TODO: isolate database operations away from worker and into WEB # TODO: isolate database operations away from worker and into WEB
# ONLY WORKER # ONLY WORKER
def update_sheet_last_url_archived_at(db: Session, sheet_id: str): def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first() db_sheet = (
db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
)
if db_sheet: if db_sheet:
db_sheet.last_url_archived_at = datetime.now() db_sheet.last_url_archived_at = datetime.now()
db.commit() db.commit()
@@ -17,12 +21,17 @@ def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
# ONLY WORKER and INTEROP # ONLY WORKER and INTEROP
def get_group(db: Session, group_name: str) -> models.Group | None:
    """Fetch the group whose id equals group_name.

    Returns the first matching models.Group row, or None when no group with
    that id exists.
    """
    return db.query(models.Group).filter(models.Group.id == group_name).first()
def create_or_get_user(db: Session, author_id: str) -> models.User: def create_or_get_user(db: Session, author_id: str) -> models.User:
if type(author_id) == str: author_id = author_id.lower() if isinstance(author_id, str):
db_user = db.query(models.User).filter(models.User.email == author_id).first() author_id = author_id.lower()
db_user = (
db.query(models.User).filter(models.User.email == author_id).first()
)
if not db_user: if not db_user:
db_user = models.User(email=author_id) db_user = models.User(email=author_id)
db.add(db_user) db.add(db_user)
@@ -41,8 +50,22 @@ def create_tag(db: Session, tag: str) -> models.Tag:
return db_tag return db_tag
def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]) -> models.Archive: def create_archive(
db_archive = models.Archive(id=archive.id, url=archive.url, result=archive.result, public=archive.public, author_id=archive.author_id, group_id=archive.group_id, sheet_id=archive.sheet_id, store_until=archive.store_until) db: Session,
archive: schemas.ArchiveCreate,
tags: list[models.Tag],
urls: list[models.ArchiveUrl],
) -> models.Archive:
db_archive = models.Archive(
id=archive.id,
url=archive.url,
result=archive.result,
public=archive.public,
author_id=archive.author_id,
group_id=archive.group_id,
sheet_id=archive.sheet_id,
store_until=archive.store_until,
)
db_archive.tags = tags db_archive.tags = tags
db_archive.urls = urls db_archive.urls = urls
db.add(db_archive) db.add(db_archive)
@@ -51,10 +74,14 @@ def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[model
return db_archive return db_archive
def store_archived_url(
    db: Session, archive: schemas.ArchiveCreate
) -> models.Archive:
    """Persist an archive together with its author and tags.

    Ensures the author exists, creates/loads one tag per entry in
    archive.tags (treated as empty when falsy), then inserts the archive with
    those tags and archive.urls and returns the stored row.
    """
    # create and load user, tags, if needed
    create_or_get_user(db, archive.author_id)
    db_tags = []
    for tag in archive.tags or []:
        db_tags.append(create_tag(db, tag))
    # insert everything
    return create_archive(db, archive=archive, tags=db_tags, urls=archive.urls)

View File

@@ -1,4 +1,5 @@
import traceback import traceback
from loguru import logger from loguru import logger
@@ -7,7 +8,9 @@ logger.add("logs/api_logs.log", retention="30 days")
logger.add("logs/error_logs.log", retention="30 days", level="ERROR") logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
def log_error(e: Exception, traceback_str: str | None = None, extra: str = ""):
    """Log an exception at ERROR level together with its traceback.

    Args:
        e: the exception to report; its class name and message are logged.
        traceback_str: pre-formatted traceback; when empty or None the
            traceback of the exception currently being handled is used.
        extra: optional context, logged on its own line before the exception.
    """
    if not traceback_str:
        traceback_str = traceback.format_exc()
    if extra:
        extra = f"{extra}\n"
    logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")

View File

@@ -1,7 +1,8 @@
from datetime import datetime
from typing import Annotated from typing import Annotated
from annotated_types import Len from annotated_types import Len
from pydantic import BaseModel from pydantic import BaseModel
from datetime import datetime
class SubmitSheet(BaseModel): class SubmitSheet(BaseModel):
@@ -10,6 +11,7 @@ class SubmitSheet(BaseModel):
group_id: str = "default" group_id: str = "default"
tags: set[str] | None = set() tags: set[str] | None = set()
class ArchiveUrl(BaseModel): class ArchiveUrl(BaseModel):
url: str url: str
public: bool = False public: bool = False
@@ -17,6 +19,7 @@ class ArchiveUrl(BaseModel):
group_id: str | None group_id: str | None
tags: set[str] | None = set() tags: set[str] | None = set()
class ArchiveResult(BaseModel): class ArchiveResult(BaseModel):
id: str id: str
url: str url: str

View File

@@ -1,31 +1,38 @@
from functools import lru_cache
import os import os
from functools import lru_cache
from typing import Annotated, Set
from annotated_types import Len
from fastapi_mail import ConnectionConfig from fastapi_mail import ConnectionConfig
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Annotated, Set
from annotated_types import Len
class Settings(BaseSettings): class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=os.environ.get("ENVIRONMENT_FILE"),
env_file_encoding="utf-8",
extra="ignore",
str_strip_whitespace=True,
)
model_config = SettingsConfigDict(env_file=os.environ.get("ENVIRONMENT_FILE") , env_file_encoding='utf-8', extra='ignore', str_strip_whitespace=True) # general
# general
SERVE_LOCAL_ARCHIVE: str | None = None SERVE_LOCAL_ARCHIVE: str | None = None
USER_GROUPS_FILENAME: str = "app/user-groups.yaml" USER_GROUPS_FILENAME: str = "app/user-groups.yaml"
# database # database
DATABASE_PATH: str DATABASE_PATH: str
DATABASE_QUERY_LIMIT: int = 100 DATABASE_QUERY_LIMIT: int = 100
@property @property
def ASYNC_DATABASE_PATH(self) -> str: def async_database_path(self) -> str:
return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://") return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://")
# security # security
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)] API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)] ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)]
CHROME_APP_IDS: Annotated[Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)] CHROME_APP_IDS: Annotated[
Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)
]
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set() BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
# if not provided only OAUTH access_tokens are allowed # if not provided only OAUTH access_tokens are allowed
FIREBASE_SERVICE_ACCOUNT_JSON: str = "" FIREBASE_SERVICE_ACCOUNT_JSON: str = ""
@@ -34,8 +41,9 @@ class Settings(BaseSettings):
REDIS_PASSWORD: str = "" REDIS_PASSWORD: str = ""
REDIS_HOSTNAME: str = "localhost" REDIS_HOSTNAME: str = "localhost"
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel" REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
@property @property
def CELERY_BROKER_URL(self)-> str: def celery_broker_url(self) -> str:
if self.REDIS_PASSWORD: if self.REDIS_PASSWORD:
return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379" return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379"
return f"redis://{self.REDIS_HOSTNAME}:6379" return f"redis://{self.REDIS_HOSTNAME}:6379"
@@ -47,7 +55,7 @@ class Settings(BaseSettings):
CRON_DELETE_SCHEDULED_ARCHIVES: bool = False CRON_DELETE_SCHEDULED_ARCHIVES: bool = False
DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7 DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7
# observability # observability
REPEAT_COUNT_METRICS_SECONDS: int = 30 REPEAT_COUNT_METRICS_SECONDS: int = 30
# email configuration, if needed # email configuration, if needed
@@ -59,8 +67,9 @@ class Settings(BaseSettings):
MAIL_PORT: int = 587 MAIL_PORT: int = 587
MAIL_STARTTLS: bool = False MAIL_STARTTLS: bool = False
MAIL_SSL_TLS: bool = True MAIL_SSL_TLS: bool = True
@property @property
def MAIL_CONFIG(self) -> str: def mail_config(self) -> ConnectionConfig:
return ConnectionConfig( return ConnectionConfig(
MAIL_FROM=self.MAIL_FROM, MAIL_FROM=self.MAIL_FROM,
MAIL_FROM_NAME=self.MAIL_FROM_NAME, MAIL_FROM_NAME=self.MAIL_FROM_NAME,

View File

@@ -1,8 +1,8 @@
from functools import lru_cache from functools import lru_cache
from celery import Celery
import redis
from celery import Celery
import redis
from app.shared.settings import get_settings from app.shared.settings import get_settings
@@ -10,14 +10,14 @@ from app.shared.settings import get_settings
def get_celery(name: str = "") -> Celery:
    """Build a Celery app that uses the configured Redis URL as broker and result backend.

    Queue ordering uses the "priority" strategy and connection retries on
    startup are disabled.
    """
    redis_url = get_settings().celery_broker_url
    transport_options = {"queue_order_strategy": "priority"}
    return Celery(
        name,
        broker_url=redis_url,
        result_backend=redis_url,
        broker_connection_retry_on_startup=False,
        broker_transport_options=transport_options,
    )
def get_redis() -> redis.Redis:
    """Return a Redis client connected to the same URL Celery uses as broker."""
    redis_url = get_settings().celery_broker_url
    return redis.Redis.from_url(redis_url)

View File

@@ -1,9 +1,16 @@
import json import json
import os import os
from typing import Dict, List, Set
import yaml import yaml
from loguru import logger from loguru import logger
from pydantic import BaseModel, computed_field, field_validator, Field, model_validator from pydantic import (
from typing import Dict, List, Set BaseModel,
Field,
computed_field,
field_validator,
model_validator,
)
from typing_extensions import Self from typing_extensions import Self
@@ -12,13 +19,16 @@ class UserGroups:
user_groups = self.read_yaml(filename) user_groups = self.read_yaml(filename)
self.validate_and_load(user_groups) self.validate_and_load(user_groups)
@staticmethod
def read_yaml(user_groups_filename):
    """Read and parse the user-groups YAML file.

    The file is decoded as UTF-8 regardless of the platform locale and parsed
    with yaml.safe_load, so no arbitrary objects can be constructed.

    Raises:
        yaml.YAMLError: when the file is not valid YAML (logged, then re-raised).
        OSError: when the file cannot be opened.
    """
    # read yaml safely
    with open(user_groups_filename, encoding="utf-8") as inf:
        try:
            return yaml.safe_load(inf)
        except yaml.YAMLError as e:
            logger.error(
                f"could not open user groups filename {user_groups_filename}: {e}"
            )
            # bare raise keeps the original traceback intact
            raise
def validate_and_load(self, user_groups): def validate_and_load(self, user_groups):
@@ -45,22 +55,36 @@ class GroupPermissions(BaseModel):
max_monthly_mbs: int = 0 max_monthly_mbs: int = 0
priority: str = "low" priority: str = "low"
# NOTE: @field_validator must be the outermost decorator. With @classmethod
# stacked on top, the attribute stored on the class is a plain classmethod
# and pydantic does not register the validator, so it silently never runs.
@field_validator(
    "max_sheets",
    "max_archive_lifespan_months",
    "max_monthly_urls",
    "max_monthly_mbs",
    mode="before",
)
@classmethod
def validate_max_values(cls, v):
    """Reject max_* limits below -1; -1 itself means "no limit"."""
    if v < -1:
        raise ValueError(
            "max_* values should be positive integers or -1 (for no limit)."
        )
    return v
# NOTE: @field_validator must be the outermost decorator; with @classmethod
# on top pydantic does not register the validator and it never runs.
@field_validator("sheet_frequency", mode="before")
@classmethod
def validate_sheet_frequency(cls, v):
    """Normalise a missing value to [] and reject unknown frequencies.

    Only "daily" and "hourly" are accepted; any other entry raises ValueError.
    """
    if not v:
        return []
    allowed = ["daily", "hourly"]
    for k in v:
        if k not in allowed:
            raise ValueError(
                f"Invalid sheet frequency: '{k}', expected one of {allowed}"
            )
    return v
@field_validator('priority', mode='before') @classmethod
@field_validator("priority", mode="before")
def validate_priority(cls, v): def validate_priority(cls, v):
v = v.lower() v = v.lower()
if v not in ["low", "high"]: if v not in ["low", "high"]:
@@ -70,19 +94,31 @@ class GroupPermissions(BaseModel):
class GroupModel(BaseModel): class GroupModel(BaseModel):
description: str description: str
orchestrator: str orchestrator: str | None = None
orchestrator_sheet: str orchestrator_sheet: str | None = None
permissions: GroupPermissions permissions: GroupPermissions
# Runs after all fields are validated: a before-mode field validator cannot
# see `permissions` (there is no `cls.permissions`, and the field is declared
# after `orchestrator`), and it would not run for the default None value.
@model_validator(mode="after")
def validate_orchestrator(self) -> Self:
    """Require an existing orchestrator file when the group can archive URLs.

    NOTE(review): assumes GroupPermissions exposes a boolean `archive_url`
    field, as the previous implementation did - confirm.
    """
    # orchestrator is only needed if the group has archive_url permission
    if self.permissions.archive_url and not (
        self.orchestrator and os.path.exists(self.orchestrator)
    ):
        raise ValueError(
            f"Orchestrator file not found with this path: {self.orchestrator}"
        )
    return self
# Runs after all fields are validated: a before-mode field validator cannot
# see `permissions` (there is no `cls.permissions`, and the field is declared
# after `orchestrator_sheet`), and it would not run for the default None.
@model_validator(mode="after")
def validate_orchestrator_sheet(self) -> Self:
    """Require an existing orchestrator sheet file when the group can archive sheets.

    NOTE(review): assumes GroupPermissions exposes a boolean `archive_sheet`
    field, as the previous implementation did - confirm.
    """
    # orchestrator_sheet is only needed if the group has archive_sheet permission
    if self.permissions.archive_sheet and not (
        self.orchestrator_sheet and os.path.exists(self.orchestrator_sheet)
    ):
        raise ValueError(
            f"Orchestrator file not found with this path: {self.orchestrator_sheet}"
        )
    return self
@computed_field @computed_field
@property @property
def service_account_email(self) -> str: def service_account_email(self) -> str:
if self.orchestrator_sheet is None:
return ""
if hasattr(self, "_service_account_email"): if hasattr(self, "_service_account_email"):
return self._service_account_email return self._service_account_email
orch = yaml.safe_load(open(self.orchestrator_sheet)) orch = yaml.safe_load(open(self.orchestrator_sheet))
@@ -98,13 +134,17 @@ class GroupModel(BaseModel):
service_account_json = find_service_account_email(orch) service_account_json = find_service_account_email(orch)
if not service_account_json: if not service_account_json:
raise ValueError(f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}.") raise ValueError(
f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}."
)
with open(service_account_json) as f: with open(service_account_json) as f:
self._service_account_email = json.load(f).get("client_email") self._service_account_email = json.load(f).get("client_email")
if not self._service_account_email: if not self._service_account_email:
raise ValueError(f"Service account email not found in {service_account_json}.") raise ValueError(
f"Service account email not found in {service_account_json}."
)
return self._service_account_email return self._service_account_email
@@ -114,29 +154,45 @@ class UserGroupModel(BaseModel):
domains: Dict[str, List[str]] = Field(default_factory=dict) domains: Dict[str, List[str]] = Field(default_factory=dict)
groups: Dict[str, GroupModel] = Field(default_factory=dict) groups: Dict[str, GroupModel] = Field(default_factory=dict)
# NOTE: @field_validator must be the outermost decorator; with @classmethod
# on top pydantic does not register the validator and it never runs.
@field_validator("users", mode="before")
@classmethod
def validate_emails(cls, v):
    """Validate the users mapping and normalise its keys and group names.

    Every key must contain "@" and map to a non-empty list of groups.
    Returns a new mapping with lowercased/stripped emails whose group lists
    are de-duplicated, lowercased/stripped and always include "default".
    """
    for email in v.keys():
        if "@" not in email:
            raise ValueError(
                f"Invalid user, it should be an address: {email}"
            )
        if not v[email]:
            raise ValueError(
                f"User {email} has no explicitly listed groups, only include them here if they should be in a group."
            )
    # all users belong to the default group
    return {
        email.lower().strip(): list(
            {"default", *(g.lower().strip() for g in groups)}
        )
        for email, groups in v.items()
    }
# NOTE: @field_validator must be the outermost decorator; with @classmethod
# on top pydantic does not register the validator and it never runs.
@field_validator("domains", mode="before")
@classmethod
def validate_domains(cls, v):
    """Validate the domains mapping and normalise its keys and group names.

    Every domain must contain a dot and list at least one group. Returns a
    new mapping with lowercased/stripped domains whose group lists are
    de-duplicated and lowercased/stripped.
    """
    for domain, members in v.items():
        if "." not in domain:
            raise ValueError(
                f"Invalid domain, it should contain a dot: {domain}"
            )
        if not members:
            raise ValueError(
                f"Domain {domain} should have at least one member."
            )
    # set comprehension de-duplicates; a list inside a set literal would
    # raise "TypeError: unhashable type: 'list'"
    return {
        domain.lower().strip(): list({g.lower().strip() for g in members})
        for domain, members in v.items()
    }
@field_validator('groups', mode='before')
@classmethod @classmethod
@field_validator("groups", mode="before")
def validate_groups(cls, v): def validate_groups(cls, v):
if "default" not in v.keys(): if "default" not in v.keys():
raise ValueError("Please include a 'default' group.") raise ValueError("Please include a 'default' group.")
@@ -147,20 +203,28 @@ class UserGroupModel(BaseModel):
raise ValueError(f"Group names should be lowercase: {group}") raise ValueError(f"Group names should be lowercase: {group}")
return v return v
@model_validator(mode="after")
def check_groups_consistency(self) -> Self:
    """Warn about groups referenced by domains/users but missing from `groups`.

    Never raises: undefined groups only produce log warnings, since
    historical DB data may still reference them.
    """
    groups_in_domains = set().union(*self.domains.values())
    groups_in_users = set().union(*self.users.values())
    configured_groups = set(self.groups.keys())

    # groups mentioned in domains and users should be defined, but this is
    # not a ValueError since historical DB data may require it
    if groups_in_domains - configured_groups:
        logger.warning(
            f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}"
        )
    if groups_in_users - configured_groups:
        logger.warning(
            f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}"
        )
    return self
# for the API return values # for the API return values

View File

@@ -1,10 +1,14 @@
def fnv1a_hash_mod(s: str, modulo: int) -> int:
    """Map a string to a bucket in [0, modulo - 1] using a 32-bit FNV-1a hash.

    Each character's code point is XOR-ed into the running hash, which is then
    multiplied by the FNV prime and truncated to 32 bits. The final value is
    reinterpreted as a signed 32-bit integer before the modulo is taken, which
    gives an even distribution over the modulo range.
    """
    running = 0x811C9DC5  # FNV offset basis
    for code_point in map(ord, s):
        # xor, multiply by the FNV prime (0x01000193), keep it 32-bit
        running = ((running ^ code_point) * 0x01000193) & 0xFFFFFFFF
    if running >= 0x80000000:
        # reinterpret as signed 32-bit
        running -= 0x100000000
    return running % modulo

View File

@@ -1,19 +1,39 @@
import os import os
from datetime import datetime
from http import HTTPStatus
from typing import AsyncGenerator from typing import AsyncGenerator
from fastapi.testclient import TestClient
import pytest
from unittest.mock import patch from unittest.mock import patch
import pytest
import pytest_asyncio import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine from fastapi.testclient import TestClient
from app.web.config import ALLOW_ANY_EMAIL from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from app.shared.db import models
from app.shared.db.database import (
make_async_engine,
make_async_session_local,
make_engine,
make_session_local,
)
from app.shared.settings import Settings from app.shared.settings import Settings
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud
from app.web.db.crud import get_user_group_names
from app.web.db.user_state import UserState from app.web.db.user_state import UserState
from app.web.main import app_factory
from app.web.security import (
get_token_or_user_auth,
get_user_auth,
get_user_state,
token_api_key_auth,
)
@pytest.fixture(autouse=True)
def mock_logger_add():
    """Autouse fixture replacing loguru.logger.add with a mock for every test."""
    patcher = patch("loguru.logger.add")
    mock_add = patcher.start()
    try:
        yield mock_add  # tests may inspect the mock
    finally:
        patcher.stop()
@@ -24,23 +44,22 @@ def get_settings():
@pytest.fixture(autouse=True)
def mock_settings():
    """Autouse fixture making app.shared.settings.Settings return test settings.

    The replacement instance is loaded from the ".env.test" file before the
    patch is applied; the patch mock is yielded to the test.
    """
    test_settings = Settings(_env_file=".env.test")
    settings_patch = patch(
        "app.shared.settings.Settings", return_value=test_settings
    )
    with settings_patch as patched_settings:
        yield patched_settings
@pytest.fixture() @pytest.fixture()
def test_db(get_settings: Settings): def test_db(get_settings: Settings):
from app.shared.db import models
from app.shared.db.database import make_engine
from app.web.db.crud import get_user_group_names
get_user_group_names.cache_clear() get_user_group_names.cache_clear()
make_engine.cache_clear() make_engine.cache_clear()
engine = make_engine(get_settings.DATABASE_PATH) engine = make_engine(get_settings.DATABASE_PATH)
fs = get_settings.DATABASE_PATH.replace("sqlite:///", "") fs = get_settings.DATABASE_PATH.replace("sqlite:///", "")
if not os.path.exists(fs): if not os.path.exists(fs):
open(fs, 'w').close() open(fs, "w").close()
models.Base.metadata.create_all(engine) models.Base.metadata.create_all(engine)
@@ -57,7 +76,6 @@ def test_db(get_settings: Settings):
@pytest.fixture() @pytest.fixture()
def db_session(test_db): def db_session(test_db):
from app.shared.db.database import make_session_local
session_local = make_session_local(test_db) session_local = make_session_local(test_db)
with session_local() as session: with session_local() as session:
yield session yield session
@@ -65,17 +83,12 @@ def db_session(test_db):
@pytest_asyncio.fixture() @pytest_asyncio.fixture()
async def async_test_db(get_settings: Settings): async def async_test_db(get_settings: Settings):
from app.shared.db import models
from app.shared.db.database import make_async_engine
from app.web.db.crud import get_user_group_names
import asyncio
get_user_group_names.cache_clear() get_user_group_names.cache_clear()
engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH) engine = await make_async_engine(get_settings.async_database_path)
fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "") fs = get_settings.async_database_path.replace("sqlite+aiosqlite:///", "")
if not os.path.exists(fs): if not os.path.exists(fs):
open(fs, 'w').close() open(fs, "w").close()
async def create_all(): async def create_all():
async with engine.begin() as conn: async with engine.begin() as conn:
@@ -99,8 +112,9 @@ async def async_test_db(get_settings: Settings):
@pytest_asyncio.fixture() @pytest_asyncio.fixture()
async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSession, None]: async def async_db_session(
from app.shared.db.database import make_async_session_local async_test_db: AsyncEngine,
) -> AsyncGenerator[AsyncSession, None]:
session_local = await make_async_session_local(async_test_db) session_local = await make_async_session_local(async_test_db)
async with session_local() as session: async with session_local() as session:
yield session yield session
@@ -108,8 +122,6 @@ async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSe
@pytest.fixture() @pytest.fixture()
def app(db_session): def app(db_session):
from app.web.main import app_factory
from app.web.db import crud
app = app_factory() app = app_factory()
crud.upsert_user_groups(db_session) crud.upsert_user_groups(db_session)
return app return app
@@ -123,10 +135,13 @@ def client(app):
@pytest.fixture() @pytest.fixture()
def app_with_auth(app, db_session): def app_with_auth(app, db_session):
from app.web.security import get_token_or_user_auth, get_user_auth, get_user_state app.dependency_overrides[get_token_or_user_auth] = (
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com" lambda: "rick@example.com"
)
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com" app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
app.dependency_overrides[get_user_state] = lambda: UserState(db_session, "MORTY@example.com") app.dependency_overrides[get_user_state] = lambda: UserState(
db_session, "MORTY@example.com"
)
return app return app
@@ -138,7 +153,6 @@ def client_with_auth(app_with_auth):
@pytest.fixture() @pytest.fixture()
def app_with_token(app): def app_with_token(app):
from app.web.security import token_api_key_auth, get_token_or_user_auth
app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL
app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL
return app return app
@@ -155,6 +169,93 @@ def test_no_auth():
# reusable code to ensure a method/endpoint combination is unauthorized # reusable code to ensure a method/endpoint combination is unauthorized
def no_auth(http_method, endpoint): def no_auth(http_method, endpoint):
response = http_method(endpoint) response = http_method(endpoint)
assert response.status_code == 403 assert response.status_code == HTTPStatus.FORBIDDEN
assert response.json() == {"detail": "Not authenticated"} assert response.json() == {"detail": "Not authenticated"}
return no_auth return no_auth
@pytest.fixture()
def test_data(db_session):
author_emails = [
"rick@example.com",
"morty@example.com",
"jerry@example.com",
]
# creates 3 users
for email in author_emails:
db_session.add(models.User(email=email))
db_session.commit()
assert db_session.query(models.User).count() == 3
# creates 100 archives for 3 users over 2 months with repeating URLs
for i in range(100):
author = author_emails[i % 3]
archive = models.Archive(
id=f"archive-id-456-{i}",
url=f"https://example-{i % 3}.com",
result={},
public=author == "jerry@example.com",
author_id=author,
group_id="spaceship"
if author == "morty@example.com" and i % 2 == 0
else None,
created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1),
)
if i % 5 == 0:
archive.tags.append(models.Tag(id=f"tag-{i}"))
if i % 10 == 0:
archive.tags.append(models.Tag(id=f"tag-second-{i}"))
if i % 4 == 0:
archive.tags.append(models.Tag(id=f"tag-third-{i}"))
for j in range(10):
archive.urls.append(
models.ArchiveUrl(
url=f"https://example-{i}.com/{j}", key=f"media_{j}"
)
)
db_session.add(archive)
# creates a sheet for each user
for i, email in enumerate(
["rick@example.com", "morty@example.com", "jerry@example.com"]
):
db_session.add(
models.Sheet(
id=f"sheet-{i}",
name=f"sheet-{i}",
author_id=email,
group_id=None,
frequency="daily",
)
)
if email == "rick@example.com":
db_session.add(
models.Sheet(
id=f"sheet-{i}-2",
name=f"sheet-{i}-2",
author_id=email,
group_id="spaceship",
frequency="hourly",
)
)
db_session.commit()
assert db_session.query(models.Archive).count() == 100
assert db_session.query(models.Tag).count() == 20 + 10 + 25
assert db_session.query(models.ArchiveUrl).count() == 1000
assert (
db_session.query(models.ArchiveUrl)
.filter(models.ArchiveUrl.archive_id == "archive-id-456-0")
.count()
== 10
)
# setup groups
assert db_session.query(models.Group).count() == 0
crud.upsert_user_groups(db_session)
assert db_session.query(models.Group).count() == 4
assert db_session.query(models.User).count() == 3

View File

@@ -1,3 +1,3 @@
{ {
"client_email": "fake_service_account@fake_service_account.iam.gserviceaccount.com" "client_email": "fake_service_account@fake_service_account.iam.gserviceaccount.com"
} }

View File

@@ -1,7 +1,8 @@
steps: steps:
feeder: cli_feeder feeders:
- cli_feeder
archivers: # order matters archivers: # order matters
- youtubedl_archiver - generic_extractor
enrichers: enrichers:
- hash_enricher - hash_enricher
@@ -12,7 +13,7 @@ steps:
- console_db - console_db
configurations: configurations:
gsheet_feeder: gsheet_feeder_db:
service_account: "app/tests/fake_service_account.json" service_account: "app/tests/fake_service_account.json"
cli_feeder: cli_feeder:
urls: urls:

View File

@@ -1,6 +1,7 @@
def test_generate_uuid(): from app.shared.db.models import generate_uuid
from app.shared.db.models import generate_uuid
assert generate_uuid() != generate_uuid()
assert len(generate_uuid()) == 36 def test_generate_uuid():
assert generate_uuid().count("-") == 4 assert generate_uuid() != generate_uuid()
assert len(generate_uuid()) == 36
assert generate_uuid().count("-") == 4

View File

@@ -1,12 +1,10 @@
from app.shared.db import models
from app.shared.db import worker_crud, models
from datetime import datetime from datetime import datetime
from app.shared import schemas
from app.shared.db import models, worker_crud
from app.tests.web.db.test_crud import test_data
def test_update_sheet_last_url_archived_at(db_session): def test_update_sheet_last_url_archived_at(db_session):
# Create test sheet # Create test sheet
test_sheet = models.Sheet(id="sheet-123") test_sheet = models.Sheet(id="sheet-123")
db_session.add(test_sheet) db_session.add(test_sheet)
@@ -15,17 +13,24 @@ def test_update_sheet_last_url_archived_at(db_session):
# Test updating existing sheet # Test updating existing sheet
assert isinstance(test_sheet.last_url_archived_at, datetime) assert isinstance(test_sheet.last_url_archived_at, datetime)
before = test_sheet.last_url_archived_at before = test_sheet.last_url_archived_at
assert worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123") is True assert (
worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123")
is True
)
db_session.refresh(test_sheet) db_session.refresh(test_sheet)
assert isinstance(test_sheet.last_url_archived_at, datetime) assert isinstance(test_sheet.last_url_archived_at, datetime)
assert test_sheet.last_url_archived_at > before assert test_sheet.last_url_archived_at > before
# Test non-existent sheet # Test non-existent sheet
assert worker_crud.update_sheet_last_url_archived_at(db_session, "non-existent-sheet") is False assert (
worker_crud.update_sheet_last_url_archived_at(
db_session, "non-existent-sheet"
)
is False
)
def test_get_group(test_data, db_session): def test_get_group(test_data, db_session):
from app.shared.db import worker_crud
assert worker_crud.get_group(db_session, "spaceship") is not None assert worker_crud.get_group(db_session, "spaceship") is not None
assert worker_crud.get_group(db_session, "interdimensional") is not None assert worker_crud.get_group(db_session, "interdimensional") is not None
assert worker_crud.get_group(db_session, "animated-characters") is not None assert worker_crud.get_group(db_session, "animated-characters") is not None
@@ -33,24 +38,24 @@ def test_get_group(test_data, db_session):
def test_create_or_get_user(test_data, db_session): def test_create_or_get_user(test_data, db_session):
from app.shared.db import worker_crud
assert db_session.query(models.User).count() == 3 assert db_session.query(models.User).count() == 3
# already exists # already exists
assert (u1 := worker_crud.create_or_get_user(db_session, "rick@example.com")) is not None assert (
u1 := worker_crud.create_or_get_user(db_session, "rick@example.com")
) is not None
assert u1.email == "rick@example.com" assert u1.email == "rick@example.com"
# new user # new user
assert (u2 := worker_crud.create_or_get_user(db_session, "beth@example.com")) is not None assert (
u2 := worker_crud.create_or_get_user(db_session, "beth@example.com")
) is not None
assert u2.email == "beth@example.com" assert u2.email == "beth@example.com"
assert db_session.query(models.User).count() == 4 assert db_session.query(models.User).count() == 4
def test_create_tag(db_session): def test_create_tag(db_session):
from app.shared.db import worker_crud
assert db_session.query(models.Tag).count() == 0 assert db_session.query(models.Tag).count() == 0
# create first # create first
@@ -58,7 +63,10 @@ def test_create_tag(db_session):
assert create_tag is not None assert create_tag is not None
assert create_tag.id == "tag-101" assert create_tag.id == "tag-101"
assert db_session.query(models.Tag).count() == 1 assert db_session.query(models.Tag).count() == 1
assert db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() == create_tag assert (
db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first()
== create_tag
)
# same id does not add new db entry # same id does not add new db entry
existing_tag = worker_crud.create_tag(db_session, "tag-101") existing_tag = worker_crud.create_tag(db_session, "tag-101")
@@ -73,9 +81,6 @@ def test_create_tag(db_session):
def test_create_task(db_session): def test_create_task(db_session):
from app.shared.db import worker_crud
from app.shared import schemas
task = schemas.ArchiveCreate( task = schemas.ArchiveCreate(
id="archive-id-456-101", id="archive-id-456-101",
url="https://example-0.com", url="https://example-0.com",
@@ -84,17 +89,22 @@ def test_create_task(db_session):
author_id="rick@example.com", author_id="rick@example.com",
group_id="spaceship", group_id="spaceship",
tags=[], tags=[],
urls=[] urls=[],
) )
# with tags and urls # with tags and urls
nt = worker_crud.create_archive(db_session, task, [models.Tag(id="tag-101")], [models.ArchiveUrl(url="https://example-0.com/0", key="media_0")]) nt = worker_crud.create_archive(
db_session,
task,
[models.Tag(id="tag-101")],
[models.ArchiveUrl(url="https://example-0.com/0", key="media_0")],
)
assert nt is not None assert nt is not None
assert nt.id == "archive-id-456-101" assert nt.id == "archive-id-456-101"
assert nt.url == "https://example-0.com" assert nt.url == "https://example-0.com"
assert nt.author_id == "rick@example.com" assert nt.author_id == "rick@example.com"
assert nt.public == False assert nt.public is False
assert nt.group_id == "spaceship" assert nt.group_id == "spaceship"
assert len(nt.tags) == 1 assert len(nt.tags) == 1
assert nt.tags[0].id == "tag-101" assert nt.tags[0].id == "tag-101"
@@ -110,7 +120,7 @@ def test_create_task(db_session):
assert nt.id == "archive-id-456-102" assert nt.id == "archive-id-456-102"
assert nt.url == "https://example-0.com" assert nt.url == "https://example-0.com"
assert nt.author_id == "rick@example.com" assert nt.author_id == "rick@example.com"
assert nt.public == False assert nt.public is False
assert nt.group_id == "spaceship" assert nt.group_id == "spaceship"
assert len(nt.tags) == 0 assert len(nt.tags) == 0
assert len(nt.urls) == 0 assert len(nt.urls) == 0

View File

@@ -1,10 +1,15 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from app.shared.business_logic import get_store_archive_until, get_store_archive_until_or_never
from app.shared.business_logic import (
get_store_archive_until,
get_store_archive_until_or_never,
)
class Test_get_store_archive_until: class TestGetStoreArchiveUntil:
GROUP_ID = "test-group" GROUP_ID = "test-group"
def test_group_not_found(self, db_session): def test_group_not_found(self, db_session):
@@ -12,7 +17,10 @@ class Test_get_store_archive_until:
get_store_archive_until(db_session, self.GROUP_ID) get_store_archive_until(db_session, self.GROUP_ID)
assert str(exc.value) == f"Group {self.GROUP_ID} not found." assert str(exc.value) == f"Group {self.GROUP_ID} not found."
@patch("app.shared.db.worker_crud.get_group", return_value=MagicMock(permissions=None)) @patch(
"app.shared.db.worker_crud.get_group",
return_value=MagicMock(permissions=None),
)
def test_group_no_permissions(self, db_session): def test_group_no_permissions(self, db_session):
with pytest.raises(AssertionError) as exc: with pytest.raises(AssertionError) as exc:
get_store_archive_until(db_session, self.GROUP_ID) get_store_archive_until(db_session, self.GROUP_ID)
@@ -43,14 +51,17 @@ class Test_get_store_archive_until:
mock_get_group.assert_called_once_with(db_session, self.GROUP_ID) mock_get_group.assert_called_once_with(db_session, self.GROUP_ID)
class Test_get_store_archive_until_or_never: class TestGetStoreArchiveUntilOrNever:
GROUP_ID = "test-group" GROUP_ID = "test-group"
def test_group_not_found(self, db_session): def test_group_not_found(self, db_session):
result = get_store_archive_until_or_never(db_session, self.GROUP_ID) result = get_store_archive_until_or_never(db_session, self.GROUP_ID)
assert result is None assert result is None
@patch("app.shared.db.worker_crud.get_group", return_value=MagicMock(permissions=None)) @patch(
"app.shared.db.worker_crud.get_group",
return_value=MagicMock(permissions=None),
)
def test_group_no_permissions(self, db_session): def test_group_no_permissions(self, db_session):
result = get_store_archive_until_or_never(db_session, self.GROUP_ID) result = get_store_archive_until_or_never(db_session, self.GROUP_ID)
assert result is None assert result is None

View File

@@ -2,123 +2,379 @@ from datetime import datetime, timedelta
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import sqlalchemy
import yaml import yaml
from sqlalchemy import false, true
from sqlalchemy.sql import select
from app.shared.db import models from app.shared.db import models
from app.shared.settings import Settings from app.shared.settings import Settings
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud from app.web.db import crud
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
@pytest.fixture()
def test_data(db_session):
# creates 3 users
for email in authors:
db_session.add(models.User(email=email))
db_session.commit()
assert db_session.query(models.User).count() == 3
# creates 100 archives for 3 users over 2 months with repeating URLs
for i in range(100):
author = authors[i % 3]
archive = models.Archive(
id=f"archive-id-456-{i}",
url=f"https://example-{i%3}.com",
result={},
public=author == "jerry@example.com",
author_id=author,
group_id="spaceship" if author == "morty@example.com" and i % 2 == 0 else None,
created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1)
)
if i % 5 == 0:
archive.tags.append(models.Tag(id=f"tag-{i}"))
if i % 10 == 0:
archive.tags.append(models.Tag(id=f"tag-second-{i}"))
if i % 4 == 0:
archive.tags.append(models.Tag(id=f"tag-third-{i}"))
for j in range(10):
archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
db_session.add(archive)
# creates a sheet for each user
for i, email in enumerate(authors):
db_session.add(models.Sheet(id=f"sheet-{i}", name=f"sheet-{i}", author_id=email, group_id=None, frequency="daily"))
if email == "rick@example.com":
db_session.add(models.Sheet(id=f"sheet-{i}-2", name=f"sheet-{i}-2", author_id=email, group_id="spaceship", frequency="hourly"))
db_session.commit()
assert db_session.query(models.Archive).count() == 100
assert db_session.query(models.Tag).count() == 20 + 10 + 25
assert db_session.query(models.ArchiveUrl).count() == 1000
assert db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.archive_id == "archive-id-456-0").count() == 10
# setup groups
assert db_session.query(models.Group).count() == 0
from app.web.db import crud
crud.upsert_user_groups(db_session)
assert db_session.query(models.Group).count() == 4
assert db_session.query(models.User).count() == 3
def test_search_archives_by_url(test_data, db_session): def test_search_archives_by_url(test_data, db_session):
from app.web.config import ALLOW_ANY_EMAIL # Rick's archives are private
assert (
# rick's archives are private len(
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", True, False)) == 34 crud.search_archives_by_url(
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", [], False)) == 34 db_session,
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", [], True)) == 34 "https://example-0.com",
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL, [], False)) == 34 "rick@example.com",
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL, True, False)) == 34 True,
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com", [], False)) == 0 False,
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com", [], True)) == 0 )
)
== 34
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-0.com",
"rick@example.com",
[],
False,
)
)
== 34
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-0.com",
"rick@example.com",
[],
True,
)
)
== 34
)
assert (
len(
crud.search_archives_by_url(
db_session, "https://example-0.com", ALLOW_ANY_EMAIL, [], False
)
)
== 34
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-0.com",
ALLOW_ANY_EMAIL,
True,
False,
)
)
== 34
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-0.com",
"morty@example.com",
[],
False,
)
)
== 0
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-0.com",
"morty@example.com",
[],
True,
)
)
== 0
)
# morty's archives are public but half are in spaceship group # morty's archives are public but half are in spaceship group
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com", ["spaceship"], False)) == 16 assert (
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com", True, False)) == 16 len(
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "jerry@example.com", True, True)) == 16 crud.search_archives_by_url(
db_session,
"https://example-1.com",
"rick@example.com",
["spaceship"],
False,
)
)
== 16
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-1.com",
"rick@example.com",
True,
False,
)
)
== 16
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-1.com",
"jerry@example.com",
True,
True,
)
)
== 16
)
# jerry's archives are public # Jerry's archives are public
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "jerry@example.com", [], True)) == 33 assert (
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "rick@example.com", [], True)) == 33 len(
crud.search_archives_by_url(
db_session,
"https://example-2.com",
"jerry@example.com",
[],
True,
)
)
== 33
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-2.com",
"rick@example.com",
[],
True,
)
)
== 33
)
# fuzzy search # fuzzy search
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False)) == 100 assert (
assert len(crud.search_archives_by_url(db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL, False, False)) == 100 len(
assert len(crud.search_archives_by_url(db_session, "2.com", ALLOW_ANY_EMAIL, False, False)) == 33 crud.search_archives_by_url(
db_session, "https://example", ALLOW_ANY_EMAIL, False, False
)
)
== 100
)
assert (
len(
crud.search_archives_by_url(
db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL, False, False
)
)
== 100
)
assert (
len(
crud.search_archives_by_url(
db_session, "2.com", ALLOW_ANY_EMAIL, False, False
)
)
== 33
)
# absolute search # absolute search
assert len(crud.search_archives_by_url(db_session, "example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True)) == 0 assert (
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True)) == 33 len(
crud.search_archives_by_url(
db_session,
"example-2.com",
ALLOW_ANY_EMAIL,
[],
False,
absolute_search=True,
)
)
== 0
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-2.com",
ALLOW_ANY_EMAIL,
[],
False,
absolute_search=True,
)
)
== 33
)
# archived_after # archived_after
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, archived_after=datetime(2010, 1, 1))) == 100 assert (
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 15))) == 70 len(
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2031, 1, 1))) == 0 crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
True,
True,
archived_after=datetime(2010, 1, 1),
)
)
== 100
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_after=datetime(2021, 1, 15),
)
)
== 70
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_after=datetime(2031, 1, 1),
)
)
== 0
)
# archived before # archived before
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2010, 1, 1))) == 0 assert (
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2021, 1, 15))) == 28 len(
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2031, 1, 1))) == 100 crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_before=datetime(2010, 1, 1),
)
)
== 0
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_before=datetime(2021, 1, 15),
)
)
== 28
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_before=datetime(2031, 1, 1),
)
)
== 100
)
# archived before and after # archived before and after
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2001, 1, 1), archived_before=datetime(2031, 1, 11))) == 100 assert (
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 14), archived_before=datetime(2021, 1, 16))) == 2 len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_after=datetime(2001, 1, 1),
archived_before=datetime(2031, 1, 11),
)
)
== 100
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_after=datetime(2021, 1, 14),
archived_before=datetime(2021, 1, 16),
)
)
== 2
)
# limit # limit
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, limit=10)) == 10 assert (
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, limit=-1)) == 1 len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
limit=10,
)
)
== 10
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
limit=-1,
)
)
== 1
)
# skip # skip
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, skip=10)) == 90 assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
skip=10,
)
)
== 90
)
def test_search_archives_by_email(test_data, db_session): def test_search_archives_by_email(test_data, db_session):
from app.web.config import ALLOW_ANY_EMAIL
# lower/upper case # lower/upper case
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34 assert (
len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34
)
# ALLOW_ANY_EMAIL is not a user # ALLOW_ANY_EMAIL is not a user
assert len(crud.search_archives_by_email(db_session, ALLOW_ANY_EMAIL)) == 0 assert len(crud.search_archives_by_email(db_session, ALLOW_ANY_EMAIL)) == 0
@@ -136,45 +392,108 @@ def test_search_archives_by_email(test_data, db_session):
@patch("app.web.db.crud.DATABASE_QUERY_LIMIT", new=25) @patch("app.web.db.crud.DATABASE_QUERY_LIMIT", new=25)
def test_max_query_limit(test_data, db_session): def test_max_query_limit(test_data, db_session):
from app.web.config import ALLOW_ANY_EMAIL assert (
len(
crud.search_archives_by_url(
db_session, "https://example", ALLOW_ANY_EMAIL, [], False
)
)
== 25
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
True,
True,
limit=1000,
)
)
== 25
)
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, [], False)) == 25 assert (
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, limit=1000)) == 25 len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25
)
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25 assert (
assert len(crud.search_archives_by_email(db_session, "rick@example.com", limit=1000)) == 25 len(
crud.search_archives_by_email(
db_session, "rick@example.com", limit=1000
)
)
== 25
)
def test_soft_delete(test_data, db_session): def test_soft_delete(test_data, db_session):
# none deleted yet # none deleted yet
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").first() is not None assert (
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 0 db_session.query(models.Archive)
.filter(models.Archive.id == "archive-id-456-0")
.first()
is not None
)
assert (
db_session.query(models.Archive)
.filter(models.Archive.deleted.is_(true()))
.count()
== 0
)
# delete # delete
assert crud.soft_delete_archive(db_session, "archive-id-456-0", "rick@example.com") == True assert (
crud.soft_delete_archive(
db_session, "archive-id-456-0", "rick@example.com"
)
is True
)
# ensure soft delete # ensure soft delete
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 1 assert (
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").first() is None db_session.query(models.Archive)
.filter(models.Archive.deleted.is_(true()))
.count()
== 1
)
assert (
db_session.query(models.Archive)
.filter(models.Archive.id == "archive-id-456-0")
.filter(models.Archive.deleted.is_(false()))
.first()
is None
)
# already deleted # already deleted
assert crud.soft_delete_archive(db_session, "archive-id-456-0", "rick@example.com") == False assert (
crud.soft_delete_archive(
db_session, "archive-id-456-0", "rick@example.com"
)
is False
)
def test_count_archives(test_data, db_session): def test_count_archives(test_data, db_session):
assert crud.count_archives(db_session) == 100 assert crud.count_archives(db_session) == 100
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete() db_session.query(models.Archive).filter(
models.Archive.id == "archive-id-456-0"
).delete()
db_session.commit() db_session.commit()
assert crud.count_archives(db_session) == 99 assert crud.count_archives(db_session) == 99
def test_count_archive_urls(test_data, db_session): def test_count_archive_urls(test_data, db_session):
assert crud.count_archive_urls(db_session) == 1000 assert crud.count_archive_urls(db_session) == 1000
db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.url == "https://example-0.com/0").delete() db_session.query(models.ArchiveUrl).filter(
models.ArchiveUrl.url == "https://example-0.com/0"
).delete()
db_session.commit() db_session.commit()
assert crud.count_archive_urls(db_session) == 999 assert crud.count_archive_urls(db_session) == 999
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete() db_session.query(models.Archive).filter(
models.Archive.id == "archive-id-456-0"
).delete()
db_session.commit() db_session.commit()
# no Cascade is enabled # no Cascade is enabled
assert crud.count_archives(db_session) == 99 assert crud.count_archives(db_session) == 99
@@ -183,16 +502,23 @@ def test_count_archive_urls(test_data, db_session):
def test_count_users(test_data, db_session): def test_count_users(test_data, db_session):
assert crud.count_users(db_session) == 3 assert crud.count_users(db_session) == 3
db_session.query(models.User).filter(models.User.email == "rick@example.com").delete() db_session.query(models.User).filter(
models.User.email == "rick@example.com"
).delete()
db_session.commit() db_session.commit()
assert crud.count_users(db_session) == 2 assert crud.count_users(db_session) == 2
def test_count_by_users_since(test_data, db_session): def test_count_by_users_since(test_data, db_session):
from app.web.db import crud
# 100y window # 100y window
assert len(cu := crud.count_by_user_since(db_session, 60 * 60 * 24 * 31 * 12 * 100)) == 3 assert (
len(
cu := crud.count_by_user_since(
db_session, 60 * 60 * 24 * 31 * 12 * 100
)
)
== 3
)
assert cu[0].total == 34 assert cu[0].total == 34
assert cu[1].total == 33 assert cu[1].total == 33
assert cu[2].total == 33 assert cu[2].total == 33
@@ -201,9 +527,18 @@ def test_count_by_users_since(test_data, db_session):
def test_upsert_group(test_data, db_session): def test_upsert_group(test_data, db_session):
assert db_session.query(models.Group).count() == 4 assert db_session.query(models.Group).count() == 4
repeatable_params = ["desc 1", "orch.yaml", "sheet.yaml", "service_account_email@example.com", {"read": ["all"]}, ["example.com"]] repeatable_params = [
"desc 1",
"orch.yaml",
"sheet.yaml",
"service_account_email@example.com",
{"read": ["all"]},
["example.com"],
]
assert (g1 := crud.upsert_group(db_session, "spaceship", *repeatable_params)) is not None assert (
g1 := crud.upsert_group(db_session, "spaceship", *repeatable_params)
) is not None
assert g1.id == "spaceship" assert g1.id == "spaceship"
assert g1.description == "desc 1" assert g1.description == "desc 1"
assert g1.orchestrator == "orch.yaml" assert g1.orchestrator == "orch.yaml"
@@ -212,14 +547,25 @@ def test_upsert_group(test_data, db_session):
assert g1.permissions == {"read": ["all"]} assert g1.permissions == {"read": ["all"]}
assert g1.domains == ["example.com"] assert g1.domains == ["example.com"]
assert len(g1.users) == 2 assert len(g1.users) == 2
assert [u.email for u in g1.users] == ["rick@example.com", "morty@example.com"] assert [u.email for u in g1.users] == [
"rick@example.com",
"morty@example.com",
]
assert (g2 := crud.upsert_group(db_session, "interdimensional", *repeatable_params)) is not None assert (
g2 := crud.upsert_group(
db_session, "interdimensional", *repeatable_params
)
) is not None
assert g2.id == "interdimensional" assert g2.id == "interdimensional"
assert len(g2.users) == 1 assert len(g2.users) == 1
assert [u.email for u in g2.users] == ["rick@example.com"] assert [u.email for u in g2.users] == ["rick@example.com"]
assert (g3 := crud.upsert_group(db_session, "this-is-a-new-group", *repeatable_params)) is not None assert (
g3 := crud.upsert_group(
db_session, "this-is-a-new-group", *repeatable_params
)
) is not None
assert g3.id == "this-is-a-new-group" assert g3.id == "this-is-a-new-group"
assert len(g3.users) == 0 assert len(g3.users) == 0
@@ -227,29 +573,38 @@ def test_upsert_group(test_data, db_session):
def test_upsert_user_groups(db_session): def test_upsert_user_groups(db_session):
@patch('app.web.db.crud.get_settings', new=lambda: bad_setings) @patch("app.web.db.crud.get_settings", new=lambda: bad_settings)
def test_missing_yaml(db_session): def test_missing_yaml(db_session):
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
crud.upsert_user_groups(db_session) crud.upsert_user_groups(db_session)
@patch('app.web.db.crud.get_settings', new=lambda: bad_setings) @patch("app.web.db.crud.get_settings", new=lambda: bad_settings)
def test_broken_yaml(db_session): def test_broken_yaml(db_session):
with pytest.raises(yaml.YAMLError): with pytest.raises(yaml.YAMLError):
crud.upsert_user_groups(db_session) crud.upsert_user_groups(db_session)
bad_setings = Settings(_env_file=".env.test") bad_settings = Settings(_env_file=".env.test")
bad_setings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.missing.yaml" bad_settings.USER_GROUPS_FILENAME = (
"app/tests/user-groups.test.missing.yaml"
)
test_missing_yaml(db_session) test_missing_yaml(db_session)
bad_setings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.broken.yaml" bad_settings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.broken.yaml"
test_broken_yaml(db_session) test_broken_yaml(db_session)
def test_create_sheet(db_session): def test_create_sheet(db_session):
assert db_session.query(models.Sheet).count() == 0 assert db_session.query(models.Sheet).count() == 0
s = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly") s = crud.create_sheet(
db_session,
"sheet-id-123",
"sheet name",
"email@example.com",
"group-id",
"hourly",
)
assert s is not None assert s is not None
assert s.id == "sheet-id-123" assert s.id == "sheet-id-123"
assert s.name == "sheet name" assert s.name == "sheet name"
@@ -259,19 +614,35 @@ def test_create_sheet(db_session):
assert db_session.query(models.Sheet).count() == 1 assert db_session.query(models.Sheet).count() == 1
# duplicate id
import sqlalchemy
with pytest.raises(sqlalchemy.exc.IntegrityError): with pytest.raises(sqlalchemy.exc.IntegrityError):
crud.create_sheet(db_session, "sheet-id-123", "I thought this was another sheet", "email", "group-id", "hourly") crud.create_sheet(
db_session,
"sheet-id-123",
"I thought this was another sheet",
"email",
"group-id",
"hourly",
)
def test_get_user_sheet(test_data, db_session): def test_get_user_sheet(test_data, db_session):
assert crud.get_user_sheet(db_session, "", "sheet-0") is None assert crud.get_user_sheet(db_session, "", "sheet-0") is None
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None assert (
crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None
)
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0") is not None assert (
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2") is not None crud.get_user_sheet(db_session, "rick@example.com", "sheet-0")
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-1") is not None is not None
)
assert (
crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2")
is not None
)
assert (
crud.get_user_sheet(db_session, "morty@example.com", "sheet-1")
is not None
)
def test_get_user_sheets(test_data, db_session): def test_get_user_sheets(test_data, db_session):
@@ -283,9 +654,9 @@ def test_get_user_sheets(test_data, db_session):
def test_delete_sheet(test_data, db_session): def test_delete_sheet(test_data, db_session):
assert crud.delete_sheet(db_session, "sheet-0", "") == False assert crud.delete_sheet(db_session, "sheet-0", "") is False
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == True assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") is True
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == False assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") is False
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -297,21 +668,21 @@ async def test_find_by_store_until(async_db_session):
url="https://example-expired-1.com", url="https://example-expired-1.com",
result={}, result={},
author_id="rick@example.com", author_id="rick@example.com",
store_until=now - timedelta(days=1) store_until=now - timedelta(days=1),
) )
archive2 = models.Archive( archive2 = models.Archive(
id="archive-expired-2", id="archive-expired-2",
url="https://example-expired-2.com", url="https://example-expired-2.com",
result={}, result={},
author_id="rick@example.com", author_id="rick@example.com",
store_until=now - timedelta(hours=1) store_until=now - timedelta(hours=1),
) )
archive3 = models.Archive( archive3 = models.Archive(
id="archive-active", id="archive-active",
url="https://example-active.com", url="https://example-active.com",
result={}, result={},
author_id="rick@example.com", author_id="rick@example.com",
store_until=now + timedelta(days=1) store_until=now + timedelta(days=1),
) )
async_db_session.add_all([archive1, archive2, archive3]) async_db_session.add_all([archive1, archive2, archive3])
await async_db_session.commit() await async_db_session.commit()
@@ -321,11 +692,15 @@ async def test_find_by_store_until(async_db_session):
assert len(list(expired)) == 2 assert len(list(expired)) == 2
# Should find 1 archive expired before 2 hours ago # Should find 1 archive expired before 2 hours ago
expired = await crud.find_by_store_until(async_db_session, now - timedelta(hours=2)) expired = await crud.find_by_store_until(
async_db_session, now - timedelta(hours=2)
)
assert len(list(expired)) == 1 assert len(list(expired)) == 1
# Should find no archives expired before 2 days ago # Should find no archives expired before 2 days ago
expired = await crud.find_by_store_until(async_db_session, now - timedelta(days=2)) expired = await crud.find_by_store_until(
async_db_session, now - timedelta(days=2)
)
assert len(list(expired)) == 0 assert len(list(expired)) == 0
# Should not find deleted archives # Should not find deleted archives
@@ -337,44 +712,82 @@ async def test_find_by_store_until(async_db_session):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_sheets_by_id_hash(async_db_session): async def test_get_sheets_by_id_hash(async_db_session):
author_emails = [
"rick@example.com",
"morty@example.com",
"jerry@example.com",
]
# Add test data # Add test data
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
sheets = [ sheets = [
models.Sheet(id="sheet-0", name="sheet-0", author_id=authors[0], group_id=None, frequency="daily"), models.Sheet(
models.Sheet(id="sheet-0-2", name="sheet-0-2", author_id=authors[0], group_id="spaceship", frequency="hourly"), id="sheet-0",
models.Sheet(id="sheet-1", name="sheet-1", author_id=authors[1], group_id=None, frequency="daily"), name="sheet-0",
models.Sheet(id="sheet-2", name="sheet-2", author_id=authors[2], group_id=None, frequency="daily") author_id=author_emails[0],
group_id=None,
frequency="daily",
),
models.Sheet(
id="sheet-0-2",
name="sheet-0-2",
author_id=author_emails[0],
group_id="spaceship",
frequency="hourly",
),
models.Sheet(
id="sheet-1",
name="sheet-1",
author_id=author_emails[1],
group_id=None,
frequency="daily",
),
models.Sheet(
id="sheet-2",
name="sheet-2",
author_id=author_emails[2],
group_id=None,
frequency="daily",
),
] ]
async_db_session.add_all(sheets) async_db_session.add_all(sheets)
await async_db_session.commit() await async_db_session.commit()
with patch("app.web.db.crud.fnv1a_hash_mod", return_value=1): with patch("app.web.db.crud.fnv1a_hash_mod", return_value=1):
# Test retrieving hourly sheets # Test retrieving hourly sheets
hourly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "hourly", 4, 1) hourly_sheets = await crud.get_sheets_by_id_hash(
async_db_session, "hourly", 4, 1
)
assert len(hourly_sheets) == 1 assert len(hourly_sheets) == 1
assert hourly_sheets[0].id == "sheet-0-2" assert hourly_sheets[0].id == "sheet-0-2"
assert hourly_sheets[0].frequency == "hourly" assert hourly_sheets[0].frequency == "hourly"
# Test retrieving daily sheets # Test retrieving daily sheets
daily_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 1) daily_sheets = await crud.get_sheets_by_id_hash(
async_db_session, "daily", 4, 1
)
assert len(daily_sheets) == 3 assert len(daily_sheets) == 3
assert all(sheet.frequency == "daily" for sheet in daily_sheets) assert all(sheet.frequency == "daily" for sheet in daily_sheets)
assert {sheet.id for sheet in daily_sheets} == {"sheet-0", "sheet-1", "sheet-2"} assert {sheet.id for sheet in daily_sheets} == {
"sheet-0",
"sheet-1",
"sheet-2",
}
# Test with non-matching hash # Test with non-matching hash
no_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 3) no_sheets = await crud.get_sheets_by_id_hash(
async_db_session, "daily", 4, 3
)
assert len(no_sheets) == 0 assert len(no_sheets) == 0
# Test with non-existent frequency # Test with non-existent frequency
weekly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "weekly", 4, 1) weekly_sheets = await crud.get_sheets_by_id_hash(
async_db_session, "weekly", 4, 1
)
assert len(weekly_sheets) == 0 assert len(weekly_sheets) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_stale_sheets(async_db_session): async def test_delete_stale_sheets(async_db_session):
from datetime import datetime, timedelta
from sqlalchemy.sql import select
now = datetime.now() now = datetime.now()
active_date = now - timedelta(days=5) active_date = now - timedelta(days=5)
stale_date = now - timedelta(days=15) stale_date = now - timedelta(days=15)
@@ -386,29 +799,29 @@ async def test_delete_stale_sheets(async_db_session):
name="Active Sheet 1", name="Active Sheet 1",
author_id="rick@example.com", author_id="rick@example.com",
frequency="daily", frequency="daily",
last_url_archived_at=active_date last_url_archived_at=active_date,
), ),
models.Sheet( models.Sheet(
id="sheet-active-2", id="sheet-active-2",
name="Active Sheet 2", name="Active Sheet 2",
author_id="morty@example.com", author_id="morty@example.com",
frequency="hourly", frequency="hourly",
last_url_archived_at=active_date last_url_archived_at=active_date,
), ),
models.Sheet( models.Sheet(
id="sheet-stale-1", id="sheet-stale-1",
name="Stale Sheet 1", name="Stale Sheet 1",
author_id="rick@example.com", author_id="rick@example.com",
frequency="daily", frequency="daily",
last_url_archived_at=stale_date last_url_archived_at=stale_date,
), ),
models.Sheet( models.Sheet(
id="sheet-stale-2", id="sheet-stale-2",
name="Stale Sheet 2", name="Stale Sheet 2",
author_id="morty@example.com", author_id="morty@example.com",
frequency="daily", frequency="daily",
last_url_archived_at=stale_date last_url_archived_at=stale_date,
) ),
] ]
async_db_session.add_all(sheets) async_db_session.add_all(sheets)
await async_db_session.commit() await async_db_session.commit()

View File

@@ -1,10 +1,11 @@
from unittest.mock import MagicMock, PropertyMock, patch from unittest.mock import MagicMock, PropertyMock, patch
import pytest import pytest
from app.shared.db import models from app.shared.db import models
from app.shared.user_groups import GroupInfo, GroupPermissions from app.shared.user_groups import GroupInfo, GroupPermissions
from app.web.db.user_state import UserState from app.web.db.user_state import UserState
from app.web.utils.misc import convert_priority_to_queue_dict
def fresh_user_state(): def fresh_user_state():
@@ -20,39 +21,73 @@ def user_state():
def user_state_with_groups(user_state): def user_state_with_groups(user_state):
user_groups = [ user_groups = [
models.Group(id="no-permissions", permissions={}), models.Group(id="no-permissions", permissions={}),
models.Group(id="group1", description="this is g1", service_account_email="sa1@example.com", permissions={"read": ["group1", "no-permissions"], "read_public": True, "archive_url": True, "archive_sheet": True, "max_archive_lifespan_months": 24, "max_monthly_urls": 100, "max_monthly_mbs": 1000, "priority": "high"}), models.Group(
models.Group(id="group2", description="this is g2", service_account_email="sa2@example.com", permissions={"read": ["all"], "read_public": True, "archive_url": False, "archive_sheet": False, "max_archive_lifespan_months": -1, "max_monthly_urls": -1, "max_monthly_mbs": -1, "priority": "low", "sheet_frequency": {"daily"}}), id="group1",
description="this is g1",
service_account_email="sa1@example.com",
permissions={
"read": ["group1", "no-permissions"],
"read_public": True,
"archive_url": True,
"archive_sheet": True,
"max_archive_lifespan_months": 24,
"max_monthly_urls": 100,
"max_monthly_mbs": 1000,
"priority": "high",
},
),
models.Group(
id="group2",
description="this is g2",
service_account_email="sa2@example.com",
permissions={
"read": ["all"],
"read_public": True,
"archive_url": False,
"archive_sheet": False,
"max_archive_lifespan_months": -1,
"max_monthly_urls": -1,
"max_monthly_mbs": -1,
"priority": "low",
"sheet_frequency": {"daily"},
},
),
] ]
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=user_groups): with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=user_groups,
):
yield user_state yield user_state
def test_permissions(user_state_with_groups): def test_permissions(user_state_with_groups):
permissions = user_state_with_groups.permissions permissions = user_state_with_groups.permissions
assert permissions["all"].read == True assert permissions["all"].read is True
assert permissions["all"].read_public == True assert permissions["all"].read_public is True
assert permissions["all"].archive_url == True assert permissions["all"].archive_url is True
assert permissions["all"].archive_sheet == True assert permissions["all"].archive_sheet is True
assert permissions["all"].max_archive_lifespan_months == -1 assert permissions["all"].max_archive_lifespan_months == -1
assert permissions["all"].max_monthly_urls == -1 assert permissions["all"].max_monthly_urls == -1
assert permissions["all"].max_monthly_mbs == -1 assert permissions["all"].max_monthly_mbs == -1
assert permissions["all"].priority == "high" assert permissions["all"].priority == "high"
assert permissions["group1"].read == set(["group1", "no-permissions"]) assert permissions["group1"].read == {"group1", "no-permissions"}
assert permissions["group1"].read_public == True assert permissions["group1"].read_public is True
assert permissions["group1"].archive_url == True assert permissions["group1"].archive_url is True
assert permissions["group1"].archive_sheet == True assert permissions["group1"].archive_sheet is True
assert permissions["group1"].max_archive_lifespan_months == 24 assert permissions["group1"].max_archive_lifespan_months == 24
assert permissions["group1"].max_monthly_urls == 100 assert permissions["group1"].max_monthly_urls == 100
assert permissions["group1"].max_monthly_mbs == 1000 assert permissions["group1"].max_monthly_mbs == 1000
assert permissions["group1"].priority == "high" assert permissions["group1"].priority == "high"
assert permissions["group2"].read == set(["all"]) assert permissions["group2"].read == {"all"}
assert permissions["group2"].read_public == True assert permissions["group2"].read_public is True
assert permissions["group2"].archive_url == False assert permissions["group2"].archive_url is False
assert permissions["group2"].archive_sheet == False assert permissions["group2"].archive_sheet is False
assert permissions["group2"].max_archive_lifespan_months == -1 assert permissions["group2"].max_archive_lifespan_months == -1
assert permissions["group2"].max_monthly_urls == -1 assert permissions["group2"].max_monthly_urls == -1
assert permissions["group2"].max_monthly_mbs == -1 assert permissions["group2"].max_monthly_mbs == -1
@@ -62,13 +97,19 @@ def test_permissions(user_state_with_groups):
def test_user_groups_names(user_state): def test_user_groups_names(user_state):
with patch('app.web.db.crud.get_user_group_names', return_value=["group1", "group2"]) as mock: with patch(
"app.web.db.crud.get_user_group_names",
return_value=["group1", "group2"],
) as mock:
assert user_state.user_groups_names == ["group1", "group2", "default"] assert user_state.user_groups_names == ["group1", "group2", "default"]
mock.assert_called_once_with(None, "test@example.com") mock.assert_called_once_with(None, "test@example.com")
def test_user_groups(user_state): def test_user_groups(user_state):
with patch('app.web.db.crud.get_user_groups_by_name', return_value=[MagicMock(), MagicMock()]) as mock: with patch(
"app.web.db.crud.get_user_groups_by_name",
return_value=[MagicMock(), MagicMock()],
) as mock:
user_state._user_groups_names = ["group1", "group2"] user_state._user_groups_names = ["group1", "group2"]
assert len(user_state.user_groups) == 2 assert len(user_state.user_groups) == 2
mock.assert_called_once_with(None, ["group1", "group2"]) mock.assert_called_once_with(None, ["group1", "group2"])
@@ -77,85 +118,166 @@ def test_user_groups(user_state):
def test_read(): def test_read():
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_read") assert not hasattr(us, "_read")
assert us.read == set() assert us.read == set()
assert us._read == set() assert us._read == set()
mock.assert_called_once() mock.assert_called_once()
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read": ["group1", "no-permissions"]})]): with patch.object(
assert us.read == set(["group1", "no-permissions"]) UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(
id="group1", permissions={"read": ["group1", "no-permissions"]}
)
],
):
assert us.read == {"group1", "no-permissions"}
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read": ["all"]})]): with patch.object(
assert us.read == True UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="group1", permissions={"read": ["all"]})],
):
assert us.read is True
def test_read_public(): def test_read_public():
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_read_public") assert not hasattr(us, "_read_public")
assert us.read_public == False assert us.read_public is False
assert us._read_public == False assert us._read_public is False
mock.assert_called_once() mock.assert_called_once()
# no new calls # no new calls
assert us.read_public == False assert us.read_public is False
mock.assert_called_once() mock.assert_called_once()
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read_public": True})]): with patch.object(
assert us.read_public == True UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"read_public": True})
],
):
assert us.read_public is True
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read_public": False})]): with patch.object(
assert us.read_public == False UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"read_public": False})
],
):
assert us.read_public is False
def test_archive_url(): def test_archive_url():
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_archive_url") assert not hasattr(us, "_archive_url")
assert us.archive_url == False assert us.archive_url is False
assert us._archive_url == False assert us._archive_url is False
mock.assert_called_once() mock.assert_called_once()
# no new calls # no new calls
assert us.archive_url == False assert us.archive_url is False
mock.assert_called_once() mock.assert_called_once()
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_url": False})]): with patch.object(
assert us.archive_url == False UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"archive_url": False})
],
):
assert us.archive_url is False
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_url": True})]): with patch.object(
assert us.archive_url == True UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"archive_url": True})
],
):
assert us.archive_url is True
def test_archive_sheet(): def test_archive_sheet():
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_archive_sheet") assert not hasattr(us, "_archive_sheet")
assert us.archive_sheet == False assert us.archive_sheet is False
assert us._archive_sheet == False assert us._archive_sheet is False
mock.assert_called_once() mock.assert_called_once()
# no new calls # no new calls
assert us.archive_sheet == False assert us.archive_sheet is False
mock.assert_called_once() mock.assert_called_once()
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_sheet": False})]): with patch.object(
assert us.archive_sheet == False UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"archive_sheet": False})
],
):
assert us.archive_sheet is False
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_sheet": True})]): with patch.object(
assert us.archive_sheet == True UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"archive_sheet": True})
],
):
assert us.archive_sheet is True
def test_sheet_frequency(): def test_sheet_frequency():
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_sheet_frequency") assert not hasattr(us, "_sheet_frequency")
assert us.sheet_frequency == set() assert us.sheet_frequency == set()
assert us._sheet_frequency == set() assert us._sheet_frequency == set()
@@ -165,18 +287,42 @@ def test_sheet_frequency():
mock.assert_called_once() mock.assert_called_once()
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"sheet_frequency": ["daily", "hourly"]})]): with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(
id="group1",
permissions={"sheet_frequency": ["daily", "hourly"]},
)
],
):
assert us.sheet_frequency == {"daily", "hourly"} assert us.sheet_frequency == {"daily", "hourly"}
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"sheet_frequency": []})]): with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"sheet_frequency": []})
],
):
assert us.sheet_frequency == set() assert us.sheet_frequency == set()
def test_max_archive_lifespan_months(): def test_max_archive_lifespan_months():
us = fresh_user_state() us = fresh_user_state()
default = GroupPermissions.model_fields["max_archive_lifespan_months"].default default = GroupPermissions.model_fields[
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: "max_archive_lifespan_months"
].default
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_max_archive_lifespan_months") assert not hasattr(us, "_max_archive_lifespan_months")
assert us.max_archive_lifespan_months == default assert us.max_archive_lifespan_months == default
assert us._max_archive_lifespan_months == default assert us._max_archive_lifespan_months == default
@@ -186,18 +332,44 @@ def test_max_archive_lifespan_months():
mock.assert_called_once() mock.assert_called_once()
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_archive_lifespan_months": 24})]): with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(
id="group1", permissions={"max_archive_lifespan_months": 24}
)
],
):
assert us.max_archive_lifespan_months == 24 assert us.max_archive_lifespan_months == 24
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_archive_lifespan_months": 150}), models.Group(id="group2", permissions={"max_archive_lifespan_months": -1})]): with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(
id="group1", permissions={"max_archive_lifespan_months": 150}
),
models.Group(
id="group2", permissions={"max_archive_lifespan_months": -1}
),
],
):
assert us.max_archive_lifespan_months == -1 assert us.max_archive_lifespan_months == -1
def test_max_monthly_urls(): def test_max_monthly_urls():
us = fresh_user_state() us = fresh_user_state()
default = GroupPermissions.model_fields["max_monthly_urls"].default default = GroupPermissions.model_fields["max_monthly_urls"].default
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_max_monthly_urls") assert not hasattr(us, "_max_monthly_urls")
assert us.max_monthly_urls == default assert us.max_monthly_urls == default
assert us._max_monthly_urls == default assert us._max_monthly_urls == default
@@ -207,18 +379,38 @@ def test_max_monthly_urls():
mock.assert_called_once() mock.assert_called_once()
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_urls": 100})]): with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"max_monthly_urls": 100})
],
):
assert us.max_monthly_urls == 100 assert us.max_monthly_urls == 100
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_urls": 150}), models.Group(id="group2", permissions={"max_monthly_urls": -1})]): with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"max_monthly_urls": 150}),
models.Group(id="group2", permissions={"max_monthly_urls": -1}),
],
):
assert us.max_monthly_urls == -1 assert us.max_monthly_urls == -1
def test_max_monthly_mbs(): def test_max_monthly_mbs():
us = fresh_user_state() us = fresh_user_state()
default = GroupPermissions.model_fields["max_monthly_mbs"].default default = GroupPermissions.model_fields["max_monthly_mbs"].default
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_max_monthly_mbs") assert not hasattr(us, "_max_monthly_mbs")
assert us.max_monthly_mbs == default assert us.max_monthly_mbs == default
assert us._max_monthly_mbs == default assert us._max_monthly_mbs == default
@@ -228,17 +420,37 @@ def test_max_monthly_mbs():
mock.assert_called_once() mock.assert_called_once()
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_mbs": 1000})]): with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"max_monthly_mbs": 1000})
],
):
assert us.max_monthly_mbs == 1000 assert us.max_monthly_mbs == 1000
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_mbs": 1500}), models.Group(id="group2", permissions={"max_monthly_mbs": -1})]): with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"max_monthly_mbs": 1500}),
models.Group(id="group2", permissions={"max_monthly_mbs": -1}),
],
):
assert us.max_monthly_mbs == -1 assert us.max_monthly_mbs == -1
def test_priority(user_state): def test_priority(user_state):
default = GroupPermissions.model_fields["priority"].default default = GroupPermissions.model_fields["priority"].default
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock: with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(user_state, "_priority") assert not hasattr(user_state, "_priority")
assert user_state.priority == default assert user_state.priority == default
assert user_state._priority == default assert user_state._priority == default
@@ -248,11 +460,26 @@ def test_priority(user_state):
mock.assert_called_once() mock.assert_called_once()
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"priority": "high"})]): with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"priority": "high"})
],
):
assert us.priority == "high" assert us.priority == "high"
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"priority": "low"}), models.Group(id="group2", permissions={"priority": "medium"})]): with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"priority": "low"}),
models.Group(id="group2", permissions={"priority": "medium"}),
],
):
assert us.priority == "low" assert us.priority == "low"
@@ -262,21 +489,45 @@ def test_active():
(True, False, False, False, True), (True, False, False, False, True),
(False, True, False, False, True), (False, True, False, False, True),
(False, False, True, False, True), (False, False, True, False, True),
(False, False, False, True, True) (False, False, False, True, True),
]: ]:
us = fresh_user_state() us = fresh_user_state()
with patch.object(UserState, 'read', new_callable=PropertyMock, return_value=read), \ with (
patch.object(UserState, 'read_public', new_callable=PropertyMock, return_value=read_public), \ patch.object(
patch.object(UserState, 'archive_url', new_callable=PropertyMock, return_value=archive_url), \ UserState, "read", new_callable=PropertyMock, return_value=read
patch.object(UserState, 'archive_sheet', new_callable=PropertyMock, return_value=archive_sheet): ),
patch.object(
UserState,
"read_public",
new_callable=PropertyMock,
return_value=read_public,
),
patch.object(
UserState,
"archive_url",
new_callable=PropertyMock,
return_value=archive_url,
),
patch.object(
UserState,
"archive_sheet",
new_callable=PropertyMock,
return_value=archive_sheet,
),
):
assert us.active == is_active assert us.active == is_active
def test_in_group(user_state):
    """Check ``in_group`` against a patched ``user_groups_names`` list."""
    # The property is patched on the class so the fixture instance reads the
    # fixed membership list below.
    with patch.object(
        UserState,
        "user_groups_names",
        new_callable=PropertyMock,
        return_value=["group1", "group2"],
    ):
        assert user_state.in_group("group1") is True
        assert user_state.in_group("group2") is True
        assert user_state.in_group("group3") is False
def test_usage(db_session): def test_usage(db_session):
@@ -294,10 +545,34 @@ def test_usage(db_session):
] ]
megabytes = int(sum(bytes) / 1024 / 1024) megabytes = int(sum(bytes) / 1024 / 1024)
with patch.object(db_session, 'query', side_effect=[ with patch.object(
MagicMock(filter=MagicMock(return_value=MagicMock(group_by=MagicMock(return_value=MagicMock(all=MagicMock(return_value=user_sheets)))))), db_session,
MagicMock(filter=MagicMock(return_value=MagicMock(group_by=MagicMock(return_value=MagicMock(all=MagicMock(return_value=urls_by_group)))))) "query",
]): side_effect=[
MagicMock(
filter=MagicMock(
return_value=MagicMock(
group_by=MagicMock(
return_value=MagicMock(
all=MagicMock(return_value=user_sheets)
)
)
)
)
),
MagicMock(
filter=MagicMock(
return_value=MagicMock(
group_by=MagicMock(
return_value=MagicMock(
all=MagicMock(return_value=urls_by_group)
)
)
)
)
),
],
):
usage_response = user_state.usage() usage_response = user_state.usage()
assert usage_response.monthly_urls == 155 assert usage_response.monthly_urls == 155
@@ -305,11 +580,15 @@ def test_usage(db_session):
assert usage_response.total_sheets == 115 assert usage_response.total_sheets == 115
assert usage_response.groups["group1"].monthly_urls == 50 assert usage_response.groups["group1"].monthly_urls == 50
assert usage_response.groups["group1"].monthly_mbs == int(bytes[0] / 1024 / 1024) assert usage_response.groups["group1"].monthly_mbs == int(
bytes[0] / 1024 / 1024
)
assert usage_response.groups["group1"].total_sheets == 5 assert usage_response.groups["group1"].total_sheets == 5
assert usage_response.groups["group2"].monthly_urls == 100 assert usage_response.groups["group2"].monthly_urls == 100
assert usage_response.groups["group2"].monthly_mbs == int(bytes[1] / 1024 / 1024) assert usage_response.groups["group2"].monthly_mbs == int(
bytes[1] / 1024 / 1024
)
assert usage_response.groups["group2"].total_sheets == 10 assert usage_response.groups["group2"].total_sheets == 10
assert usage_response.groups["group3"].monthly_urls == 0 assert usage_response.groups["group3"].monthly_urls == 0
@@ -317,7 +596,9 @@ def test_usage(db_session):
assert usage_response.groups["group3"].total_sheets == 100 assert usage_response.groups["group3"].total_sheets == 100
assert usage_response.groups["group4"].monthly_urls == 5 assert usage_response.groups["group4"].monthly_urls == 5
assert usage_response.groups["group4"].monthly_mbs == int(bytes[2] / 1024 / 1024) assert usage_response.groups["group4"].monthly_mbs == int(
bytes[2] / 1024 / 1024
)
assert usage_response.groups["group4"].total_sheets == 0 assert usage_response.groups["group4"].total_sheets == 0
@@ -333,8 +614,23 @@ def test_has_quota_monthly_sheets(db_session):
] ]
for permissions, count, expected in test_cases: for permissions, count, expected in test_cases:
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions): with patch.object(
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))): UserState,
"permissions",
new_callable=PropertyMock,
return_value=permissions,
):
with patch.object(
us.db,
"query",
return_value=MagicMock(
filter=MagicMock(
return_value=MagicMock(
count=MagicMock(return_value=count)
)
)
),
):
assert us.has_quota_monthly_sheets("group1") == expected assert us.has_quota_monthly_sheets("group1") == expected
@@ -349,8 +645,23 @@ def test_has_quota_max_monthly_urls(db_session):
] ]
for permissions, count, expected in test_cases: for permissions, count, expected in test_cases:
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions): with patch.object(
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))): UserState,
"permissions",
new_callable=PropertyMock,
return_value=permissions,
):
with patch.object(
us.db,
"query",
return_value=MagicMock(
filter=MagicMock(
return_value=MagicMock(
count=MagicMock(return_value=count)
)
)
),
):
assert us.has_quota_max_monthly_urls("group1") == expected assert us.has_quota_max_monthly_urls("group1") == expected
test_cases = [ test_cases = [
(-1, 1000, True), (-1, 1000, True),
@@ -360,8 +671,23 @@ def test_has_quota_max_monthly_urls(db_session):
] ]
for max_urls, count, expected in test_cases: for max_urls, count, expected in test_cases:
with patch.object(UserState, 'max_monthly_urls', new_callable=PropertyMock, return_value=max_urls): with patch.object(
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))): UserState,
"max_monthly_urls",
new_callable=PropertyMock,
return_value=max_urls,
):
with patch.object(
us.db,
"query",
return_value=MagicMock(
filter=MagicMock(
return_value=MagicMock(
count=MagicMock(return_value=count)
)
)
),
):
assert us.has_quota_max_monthly_urls("") == expected assert us.has_quota_max_monthly_urls("") == expected
@@ -376,8 +702,29 @@ def test_has_quota_max_monthly_mbs(db_session):
] ]
for permissions, mbs, expected in test_cases: for permissions, mbs, expected in test_cases:
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions): with patch.object(
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(with_entities=MagicMock(return_value=MagicMock(scalar=MagicMock(return_value=mbs * 1024 * 1024))))))): UserState,
"permissions",
new_callable=PropertyMock,
return_value=permissions,
):
with patch.object(
us.db,
"query",
return_value=MagicMock(
filter=MagicMock(
return_value=MagicMock(
with_entities=MagicMock(
return_value=MagicMock(
scalar=MagicMock(
return_value=mbs * 1024 * 1024
)
)
)
)
)
),
):
assert us.has_quota_max_monthly_mbs("group1") == expected assert us.has_quota_max_monthly_mbs("group1") == expected
test_cases = [ test_cases = [
@@ -388,8 +735,29 @@ def test_has_quota_max_monthly_mbs(db_session):
] ]
for max_mbs, mbs, expected in test_cases: for max_mbs, mbs, expected in test_cases:
with patch.object(UserState, 'max_monthly_mbs', new_callable=PropertyMock, return_value=max_mbs): with patch.object(
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(with_entities=MagicMock(return_value=MagicMock(scalar=MagicMock(return_value=mbs * 1024 * 1024))))))): UserState,
"max_monthly_mbs",
new_callable=PropertyMock,
return_value=max_mbs,
):
with patch.object(
us.db,
"query",
return_value=MagicMock(
filter=MagicMock(
return_value=MagicMock(
with_entities=MagicMock(
return_value=MagicMock(
scalar=MagicMock(
return_value=mbs * 1024 * 1024
)
)
)
)
)
),
):
assert us.has_quota_max_monthly_mbs("") == expected assert us.has_quota_max_monthly_mbs("") == expected
@@ -399,10 +767,15 @@ def test_can_manually_trigger(user_state):
"group2": GroupInfo(manually_trigger_sheet=False), "group2": GroupInfo(manually_trigger_sheet=False),
} }
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions): with patch.object(
assert user_state.can_manually_trigger("group1") == True UserState,
assert user_state.can_manually_trigger("group2") == False "permissions",
assert user_state.can_manually_trigger("group3") == False new_callable=PropertyMock,
return_value=permissions,
):
assert user_state.can_manually_trigger("group1") is True
assert user_state.can_manually_trigger("group2") is False
assert user_state.can_manually_trigger("group3") is False
def test_is_sheet_frequency_allowed(user_state): def test_is_sheet_frequency_allowed(user_state):
@@ -411,23 +784,44 @@ def test_is_sheet_frequency_allowed(user_state):
"group2": GroupInfo(sheet_frequency={"daily"}), "group2": GroupInfo(sheet_frequency={"daily"}),
} }
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions): with patch.object(
assert user_state.is_sheet_frequency_allowed("group1", "daily") == True UserState,
assert user_state.is_sheet_frequency_allowed("group1", "hourly") == True "permissions",
assert user_state.is_sheet_frequency_allowed("group1", "weekly") == False new_callable=PropertyMock,
assert user_state.is_sheet_frequency_allowed("group2", "hourly") == False return_value=permissions,
assert user_state.is_sheet_frequency_allowed("group2", "daily") == True ):
assert user_state.is_sheet_frequency_allowed("group3", "daily") == False assert user_state.is_sheet_frequency_allowed("group1", "daily") is True
assert user_state.is_sheet_frequency_allowed("group1", "hourly") is True
assert (
user_state.is_sheet_frequency_allowed("group1", "weekly") is False
)
assert (
user_state.is_sheet_frequency_allowed("group2", "hourly") is False
)
assert user_state.is_sheet_frequency_allowed("group2", "daily") is True
assert user_state.is_sheet_frequency_allowed("group3", "daily") is False
def test_priority_group(user_state):
    """Check that ``priority_group`` matches ``convert_priority_to_queue_dict``.

    ``group4`` is not among the patched groups; the test expects it to resolve
    to the same queue dict as the "low" priority.
    """
    from app.web.utils.misc import convert_priority_to_queue_dict

    patched_groups = [
        models.Group(id="group1", permissions={"priority": "high"}),
        models.Group(id="group2", permissions={"priority": "medium"}),
        models.Group(id="group3", permissions={"priority": "low"}),
    ]
    expected_priorities = (
        ("group1", "high"),
        ("group2", "medium"),
        ("group3", "low"),
        ("group4", "low"),
    )
    with patch.object(
        UserState,
        "user_groups",
        new_callable=PropertyMock,
        return_value=patched_groups,
    ):
        for group_id, priority in expected_priorities:
            assert user_state.priority_group(
                group_id
            ) == convert_priority_to_queue_dict(priority)

View File

@@ -1,60 +0,0 @@
from datetime import datetime
import json
from unittest.mock import MagicMock, patch
from app.shared.db import models
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud
def test_submit_manual_archive_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/interop/submit-archive")
def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
test_no_auth(client_with_auth.post, "/interop/submit-archive")
@patch(
    "app.web.endpoints.interoperability.business_logic",
    return_value=MagicMock(get_store_archive_until=MagicMock(return_value=datetime)),
)
def test_submit_manual_archive(m1, client_with_token, db_session):
    """Submit an archive, check the stored row, then check a rejected resubmission."""
    # normal workflow
    aa_metadata = json.dumps(
        {
            "status": "test: success",
            "metadata": {"url": "http://example.com"},
            "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}],
        }
    )
    r = client_with_token.post(
        "/interop/submit-archive",
        json={
            "result": aa_metadata,
            "public": True,
            "author_id": "jerry@gmail.com",
            "group_id": "spaceship",
            "tags": ["test"],
            "url": "http://example.com",
        },
    )
    assert r.status_code == 201
    assert "id" in r.json()

    inserted = (
        db_session.query(models.Archive)
        .filter(models.Archive.id == r.json()["id"])
        .first()
    )
    assert inserted.url == "http://example.com"
    assert inserted.group_id == "spaceship"
    assert inserted.author_id == "jerry@gmail.com"
    assert sorted([t.id for t in inserted.tags]) == sorted(["test", "manual"])
    assert inserted.public
    # Exact-type checks, written with `is` rather than `==`.
    assert type(inserted.result) is dict
    assert [u.url for u in inserted.urls] == ["http://example.s3.com"]
    assert type(inserted.store_until) is datetime

    # cannot have the same URL twice
    aa_metadata = json.dumps(
        {
            "status": "test: success",
            "metadata": {"url": "http://example.com"},
            "media": [
                {
                    "filename": "fn1",
                    "urls": ["http://example.com", "http://example.com"],
                }
            ],
        }
    )
    r = client_with_token.post(
        "/interop/submit-archive",
        json={
            "result": aa_metadata,
            "public": False,
            "author_id": "jerry@gmail.com",
            "tags": ["test"],
            "url": "http://example.com",
        },
    )
    assert r.status_code == 422
    assert r.json() == {
        "detail": "Cannot insert into DB due to integrity error, likely duplicate urls."
    }
# test with invalid JSON
def test_submit_manual_archive_invalid_json(client_with_token):
r = client_with_token.post("/interop/submit-archive", json={"result": "invalid json", "public": False, "author_id": "jer", "tags": ["test"], "url": "http://example.com"})
assert r.status_code == 422
assert r.json() == {"detail": "Invalid JSON in result field."}
@patch(
    "app.web.endpoints.interoperability.business_logic.get_store_archive_until",
    side_effect=AssertionError("AssertionError"),
)
def test_submit_manual_archive_no_store_until(m_sau, client_with_token, db_session):
    """Submission still succeeds, with no ``store_until``, when the lookup raises."""
    aa_metadata = json.dumps(
        {
            "status": "test: success",
            "metadata": {"url": "http://example.com"},
            "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}],
        }
    )
    r = client_with_token.post(
        "/interop/submit-archive",
        json={
            "result": aa_metadata,
            "public": True,
            "author_id": "jerry@gmail.com",
            "group_id": "spaceship",
            "tags": ["test"],
            "url": "http://example.com",
        },
    )
    assert r.status_code == 201
    archive_id = r.json()["id"]
    assert len(archive_id) == 36

    res = db_session.query(models.Archive).filter(models.Archive.id == archive_id).first()
    assert res.store_until is None

    # testing that store_until = None is not comparable with datetime, and will always return False
    res = (
        db_session.query(models.Archive)
        .filter(
            models.Archive.id == archive_id,
            models.Archive.store_until < datetime.now(),
        )
        .first()
    )
    assert res is None

View File

@@ -1,193 +0,0 @@
from datetime import datetime
import json
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
from app.shared.schemas import TaskResult
def test_endpoints_no_auth(client, test_no_auth):
test_no_auth(client.post, "/sheet/create")
test_no_auth(client.get, "/sheet/mine")
test_no_auth(client.delete, "/sheet/123-sheet-id")
test_no_auth(client.post, "/sheet/123-sheet-id/archive")
def test_create_sheet_endpoint(app_with_auth, db_session):
    """Exercise POST /sheet/create.

    Covers a successful creation, a duplicate id, a group the user cannot
    access, a frequency not allowed for the group, and a second sheet rejected
    with 429.
    """
    client_with_auth = TestClient(app_with_auth)
    good_data = {
        "id": "123-sheet-id",
        "name": "Test Sheet",
        "group_id": "spaceship",
        "frequency": "daily",
    }

    # with good data
    response = client_with_auth.post("/sheet/create", json=good_data)
    assert response.status_code == 201
    j = response.json()
    assert datetime.fromisoformat(j.pop("created_at"))
    assert datetime.fromisoformat(j.pop("last_url_archived_at"))
    assert j.pop("author_id") == "morty@example.com"
    assert j == good_data

    # already exists
    response = client_with_auth.post("/sheet/create", json=good_data)
    assert response.status_code == 400
    assert response.json() == {"detail": "Sheet with this ID is already being archived."}

    # bad group
    bad_data = good_data.copy()
    bad_data["group_id"] = "not a group"
    response = client_with_auth.post("/sheet/create", json=bad_data)
    assert response.status_code == 403
    assert response.json() == {"detail": "User does not have access to this group."}

    # switch to jerry who's got less quota/permissions
    from app.web.security import get_user_state
    from app.web.db.user_state import UserState

    # Remember the current override so it can be put back afterwards; the
    # mutation below would otherwise stay on the app object after this test.
    overrides = app_with_auth.dependency_overrides
    not_set = object()
    previous_override = overrides.get(get_user_state, not_set)
    overrides[get_user_state] = lambda: UserState(db_session, "jerry@example.com")
    try:
        client_jerry = TestClient(app_with_auth)

        # frequency not allowed
        jerry_data = good_data.copy()
        jerry_data["group_id"] = "animated-characters"
        jerry_data["frequency"] = "hourly"
        jerry_data["id"] = "jerry-sheet-id"
        response = client_jerry.post("/sheet/create", json=jerry_data)
        assert response.status_code == 422
        assert response.json() == {"detail": "Invalid frequency selected for this group."}
        jerry_data["frequency"] = "daily"

        # success for the first sheet, bad quota on second
        response = client_jerry.post("/sheet/create", json=jerry_data)
        assert response.status_code == 201
        response = client_jerry.post("/sheet/create", json=jerry_data)
        assert response.status_code == 429
        assert response.json() == {
            "detail": "User has reached their sheet quota for this group."
        }
    finally:
        if previous_override is not_set:
            overrides.pop(get_user_state, None)
        else:
            overrides[get_user_state] = previous_override
def test_get_user_sheets_endpoint(client_with_auth, db_session):
    """GET /sheet/mine returns nothing at first, then two of the three inserted sheets."""
    # no data
    response = client_with_auth.get("/sheet/mine")
    assert response.status_code == 200
    assert response.json() == []

    # with data
    from app.shared.db import models

    db_session.add(
        models.Sheet(
            id="123",
            name="Test Sheet 1",
            author_id="morty@example.com",
            group_id="spaceship",
            frequency="hourly",
        )
    )
    db_session.commit()
    db_session.add_all(
        [
            models.Sheet(
                id="456",
                name="Test Sheet 2",
                author_id="morty@example.com",
                group_id="interdimensional",
                frequency="daily",
            ),
            models.Sheet(
                id="789",
                name="Test Sheet 3",
                author_id="rick@example.com",
                group_id="interdimensional",
                frequency="hourly",
            ),
        ]
    )
    db_session.commit()

    response = client_with_auth.get("/sheet/mine")
    assert response.status_code == 200
    r = response.json()
    assert isinstance(r, list)
    assert len(r) == 2

    # Timestamps are only checked for being ISO-parseable, then removed so the
    # remaining fields can be compared exactly.
    for sheet in r:
        assert datetime.fromisoformat(sheet.pop("created_at"))
        assert datetime.fromisoformat(sheet.pop("last_url_archived_at"))

    assert r[0] == {
        "id": "123",
        "author_id": "morty@example.com",
        "frequency": "hourly",
        "group_id": "spaceship",
        "name": "Test Sheet 1",
    }
    assert r[1] == {
        "id": "456",
        "author_id": "morty@example.com",
        "frequency": "daily",
        "group_id": "interdimensional",
        "name": "Test Sheet 2",
    }
def test_delete_sheet_endpoint(client_with_auth, db_session):
    """DELETE /sheet/{id} always answers 200; ``deleted`` is true only for the first delete of the caller's own sheet."""
    # missing sheet
    response = client_with_auth.delete("/sheet/123-sheet-id")
    assert response.status_code == 200
    assert response.json() == {"id": "123-sheet-id", "deleted": False}

    # add sheets for deletion
    from app.shared.db import models

    db_session.add_all(
        [
            models.Sheet(
                id="123-sheet-id",
                name="Test Sheet 1",
                author_id="morty@example.com",
                group_id="interdimensional",
                frequency="daily",
            ),
            models.Sheet(
                id="456-sheet-id",
                name="Test Sheet 2",
                author_id="rick@example.com",
                group_id="spaceship",
                frequency="hourly",
            ),
        ]
    )
    db_session.commit()

    # morty can delete his
    response = client_with_auth.delete("/sheet/123-sheet-id")
    assert response.status_code == 200
    assert response.json() == {"id": "123-sheet-id", "deleted": True}

    # but only once
    response = client_with_auth.delete("/sheet/123-sheet-id")
    assert response.status_code == 200
    assert response.json() == {"id": "123-sheet-id", "deleted": False}

    # and not rick's
    response = client_with_auth.delete("/sheet/456-sheet-id")
    assert response.status_code == 200
    assert response.json() == {"id": "456-sheet-id", "deleted": False}
class TestArchiveUserSheetEndpoint:
@patch("app.web.endpoints.sheet.celery", return_value=MagicMock())
def test_normal_flow(self, m_celery, client_with_auth, db_session):
from app.shared.db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly"))
db_session.commit()
m_signature = MagicMock()
m_signature.apply_async.return_value = TaskResult(id="123-taskid", status="PENDING", result="")
m_celery.signature.return_value = m_signature
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 201
assert r.json() == {"id": "123-taskid"}
m_celery.signature.assert_called_once()
m_signature.apply_async.assert_called_once()
def test_token_auth(self, client_with_token, test_no_auth):
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
def test_missing_data(self, client_with_auth):
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "No access to this sheet."}
def test_no_access(self, client_with_auth, db_session):
from app.shared.db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="rick@example.com", group_id="spaceship", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "No access to this sheet."}
def test_user_not_in_group(self, client_with_auth, db_session):
from app.shared.db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "User does not have access to this group."}
def test_user_cannot_manually_trigger(self, client_with_auth, db_session):
from app.shared.db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="default", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 429
assert r.json() == {"detail": "User cannot manually trigger sheet archiving in this group."}

View File

@@ -1,202 +0,0 @@
import json
from unittest.mock import MagicMock, patch
from app.shared.schemas import ArchiveCreate, TaskResult
from app.web.config import ALLOW_ANY_EMAIL
def test_archive_url_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/url/archive")
@patch("app.web.endpoints.url.UserState")
@patch("app.web.endpoints.url.celery", return_value=MagicMock())
def test_archive_url(m_celery, m2, client_with_auth):
    """Walk POST /url/archive through validation, success, group and quota failures.

    ``m_celery`` replaces the endpoint's celery object and ``m2`` replaces its
    ``UserState`` class, so group and quota answers are set on the mock.
    """
    m_signature = MagicMock()
    m_signature.apply_async.return_value = TaskResult(
        id="123-456-789", status="PENDING", result=""
    )
    m_celery.signature.return_value = m_signature
    m_user_state = MagicMock()
    m2.return_value = m_user_state

    # url is too short
    response = client_with_auth.post("/url/archive", json={"url": "bad"})
    assert response.status_code == 422
    assert response.json()["detail"][0]["msg"] == "String should have at least 5 characters"
    m_celery.signature.assert_not_called()

    # url is invalid
    response = client_with_auth.post("/url/archive", json={"url": "example.com"})
    assert response.status_code == 400
    assert response.json()["detail"] == "Invalid URL received."

    # valid request
    m_user_state.has_quota_max_monthly_urls.return_value = True
    m_user_state.has_quota_max_monthly_mbs.return_value = True
    response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
    assert response.status_code == 201
    assert response.json() == {"id": "123-456-789"}
    m_celery.signature.assert_called_once()
    m_signature.apply_async.assert_called_once()
    called_val = m_celery.signature.call_args
    assert called_val.args[0] == "create_archive_task"
    # The task payload is passed as a JSON string in the ``args`` keyword.
    assert json.loads(called_val.kwargs["args"][0]) == {
        "id": None,
        "url": "https://example.com",
        "result": None,
        "public": False,
        "author_id": "rick@example.com",
        "group_id": "default",
        "tags": None,
        "sheet_id": None,
        "store_until": None,
        "urls": None,
    }
    m_user_state.has_quota_max_monthly_urls.assert_called_once()
    m_user_state.has_quota_max_monthly_mbs.assert_called_once()
    m_user_state.in_group.assert_called_once_with("default")

    # user is not in group
    m_user_state.in_group.return_value = False
    response = client_with_auth.post(
        "/url/archive", json={"url": "https://example.com", "group_id": "new-group"}
    )
    assert response.status_code == 403
    assert response.json()["detail"] == "User does not have access to this group."
    m_user_state.in_group.assert_called_with("new-group")

    # user is in group
    m_user_state.in_group.return_value = True
    response = client_with_auth.post(
        "/url/archive", json={"url": "https://example.com", "group_id": "spaceship"}
    )
    assert response.status_code == 201
    assert response.json() == {"id": "123-456-789"}
    assert m_celery.signature.call_count == 2
    assert m_signature.apply_async.call_count == 2
    called_val = m_celery.signature.call_args
    assert json.loads(called_val.kwargs["args"][0])["group_id"] == "spaceship"
    m_user_state.in_group.assert_called_with("spaceship")

    # user is over monthly URL quota
    m_user_state.has_quota_max_monthly_urls.return_value = False
    m_user_state.has_quota_max_monthly_mbs.return_value = True
    response = client_with_auth.post(
        "/url/archive", json={"url": "https://example.com", "group_id": "spaceship"}
    )
    assert response.status_code == 429
    assert response.json()["detail"] == "User has reached their monthly URL quota."
    m_user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")

    # user is over monthly MB quota
    m_user_state.has_quota_max_monthly_urls.return_value = True
    m_user_state.has_quota_max_monthly_mbs.return_value = False
    response = client_with_auth.post(
        "/url/archive", json={"url": "https://example.com", "group_id": "spacesuit"}
    )
    assert response.status_code == 429
    assert response.json()["detail"] == "User has reached their monthly MB quota."
    m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")

    # Neither quota failure may have dispatched another task.
    assert m_celery.signature.call_count == 2
    assert m_signature.apply_async.call_count == 2
@patch("app.web.endpoints.url.UserState")
def test_archive_url_quotas(m1, client_with_auth):
    """POST /url/archive answers 429 when either monthly quota check returns False."""
    m_user_state = MagicMock()
    m1.return_value = m_user_state
    payload = {"url": "https://example.com"}

    # misses on monthly URLs quota
    m_user_state.has_quota_max_monthly_urls.return_value = False
    response = client_with_auth.post("/url/archive", json=payload)
    assert response.status_code == 429
    assert response.json()["detail"] == "User has reached their monthly URL quota."
    m_user_state.has_quota_max_monthly_urls.assert_called_once()

    # misses on monthly MBs quota
    m_user_state.has_quota_max_monthly_urls.return_value = True
    m_user_state.has_quota_max_monthly_mbs.return_value = False
    response = client_with_auth.post("/url/archive", json=payload)
    assert response.status_code == 429
    assert response.json()["detail"] == "User has reached their monthly MB quota."
    m_user_state.has_quota_max_monthly_mbs.assert_called_once()
@patch("app.web.endpoints.url.celery", return_value=MagicMock())
def test_archive_url_with_api_token(m_celery, client_with_token):
    """With an API token, ``author_id`` comes from the request or falls back to ``ALLOW_ANY_EMAIL``."""
    m_signature = MagicMock()
    m_signature.apply_async.return_value = TaskResult(
        id="123-456-789", status="PENDING", result=""
    )
    m_celery.signature.return_value = m_signature

    def expected_payload(author_id):
        """Build the task payload expected for the given author."""
        return {
            "id": None,
            "url": "https://example.com",
            "result": None,
            "public": False,
            "author_id": author_id,
            "group_id": "default",
            "tags": None,
            "sheet_id": None,
            "store_until": None,
            "urls": None,
        }

    response = client_with_token.post(
        "/url/archive",
        json={"url": "https://example.com", "author_id": "someone@example.com"},
    )
    assert response.status_code == 201
    assert response.json() == {"id": "123-456-789"}
    m_celery.signature.assert_called_once()
    m_signature.apply_async.assert_called_once()
    called_val = m_celery.signature.call_args
    assert called_val.args[0] == "create_archive_task"
    assert json.loads(called_val.kwargs["args"][0]) == expected_payload(
        "someone@example.com"
    )

    # missing id should use ALLOW_ANY_EMAIL
    response = client_with_token.post(
        "/url/archive", json={"url": "https://example.com", "author_id": None}
    )
    assert response.status_code == 201
    called_val = m_celery.signature.call_args
    assert called_val.args[0] == "create_archive_task"
    assert json.loads(called_val.kwargs["args"][0]) == expected_payload(ALLOW_ANY_EMAIL)
def test_search_by_url_unauthenticated(client, test_no_auth):
test_no_auth(client.get, "/url/search")
def test_search_by_url(client_with_auth, client_with_token, db_session):
    """Exercise GET /url/search: required ``url``, result counts, paging and date filters."""
    # tests the search endpoint, including through some db data for the endpoint params
    response = client_with_auth.get("/url/search")
    assert response.status_code == 422
    assert response.json()["detail"][0]["msg"] == "Field required"

    response = client_with_auth.get("/url/search?url=https://example.com")
    assert response.status_code == 200
    assert response.json() == []

    from app.shared import schemas
    from app.shared.db import worker_crud

    # Ten archives for example.com plus one for a different URL.
    for i in range(11):
        archive_url = "https://example.com" if i < 10 else "https://something-else.com"
        worker_crud.create_archive(
            db_session,
            ArchiveCreate(
                id=f"url-456-{i}",
                url=archive_url,
                result={},
                public=True,
                author_id="rick@example.com",
            ),
            [],
            [],
        )

    # NB: this insertion is too fast for the ordering to be correct as they are within the same second
    response = client_with_auth.get("/url/search?url=https://example.com")
    assert response.status_code == 200
    # Bound outside the assert so the name still exists when asserts are
    # stripped (python -O).
    j = response.json()
    assert len(j) == 10
    found_ids = [item["id"] for item in j]
    assert "url-456-0" in found_ids
    assert "url-456-9" in found_ids
    assert "url-456-10" not in found_ids
    assert j[0].keys() == schemas.ArchiveResult.model_fields.keys()

    response = client_with_auth.get("/url/search?url=https://example.com&limit=5")
    assert response.status_code == 200
    assert len(response.json()) == 5

    response = client_with_auth.get("/url/search?url=https://example.com&skip=5&limit=2")
    assert response.status_code == 200
    assert len(response.json()) == 2

    response = client_with_auth.get(
        "/url/search?url=https://example.com&archived_before=2010-01-01"
    )
    assert response.status_code == 200
    assert len(response.json()) == 0

    response = client_with_auth.get(
        "/url/search?url=https://example.com&archived_after=2010-01-01"
    )
    assert response.status_code == 200
    assert len(response.json()) == 10

    # API token will also work
    response = client_with_token.get(
        "/url/search?url=https://example.com&archived_after=2010-01-01"
    )
    assert response.status_code == 200
    assert len(response.json()) == 10
@patch("app.web.endpoints.url.UserState")
def test_search_no_read_access(mock_user_state, client_with_auth):
    """Searching is rejected with 403 when the user state grants no read rights."""
    # the endpoint instantiates UserState, so configure the instance it gets
    state = mock_user_state.return_value
    state.read = False
    state.read_public = False

    denied = client_with_auth.get("/url/search?url=https://example.com")

    assert denied.status_code == 403
    assert denied.json() == {"detail": "User does not have read access."}
def test_delete_task_unauthenticated(client, test_no_auth):
test_no_auth(client.delete, "/url/123-456-789")
def test_delete_task(client_with_auth, db_session):
    """DELETE /url/<id> reports deleted=False for an unknown id and True once it exists."""
    archive_id = "delete-123-456-789"
    endpoint = "/url/delete-123-456-789"

    # nothing stored under this id yet
    first_attempt = client_with_auth.delete(endpoint)
    assert first_attempt.status_code == 200
    assert first_attempt.json() == {"id": archive_id, "deleted": False}

    from app.shared.db import worker_crud

    # create the archive, authored by morty, then delete it
    archive = ArchiveCreate(
        id="delete-123-456-789",
        url="https://example.com",
        result={},
        public=True,
        author_id="morty@example.com",
    )
    worker_crud.create_archive(db_session, archive, [], [])

    second_attempt = client_with_auth.delete(endpoint)
    assert second_attempt.status_code == 200
    assert second_attempt.json() == {"id": archive_id, "deleted": True}

View File

@@ -1,15 +1,20 @@
from http import HTTPStatus
from unittest.mock import MagicMock from unittest.mock import MagicMock
from fastapi.testclient import TestClient
import pytest import pytest
from fastapi.testclient import TestClient
from loguru import logger
from app.shared.schemas import Usage, UsageResponse from app.shared.schemas import Usage, UsageResponse
from app.shared.user_groups import GroupInfo from app.shared.user_groups import GroupInfo
from app.web.config import VERSION from app.web.config import VERSION
from app.tests.web.db.test_crud import test_data from app.web.security import get_user_state
from app.web.utils.metrics import measure_regular_metrics
def test_endpoint_home(client_with_auth): def test_endpoint_home(client_with_auth):
r = client_with_auth.get("/") r = client_with_auth.get("/")
assert r.status_code == 200 assert r.status_code == HTTPStatus.OK
j = r.json() j = r.json()
assert "version" in j and j["version"] == VERSION assert "version" in j and j["version"] == VERSION
assert "breakingChanges" in j assert "breakingChanges" in j
@@ -18,7 +23,7 @@ def test_endpoint_home(client_with_auth):
def test_endpoint_health(client_with_auth): def test_endpoint_health(client_with_auth):
r = client_with_auth.get("/health") r = client_with_auth.get("/health")
assert r.status_code == 200 assert r.status_code == HTTPStatus.OK
assert r.json() == {"status": "ok"} assert r.json() == {"status": "ok"}
@@ -29,32 +34,31 @@ def test_endpoint_active_no_auth(client, test_no_auth):
def test_endpoint_active(app): def test_endpoint_active(app):
m_user_state = MagicMock() m_user_state = MagicMock()
from app.web.security import get_user_state
app.dependency_overrides[get_user_state] = lambda: m_user_state app.dependency_overrides[get_user_state] = lambda: m_user_state
# inactive user # inactive user
m_user_state.active = False m_user_state.active = False
client = TestClient(app) client = TestClient(app)
r = client.get("/user/active") r = client.get("/user/active")
assert r.status_code == 200 assert r.status_code == HTTPStatus.OK
assert r.json() == {"active": False} assert r.json() == {"active": False}
# active user # active user
m_user_state.active = True m_user_state.active = True
client = TestClient(app) client = TestClient(app)
r = client.get("/user/active") r = client.get("/user/active")
assert r.status_code == 200 assert r.status_code == HTTPStatus.OK
assert r.json() == {"active": True} assert r.json() == {"active": True}
def test_no_serve_local_archive_by_default(client_with_auth): def test_no_serve_local_archive_by_default(client_with_auth):
r = client_with_auth.get("/app/local_archive_test/temp.txt") r = client_with_auth.get("/app/local_archive_test/temp.txt")
assert r.status_code == 404 assert r.status_code == HTTPStatus.NOT_FOUND
def test_favicon(client_with_auth): def test_favicon(client_with_auth):
r = client_with_auth.get("/favicon.ico") r = client_with_auth.get("/favicon.ico")
assert r.status_code == 200 assert r.status_code == HTTPStatus.OK
assert r.headers["content-type"] == "image/vnd.microsoft.icon" assert r.headers["content-type"] == "image/vnd.microsoft.icon"
@@ -70,8 +74,10 @@ def test_endpoint_test_prometheus_no_user_auth(client_with_auth, test_no_auth):
async def test_prometheus_metrics(test_data, client_with_token, get_settings): async def test_prometheus_metrics(test_data, client_with_token, get_settings):
# before metrics calculation # before metrics calculation
r = client_with_token.get("/metrics") r = client_with_token.get("/metrics")
assert r.status_code == 200 assert r.status_code == HTTPStatus.OK
assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8" assert (
r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
)
assert "disk_utilization" in r.text assert "disk_utilization" in r.text
assert "database_metrics" in r.text assert "database_metrics" in r.text
assert "exceptions" in r.text assert "exceptions" in r.text
@@ -79,8 +85,9 @@ async def test_prometheus_metrics(test_data, client_with_token, get_settings):
assert 'disk_utilization{type="used"}' not in r.text assert 'disk_utilization{type="used"}' not in r.text
# after metrics calculation # after metrics calculation
from app.web.utils.metrics import measure_regular_metrics await measure_regular_metrics(
await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100) get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100
)
r2 = client_with_token.get("/metrics") r2 = client_with_token.get("/metrics")
assert 'disk_utilization{type="used"}' in r2.text assert 'disk_utilization{type="used"}' in r2.text
assert 'disk_utilization{type="free"}' in r2.text assert 'disk_utilization{type="free"}' in r2.text
@@ -88,20 +95,37 @@ async def test_prometheus_metrics(test_data, client_with_token, get_settings):
assert 'database_metrics{query="count_archives"} 100.0' in r2.text assert 'database_metrics{query="count_archives"} 100.0' in r2.text
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r2.text assert 'database_metrics{query="count_archive_urls"} 1000.0' in r2.text
assert 'database_metrics{query="count_users"} 3.0' in r2.text assert 'database_metrics{query="count_users"} 3.0' in r2.text
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r2.text assert (
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r2.text 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0'
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r2.text in r2.text
)
assert (
'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0'
in r2.text
)
assert (
'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0'
in r2.text
)
# 30s window, should not change the gauges nor the total in the counters # 30s window, should not change the gauges nor the total in the counters
from app.web.utils.metrics import measure_regular_metrics
await measure_regular_metrics(get_settings.DATABASE_PATH, 30) await measure_regular_metrics(get_settings.DATABASE_PATH, 30)
r3 = client_with_token.get("/metrics") r3 = client_with_token.get("/metrics")
assert 'database_metrics{query="count_archives"} 100.0' in r3.text assert 'database_metrics{query="count_archives"} 100.0' in r3.text
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text
assert 'database_metrics{query="count_users"} 3.0' in r3.text assert 'database_metrics{query="count_users"} 3.0' in r3.text
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r3.text assert (
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r3.text 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0'
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r3.text in r3.text
)
assert (
'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0'
in r3.text
)
assert (
'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0'
in r3.text
)
def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth): def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth):
@@ -109,14 +133,12 @@ def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth):
def test_endpoint_get_user_permissions(app): def test_endpoint_get_user_permissions(app):
from app.web.security import get_user_state
m_user_state = MagicMock() m_user_state = MagicMock()
rv = { rv = {
"all": GroupInfo(read=True), "all": GroupInfo(read=True),
"group1": GroupInfo(archive_url=True), "group1": GroupInfo(archive_url=True),
} }
from loguru import logger
logger.info(rv) logger.info(rv)
m_user_state.permissions = rv m_user_state.permissions = rv
@@ -124,13 +146,13 @@ def test_endpoint_get_user_permissions(app):
client = TestClient(app) client = TestClient(app)
r = client.get("/user/permissions") r = client.get("/user/permissions")
assert r.status_code == 200 assert r.status_code == HTTPStatus.OK
response = r.json() response = r.json()
assert response.keys() == {"all", "group1"} assert response.keys() == {"all", "group1"}
assert response["all"]["read"] assert response["all"]["read"]
assert response["group1"]["read"] == [] assert response["group1"]["read"] == []
assert response["group1"]["archive_url"] assert response["group1"]["archive_url"]
assert response["all"]["archive_url"] == False assert response["all"]["archive_url"] is False
def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth): def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth):
@@ -138,8 +160,6 @@ def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth):
def test_endpoint_get_user_usage_inactive(app): def test_endpoint_get_user_usage_inactive(app):
from app.web.security import get_user_state
m_user_state = MagicMock() m_user_state = MagicMock()
m_user_state.active = False m_user_state.active = False
@@ -147,13 +167,11 @@ def test_endpoint_get_user_usage_inactive(app):
client = TestClient(app) client = TestClient(app)
r = client.get("/user/usage") r = client.get("/user/usage")
assert r.status_code == 403 assert r.status_code == HTTPStatus.FORBIDDEN
assert r.json() == {"detail": "User is not active."} assert r.json() == {"detail": "User is not active."}
def test_endpoint_get_user_usage_active(app): def test_endpoint_get_user_usage_active(app):
from app.web.security import get_user_state
m_user_state = MagicMock() m_user_state = MagicMock()
m_user_state.active = True m_user_state.active = True
mock_usage = UsageResponse( mock_usage = UsageResponse(
@@ -162,8 +180,8 @@ def test_endpoint_get_user_usage_active(app):
total_sheets=3, total_sheets=3,
groups={ groups={
"group1": Usage(monthly_urls=4, monthly_mbs=5, total_sheets=6), "group1": Usage(monthly_urls=4, monthly_mbs=5, total_sheets=6),
"group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9) "group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9),
} },
) )
m_user_state.usage.return_value = mock_usage m_user_state.usage.return_value = mock_usage
@@ -171,5 +189,5 @@ def test_endpoint_get_user_usage_active(app):
client = TestClient(app) client = TestClient(app)
r = client.get("/user/usage") r = client.get("/user/usage")
assert r.status_code == 200 assert r.status_code == HTTPStatus.OK
assert UsageResponse(**r.json()) == mock_usage assert UsageResponse(**r.json()) == mock_usage

View File

@@ -0,0 +1,147 @@
import json
from datetime import datetime
from http import HTTPStatus
from unittest.mock import MagicMock, patch
from app.shared.db import models
def test_submit_manual_archive_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/interop/submit-archive")
def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
test_no_auth(client_with_auth.post, "/interop/submit-archive")
# NOTE(review): the mocked get_store_archive_until is configured to return the
# `datetime` class itself (not an instance) and is attached to the patched
# object's return_value - confirm this is the intended setup.
@patch(
    "app.web.routers.interoperability.business_logic",
    return_value=MagicMock(
        get_store_archive_until=MagicMock(return_value=datetime)
    ),
)
def test_submit_manual_archive(m1, client_with_token, db_session):
    """Submit a manual archive via the API token, then check duplicate media URLs are rejected."""
    endpoint = "/interop/submit-archive"

    # normal workflow
    valid_result = json.dumps(
        {
            "status": "test: success",
            "metadata": {"url": "http://example.com"},
            "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}],
        }
    )
    valid_payload = {
        "result": valid_result,
        "public": True,
        "author_id": "jerry@gmail.com",
        "group_id": "spaceship",
        "tags": ["test"],
        "url": "http://example.com",
    }
    r = client_with_token.post(endpoint, json=valid_payload)
    assert r.status_code == HTTPStatus.CREATED
    assert "id" in r.json()

    # read the stored row back and compare it with what was submitted
    lookup = db_session.query(models.Archive).filter(
        models.Archive.id == r.json()["id"]
    )
    inserted = lookup.first()
    assert inserted.url == "http://example.com"
    assert inserted.group_id == "spaceship"
    assert inserted.author_id == "jerry@gmail.com"
    # a "manual" tag is expected next to the submitted one
    assert sorted(tag.id for tag in inserted.tags) == sorted(["test", "manual"])
    assert inserted.public
    assert isinstance(inserted.result, dict)
    assert [entry.url for entry in inserted.urls] == ["http://example.s3.com"]
    assert isinstance(inserted.store_until, datetime)

    # cannot have the same URL twice
    duplicated_result = json.dumps(
        {
            "status": "test: success",
            "metadata": {"url": "http://example.com"},
            "media": [
                {
                    "filename": "fn1",
                    "urls": ["http://example.com", "http://example.com"],
                }
            ],
        }
    )
    duplicate_payload = {
        "result": duplicated_result,
        "public": False,
        "author_id": "jerry@gmail.com",
        "tags": ["test"],
        "url": "http://example.com",
    }
    r = client_with_token.post(endpoint, json=duplicate_payload)
    assert r.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
    assert r.json() == {
        "detail": "Cannot insert into DB due to integrity error, likely duplicate urls."
    }
# test with invalid JSON
def test_submit_manual_archive_invalid_json(client_with_token):
r = client_with_token.post(
"/interop/submit-archive",
json={
"result": "invalid json",
"public": False,
"author_id": "jer",
"tags": ["test"],
"url": "http://example.com",
},
)
assert r.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert r.json() == {"detail": "Invalid JSON in result field."}
@patch(
    "app.web.routers.interoperability.business_logic.get_store_archive_until",
    side_effect=AssertionError("AssertionError"),
)
def test_submit_manual_archive_no_store_until(
    m_sau, client_with_token, db_session
):
    """When get_store_archive_until raises AssertionError the archive is still created, with store_until unset."""
    result_json = json.dumps(
        {
            "status": "test: success",
            "metadata": {"url": "http://example.com"},
            "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}],
        }
    )
    payload = {
        "result": result_json,
        "public": True,
        "author_id": "jerry@gmail.com",
        "group_id": "spaceship",
        "tags": ["test"],
        "url": "http://example.com",
    }

    r = client_with_token.post("/interop/submit-archive", json=payload)
    assert r.status_code == HTTPStatus.CREATED
    # 36 characters - presumably a UUID string; confirm against the router
    assert len(r.json()["id"]) == 36

    by_id = (
        db_session.query(models.Archive)
        .filter(models.Archive.id == r.json()["id"])
        .first()
    )
    assert by_id.store_until is None

    # testing that store_until = None is not comparable with datetime, and will always return False
    by_id_and_expiry = (
        db_session.query(models.Archive)
        .filter(
            models.Archive.id == r.json()["id"],
            models.Archive.store_until < datetime.now(),
        )
        .first()
    )
    assert by_id_and_expiry is None

View File

@@ -0,0 +1,268 @@
from datetime import datetime
from http import HTTPStatus
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
from app.shared.constants import STATUS_PENDING
from app.shared.db import models
from app.shared.schemas import TaskResult
from app.web.db.user_state import UserState
from app.web.security import get_user_state
def test_endpoints_no_auth(client, test_no_auth):
test_no_auth(client.post, "/sheet/create")
test_no_auth(client.get, "/sheet/mine")
test_no_auth(client.delete, "/sheet/123-sheet-id")
test_no_auth(client.post, "/sheet/123-sheet-id/archive")
def test_create_sheet_endpoint(app_with_auth, db_session):
    """Walk POST /sheet/create through success, duplicate id, bad group, bad frequency and quota cases."""
    create_path = "/sheet/create"
    morty_client = TestClient(app_with_auth)
    good_data = {
        "id": "123-sheet-id",
        "name": "Test Sheet",
        "group_id": "spaceship",
        "frequency": "daily",
    }

    # with good data: the response echoes the payload plus timestamps and author
    created = morty_client.post(create_path, json=good_data)
    assert created.status_code == HTTPStatus.CREATED
    body = created.json()
    assert datetime.fromisoformat(body.pop("created_at"))
    assert datetime.fromisoformat(body.pop("last_url_archived_at"))
    assert body.pop("author_id") == "morty@example.com"
    assert body == good_data

    # already exists
    duplicate = morty_client.post(create_path, json=good_data)
    assert duplicate.status_code == HTTPStatus.BAD_REQUEST
    assert duplicate.json() == {
        "detail": "Sheet with this ID is already being archived."
    }

    # bad group
    wrong_group = {**good_data, "group_id": "not a group"}
    forbidden = morty_client.post(create_path, json=wrong_group)
    assert forbidden.status_code == HTTPStatus.FORBIDDEN
    assert forbidden.json() == {
        "detail": "User does not have access to this group."
    }

    # switch to jerry who's got less quota/permissions
    app_with_auth.dependency_overrides[get_user_state] = lambda: UserState(
        db_session, "jerry@example.com"
    )
    client_jerry = TestClient(app_with_auth)

    # frequency not allowed
    jerry_hourly = {
        **good_data,
        "group_id": "animated-characters",
        "frequency": "hourly",
        "id": "jerry-sheet-id",
    }
    bad_frequency = client_jerry.post(create_path, json=jerry_hourly)
    assert bad_frequency.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
    assert bad_frequency.json() == {
        "detail": "Invalid frequency selected for this group."
    }

    # success for the first sheet, bad quota on second
    jerry_daily = {**jerry_hourly, "frequency": "daily"}
    first = client_jerry.post(create_path, json=jerry_daily)
    assert first.status_code == HTTPStatus.CREATED
    second = client_jerry.post(create_path, json=jerry_daily)
    assert second.status_code == HTTPStatus.TOO_MANY_REQUESTS
    assert second.json() == {
        "detail": "User has reached their sheet quota for this group."
    }
def test_get_user_sheets_endpoint(client_with_auth, db_session):
    """GET /sheet/mine is empty at first and afterwards lists only the two sheets authored by morty."""
    # no data
    empty = client_with_auth.get("/sheet/mine")
    assert empty.status_code == HTTPStatus.OK
    assert empty.json() == []

    # with data: the first sheet is committed on its own, the other two together
    first_sheet = models.Sheet(
        id="123",
        name="Test Sheet 1",
        author_id="morty@example.com",
        group_id="spaceship",
        frequency="hourly",
    )
    db_session.add(first_sheet)
    db_session.commit()

    morty_second = models.Sheet(
        id="456",
        name="Test Sheet 2",
        author_id="morty@example.com",
        group_id="interdimensional",
        frequency="daily",
    )
    rick_sheet = models.Sheet(
        id="789",
        name="Test Sheet 3",
        author_id="rick@example.com",
        group_id="interdimensional",
        frequency="hourly",
    )
    db_session.add_all([morty_second, rick_sheet])
    db_session.commit()

    listing = client_with_auth.get("/sheet/mine")
    assert listing.status_code == HTTPStatus.OK
    sheets = listing.json()
    assert isinstance(sheets, list)
    assert len(sheets) == 2

    # timestamps only need to parse; strip them before the exact comparison
    for sheet in sheets:
        assert datetime.fromisoformat(sheet.pop("created_at"))
        assert datetime.fromisoformat(sheet.pop("last_url_archived_at"))

    expected_first = {
        "id": "123",
        "author_id": "morty@example.com",
        "frequency": "hourly",
        "group_id": "spaceship",
        "name": "Test Sheet 1",
    }
    expected_second = {
        "id": "456",
        "author_id": "morty@example.com",
        "frequency": "daily",
        "group_id": "interdimensional",
        "name": "Test Sheet 2",
    }
    assert sheets[0] == expected_first
    assert sheets[1] == expected_second
def test_delete_sheet_endpoint(client_with_auth, db_session):
    """DELETE /sheet/<id> always answers 200; `deleted` is True only for morty's own existing sheet."""

    def delete_and_expect(sheet_id, deleted):
        # one DELETE round-trip plus the two assertions shared by every case
        outcome = client_with_auth.delete("/sheet/" + sheet_id)
        assert outcome.status_code == HTTPStatus.OK
        assert outcome.json() == {"id": sheet_id, "deleted": deleted}

    # missing sheet
    delete_and_expect("123-sheet-id", False)

    # add sheets for deletion
    morty_sheet = models.Sheet(
        id="123-sheet-id",
        name="Test Sheet 1",
        author_id="morty@example.com",
        group_id="interdimensional",
        frequency="daily",
    )
    rick_sheet = models.Sheet(
        id="456-sheet-id",
        name="Test Sheet 2",
        author_id="rick@example.com",
        group_id="spaceship",
        frequency="hourly",
    )
    db_session.add_all([morty_sheet, rick_sheet])
    db_session.commit()

    # morty can delete his
    delete_and_expect("123-sheet-id", True)
    # but only once
    delete_and_expect("123-sheet-id", False)
    # and not Rick's
    delete_and_expect("456-sheet-id", False)
class TestArchiveUserSheetEndpoint:
    """Tests for POST /sheet/<id>/archive, the manual sheet-archiving trigger."""

    ARCHIVE_PATH = "/sheet/123-sheet-id/archive"

    @staticmethod
    def _store_sheet(db_session, author_id, group_id):
        # every scenario uses the same hourly sheet; only owner and group vary
        db_session.add(
            models.Sheet(
                id="123-sheet-id",
                name="Test Sheet 1",
                author_id=author_id,
                group_id=group_id,
                frequency="hourly",
            )
        )
        db_session.commit()

    @patch("app.web.routers.sheet.celery", return_value=MagicMock())
    def test_normal_flow(self, m_celery, client_with_auth, db_session):
        """Owner in an allowed group gets 201 and the id of the dispatched task."""
        self._store_sheet(db_session, "morty@example.com", "spaceship")

        # celery.signature(...).apply_async() yields a pending task result
        task_signature = MagicMock()
        task_signature.apply_async.return_value = TaskResult(
            id="123-taskid", status=STATUS_PENDING, result=""
        )
        m_celery.signature.return_value = task_signature

        r = client_with_auth.post(self.ARCHIVE_PATH)

        assert r.status_code == HTTPStatus.CREATED
        assert r.json() == {"id": "123-taskid"}
        m_celery.signature.assert_called_once()
        task_signature.apply_async.assert_called_once()

    def test_token_auth(self, client_with_token, test_no_auth):
        """The API-token client is passed through the shared no-auth checker."""
        test_no_auth(client_with_token.post, self.ARCHIVE_PATH)

    def test_missing_data(self, client_with_auth):
        """With no sheet stored the endpoint answers 403, same as for a foreign sheet."""
        r = client_with_auth.post(self.ARCHIVE_PATH)
        assert r.status_code == HTTPStatus.FORBIDDEN
        assert r.json() == {"detail": "No access to this sheet."}

    def test_no_access(self, client_with_auth, db_session):
        """A sheet authored by rick cannot be triggered by the authenticated user."""
        self._store_sheet(db_session, "rick@example.com", "spaceship")

        r = client_with_auth.post(self.ARCHIVE_PATH)

        assert r.status_code == HTTPStatus.FORBIDDEN
        assert r.json() == {"detail": "No access to this sheet."}

    def test_user_not_in_group(self, client_with_auth, db_session):
        """Own sheet stored under the "interdimensional" group is rejected with 403."""
        self._store_sheet(db_session, "morty@example.com", "interdimensional")

        r = client_with_auth.post(self.ARCHIVE_PATH)

        assert r.status_code == HTTPStatus.FORBIDDEN
        assert r.json() == {
            "detail": "User does not have access to this group."
        }

    def test_user_cannot_manually_trigger(self, client_with_auth, db_session):
        """Own sheet in the "default" group is rejected with 429."""
        self._store_sheet(db_session, "morty@example.com", "default")

        r = client_with_auth.post(self.ARCHIVE_PATH)

        assert r.status_code == HTTPStatus.TOO_MANY_REQUESTS
        assert r.json() == {
            "detail": "User cannot manually trigger sheet archiving in this group."
        }

View File

@@ -1,51 +1,53 @@
from http import HTTPStatus
from unittest.mock import patch from unittest.mock import patch
from app.shared.constants import STATUS_FAILURE, STATUS_PENDING, STATUS_SUCCESS
def test_endpoint_task_status_no_auth(client, test_no_auth): def test_endpoint_task_status_no_auth(client, test_no_auth):
test_no_auth(client.get, "/task/test-task-id") test_no_auth(client.get, "/task/test-task-id")
@patch("app.web.endpoints.task.AsyncResult") @patch("app.web.routers.task.AsyncResult")
def test_get_status_success(mock_async_result, client_with_auth): def test_get_status_success(mock_async_result, client_with_auth):
mock_async_result.return_value.status = "SUCCESS" mock_async_result.return_value.status = STATUS_SUCCESS
mock_async_result.return_value.result = {"data": "some result"} mock_async_result.return_value.result = {"data": "some result"}
response = client_with_auth.get("/task/test-task-id") response = client_with_auth.get("/task/test-task-id")
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response.json() == { assert response.json() == {
"id": "test-task-id", "id": "test-task-id",
"status": "SUCCESS", "status": STATUS_SUCCESS,
"result": {"data": "some result"} "result": {"data": "some result"},
} }
@patch("app.web.endpoints.task.AsyncResult") @patch("app.web.routers.task.AsyncResult")
def test_get_status_failure(mock_async_result, client_with_auth): def test_get_status_failure(mock_async_result, client_with_auth):
mock_async_result.return_value.status = STATUS_FAILURE
mock_async_result.return_value.status = "FAILURE"
mock_async_result.return_value.result = Exception("Some error") mock_async_result.return_value.result = Exception("Some error")
response = client_with_auth.get("/task/test-task-id") response = client_with_auth.get("/task/test-task-id")
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response.json() == { assert response.json() == {
"id": "test-task-id", "id": "test-task-id",
"status": "FAILURE", "status": STATUS_FAILURE,
"result": {"error": "Some error"} "result": {"error": "Some error"},
} }
@patch("app.web.endpoints.task.AsyncResult") @patch("app.web.routers.task.AsyncResult")
def test_get_status_pending(mock_async_result, client_with_auth): def test_get_status_pending(mock_async_result, client_with_auth):
mock_async_result.return_value.status = "PENDING" mock_async_result.return_value.status = STATUS_PENDING
mock_async_result.return_value.result = None mock_async_result.return_value.result = None
response = client_with_auth.get("/task/test-task-id") response = client_with_auth.get("/task/test-task-id")
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response.json() == { assert response.json() == {
"id": "test-task-id", "id": "test-task-id",
"status": "PENDING", "status": STATUS_PENDING,
"result": None "result": None,
} }

View File

@@ -0,0 +1,312 @@
import json
from http import HTTPStatus
from unittest.mock import MagicMock, patch
from app.shared import schemas
from app.shared.constants import STATUS_PENDING
from app.shared.db import worker_crud
from app.shared.schemas import ArchiveCreate, TaskResult
from app.web.config import ALLOW_ANY_EMAIL
def test_archive_url_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/url/archive")
@patch("app.web.routers.url.UserState")
@patch("app.web.routers.url.celery", return_value=MagicMock())
def test_archive_url(m_celery, m2, client_with_auth):
    """Drive POST /url/archive through validation, group membership and quota branches.

    Both celery and UserState are mocked, so the test checks status codes,
    error details, and which mock methods were invoked with which arguments.
    """
    archive_path = "/url/archive"

    # celery.signature(...).apply_async() yields a pending task result
    task_signature = MagicMock()
    task_signature.apply_async.return_value = TaskResult(
        id="123-456-789", status=STATUS_PENDING, result=""
    )
    m_celery.signature.return_value = task_signature

    # the instance the router obtains from UserState(...)
    user_state = MagicMock()
    m2.return_value = user_state

    # url is too short
    too_short = client_with_auth.post(archive_path, json={"url": "bad"})
    assert too_short.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
    first_error = too_short.json()["detail"][0]
    assert first_error["msg"] == "String should have at least 5 characters"
    m_celery.signature.assert_not_called()

    # url is invalid
    invalid = client_with_auth.post(archive_path, json={"url": "example.com"})
    assert invalid.status_code == HTTPStatus.BAD_REQUEST
    assert invalid.json()["detail"] == "Invalid URL received."

    # valid request
    user_state.has_quota_max_monthly_urls.return_value = True
    user_state.has_quota_max_monthly_mbs.return_value = True
    accepted = client_with_auth.post(
        archive_path, json={"url": "https://example.com"}
    )
    assert accepted.status_code == HTTPStatus.CREATED
    assert accepted.json() == {"id": "123-456-789"}
    m_celery.signature.assert_called_once()
    task_signature.apply_async.assert_called_once()
    signature_call = m_celery.signature.call_args
    assert signature_call.args[0] == "create_archive_task"
    # the first task argument is the archive serialised as JSON
    assert json.loads(signature_call.kwargs["args"][0]) == {
        "id": None,
        "url": "https://example.com",
        "result": None,
        "public": False,
        "author_id": "rick@example.com",
        "group_id": "default",
        "tags": None,
        "sheet_id": None,
        "store_until": None,
        "urls": None,
    }
    user_state.has_quota_max_monthly_urls.assert_called_once()
    user_state.has_quota_max_monthly_mbs.assert_called_once()
    user_state.in_group.assert_called_once_with("default")

    # user is not in group
    user_state.in_group.return_value = False
    outsider = client_with_auth.post(
        archive_path,
        json={"url": "https://example.com", "group_id": "new-group"},
    )
    assert outsider.status_code == HTTPStatus.FORBIDDEN
    assert outsider.json()["detail"] == "User does not have access to this group."
    user_state.in_group.assert_called_with("new-group")

    # user is in group
    user_state.in_group.return_value = True
    member = client_with_auth.post(
        archive_path,
        json={"url": "https://example.com", "group_id": "spaceship"},
    )
    assert member.status_code == HTTPStatus.CREATED
    assert member.json() == {"id": "123-456-789"}
    assert m_celery.signature.call_count == 2
    assert task_signature.apply_async.call_count == 2
    signature_call = m_celery.signature.call_args
    assert json.loads(signature_call.kwargs["args"][0])["group_id"] == "spaceship"
    user_state.in_group.assert_called_with("spaceship")

    # user is over monthly URL quota
    user_state.has_quota_max_monthly_urls.return_value = False
    user_state.has_quota_max_monthly_mbs.return_value = True
    over_urls = client_with_auth.post(
        archive_path,
        json={"url": "https://example.com", "group_id": "spaceship"},
    )
    assert over_urls.status_code == HTTPStatus.TOO_MANY_REQUESTS
    assert over_urls.json()["detail"] == "User has reached their monthly URL quota."
    user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")

    # user is over monthly MB quota
    user_state.has_quota_max_monthly_urls.return_value = True
    user_state.has_quota_max_monthly_mbs.return_value = False
    over_mbs = client_with_auth.post(
        archive_path,
        json={"url": "https://example.com", "group_id": "spacesuit"},
    )
    assert over_mbs.status_code == HTTPStatus.TOO_MANY_REQUESTS
    assert over_mbs.json()["detail"] == "User has reached their monthly MB quota."
    user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")

    # the two quota rejections must not have dispatched any further task
    assert m_celery.signature.call_count == 2
    assert task_signature.apply_async.call_count == 2
@patch("app.web.routers.url.UserState")
def test_archive_url_quotas(m1, client_with_auth):
    """Each exhausted monthly quota (URLs, then MBs) makes POST /url/archive answer 429."""
    user_state = MagicMock()
    m1.return_value = user_state
    request_body = {"url": "https://example.com"}

    # misses on monthly URLs quota
    user_state.has_quota_max_monthly_urls.return_value = False
    over_urls = client_with_auth.post("/url/archive", json=request_body)
    assert over_urls.status_code == HTTPStatus.TOO_MANY_REQUESTS
    assert over_urls.json()["detail"] == "User has reached their monthly URL quota."
    user_state.has_quota_max_monthly_urls.assert_called_once()

    # misses on monthly MBs quota
    user_state.has_quota_max_monthly_urls.return_value = True
    user_state.has_quota_max_monthly_mbs.return_value = False
    over_mbs = client_with_auth.post("/url/archive", json=request_body)
    assert over_mbs.status_code == HTTPStatus.TOO_MANY_REQUESTS
    assert over_mbs.json()["detail"] == "User has reached their monthly MB quota."
    user_state.has_quota_max_monthly_mbs.assert_called_once()
@patch("app.web.routers.url.celery", return_value=MagicMock())
def test_archive_url_with_api_token(m_celery, client_with_token):
    """With an API token the submitted author_id is used, and ALLOW_ANY_EMAIL when it is None."""
    task_signature = MagicMock()
    task_signature.apply_async.return_value = TaskResult(
        id="123-456-789", status=STATUS_PENDING, result=""
    )
    m_celery.signature.return_value = task_signature

    # fields of the serialised archive that are the same in both requests
    common_fields = {
        "id": None,
        "url": "https://example.com",
        "result": None,
        "public": False,
        "group_id": "default",
        "tags": None,
        "sheet_id": None,
        "store_until": None,
        "urls": None,
    }

    # explicit author
    explicit = client_with_token.post(
        "/url/archive",
        json={"url": "https://example.com", "author_id": "someone@example.com"},
    )
    assert explicit.status_code == HTTPStatus.CREATED
    assert explicit.json() == {"id": "123-456-789"}
    m_celery.signature.assert_called_once()
    task_signature.apply_async.assert_called_once()
    signature_call = m_celery.signature.call_args
    assert signature_call.args[0] == "create_archive_task"
    assert json.loads(signature_call.kwargs["args"][0]) == {
        **common_fields,
        "author_id": "someone@example.com",
    }

    # missing id should use ALLOW_ANY_EMAIL
    anonymous = client_with_token.post(
        "/url/archive", json={"url": "https://example.com", "author_id": None}
    )
    assert anonymous.status_code == HTTPStatus.CREATED
    signature_call = m_celery.signature.call_args
    assert signature_call.args[0] == "create_archive_task"
    assert json.loads(signature_call.kwargs["args"][0]) == {
        **common_fields,
        "author_id": ALLOW_ANY_EMAIL,
    }
def test_search_by_url_unauthenticated(client, test_no_auth):
test_no_auth(client.get, "/url/search")
def test_search_by_url(client_with_auth, client_with_token, db_session):
    """Exercise ``/url/search`` before and after seeding archives in the db.

    Covers the required ``url`` parameter, an empty result set, and the
    ``limit``, ``skip``, ``archived_before`` and ``archived_after`` query
    parameters, then repeats one query with the API-token client.
    """
    # the url query parameter is mandatory
    missing_param = client_with_auth.get("/url/search")
    assert missing_param.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
    assert missing_param.json()["detail"][0]["msg"] == "Field required"

    # nothing has been archived yet
    empty_search = client_with_auth.get("/url/search?url=https://example.com")
    assert empty_search.status_code == HTTPStatus.OK
    assert empty_search.json() == []

    # seed ten matching archives plus one for a different url
    for idx in range(11):
        if idx < 10:
            seeded_url = "https://example.com"
        else:
            seeded_url = "https://something-else.com"
        worker_crud.create_archive(
            db_session,
            ArchiveCreate(
                id=f"url-456-{idx}",
                url=seeded_url,
                result={},
                public=True,
                author_id="rick@example.com",
            ),
            [],
            [],
        )

    # NB: this insertion is too fast for the ordering to be correct as they
    # are within the same second, so only membership is checked, not order
    full_search = client_with_auth.get("/url/search?url=https://example.com")
    assert full_search.status_code == HTTPStatus.OK
    found = full_search.json()
    assert len(found) == 10
    found_ids = [entry["id"] for entry in found]
    assert "url-456-0" in found_ids
    assert "url-456-9" in found_ids
    assert "url-456-10" not in found_ids
    assert found[0].keys() == schemas.ArchiveResult.model_fields.keys()

    # each query string is paired with the number of results it should yield
    filtered_cases = (
        ("/url/search?url=https://example.com&limit=5", 5),
        ("/url/search?url=https://example.com&skip=5&limit=2", 2),
        ("/url/search?url=https://example.com&archived_before=2010-01-01", 0),
        ("/url/search?url=https://example.com&archived_after=2010-01-01", 10),
    )
    for query, expected_count in filtered_cases:
        filtered = client_with_auth.get(query)
        assert filtered.status_code == HTTPStatus.OK
        assert len(filtered.json()) == expected_count

    # API token will also work
    token_search = client_with_token.get(
        "/url/search?url=https://example.com&archived_after=2010-01-01"
    )
    assert token_search.status_code == HTTPStatus.OK
    assert len(token_search.json()) == 10
@patch("app.web.routers.url.UserState")
def test_search_no_read_access(mock_user_state, client_with_auth):
    """Search must answer 403 when the user state grants no read access."""
    # Both read flags are switched off on the instance the router will build.
    denied_state = mock_user_state.return_value
    denied_state.read = False
    denied_state.read_public = False

    forbidden = client_with_auth.get("/url/search?url=https://example.com")

    assert forbidden.status_code == HTTPStatus.FORBIDDEN
    assert forbidden.json() == {"detail": "User does not have read access."}
def test_delete_task_unauthenticated(client, test_no_auth):
test_no_auth(client.delete, "/url/123-456-789")
def test_delete_task(client_with_auth, db_session):
    """Delete an archive id before and after the archive exists.

    The first DELETE targets an id with no stored archive and must report
    ``deleted: False``; after the archive is created the same request must
    report ``deleted: True``.
    """
    archive_id = "delete-123-456-789"
    endpoint = "/url/delete-123-456-789"

    # nothing stored yet: the endpoint still answers 200 but deletes nothing
    before_creation = client_with_auth.delete(endpoint)
    assert before_creation.status_code == HTTPStatus.OK
    assert before_creation.json() == {"id": archive_id, "deleted": False}

    # presumably morty@example.com is the user behind client_with_auth - TODO confirm
    new_archive = ArchiveCreate(
        id=archive_id,
        url="https://example.com",
        result={},
        public=True,
        author_id="morty@example.com",
    )
    worker_crud.create_archive(db_session, new_archive, [], [])

    after_creation = client_with_auth.delete(endpoint)
    assert after_creation.status_code == HTTPStatus.OK
    assert after_creation.json() == {"id": archive_id, "deleted": True}

View File

@@ -1,25 +1,49 @@
import os import os
import shutil
from http import HTTPStatus
from unittest.mock import patch from unittest.mock import patch
import alembic.config
import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
import shutil from app.web.main import app_factory
from app.web.utils.metrics import EXCEPTION_COUNTER
import pytest
def test_lifespan(app): def test_lifespan(app):
with TestClient(app) as client: with TestClient(app) as client:
r = client.get("/health") r = client.get("/health")
assert r.status_code == 200 assert r.status_code == HTTPStatus.OK
assert r.json() == {"status": "ok"} assert r.json() == {"status": "ok"}
def test_alembic(db_session):
import alembic.config
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
alembic.config.main(argv=['--raiseerr', 'downgrade', 'base'])
@patch("app.web.endpoints.url.crud.soft_delete_archive", side_effect=Exception('mocked error')) def test_alembic(db_session):
alembic.config.main(
argv=[
"-c",
"./app/migrations/alembic.ini",
"--raiseerr",
"upgrade",
"head",
]
)
alembic.config.main(
argv=[
"-c",
"./app/migrations/alembic.ini",
"--raiseerr",
"downgrade",
"base",
]
)
@patch(
"app.web.routers.url.crud.soft_delete_archive",
side_effect=Exception("mocked error"),
)
def test_logging_middleware(m1, client_with_auth): def test_logging_middleware(m1, client_with_auth):
from app.web.utils.metrics import EXCEPTION_COUNTER
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 0 assert len(EXCEPTION_COUNTER.collect()[0].samples) == 0
with pytest.raises(Exception, match="mocked error"): with pytest.raises(Exception, match="mocked error"):
client_with_auth.delete("/url/123") client_with_auth.delete("/url/123")
@@ -36,13 +60,13 @@ def test_serve_local_archive_logic(get_settings):
try: try:
# modify the settings # modify the settings
get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test" get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test"
from app.web.main import app_factory
app = app_factory(get_settings) app = app_factory(get_settings)
# test # test
client = TestClient(app) client = TestClient(app)
r = client.get("/app/local_archive_test/temp.txt") r = client.get("/app/local_archive_test/temp.txt")
assert r.status_code == 200 assert r.status_code == HTTPStatus.OK
assert r.text == "test" assert r.text == "test"
finally: finally:
# cleanup # cleanup

View File

@@ -1,101 +1,168 @@
from http import HTTPStatus
from unittest import mock from unittest import mock
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials from fastapi.security import HTTPAuthorizationCredentials
import pytest
from app.web.config import ALLOW_ANY_EMAIL from app.web.config import ALLOW_ANY_EMAIL
from app.web.db.user_state import UserState
from app.web.security import (
authenticate_user,
get_token_or_user_auth,
get_user_auth,
get_user_state,
secure_compare,
token_api_key_auth,
)
def test_secure_compare(): def test_secure_compare():
from app.web.security import secure_compare
assert secure_compare("test", "test") assert secure_compare("test", "test")
assert not secure_compare("test", "test2") assert not secure_compare("test", "test2")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_token_or_user_auth_with_api(): async def test_get_token_or_user_auth_with_api():
from app.web.security import get_token_or_user_auth mock_api = HTTPAuthorizationCredentials(
mock_api = HTTPAuthorizationCredentials(scheme="lorem", credentials="this_is_the_test_api_token") scheme="lorem", credentials="this_is_the_test_api_token"
)
assert await get_token_or_user_auth(mock_api) == ALLOW_ANY_EMAIL assert await get_token_or_user_auth(mock_api) == ALLOW_ANY_EMAIL
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_token_or_user_auth_with_user(): async def test_get_token_or_user_auth_with_user():
from app.web.security import get_token_or_user_auth bad_user = HTTPAuthorizationCredentials(
bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="invalid") scheme="ipsum", credentials="invalid"
e: pytest.ExceptionInfo = None )
with pytest.raises(HTTPException) as e: with pytest.raises(HTTPException) as e:
await get_token_or_user_auth(bad_user) await get_token_or_user_auth(bad_user)
assert e.value.status_code == 401 assert e.value.status_code == HTTPStatus.UNAUTHORIZED
assert e.value.detail == "invalid access_token" assert e.value.detail == "invalid access_token"
@patch("app.web.security.authenticate_user", return_value=(True, "summer@example.com")) @patch(
"app.web.security.authenticate_user",
return_value=(True, "summer@example.com"),
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_auth(m1): async def test_get_user_auth(m1):
from app.web.security import get_user_auth good_user = HTTPAuthorizationCredentials(
good_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good") scheme="ipsum", credentials="valid-and-good"
)
assert await get_user_auth(good_user) == "summer@example.com" assert await get_user_auth(good_user) == "summer@example.com"
@patch("app.web.security.secure_compare", return_value=False) @patch("app.web.security.secure_compare", return_value=False)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_token_api_key_auth_exception(m1): async def test_token_api_key_auth_exception(m1):
from app.web.security import token_api_key_auth
e: pytest.ExceptionInfo = None
with pytest.raises(HTTPException) as e: with pytest.raises(HTTPException) as e:
await token_api_key_auth(HTTPAuthorizationCredentials(scheme="ipsum", credentials="does-not-matter"), auto_error=True) await token_api_key_auth(
assert e.value.status_code == 401 HTTPAuthorizationCredentials(
scheme="ipsum", credentials="does-not-matter"
),
auto_error=True,
)
assert e.value.status_code == HTTPStatus.UNAUTHORIZED
assert e.value.detail == "Wrong auth credentials" assert e.value.detail == "Wrong auth credentials"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_user(): async def test_authenticate_user():
from app.web.security import authenticate_user
assert authenticate_user("test") == (False, "invalid access_token") assert authenticate_user("test") == (False, "invalid access_token")
assert authenticate_user(123) == (False, "invalid access_token") assert authenticate_user(123) == (False, "invalid access_token")
with patch("app.web.security.requests.get") as mock_get: with patch("app.web.security.requests.get") as mock_get:
# bad response from oauth2 # bad response from oauth2
mock_get.return_value.status_code = 403 mock_get.return_value.status_code = HTTPStatus.FORBIDDEN
assert authenticate_user("this-will-call-requests") == (False, "invalid token") assert authenticate_user("this-will-call-requests") == (
False,
"invalid token",
)
assert mock_get.call_count == 1 assert mock_get.call_count == 1
# 200 but invalid json # 200 but invalid json
mock_get.return_value.status_code = 200 mock_get.return_value.status_code = HTTPStatus.OK
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID") assert authenticate_user("this-will-call-requests") == (
False,
"token does not belong to valid APP_ID",
)
assert mock_get.call_count == 2 assert mock_get.call_count == 2
# 200 but invalid azp and aud # 200 but invalid azp and aud
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app"} mock_get.return_value.json.return_value = {
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID") "email": "summer@example.com",
"azp": "not_an_app",
}
assert authenticate_user("this-will-call-requests") == (
False,
"token does not belong to valid APP_ID",
)
mock_get.return_value.json.return_value = {"email": "summer@example.com", "aud": "not_an_app"} mock_get.return_value.json.return_value = {
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID") "email": "summer@example.com",
"aud": "not_an_app",
}
assert authenticate_user("this-will-call-requests") == (
False,
"token does not belong to valid APP_ID",
)
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "not_an_app"} mock_get.return_value.json.return_value = {
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID") "email": "summer@example.com",
"azp": "not_an_app",
"aud": "not_an_app",
}
assert authenticate_user("this-will-call-requests") == (
False,
"token does not belong to valid APP_ID",
)
# blocked email # blocked email
mock_get.return_value.json.return_value = {"email": "blocked@example.com", "azp": "test_app_id_1", "aud": "not_an_app"} mock_get.return_value.json.return_value = {
assert authenticate_user("this-will-call-requests") == (False, "email 'blocked@example.com' not allowed") "email": "blocked@example.com",
"azp": "test_app_id_1",
"aud": "not_an_app",
}
assert authenticate_user("this-will-call-requests") == (
False,
"email 'blocked@example.com' not allowed",
)
# not verified # not verified
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "test_app_id_1"} mock_get.return_value.json.return_value = {
assert authenticate_user("this-will-call-requests") == (False, "email 'summer@example.com' not verified") "email": "summer@example.com",
"azp": "not_an_app",
"aud": "test_app_id_1",
}
assert authenticate_user("this-will-call-requests") == (
False,
"email 'summer@example.com' not verified",
)
# token expired # token expired
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true"} mock_get.return_value.json.return_value = {
assert authenticate_user("this-will-call-requests") == (False, "Token expired") "email": "summer@example.com",
"azp": "test_app_id_2",
"email_verified": "true",
}
assert authenticate_user("this-will-call-requests") == (
False,
"Token expired",
)
# 200 and valid azp and aup and verified # 200 and valid azp and aup and verified
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true", "expires_in": 100} mock_get.return_value.json.return_value = {
assert authenticate_user("this-will-call-requests") == (True, "summer@example.com") "email": "summer@example.com",
"azp": "test_app_id_2",
"email_verified": "true",
"expires_in": 100,
}
assert authenticate_user("this-will-call-requests") == (
True,
"summer@example.com",
)
assert mock_get.call_count == 9 assert mock_get.call_count == 9
@@ -104,6 +171,7 @@ async def test_authenticate_user():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_user_with_id_token(m_init): async def test_authenticate_user_with_id_token(m_init):
from firebase_admin import exceptions from firebase_admin import exceptions
from app.web.security import authenticate_user from app.web.security import authenticate_user
with pytest.raises(ValueError) as e: with pytest.raises(ValueError) as e:
@@ -113,12 +181,20 @@ async def test_authenticate_user_with_id_token(m_init):
with patch("app.web.security.auth.verify_id_token") as mock_verify: with patch("app.web.security.auth.verify_id_token") as mock_verify:
# missing email # missing email
mock_verify.return_value = {"email": None} mock_verify.return_value = {"email": None}
assert authenticate_user("fake_token") == (False, "email not found in token") assert authenticate_user("fake_token") == (
False,
"email not found in token",
)
assert mock_verify.call_count == 1 assert mock_verify.call_count == 1
# blocked email # blocked email
mock_verify.return_value = {"email": "blocked@example.com", } mock_verify.return_value = {
assert authenticate_user("fake_token") == (False, "email 'blocked@example.com' not allowed") "email": "blocked@example.com",
}
assert authenticate_user("fake_token") == (
False,
"email 'blocked@example.com' not allowed",
)
assert mock_verify.call_count == 2 assert mock_verify.call_count == 2
# valid email # valid email
@@ -132,17 +208,16 @@ async def test_authenticate_user_with_id_token(m_init):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_user_exception(): async def test_authenticate_user_exception():
from app.web.security import authenticate_user
with patch("app.web.security.requests.get") as mock_get: with patch("app.web.security.requests.get") as mock_get:
mock_get.return_value.status_code = 200 mock_get.return_value.status_code = HTTPStatus.OK
mock_get.return_value.json.side_effect = Exception("mocked error") mock_get.return_value.json.side_effect = Exception("mocked error")
assert authenticate_user("this-will-call-requests") == (False, "exception occurred") assert authenticate_user("this-will-call-requests") == (
False,
"exception occurred",
)
def test_get_user_state(): def test_get_user_state():
from app.web.security import get_user_state
from app.web.db.user_state import UserState
mock_session = Mock() mock_session = Mock()
test_email = "test@example.com" test_email = "test@example.com"

View File

@@ -1,29 +1,47 @@
from datetime import datetime from datetime import datetime
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from app.shared.db import models
from app.shared import schemas
from auto_archiver.core import Media, Metadata from auto_archiver.core import Media, Metadata
from app.shared import constants, schemas
from app.shared.db import models
from app.web.utils.misc import get_all_urls
from app.worker.main import create_archive_task, create_sheet_task
class Test_create_archive_task():
class TestCreateArchiveTask:
URL = "https://example-live.com" URL = "https://example-live.com"
archive = schemas.ArchiveCreate(url=URL, tags=["tag-celery"], public=True, author_id="rick@example.com", group_id="interstellar") archive = schemas.ArchiveCreate(
url=URL,
tags=["tag-celery"],
public=True,
author_id="rick@example.com",
group_id="interstellar",
)
@patch("app.worker.main.ArchivingOrchestrator") @patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.get_all_urls", return_value=[]) @patch("app.worker.main.get_all_urls", return_value=[])
@patch("app.worker.main.insert_result_into_db") @patch("app.worker.main.insert_result_into_db")
@patch("app.worker.main.get_store_until", return_value=datetime.now()) @patch("app.worker.main.get_store_until", return_value=datetime.now())
@patch("app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"]) @patch(
"app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"]
)
@patch("celery.app.task.Task.request") @patch("celery.app.task.Task.request")
def test_success(self, m_req, m_args, m_store, m_insert, m_urls, m_orchestrator, db_session): def test_success(
from app.worker.main import create_archive_task self,
m_req,
m_args,
m_store,
m_insert,
m_urls,
m_orchestrator,
db_session,
):
m_req.id = "this-just-in" m_req.id = "this-just-in"
m_orchestrator.return_value.feed.return_value = iter([Metadata().set_url(self.URL).success()]) m_orchestrator.return_value.feed.return_value = iter(
[Metadata().set_url(self.URL).success()]
)
task = create_archive_task(self.archive.model_dump_json()) task = create_archive_task(self.archive.model_dump_json())
@@ -39,15 +57,15 @@ class Test_create_archive_task():
assert len(task["media"]) == 0 assert len(task["media"]) == 0
def test_raise_invalid(self): def test_raise_invalid(self):
from app.worker.main import create_archive_task with pytest.raises(Exception) as _:
with pytest.raises(Exception):
create_archive_task(self.archive.model_dump_json()) create_archive_task(self.archive.model_dump_json())
@patch("app.worker.main.ArchivingOrchestrator") @patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.get_orchestrator_args") @patch("app.worker.main.get_orchestrator_args")
def test_raise_db_error(self, m_args, m_orchestrator): def test_raise_db_error(self, m_args, m_orchestrator):
from app.worker.main import create_archive_task m_orchestrator.return_value.feed.side_effect = Exception(
m_orchestrator.return_value.feed.side_effect = Exception("Orchestrator failed") "Orchestrator failed"
)
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
create_archive_task(self.archive.model_dump_json()) create_archive_task(self.archive.model_dump_json())
@@ -59,7 +77,6 @@ class Test_create_archive_task():
@patch("app.worker.main.insert_result_into_db", return_value=None) @patch("app.worker.main.insert_result_into_db", return_value=None)
@patch("app.worker.main.get_orchestrator_args") @patch("app.worker.main.get_orchestrator_args")
def test_raise_empty_result(self, m_args, m_insert, m_orchestrator): def test_raise_empty_result(self, m_args, m_insert, m_orchestrator):
from app.worker.main import create_archive_task
m_orchestrator.return_value.feed.return_value = iter([None]) m_orchestrator.return_value.feed.return_value = iter([None])
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
@@ -68,61 +85,83 @@ class Test_create_archive_task():
m_orchestrator.return_value.feed.assert_called_once() m_orchestrator.return_value.feed.assert_called_once()
class Test_create_sheet_task(): class TestCreateSheetTask:
URL = "https://example-live.com" URL = "https://example-live.com"
sheet = schemas.SubmitSheet(sheet_id="123", author_id="rick@example.com", group_id="interstellar", tags=["spaceship"]) sheet = schemas.SubmitSheet(
sheet_id="123",
author_id="rick@example.com",
group_id="interstellar",
tags=["spaceship"],
)
@patch("app.worker.main.get_all_urls", return_value=[]) @patch("app.worker.main.get_all_urls", return_value=[])
@patch("app.worker.main.ArchivingOrchestrator") @patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.models.generate_uuid", return_value="constant-uuid") @patch("app.worker.main.models.generate_uuid", return_value="constant-uuid")
@patch("app.worker.main.get_store_until", return_value=datetime.now()) @patch("app.worker.main.get_store_until", return_value=datetime.now())
@patch("app.worker.main.get_orchestrator_args") @patch("app.worker.main.get_orchestrator_args")
def test_success(self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session): def test_success(
from app.worker.main import create_sheet_task self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session
):
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0 assert (
db_session.query(models.Archive)
.filter(models.Archive.url == self.URL)
.count()
== 0
)
mock_metadata = Metadata().set_url(self.URL).success() mock_metadata = Metadata().set_url(self.URL).success()
mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"])) mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"]))
m_orchestrator.return_value.feed.return_value = iter([False, mock_metadata, mock_metadata]) m_orchestrator.return_value.feed.return_value = iter(
[False, mock_metadata, mock_metadata]
)
res = create_sheet_task(self.sheet.model_dump_json()) res = create_sheet_task(self.sheet.model_dump_json())
m_args.assert_called_once_with("interstellar", True, ["--gsheet_feeder.sheet_id", "123"]) m_args.assert_called_once_with(
"interstellar", True, [constants.SHEET_ID, "123"]
)
m_orchestrator.return_value.setup.assert_called_once() m_orchestrator.return_value.setup.assert_called_once()
m_orchestrator.return_value.feed.assert_called_once() m_orchestrator.return_value.feed.assert_called_once()
m_store.assert_called_with("interstellar") m_store.assert_called_with("interstellar")
m_store.call_count == 2 assert m_store.call_count == 2
m_uuid.call_count == 2 assert m_uuid.call_count == 2
assert type(res) == dict assert isinstance(res, dict)
assert res["stats"]["archived"] == 1 assert res["stats"]["archived"] == 1
assert res["stats"]["failed"] == 1 assert res["stats"]["failed"] == 1
assert len(res["stats"]["errors"]) == 1 assert len(res["stats"]["errors"]) == 1
assert res["sheet_id"] == "123" assert res["sheet_id"] == "123"
assert res["success"] assert res["success"]
assert type(res["time"]) == datetime assert isinstance(res["time"], datetime)
# query created archive entry # query created archive entry
inserted = db_session.query(models.Archive).filter(models.Archive.url == self.URL).one() inserted = (
db_session.query(models.Archive)
.filter(models.Archive.url == self.URL)
.one()
)
assert inserted is not None assert inserted is not None
assert inserted.url == self.URL assert inserted.url == self.URL
assert len(inserted.tags) == 1 assert len(inserted.tags) == 1
assert inserted.tags[0].id == "spaceship" assert inserted.tags[0].id == "spaceship"
assert inserted.group_id == "interstellar" assert inserted.group_id == "interstellar"
assert inserted.author_id == "rick@example.com" assert inserted.author_id == "rick@example.com"
assert inserted.public == False assert inserted.public is False
def test_get_all_urls(db_session): def test_get_all_urls(db_session):
from app.worker.main import get_all_urls
meta = Metadata().set_url("https://example.com") meta = Metadata().set_url("https://example.com")
m1 = meta.add_media(Media("fn1.txt", urls=["outcome1.com"])) m1 = meta.add_media(Media("fn1.txt", urls=["outcome1.com"]))
m2 = meta.add_media(Media("fn2.txt", urls=["outcome2.com"])) m2 = meta.add_media(Media("fn2.txt", urls=["outcome2.com"]))
m3 = meta.add_media(Media("fn3.txt", urls=["outcome3.com"])) m3 = meta.add_media(Media("fn3.txt", urls=["outcome3.com"]))
m1.set("screenshot", Media("screenshot.png", urls=["screenshot.com"])) m1.set("screenshot", Media("screenshot.png", urls=["screenshot.com"]))
m2.set("thumbnails", [Media("thumb1.png", urls=["thumb1.com"]), Media("thumb2.png", urls=["thumb2.com"])]) m2.set(
"thumbnails",
[
Media("thumb1.png", urls=["thumb1.com"]),
Media("thumb2.png", urls=["thumb2.com"]),
],
)
m3.set("ssl_data", Media("ssl_data.txt", urls=["ssl_data.com"]).to_dict()) m3.set("ssl_data", Media("ssl_data.txt", urls=["ssl_data.com"]).to_dict())
m3.set("bad_data", {"bad": "dict is ignored"}) m3.set("bad_data", {"bad": "dict is ignored"})

View File

@@ -1,3 +1,4 @@
from app.web.main import app_factory from app.web.main import app_factory
app = app_factory app = app_factory

View File

@@ -1,4 +1,4 @@
VERSION = "0.9.4" VERSION = "0.10.0"
API_DESCRIPTION = """ API_DESCRIPTION = """
#### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets. #### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets.
@@ -8,7 +8,10 @@ API_DESCRIPTION = """
- You can use this API to archive single URLs or entire Google Sheets. - You can use this API to archive single URLs or entire Google Sheets.
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously. - Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
""" """
BREAKING_CHANGES = {"minVersion": "0.4.0", "message": "The latest update has breaking changes, please update the extension to the most recent version."} BREAKING_CHANGES = {
"minVersion": "0.4.0",
"message": "The latest update has breaking changes, please update the extension to the most recent version.",
}
# changing this will corrupt the database logic # changing this will corrupt the database logic
ALLOW_ANY_EMAIL = "*" ALLOW_ANY_EMAIL = "*"

View File

@@ -1,18 +1,30 @@
from collections import defaultdict from collections import defaultdict
from functools import lru_cache
from sqlalchemy.orm import Session, load_only
from sqlalchemy import Column, or_, func, select
from loguru import logger
from datetime import datetime, timedelta from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession from typing import Any, Type
from cachetools import LRUCache, cached from cachetools import LRUCache, cached
from cachetools.keys import hashkey from cachetools.keys import hashkey
from loguru import logger
from sqlalchemy import (
Column,
ColumnElement,
ScalarResult,
false,
func,
not_,
or_,
select,
true,
)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, load_only
from app.web.config import ALLOW_ANY_EMAIL
from app.shared.db import models from app.shared.db import models
from app.shared.db.models import Archive, Group
from app.shared.settings import get_settings from app.shared.settings import get_settings
from app.shared.user_groups import UserGroups from app.shared.user_groups import UserGroups
from app.shared.utils.misc import fnv1a_hash_mod from app.shared.utils.misc import fnv1a_hash_mod
from app.web.config import ALLOW_ANY_EMAIL
from app.web.utils.misc import convert_priority_to_queue_dict from app.web.utils.misc import convert_priority_to_queue_dict
@@ -22,24 +34,48 @@ DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
def get_limit(user_limit: int): def get_limit(user_limit: int):
return max(1, min(user_limit, DATABASE_QUERY_LIMIT)) return max(1, min(user_limit, DATABASE_QUERY_LIMIT))
# --------------- TASK = Archive # --------------- TASK = Archive
def base_query(db: Session): def base_query(db: Session):
# NOTE: load_only is for optimization and not obfuscation, use .with_entities() if needed # NOTE: load_only is for optimization and not obfuscation, use
return db.query(models.Archive)\ # .with_entities() if needed
.filter(models.Archive.deleted == False)\ return (
.options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result, models.Archive.store_until)) db.query(models.Archive)
.filter(not_(models.Archive.deleted))
.options(
load_only(
models.Archive.id,
models.Archive.created_at,
models.Archive.url,
models.Archive.result,
models.Archive.store_until,
)
)
)
def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool | set[str], read_public: bool, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False) -> list[models.Archive]: def search_archives_by_url(
# searches for partial URLs, if email is * no ownership (or read/read_public) filtering happens db: Session,
url: str,
email: str,
read_groups: bool | set[str],
read_public: bool,
skip: int = 0,
limit: int = 100,
archived_after: datetime = None,
archived_before: datetime = None,
absolute_search: bool = False,
) -> list[Type[Archive]]:
# searches for partial URLs, if email is * no ownership
# (or read/read_public) filtering happens
query = base_query(db) query = base_query(db)
if email != ALLOW_ANY_EMAIL: if email != ALLOW_ANY_EMAIL:
or_filters = [models.Archive.author_id == email] or_filters = [models.Archive.author_id == email]
if read_public: if read_public:
or_filters.append(models.Archive.public == True) or_filters.append(models.Archive.public.is_(true()))
if read_groups == True: if read_groups is True:
or_filters.append(models.Archive.group_id.isnot(None)) or_filters.append(models.Archive.group_id.isnot(None))
else: else:
or_filters.append(models.Archive.group_id.in_(read_groups)) or_filters.append(models.Archive.group_id.in_(read_groups))
@@ -47,21 +83,43 @@ def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool
if absolute_search: if absolute_search:
query = query.filter(models.Archive.url == url) query = query.filter(models.Archive.url == url)
else: else:
query = query.filter(models.Archive.url.like(f'%{url}%')) query = query.filter(models.Archive.url.like(f"%{url}%"))
if archived_after: if archived_after:
query = query.filter(models.Archive.created_at > archived_after) query = query.filter(models.Archive.created_at > archived_after)
if archived_before: if archived_before:
query = query.filter(models.Archive.created_at < archived_before) query = query.filter(models.Archive.created_at < archived_before)
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all() return (
query.order_by(models.Archive.created_at.desc())
.offset(skip)
.limit(get_limit(limit))
.all()
)
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100): def search_archives_by_email(
return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all() db: Session, email: str, skip: int = 0, limit: int = 100
):
return (
base_query(db)
.filter(models.Archive.author_id == email)
.order_by(models.Archive.created_at.desc())
.offset(skip)
.limit(get_limit(limit))
.all()
)
def soft_delete_archive(db: Session, id: str, email: str) -> bool: def soft_delete_archive(db: Session, id: str, email: str) -> bool:
# TODO: implement hard-delete with cronjob that deletes from S3 # TODO: implement hard-delete with cronjob that deletes from S3
db_archive = db.query(models.Archive).filter(models.Archive.id == id, models.Archive.author_id == email, models.Archive.deleted == False).first() db_archive = (
db.query(models.Archive)
.filter(
models.Archive.id == id,
models.Archive.author_id == email,
models.Archive.deleted.is_(false()),
)
.first()
)
if db_archive: if db_archive:
db_archive.deleted = True db_archive.deleted = True
db.commit() db.commit()
@@ -82,22 +140,29 @@ def count_users(db: Session):
def count_by_user_since(db: Session, seconds_delta: int = 15):
    """
    Count archives per author created in the last `seconds_delta` seconds.

    Returns at most 500 `(author_id, total)` rows, ordered by the count in
    descending order. The cut-off is computed from `datetime.now()`.
    """
    cutoff = datetime.now() - timedelta(seconds=seconds_delta)
    per_author = db.query(
        models.Archive.author_id, func.count().label("total")
    )
    recent = per_author.filter(models.Archive.created_at >= cutoff)
    ranked = recent.group_by(models.Archive.author_id).order_by(
        func.count().desc()
    )
    return ranked.limit(500).all()
async def find_by_store_until(
    db: AsyncSession, store_until_is_before: datetime
) -> ScalarResult[Archive]:
    """
    Select the non-deleted archives whose `store_until` is earlier than
    `store_until_is_before`.

    Returns the `scalars()` view of the executed statement, so the caller
    iterates over `Archive` objects directly.
    """
    not_deleted = models.Archive.deleted.is_(false())
    past_retention = models.Archive.store_until < store_until_is_before
    statement = select(models.Archive).filter(not_deleted, past_retention)
    outcome = await db.execute(statement)
    return outcome.scalars()
async def soft_delete_expired_archives(db: AsyncSession) -> dict: async def soft_delete_expired_archives(db: AsyncSession) -> int:
to_delete = await find_by_store_until(db, datetime.now()) to_delete = await find_by_store_until(db, datetime.now())
counter = 0 counter = 0
for archive in to_delete: for archive in to_delete:
@@ -105,47 +170,86 @@ async def soft_delete_expired_archives(db: AsyncSession) -> dict:
counter += 1 counter += 1
await db.commit() await db.commit()
return counter return counter
# --------------- TAG # --------------- TAG
async def get_group_priority_async(db: AsyncSession, group_id: str) -> dict:
    """
    Resolve the queue settings for a group's priority.

    Looks the group up by primary key and reads `permissions["priority"]`.
    Falls back to "low" when the group does not exist, when it has no
    permissions stored, or when the key is missing. The priority string is
    then passed to `convert_priority_to_queue_dict`.
    """
    db_group = await db.get(models.Group, group_id)
    priority = "low"
    # Guard against empty/None permissions, as the other permission readers
    # do, instead of raising AttributeError on `.get`.
    if db_group and db_group.permissions:
        priority = db_group.permissions.get("priority", "low")
    return convert_priority_to_queue_dict(priority)
@cached(cache=LRUCache(maxsize=128), key=lambda db, email: hashkey(email))
def get_user_group_names(
    db: Session, email: str
) -> list[Any] | list[ColumnElement[Any]]:
    """
    Collect the ids of every group an email belongs to.

    Two sources are merged and de-duplicated: the rows of the user/group
    association table for this email, and the groups whose `domains` contain
    the email's domain. The email does not need to belong to an existing
    user. An empty or "@"-less email yields an empty list.

    Results are cached per email; the cache key ignores `db`.
    """
    # TODO: the read: [group1, group2] permissions don't currently work
    if not email or "@" not in email:
        return []

    # groups the user was explicitly added to
    explicit_rows = (
        db.query(models.association_table_user_groups)
        .filter_by(user_id=email)
        .with_entities(Column("group_id"))
        .all()
    )

    # groups that list the email's domain
    domain = email.split("@")[1]
    domain_rows = (
        db.query(models.Group.id)
        .filter(models.Group.domains.contains(domain))
        .with_entities(Column("id"))
        .all()
    )

    explicit_ids = [row[0] for row in explicit_rows]
    domain_ids = [row[0] for row in domain_rows]
    return list(set(explicit_ids + domain_ids))
def get_user_groups_by_name(
    db: Session, groups: list[str]
) -> list[models.Group]:
    """
    Load the `Group` rows whose id is listed in `groups`.

    Returns ORM instances (not classes), hence `list[models.Group]`; ids
    with no matching row are simply absent from the result.
    """
    return db.query(models.Group).filter(models.Group.id.in_(groups)).all()
# --------------- INIT User-Groups # --------------- INIT User-Groups
def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, service_account_email: str, permissions: dict, domains: list) -> models.Group: def upsert_group(
db_group = db.query(models.Group).filter(models.Group.id == group_name).first() db: Session,
group_name: str,
description: str,
orchestrator: str,
orchestrator_sheet: str,
service_account_email: str,
permissions: dict,
domains: list,
) -> models.Group:
db_group = (
db.query(models.Group).filter(models.Group.id == group_name).first()
)
if db_group is None: if db_group is None:
db_group = models.Group(id=group_name, description=description, orchestrator=orchestrator, orchestrator_sheet=orchestrator_sheet, service_account_email=service_account_email, permissions=permissions, domains=domains) db_group = models.Group(
id=group_name,
description=description,
orchestrator=orchestrator,
orchestrator_sheet=orchestrator_sheet,
service_account_email=service_account_email,
permissions=permissions,
domains=domains,
)
db.add(db_group) db.add(db_group)
else: else:
db_group.description = description db_group.description = description
@@ -172,6 +276,7 @@ def upsert_user(db: Session, email: str):
def upsert_user_groups(db: Session): def upsert_user_groups(db: Session):
def display_email_pii(email: str): def display_email_pii(email: str):
return f"'{email[0:3]}...@{email.split('@')[1]}'" return f"'{email[0:3]}...@{email.split('@')[1]}'"
""" """
reads the user_groups yaml file and inserts any new users, groups, reads the user_groups yaml file and inserts any new users, groups,
along with new participation of users in groups along with new participation of users in groups
@@ -192,18 +297,33 @@ def upsert_user_groups(db: Session):
for group in explicit_groups: for group in explicit_groups:
group_domains[group].add(domain) group_domains[group].add(domain)
import json import json
# upsert groups and save a map of groupid -> dbobject # upsert groups and save a map of groupid -> dbobject
for group_id, g in ug.groups.items(): for group_id, g in ug.groups.items():
upsert_group(db, group_id, g.description, g.orchestrator, g.orchestrator_sheet, g.service_account_email, json.loads(g.permissions.model_dump_json()), list(group_domains.get(group_id, []))) upsert_group(
db_groups: dict[str, models.Group] = {g.id: g for g in db.query(models.Group).all()} db,
group_id,
g.description,
g.orchestrator,
g.orchestrator_sheet,
g.service_account_email,
json.loads(g.permissions.model_dump_json()),
list(group_domains.get(group_id, [])),
)
db_groups: dict[str, models.Group] = {
g.id: g for g in db.query(models.Group).all()
}
# integrity checks # integrity checks
for group_in_domains in group_domains: for group_in_domains in group_domains:
if group_in_domains not in db_groups: if group_in_domains not in db_groups:
logger.warning(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.") logger.warning(
f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work."
)
# reinsert users in their EXPLICITLY DEFINED groups # reinsert users in their EXPLICITLY DEFINED groups
# domain groups are check live, as there may be new users that are not explicitly registered but belong to a domain # domain groups are check live, as there may be new users that are not
# explicitly registered but belong to a domain
for email, explicit_groups in ug.users.items(): for email, explicit_groups in ug.users.items():
explicit_groups = explicit_groups or [] explicit_groups = explicit_groups or []
logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}") logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}")
@@ -213,7 +333,9 @@ def upsert_user_groups(db: Session):
# connect users to groups # connect users to groups
for group_id in explicit_groups: for group_id in explicit_groups:
if group_id not in db_groups: if group_id not in db_groups:
logger.warning(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.") logger.warning(
f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}."
)
continue continue
db_groups[group_id].users.append(db_user) db_groups[group_id].users.append(db_user)
@@ -221,12 +343,27 @@ def upsert_user_groups(db: Session):
count_user_groups = db.query(models.association_table_user_groups).count() count_user_groups = db.query(models.association_table_user_groups).count()
count_groups = db.query(func.count(models.Group.id)).scalar() count_groups = db.query(func.count(models.Group.id)).scalar()
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].") logger.success(
f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}]."
)
# --------------- SHEET # --------------- SHEET
def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: str, frequency: str): def create_sheet(
db_sheet = models.Sheet(id=sheet_id, name=name, author_id=email, group_id=group_id, frequency=frequency) db: Session,
sheet_id: str,
name: str,
email: str,
group_id: str,
frequency: str,
):
db_sheet = models.Sheet(
id=sheet_id,
name=name,
author_id=email,
group_id=group_id,
frequency=frequency,
)
db.add(db_sheet) db.add(db_sheet)
db.commit() db.commit()
db.refresh(db_sheet) db.refresh(db_sheet)
@@ -234,20 +371,31 @@ def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: st
def get_user_sheet(
    db: Session, email: str, sheet_id: str
) -> models.Sheet | None:
    """
    Fetch one sheet by id, restricted to sheets authored by `email`.

    Returns the matching `Sheet`, or None when no row matches both the
    author and the id (`.first()` yields None on an empty result).
    """
    return (
        db.query(models.Sheet)
        .filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id)
        .first()
    )
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
    """
    List every sheet authored by `email`, ordered by `last_url_archived_at`
    descending.
    """
    authored = db.query(models.Sheet).filter(models.Sheet.author_id == email)
    most_recent_first = models.Sheet.last_url_archived_at.desc()
    return authored.order_by(most_recent_first).all()
async def get_sheets_by_id_hash(
    db: AsyncSession, frequency: str, modulo: str, id_hash: int
) -> list[models.Sheet]:
    """
    Return the sheets of a given `frequency` that fall into one hash bucket.

    All sheets with the requested frequency are loaded, then kept only when
    `fnv1a_hash_mod(sheet.id, int(modulo))` equals `id_hash`. The hash
    filter runs in Python, not in the database.
    """
    statement = select(models.Sheet).filter(
        models.Sheet.frequency == frequency
    )
    result = await db.execute(statement)
    return [
        sheet
        for sheet in result.scalars()
        if fnv1a_hash_mod(sheet.id, int(modulo)) == id_hash
    ]
@@ -255,7 +403,9 @@ async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, i
async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict: async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
time_threshold = datetime.now() - timedelta(days=inactivity_days) time_threshold = datetime.now() - timedelta(days=inactivity_days)
result = await db.execute( result = await db.execute(
select(models.Sheet).filter(models.Sheet.last_url_archived_at < time_threshold) select(models.Sheet).filter(
models.Sheet.last_url_archived_at < time_threshold
)
) )
deleted = defaultdict(list) deleted = defaultdict(list)
for sheet in result.scalars(): for sheet in result.scalars():
@@ -266,7 +416,11 @@ async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool: def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first() db_sheet = (
db.query(models.Sheet)
.filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email)
.first()
)
if db_sheet: if db_sheet:
db.delete(db_sheet) db.delete(db_sheet)
db.commit() db.commit()

View File

@@ -1,13 +1,13 @@
from typing import Dict, Set
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy import func
from datetime import datetime from datetime import datetime
from typing import Dict, Set
import sqlalchemy
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.shared.db import models from app.shared.db import models
from app.shared.user_groups import GroupInfo, GroupPermissions
from app.shared.schemas import Usage, UsageResponse from app.shared.schemas import Usage, UsageResponse
from app.shared.user_groups import GroupInfo, GroupPermissions
from app.web.db import crud from app.web.db import crud
from app.web.utils.misc import convert_priority_to_queue_dict from app.web.utils.misc import convert_priority_to_queue_dict
@@ -20,14 +20,15 @@ class UserState:
def __init__(self, db: Session, email: str):
    """
    Bind the state to a DB session and a user email.

    The email is stored lower-cased; `_permissions` starts as an empty dict
    and is filled lazily by the `permissions` property.
    """
    normalized_email = email.lower()
    self.db = db
    self.email = normalized_email
    self._permissions = {}
@property @property
def permissions(self) -> Dict[str, GroupInfo]: def permissions(self) -> Dict[str, GroupInfo]:
""" """
Returns a dict of all group permissions and a special {"all": read/archive_url/archive_sheet} key Returns a dict of all group permissions and a special
{"all": read/archive_url/archive_sheet} key
""" """
if not hasattr(self, '_permissions'): if not self._permissions:
self._permissions = {}
self._permissions["all"] = GroupInfo( self._permissions["all"] = GroupInfo(
read=self.read, read=self.read,
read_public=self.read_public, read_public=self.read_public,
@@ -37,23 +38,33 @@ class UserState:
max_archive_lifespan_months=self.max_archive_lifespan_months, max_archive_lifespan_months=self.max_archive_lifespan_months,
max_monthly_urls=self.max_monthly_urls, max_monthly_urls=self.max_monthly_urls,
max_monthly_mbs=self.max_monthly_mbs, max_monthly_mbs=self.max_monthly_mbs,
priority=self.priority priority=self.priority,
) )
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
self._permissions[group.id] = GroupInfo(**group.permissions, description=group.description, service_account_email=group.service_account_email) continue
self._permissions[group.id] = GroupInfo(
**group.permissions,
description=group.description,
service_account_email=group.service_account_email,
)
return self._permissions return self._permissions
@property
def user_groups_names(self):
    """
    Names of the groups this user belongs to, always ending with "default".

    Computed once from `crud.get_user_group_names` and then cached on the
    instance as `_user_groups_names`.
    """
    if hasattr(self, "_user_groups_names"):
        return self._user_groups_names
    # TODO: Define hidden properties in __init__ method
    resolved = crud.get_user_group_names(self.db, self.email)
    self._user_groups_names = resolved + ["default"]
    return self._user_groups_names
@property
def user_groups(self):
    """
    Group rows for every name in `user_groups_names`.

    Loaded once through `crud.get_user_groups_by_name` and cached on the
    instance as `_user_groups`.
    """
    if hasattr(self, "_user_groups"):
        return self._user_groups
    names = self.user_groups_names
    self._user_groups = crud.get_user_groups_by_name(self.db, names)
    return self._user_groups
@property @property
@@ -61,10 +72,11 @@ class UserState:
""" """
Read can be a list of group names or True, if all can be read. Read can be a list of group names or True, if all can be read.
""" """
if not hasattr(self, '_read'): if not hasattr(self, "_read"):
self._read = set() self._read = set()
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
continue
group_read_permissions = group.permissions.get("read", []) group_read_permissions = group.permissions.get("read", [])
if "all" in group_read_permissions: if "all" in group_read_permissions:
self._read = True self._read = True
@@ -78,10 +90,11 @@ class UserState:
""" """
Read public permission Read public permission
""" """
if not hasattr(self, '_read_public'): if not hasattr(self, "_read_public"):
self._read_public = False self._read_public = False
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
continue
if group.permissions.get("read_public", False): if group.permissions.get("read_public", False):
self._read_public = True self._read_public = True
return self._read_public return self._read_public
@@ -92,10 +105,11 @@ class UserState:
""" """
Archive URL permission Archive URL permission
""" """
if not hasattr(self, '_archive_url'): if not hasattr(self, "_archive_url"):
self._archive_url = False self._archive_url = False
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
continue
if group.permissions.get("archive_url", False): if group.permissions.get("archive_url", False):
self._archive_url = True self._archive_url = True
return self._archive_url return self._archive_url
@@ -106,10 +120,11 @@ class UserState:
""" """
Archive sheet permission Archive sheet permission
""" """
if not hasattr(self, '_archive_sheet'): if not hasattr(self, "_archive_sheet"):
self._archive_sheet = False self._archive_sheet = False
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
continue
if group.permissions.get("archive_sheet", False): if group.permissions.get("archive_sheet", False):
self._archive_sheet = True self._archive_sheet = True
return self._archive_sheet return self._archive_sheet
@@ -117,37 +132,53 @@ class UserState:
@property
def sheet_frequency(self):
    """
    Union of the sheet frequencies granted by all of the user's groups.

    Groups without permissions are skipped. A group whose permissions lack
    the "sheet_frequency" key (or store an empty/None value) contributes
    nothing. The result is computed once and cached on the instance as
    `_sheet_frequency`.
    """
    if not hasattr(self, "_sheet_frequency"):
        self._sheet_frequency = set()
        for group in self.user_groups:
            if not group.permissions:
                continue
            # `set.update(None)` raises TypeError, so fall back to an empty
            # iterable when the key is missing or null.
            self._sheet_frequency.update(
                group.permissions.get("sheet_frequency") or ()
            )
    return self._sheet_frequency
@property
def max_archive_lifespan_months(self) -> int:
    """
    The user's "max_archive_lifespan_months" permission, as aggregated by
    `_helper_for_grouping_max_numerical_permissions`; cached on first access.
    """
    if hasattr(self, "_max_archive_lifespan_months"):
        return self._max_archive_lifespan_months
    aggregated = self._helper_for_grouping_max_numerical_permissions(
        "max_archive_lifespan_months"
    )
    self._max_archive_lifespan_months = aggregated
    return self._max_archive_lifespan_months
@property
def max_monthly_urls(self) -> int:
    """
    The user's "max_monthly_urls" permission, as aggregated by
    `_helper_for_grouping_max_numerical_permissions`; cached on first access.
    """
    if hasattr(self, "_max_monthly_urls"):
        return self._max_monthly_urls
    aggregated = self._helper_for_grouping_max_numerical_permissions(
        "max_monthly_urls"
    )
    self._max_monthly_urls = aggregated
    return self._max_monthly_urls
@property
def max_monthly_mbs(self) -> int:
    """
    The user's "max_monthly_mbs" permission, as aggregated by
    `_helper_for_grouping_max_numerical_permissions`; cached on first access.
    """
    if hasattr(self, "_max_monthly_mbs"):
        return self._max_monthly_mbs
    aggregated = self._helper_for_grouping_max_numerical_permissions(
        "max_monthly_mbs"
    )
    self._max_monthly_mbs = aggregated
    return self._max_monthly_mbs
@property @property
def priority(self) -> str: def priority(self) -> str:
if not hasattr(self, '_priority'): if not hasattr(self, "_priority"):
self._priority = "low" self._priority = "low"
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
continue
if group.permissions.get("priority", self._priority) == "high": if group.permissions.get("priority", self._priority) == "high":
self._priority = "high" self._priority = "high"
break break
@@ -158,18 +189,28 @@ class UserState:
""" """
A user is active if they can read/archive anything A user is active if they can read/archive anything
""" """
if not hasattr(self, '_active'): if not hasattr(self, "_active"):
self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet) self._active = bool(
self.read
or self.read_public
or self.archive_url
or self.archive_sheet
)
return self._active return self._active
def _helper_for_grouping_max_numerical_permissions(self, permission_name: str) -> int: def _helper_for_grouping_max_numerical_permissions(
self, permission_name: str
) -> int:
""" """
Iterates one of the numerical permissions where -1 means no restrictions and returns either -1 or the maximum value, defaults according to GroupPermissions Iterates one of the numerical permissions where -1 means no restrictions
and returns either -1 or the maximum value, defaults according to
GroupPermissions
""" """
default = GroupPermissions.model_fields[permission_name].default default = GroupPermissions.model_fields[permission_name].default
max_value = default max_value = default
for group in self.user_groups: for group in self.user_groups:
if not group.permissions: continue if not group.permissions:
continue
group_value = group.permissions.get(permission_name, default) group_value = group.permissions.get(permission_name, default)
if group_value == -1: if group_value == -1:
max_value = -1 max_value = -1
@@ -180,43 +221,65 @@ class UserState:
def in_group(self, group_id: str) -> bool:
    """Tell whether `group_id` is one of the user's group names."""
    names = self.user_groups_names
    return group_id in names
def usage(self) -> Dict: def usage(self) -> UsageResponse:
""" """
returns the monthly quotas for the URLs/MBs and the totals for Sheets Returns the monthly quotas for the URLs/MBs and the totals for Sheets
""" """
current_month = datetime.now().month current_month = datetime.now().month
current_year = datetime.now().year current_year = datetime.now().year
# find and sum all user sheets over this month # find and sum all user sheets over this month
user_sheets = self.db.query( user_sheets = (
models.Sheet.group_id, self.db.query(
func.count(models.Sheet.id).label('sheet_count') models.Sheet.group_id,
).filter(models.Sheet.author_id == self.email).group_by(models.Sheet.group_id).all() func.count(models.Sheet.id).label("sheet_count"),
)
.filter(models.Sheet.author_id == self.email)
.group_by(models.Sheet.group_id)
.all()
)
sheets_by_group = {sheet.group_id: sheet.sheet_count for sheet in user_sheets} sheets_by_group = {
sheet.group_id: sheet.sheet_count for sheet in user_sheets
}
# find and sum all user urls over this month # find and sum all user urls over this month
urls_by_group = self.db.query( urls_by_group = (
models.Archive.group_id, self.db.query(
func.count(models.Archive.id).label('url_count'), models.Archive.group_id,
func.coalesce(func.sum( func.count(models.Archive.id).label("url_count"),
func.coalesce( func.coalesce(
func.cast( func.sum(
func.json_extract(models.Archive.result, '$.metadata.total_bytes'), func.coalesce(
sqlalchemy.Integer func.cast(
), 0 func.json_extract(
) models.Archive.result,
), 0).label('total_bytes') "$.metadata.total_bytes",
).filter( ),
models.Archive.author_id == self.email, sqlalchemy.Integer,
func.extract('month', models.Archive.created_at) == current_month, ),
func.extract('year', models.Archive.created_at) == current_year 0,
).group_by(models.Archive.group_id).all() )
),
0,
).label("total_bytes"),
)
.filter(
models.Archive.author_id == self.email,
func.extract("month", models.Archive.created_at)
== current_month,
func.extract("year", models.Archive.created_at) == current_year,
)
.group_by(models.Archive.group_id)
.all()
)
# merge the two queries # merge the two queries
usage_by_group: Dict[str, Usage] = { usage_by_group: Dict[str, Usage] = {
(url.group_id or ""): (url.group_id or ""): Usage(
Usage(monthly_urls=url.url_count, monthly_mbs=int(url.total_bytes / 1024 / 1024)) monthly_urls=url.url_count,
monthly_mbs=int(url.total_bytes / 1024 / 1024),
)
for url in urls_by_group for url in urls_by_group
} }
for group_id, sheet_count in sheets_by_group.items(): for group_id, sheet_count in sheets_by_group.items():
@@ -235,7 +298,7 @@ class UserState:
monthly_urls=total_urls, monthly_urls=total_urls,
monthly_mbs=int(total_bytes / 1024 / 1024), monthly_mbs=int(total_bytes / 1024 / 1024),
total_sheets=total_sheets, total_sheets=total_sheets,
groups=usage_by_group groups=usage_by_group,
) )
def has_quota_monthly_sheets(self, group_id: str) -> bool: def has_quota_monthly_sheets(self, group_id: str) -> bool:
@@ -245,7 +308,14 @@ class UserState:
if group_id not in self.permissions: if group_id not in self.permissions:
return False return False
user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email, models.Sheet.group_id == group_id).count() user_sheets = (
self.db.query(models.Sheet)
.filter(
models.Sheet.author_id == self.email,
models.Sheet.group_id == group_id,
)
.count()
)
sheet_quota = self.permissions[group_id].max_sheets sheet_quota = self.permissions[group_id].max_sheets
if sheet_quota == -1: if sheet_quota == -1:
@@ -254,13 +324,15 @@ class UserState:
def has_quota_max_monthly_urls(self, group_id: str) -> bool: def has_quota_max_monthly_urls(self, group_id: str) -> bool:
""" """
checks if a user has reached their monthly url quota for a group, if global then group should be empty string Checks if a user has reached their monthly url quota for a group, if
global then group should be empty string
""" """
quota = 0 quota = 0
if not group_id: if not group_id:
quota = self.max_monthly_urls quota = self.max_monthly_urls
else: else:
if group_id not in self.permissions: return False if group_id not in self.permissions:
return False
quota = self.permissions[group_id].max_monthly_urls quota = self.permissions[group_id].max_monthly_urls
if quota == -1: if quota == -1:
@@ -268,24 +340,31 @@ class UserState:
current_month = datetime.now().month current_month = datetime.now().month
current_year = datetime.now().year current_year = datetime.now().year
user_urls = self.db.query(models.Archive).filter( user_urls = (
models.Archive.author_id == self.email, self.db.query(models.Archive)
models.Archive.group_id == group_id, .filter(
func.extract('month', models.Archive.created_at) == current_month, models.Archive.author_id == self.email,
func.extract('year', models.Archive.created_at) == current_year models.Archive.group_id == group_id,
).count() func.extract("month", models.Archive.created_at)
== current_month,
func.extract("year", models.Archive.created_at) == current_year,
)
.count()
)
return user_urls < quota return user_urls < quota
def has_quota_max_monthly_mbs(self, group_id: str) -> bool: def has_quota_max_monthly_mbs(self, group_id: str) -> bool:
""" """
checks if a user has reached their monthly MBs quota for a group, if global then group should be empty string Checks if a user has reached their monthly MBs quota for a group, if
global then group should be empty string
""" """
quota = 0 quota = 0
if not group_id: if not group_id:
quota = self.max_monthly_mbs quota = self.max_monthly_mbs
else: else:
if group_id not in self.permissions: return False if group_id not in self.permissions:
return False
quota = self.permissions[group_id].max_monthly_mbs quota = self.permissions[group_id].max_monthly_mbs
if quota == -1: if quota == -1:
@@ -295,19 +374,34 @@ class UserState:
current_year = datetime.now().year current_year = datetime.now().year
# find and sum all user bytes over this month # find and sum all user bytes over this month
user_bytes = self.db.query(models.Archive).filter( user_bytes = (
models.Archive.author_id == self.email, self.db.query(models.Archive)
models.Archive.group_id == group_id, .filter(
func.extract('month', models.Archive.created_at) == current_month, models.Archive.author_id == self.email,
func.extract('year', models.Archive.created_at) == current_year models.Archive.group_id == group_id,
).with_entities(func.coalesce(func.sum( func.extract("month", models.Archive.created_at)
func.coalesce( == current_month,
func.cast( func.extract("year", models.Archive.created_at) == current_year,
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
sqlalchemy.Integer
), 0
) )
), 0).label('total')).scalar() .with_entities(
func.coalesce(
func.sum(
func.coalesce(
func.cast(
func.json_extract(
models.Archive.result,
"$.metadata.total_bytes",
),
sqlalchemy.Integer,
),
0,
)
),
0,
).label("total")
)
.scalar()
)
# convert bytes to mb # convert bytes to mb
user_mbs = int(user_bytes / 1024 / 1024) user_mbs = int(user_bytes / 1024 / 1024)
@@ -315,7 +409,7 @@ class UserState:
def can_manually_trigger(self, group_id: str) -> bool: def can_manually_trigger(self, group_id: str) -> bool:
""" """
checks if a user is allowed to manually trigger a sheet Checks if a user is allowed to manually trigger a sheet
""" """
if group_id not in self.permissions: if group_id not in self.permissions:
return False return False
@@ -324,18 +418,21 @@ class UserState:
def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool:
    """
    Checks if a user is allowed to create a sheet with this frequency for
    this group. Unknown groups are never allowed.
    """
    if group_id in self.permissions:
        allowed = self.permissions[group_id].sheet_frequency
        return frequency in allowed
    return False
def priority_group(self, group_id: str) -> dict:
    """
    Queue settings for the priority this user has in one specific group.

    Uses the first of the user's groups whose id equals `group_id` and that
    has non-empty permissions; its "priority" (default "low") is passed to
    `convert_priority_to_queue_dict`. With no such group the priority is
    "low".
    """
    candidates = (
        group
        for group in self.user_groups
        if group.id == group_id and group.permissions
    )
    match = next(candidates, None)
    if match is None:
        priority = "low"
    else:
        priority = match.permissions.get("priority", "low")
    return convert_priority_to_queue_dict(priority)

View File

@@ -1,50 +0,0 @@
from typing import Dict
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from app.web.config import VERSION, BREAKING_CHANGES
from app.shared.schemas import ActiveUser, UsageResponse
from app.web.db.user_state import UserState
from app.web.security import get_user_state
from app.shared.user_groups import GroupInfo
default_router = APIRouter()
@default_router.get("/")
async def home():
    # Root endpoint: reports the API version together with the configured
    # breaking-changes value as a JSON document.
    payload = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
    return JSONResponse(payload)
@default_router.get("/health")
async def health():
    # Health probe: always answers with a constant "ok" JSON status.
    status_payload = {"status": "ok"}
    return JSONResponse(status_payload)
@default_router.get(
    "/user/active",
    summary="Check if the user is active and can use the tool.",
)
async def active(
    user: UserState = Depends(get_user_state),
) -> ActiveUser:
    # Exposes the `active` flag of the authenticated user's state.
    is_active = user.active
    return {"active": is_active}
@default_router.get(
    "/user/permissions",
    summary="Get the user's global 'all' permissions and the permissions for each group they belong to.",
)
def get_user_permissions(
    user: UserState = Depends(get_user_state),
) -> Dict[str, GroupInfo]:
    # Returns the permissions mapping held by the user's state unchanged.
    permissions = user.permissions
    return permissions
@default_router.get(
    "/user/usage",
    summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.",
)
def get_user_usage(
    user: UserState = Depends(get_user_state),
) -> UsageResponse:
    # Only active users may query their usage; everyone else gets a 403.
    if user.active:
        return user.usage()
    raise HTTPException(status_code=403, detail="User is not active.")
@default_router.get('/favicon.ico', include_in_schema=False)
async def favicon() -> FileResponse:
    # Serves the static favicon file; the route is hidden from the schema.
    icon_path = "app/web/static/favicon.ico"
    return FileResponse(icon_path)

View File

@@ -1,81 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from sqlalchemy import exc
from sqlalchemy.orm import Session
from app.web.db.user_state import UserState
from app.shared import schemas
from app.shared.task_messaging import get_celery
from app.web.security import get_user_state
from app.web.db import crud
from app.shared.db.database import get_db_dependency
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
celery = get_celery()
@sheet_router.post(
    "/create",
    status_code=201,
    summary="Store a new Google Sheet for regular archiving.",
)
def create_sheet(
    sheet: schemas.SheetAdd,
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency),
) -> schemas.SheetResponse:
    # Checks run in a fixed order: group membership (403), sheet quota (429),
    # then allowed frequency (422). Only then is the sheet stored.
    group = sheet.group_id
    if not user.in_group(group):
        raise HTTPException(
            status_code=403,
            detail="User does not have access to this group.",
        )
    if not user.has_quota_monthly_sheets(group):
        raise HTTPException(
            status_code=429,
            detail="User has reached their sheet quota for this group.",
        )
    if not user.is_sheet_frequency_allowed(group, sheet.frequency):
        raise HTTPException(
            status_code=422,
            detail="Invalid frequency selected for this group.",
        )

    try:
        stored = crud.create_sheet(
            db, sheet.id, sheet.name, user.email, group, sheet.frequency
        )
    except exc.IntegrityError as integrity_error:
        # An integrity error from the insert is reported as a duplicate sheet.
        raise HTTPException(
            status_code=400,
            detail="Sheet with this ID is already being archived.",
        ) from integrity_error
    return stored
@sheet_router.get(
    "/mine",
    status_code=200,
    summary="Get the authenticated user's Google Sheets.",
)
def get_user_sheets(
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency)
) -> list[schemas.SheetResponse]:
    # Looks up the sheets stored under the authenticated user's email.
    owner_email = user.email
    return crud.get_user_sheets(db, owner_email)
@sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.")
def delete_sheet(
    id: str,
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency),
) -> schemas.DeleteResponse:
    # The deletion is scoped to the requesting user's email; the response
    # echoes the id next to whatever crud.delete_sheet reports.
    was_deleted = crud.delete_sheet(db, id, user.email)
    return JSONResponse({"id": id, "deleted": was_deleted})
@sheet_router.post(
    "/{id}/archive",
    status_code=201,
    summary="Trigger an archiving task for a GSheet you own.",
    response_description="task_id for the archiving task.",
)
def archive_user_sheet(
    id: str,
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency),
) -> schemas.Task:
    # The sheet must be found for this user (403), the user must belong to
    # the sheet's group (403) and be allowed to trigger manually (429).
    sheet = crud.get_user_sheet(db, user.email, sheet_id=id)
    if not sheet:
        raise HTTPException(status_code=403, detail="No access to this sheet.")

    sheet_group = sheet.group_id
    if not user.in_group(sheet_group):
        raise HTTPException(
            status_code=403,
            detail="User does not have access to this group.",
        )
    if not user.can_manually_trigger(sheet_group):
        raise HTTPException(
            status_code=429,
            detail="User cannot manually trigger sheet archiving in this group.",
        )

    # Queue options come from the group's priority and are forwarded to
    # apply_async as keyword arguments.
    group_queue = user.priority_group(sheet_group)
    submission = schemas.SubmitSheet(
        sheet_id=id, author_id=user.email, group_id=sheet_group
    )
    signature = celery.signature(
        "create_sheet_task", args=[submission.model_dump_json()]
    )
    task = signature.apply_async(**group_queue)
    return JSONResponse({"id": task.id}, status_code=201)

View File

@@ -1,40 +0,0 @@
from celery.result import AsyncResult
from fastapi import APIRouter, Depends
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from app.shared.task_messaging import get_celery
from app.web.security import get_token_or_user_auth
from app.shared import schemas
from app.shared.log import log_error
from app.web.utils.misc import custom_jsonable_encoder
task_router = APIRouter(prefix="/task", tags=["Async task operations"])
celery = get_celery()
@task_router.get(
    "/{task_id}",
    summary="Check the status of an async task by its id, works for URLs and Sheet tasks.",
)
def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult:
    # Reports id/status/result for a Celery task. Any exception, including
    # the task's own failure, is logged and returned as a FAILURE payload.
    task = AsyncResult(task_id, app=celery)
    try:
        if task.status == "FAILURE":
            # *FAILURE* The task raised an exception, or has exceeded the retry limit.
            # The :attr:`result` attribute then contains the exception raised by the task.
            # https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
            raise task.result

        response = {"id": task_id, "status": task.status, "result": task.result}
        encoded = jsonable_encoder(
            response,
            exclude_unset=True,
            custom_encoder={bytes: custom_jsonable_encoder},
        )
        return JSONResponse(encoded)
    except Exception as e:
        log_error(e)
        failure = {
            "id": task_id,
            "status": "FAILURE",
            "result": {"error": str(e)},
        }
        return JSONResponse(failure)

View File

@@ -1,85 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from datetime import datetime
from loguru import logger
from sqlalchemy.orm import Session
from app.web.config import ALLOW_ANY_EMAIL
from app.shared import schemas
from app.shared.task_messaging import get_celery
from app.web.security import get_token_or_user_auth, get_user_state
from app.web.db import crud
from app.web.db.user_state import UserState
from app.shared.db.database import get_db_dependency
from urllib.parse import urlparse
from app.web.utils.misc import convert_priority_to_queue_dict
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
celery = get_celery()
@url_router.post(
    "/archive",
    status_code=201,
    summary="Submit a single URL archive request, starts an archiving task.",
    response_description="task_id for the archiving task, will match the archive id.",
)
def archive_url(
    archive: schemas.ArchiveTrigger,
    email=Depends(get_token_or_user_auth),
    db: Session = Depends(get_db_dependency)
) -> schemas.Task:
    logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")

    # A usable URL needs both a scheme and a network location.
    target = urlparse(archive.url)
    if not (target.scheme and target.netloc):
        raise HTTPException(status_code=400, detail="Invalid URL received.")

    archive_create = schemas.ArchiveCreate(**archive.model_dump())
    if email == ALLOW_ANY_EMAIL:
        # Callers authenticated as ALLOW_ANY_EMAIL may set the author
        # themselves and are queued with "high" priority.
        archive_create.author_id = archive.author_id or email
        group_queue = convert_priority_to_queue_dict("high")
    else:
        # Other callers are always the author and are subject to group
        # membership and monthly URL/MB quota checks.
        archive_create.author_id = email
        user = UserState(db, email)
        if archive.group_id and not user.in_group(archive.group_id):
            raise HTTPException(
                status_code=403,
                detail="User does not have access to this group.",
            )
        if not user.has_quota_max_monthly_urls(archive.group_id):
            raise HTTPException(
                status_code=429,
                detail="User has reached their monthly URL quota.",
            )
        if not user.has_quota_max_monthly_mbs(archive.group_id):
            raise HTTPException(
                status_code=429,
                detail="User has reached their monthly MB quota.",
            )
        group_queue = user.priority_group(archive_create.group_id)

    signature = celery.signature(
        "create_archive_task", args=[archive_create.model_dump_json()]
    )
    task = signature.apply_async(**group_queue)
    return JSONResponse(schemas.Task(id=task.id).model_dump(), status_code=201)
@url_router.get("/search", summary="Search for archive entries by URL.")
def search_by_url(
    url: str, skip: int = 0, limit: int = 25,
    archived_after: datetime = None, archived_before: datetime = None,
    db: Session = Depends(get_db_dependency),
    email: str = Depends(get_token_or_user_auth)
) -> list[schemas.ArchiveResult]:
    # For ALLOW_ANY_EMAIL callers both read flags stay False; other callers
    # need at least one read permission and pass their own flags on.
    read_groups = False
    read_public = False
    if email != ALLOW_ANY_EMAIL:
        user = UserState(db, email)
        if not (user.read or user.read_public):
            raise HTTPException(
                status_code=403, detail="User does not have read access."
            )
        read_groups = user.read
        read_public = user.read_public

    cleaned_url = url.strip()
    return crud.search_archives_by_url(
        db,
        cleaned_url,
        email,
        read_groups,
        read_public,
        skip=skip,
        limit=limit,
        archived_after=archived_after,
        archived_before=archived_before,
    )
@url_router.delete("/{id}", summary="Delete a single URL archive by id.")
def delete_archive(
    id:str,
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency)
) -> schemas.DeleteResponse:
    # Logs the request, then delegates to crud.soft_delete_archive scoped to
    # the requesting user's email and echoes its result next to the id.
    logger.info(f"deleting url archive task {id} request by {user.email}")
    was_deleted = crud.soft_delete_archive(db, id, user.email)
    return JSONResponse({"id": id, "deleted": was_deleted})

View File

@@ -1,22 +1,32 @@
import asyncio import asyncio
from collections import defaultdict
import datetime import datetime
import logging import logging
from collections import defaultdict
from contextlib import asynccontextmanager
import alembic.config import alembic.config
from fastapi import FastAPI from fastapi import FastAPI
from contextlib import asynccontextmanager from fastapi_mail import FastMail, MessageSchema, MessageType
from fastapi_utils.tasks import repeat_every from fastapi_utils.tasks import repeat_every
from loguru import logger from loguru import logger
from fastapi_mail import FastMail, MessageSchema, MessageType
from app.shared.db import models
from app.shared.db.database import get_db, get_db_async, make_engine, wal_checkpoint
from app.shared import schemas from app.shared import schemas
from app.shared.db import models
from app.shared.db.database import (
get_db,
get_db_async,
make_engine,
wal_checkpoint,
)
from app.shared.settings import get_settings from app.shared.settings import get_settings
from app.shared.task_messaging import get_celery from app.shared.task_messaging import get_celery
from app.web.db import crud from app.web.db import crud
from app.web.middleware import increase_exceptions_counter from app.web.middleware import increase_exceptions_counter
from app.web.utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions from app.web.utils.metrics import (
measure_regular_metrics,
redis_subscribe_worker_exceptions,
)
celery = get_celery() celery = get_celery()
@@ -28,9 +38,22 @@ async def lifespan(app: FastAPI):
# STARTUP # STARTUP
engine = make_engine(get_settings().DATABASE_PATH) engine = make_engine(get_settings().DATABASE_PATH)
models.Base.metadata.create_all(bind=engine) models.Base.metadata.create_all(bind=engine)
alembic.config.main(prog="alembic", argv=['--raiseerr', 'upgrade', 'head']) alembic.config.main(
prog="alembic",
argv=[
"-c",
"./app/migrations/alembic.ini",
"--raiseerr",
"upgrade",
"head",
],
)
logging.getLogger("uvicorn.access").disabled = True # loguru logging.getLogger("uvicorn.access").disabled = True # loguru
asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL)) asyncio.create_task(
redis_subscribe_worker_exceptions(
get_settings().REDIS_EXCEPTIONS_CHANNEL
)
)
asyncio.create_task(repeat_measure_regular_metrics()) asyncio.create_task(repeat_measure_regular_metrics())
with get_db() as db: with get_db() as db:
crud.upsert_user_groups(db) crud.upsert_user_groups(db)
@@ -61,41 +84,74 @@ async def lifespan(app: FastAPI):
# CRON JOBS # CRON JOBS
@repeat_every(seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS, on_exception=increase_exceptions_counter) @repeat_every(
seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS,
on_exception=increase_exceptions_counter,
)
async def repeat_measure_regular_metrics(): async def repeat_measure_regular_metrics():
await measure_regular_metrics(get_settings().DATABASE_PATH, get_settings().REPEAT_COUNT_METRICS_SECONDS) await measure_regular_metrics(
get_settings().DATABASE_PATH,
get_settings().REPEAT_COUNT_METRICS_SECONDS,
)
@repeat_every(seconds=60, wait_first=120, on_exception=increase_exceptions_counter) @repeat_every(
seconds=60, wait_first=120, on_exception=increase_exceptions_counter
)
async def archive_hourly_sheets_cronjob(): async def archive_hourly_sheets_cronjob():
await archive_sheets_cronjob("hourly", 60, datetime.datetime.now().minute) await archive_sheets_cronjob("hourly", 60, datetime.datetime.now().minute)
@repeat_every(seconds=3600, wait_first=120, on_exception=increase_exceptions_counter) @repeat_every(
seconds=3600, wait_first=120, on_exception=increase_exceptions_counter
)
async def archive_daily_sheets_cronjob(): async def archive_daily_sheets_cronjob():
await archive_sheets_cronjob("daily", 24, datetime.datetime.now().hour) await archive_sheets_cronjob("daily", 24, datetime.datetime.now().hour)
async def archive_sheets_cronjob(frequency: str, interval: int, current_time_unit: int): async def archive_sheets_cronjob(
frequency: str, interval: int, current_time_unit: int
):
triggered_jobs = [] triggered_jobs = []
async with get_db_async() as db: async with get_db_async() as db:
sheets = await crud.get_sheets_by_id_hash(db, frequency, interval, current_time_unit) sheets = await crud.get_sheets_by_id_hash(
db, frequency, str(interval), current_time_unit
)
for s in sheets: for s in sheets:
group_queue = await crud.get_group_priority_async(db, s.group_id) group_queue = await crud.get_group_priority_async(db, s.group_id)
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group_id=s.group_id).model_dump_json()]).apply_async(**group_queue) task = celery.signature(
"create_sheet_task",
args=[
schemas.SubmitSheet(
sheet_id=s.id,
author_id=s.author_id,
group_id=s.group_id,
).model_dump_json()
],
).apply_async(**group_queue)
triggered_jobs.append({"sheet_id": s.id, "task_id": task.id}) triggered_jobs.append({"sheet_id": s.id, "task_id": task.id})
logger.debug(f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}") logger.debug(
f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}"
)
# TODO: on exception should logerror but also prometheus counter # TODO: on exception should logerror but also prometheus counter
DELETE_WINDOW = get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS * 24 * 60 * 60 DELETE_WINDOW = (
get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS * 24 * 60 * 60
)
@repeat_every(seconds=DELETE_WINDOW, wait_first=180, on_exception=increase_exceptions_counter) @repeat_every(
seconds=DELETE_WINDOW,
wait_first=180,
on_exception=increase_exceptions_counter,
)
async def notify_about_expired_archives(): async def notify_about_expired_archives():
notify_from = datetime.datetime.now() + datetime.timedelta(days=get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS) notify_from = datetime.datetime.now() + datetime.timedelta(
days=get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS
)
async with get_db_async() as db: async with get_db_async() as db:
scheduled_deletions = await crud.find_by_store_until(db, notify_from) scheduled_deletions = await crud.find_by_store_until(db, notify_from)
@@ -104,10 +160,15 @@ async def notify_about_expired_archives():
user_archives[archive.author_id].append(archive) user_archives[archive.author_id].append(archive)
if user_archives: if user_archives:
fastmail = FastMail(get_settings().MAIL_CONFIG) fastmail = FastMail(get_settings().mail_config)
# notify users # notify users
for email in user_archives: for email in user_archives:
list_of_archives = "\n".join([f'{a.url}, {a.id}, {a.store_until.isoformat()}<br/>' for a in user_archives[email]]) list_of_archives = "\n".join(
[
f"{a.url}, {a.id}, {a.store_until.isoformat()}<br/>"
for a in user_archives[email]
]
)
# TODO: how can users download them in bulk? # TODO: how can users download them in bulk?
message = MessageSchema( message = MessageSchema(
subject="Auto Archiver: Archives Scheduled for Deletion", subject="Auto Archiver: Archives Scheduled for Deletion",
@@ -127,16 +188,23 @@ async def notify_about_expired_archives():
</body> </body>
</html> </html>
""", """,
subtype=MessageType.html subtype=MessageType.html,
) )
await fastmail.send_message(message) await fastmail.send_message(message)
logger.debug(f"[CRON] Email sent to {email} about {len(user_archives[email])} scheduled archives deletion.") logger.debug(
f"[CRON] Email sent to {email} about {len(user_archives[email])} scheduled archives deletion."
)
# now schedule the deletion event # now schedule the deletion event
asyncio.create_task(delete_expired_archives()) asyncio.create_task(delete_expired_archives())
@repeat_every(max_repetitions=1, wait_first=10, seconds=0, on_exception=increase_exceptions_counter) @repeat_every(
max_repetitions=1,
wait_first=10,
seconds=0,
on_exception=increase_exceptions_counter,
)
async def delete_expired_archives(): async def delete_expired_archives():
async with get_db_async() as db: async with get_db_async() as db:
count_deleted = await crud.soft_delete_expired_archives(db) count_deleted = await crud.soft_delete_expired_archives(db)
@@ -144,19 +212,27 @@ async def delete_expired_archives():
logger.debug(f"[CRON] Deleted {count_deleted} archives.") logger.debug(f"[CRON] Deleted {count_deleted} archives.")
@repeat_every(seconds=86400, wait_first=150, on_exception=increase_exceptions_counter) @repeat_every(
seconds=86400, wait_first=150, on_exception=increase_exceptions_counter
)
async def delete_stale_sheets(): async def delete_stale_sheets():
STALE_DAYS = get_settings().DELETE_STALE_SHEETS_DAYS STALE_DAYS = get_settings().DELETE_STALE_SHEETS_DAYS
logger.debug(f"[CRON] Deleting stale sheets older than {STALE_DAYS} days.") logger.debug(f"[CRON] Deleting stale sheets older than {STALE_DAYS} days.")
async with get_db_async() as db: async with get_db_async() as db:
user_sheets = await crud.delete_stale_sheets(db, STALE_DAYS) user_sheets = await crud.delete_stale_sheets(db, STALE_DAYS)
if not user_sheets: return if not user_sheets:
return
fastmail = FastMail(get_settings().MAIL_CONFIG) fastmail = FastMail(get_settings().mail_config)
# notify users # notify users
for email in user_sheets: for email in user_sheets:
list_of_sheets = "\n".join([f'<li><a href="https://docs.google.com/spreadsheets/d/{s.id}">{s.name}</a></li>' for s in user_sheets[email]]) list_of_sheets = "\n".join(
[
f'<li><a href="https://docs.google.com/spreadsheets/d/{s.id}">{s.name}</a></li>'
for s in user_sheets[email]
]
)
message = MessageSchema( message = MessageSchema(
subject="Auto Archiver: Stale Sheets Removed", subject="Auto Archiver: Stale Sheets Removed",
recipients=[email], recipients=[email],
@@ -173,14 +249,16 @@ async def delete_stale_sheets():
</body> </body>
</html> </html>
""", """,
subtype=MessageType.html subtype=MessageType.html,
) )
await fastmail.send_message(message) await fastmail.send_message(message)
logger.debug(f"[CRON] Email sent to {email} about stale sheets deletion.") logger.debug(
f"[CRON] Email sent to {email} about stale sheets deletion."
)
# @repeat_at # @repeat_at
async def generate_users_export_csv(): async def generate_users_export_csv():
#TODO: implement a cronjob that regularly requested user data to a CSV file # TODO: implement a cronjob that regularly requested user data to a CSV file
# see https://colab.research.google.com/drive/1QDbo3QXHPBdiTuANlA1AWVvN-rqxuCPa?authuser=0#scrollTo=4nPXeSdK8RBT # see https://colab.research.google.com/drive/1QDbo3QXHPBdiTuANlA1AWVvN-rqxuCPa?authuser=0#scrollTo=4nPXeSdK8RBT
pass pass

View File

@@ -1,34 +1,42 @@
import os import os
from fastapi import FastAPI, Depends
from fastapi.staticfiles import StaticFiles from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from prometheus_fastapi_instrumentator import Instrumentator from fastapi.staticfiles import StaticFiles
from loguru import logger from loguru import logger
from prometheus_fastapi_instrumentator import Instrumentator
from app.web.middleware import logging_middleware from app.shared.settings import Settings, get_settings
from app.shared.task_messaging import get_celery from app.shared.task_messaging import get_celery
from app.web.config import API_DESCRIPTION, VERSION
from app.web.security import token_api_key_auth
from app.web.config import VERSION, API_DESCRIPTION
from app.web.events import lifespan from app.web.events import lifespan
from app.shared.settings import get_settings from app.web.middleware import logging_middleware
from app.web.routers.default import router as default_router
from app.web.routers.interoperability import router as interoperability_router
from app.web.routers.sheet import router as sheet_router
from app.web.routers.task import router as task_router
from app.web.routers.url import router as url_router
from app.web.security import token_api_key_auth
from app.web.endpoints.default import default_router
from app.web.endpoints.url import url_router
from app.web.endpoints.sheet import sheet_router
from app.web.endpoints.task import task_router
from app.web.endpoints.interoperability import interoperability_router
celery = get_celery() celery = get_celery()
def app_factory(settings = get_settings()):
def app_factory(settings: Settings = None):
# TODO: Create dev, test, and prod versions of settings that do not have
# TODO: to be passed in as a parameter
if settings is None:
settings = get_settings()
app = FastAPI( app = FastAPI(
title="Auto-Archiver API", title="Auto-Archiver API",
description=API_DESCRIPTION, description=API_DESCRIPTION,
version=VERSION, version=VERSION,
contact={"name": "GitHub", "url": "https://github.com/bellingcat/auto-archiver-api"}, contact={
lifespan=lifespan "name": "GitHub",
"url": "https://github.com/bellingcat/auto-archiver-api",
},
lifespan=lifespan,
) )
app.add_middleware( app.add_middleware(
@@ -47,14 +55,30 @@ def app_factory(settings = get_settings()):
app.include_router(interoperability_router) app.include_router(interoperability_router)
# prometheus exposed in /metrics with authentication # prometheus exposed in /metrics with authentication
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health", "/openapi.json", "/favicon.ico"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)]) Instrumentator(
should_group_status_codes=False,
excluded_handlers=[
"/metrics",
"/health",
"/openapi.json",
"/favicon.ico",
],
).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
if settings.SERVE_LOCAL_ARCHIVE: if settings.SERVE_LOCAL_ARCHIVE:
local_dir = settings.SERVE_LOCAL_ARCHIVE local_dir = settings.SERVE_LOCAL_ARCHIVE
if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")): if not os.path.isdir(local_dir) and os.path.isdir(
local_dir.replace("/app", ".")
):
local_dir = local_dir.replace("/app", ".") local_dir = local_dir.replace("/app", ".")
if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir): if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
logger.warning(f"MOUNTing local archive, use this in development only {settings.SERVE_LOCAL_ARCHIVE}") logger.warning(
app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE) f"MOUNTing local archive, use this in development only {settings.SERVE_LOCAL_ARCHIVE}"
)
app.mount(
settings.SERVE_LOCAL_ARCHIVE,
StaticFiles(directory=local_dir),
name=settings.SERVE_LOCAL_ARCHIVE,
)
return app return app

View File

@@ -1,7 +1,8 @@
import traceback import traceback
from loguru import logger
from fastapi import Request from fastapi import Request
from loguru import logger
from app.shared.log import log_error from app.shared.log import log_error
from app.web.utils.metrics import EXCEPTION_COUNTER from app.web.utils.metrics import EXCEPTION_COUNTER
@@ -9,23 +10,33 @@ from app.web.utils.metrics import EXCEPTION_COUNTER
async def logging_middleware(request: Request, call_next): async def logging_middleware(request: Request, call_next):
try: try:
response = await call_next(request) response = await call_next(request)
#TODO: use Origin to have summary prometheus metrics on where requests come from # TODO: use Origin to have summary prometheus metrics on where
# requests come from
# origin = request.headers.get("origin") # origin = request.headers.get("origin")
logger.info(f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}") logger.info(
f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}"
)
return response return response
except Exception as e: except Exception as e:
location = f"{request.method} {request.url._url}" location = f"{request.method} {request.url._url}"
await increase_exceptions_counter(e, location) await increase_exceptions_counter(e, location)
logger.info(f"{request.client.host}:{request.client.port} {location} - {e.__class__.__name__} {e}") logger.info(
f"{request.client.host}:{request.client.port} {location} - {e.__class__.__name__} {e}"
)
raise e raise e
async def increase_exceptions_counter(
    e: Exception, location: str = "cronjob"
) -> None:
    """
    Increment the exception counter for `e` and log the exception.

    When `location` is left at its default "cronjob", the name of the
    function in the last traceback frame of `e` replaces it as the counter's
    location label. If that name cannot be extracted, the problem is logged
    and the label stays "cronjob".
    """
    if location == "cronjob":
        try:
            last_trace = traceback.extract_tb(e.__traceback__)[-1]
            _file, _line, func_name, _text = last_trace
            location = func_name
        # The secondary error gets its own name on purpose: `except ... as e`
        # would rebind the `e` parameter and Python unbinds that name when
        # the clause exits, so the counter/log lines below would fail with
        # UnboundLocalError instead of recording the original exception.
        except Exception as trace_error:
            logger.error(
                f"Unable to get function name from cronjob exception traceback: {trace_error}"
            )
    EXCEPTION_COUNTER.labels(type=e.__class__.__name__, location=location).inc()
    log_error(e)

View File

View File

@@ -0,0 +1,64 @@
from http import HTTPStatus
from typing import Dict
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from app.shared.schemas import ActiveUser, UsageResponse
from app.shared.user_groups import GroupInfo
from app.web.config import BREAKING_CHANGES, VERSION
from app.web.db.user_state import UserState
from app.web.security import get_user_state
router = APIRouter()
@router.get("/")
async def home() -> JSONResponse:
    # Root endpoint: reports the API version together with the configured
    # breaking-changes value as a JSON document.
    body = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
    return JSONResponse(body)
@router.get("/health")
async def health() -> JSONResponse:
    # Health probe: always answers with a constant "ok" JSON status.
    body = {"status": "ok"}
    return JSONResponse(body)
@router.get(
    "/user/active",
    summary="Check if the user is active and can use the tool.",
)
async def active(
    user: UserState = Depends(get_user_state),
) -> ActiveUser:
    # Wraps the `active` flag of the authenticated user's state in the
    # response schema.
    is_active = user.active
    return ActiveUser(active=is_active)
@router.get(
    "/user/permissions",
    summary="Get the user's global 'all' permissions and the permissions for each group they belong to.",
)
def get_user_permissions(
    user: UserState = Depends(get_user_state),
) -> Dict[str, GroupInfo]:
    # Returns the permissions mapping held by the user's state unchanged.
    permissions = user.permissions
    return permissions
@router.get(
    "/user/usage",
    summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.",
)
def get_user_usage(
    user: UserState = Depends(get_user_state),
) -> UsageResponse:
    # Only active users may query their usage; everyone else gets FORBIDDEN.
    if user.active:
        return user.usage()
    raise HTTPException(
        status_code=HTTPStatus.FORBIDDEN, detail="User is not active."
    )
@router.get("/favicon.ico", include_in_schema=False)
async def favicon() -> FileResponse:
    # Serves the static favicon file; the route is hidden from the schema.
    icon_path = "app/web/static/favicon.ico"
    return FileResponse(icon_path)

View File

@@ -1,41 +1,53 @@
import json import json
from http import HTTPStatus
import sqlalchemy
from auto_archiver.core import Metadata
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from loguru import logger from loguru import logger
import sqlalchemy
from auto_archiver.core import Metadata
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.shared.aa_utils import get_all_urls
from app.web.config import ALLOW_ANY_EMAIL
from app.shared import business_logic, schemas from app.shared import business_logic, schemas
from app.shared.db import worker_crud from app.shared.db import models, worker_crud
from app.shared.db.database import get_db_dependency from app.shared.db.database import get_db_dependency
from app.web.security import token_api_key_auth
from app.shared.db import models
from app.shared.log import log_error from app.shared.log import log_error
from app.web.config import ALLOW_ANY_EMAIL
from app.web.security import token_api_key_auth
from app.web.utils.misc import get_all_urls
interoperability_router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."]) router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."])
# ----- endpoint to submit data archived elsewhere # ----- endpoint to submit data archived elsewhere
@interoperability_router.post("/submit-archive", status_code=201, summary="Submit a manual archive entry, for data that was archived elsewhere.") @router.post(
"/submit-archive",
status_code=HTTPStatus.CREATED,
summary="Submit a manual archive entry, for data that was archived elsewhere.",
)
def submit_manual_archive( def submit_manual_archive(
manual: schemas.SubmitManualArchive, manual: schemas.SubmitManualArchive,
auth=Depends(token_api_key_auth), auth=Depends(token_api_key_auth),
db: Session = Depends(get_db_dependency) db: Session = Depends(get_db_dependency),
): ):
try: try:
result: Metadata = Metadata.from_json(manual.result) result: Metadata = Metadata.from_json(manual.result)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
log_error(e) log_error(e)
raise HTTPException(status_code=422, detail="Invalid JSON in result field.") raise HTTPException(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
detail="Invalid JSON in result field.",
) from e
manual.author_id = manual.author_id or ALLOW_ANY_EMAIL manual.author_id = manual.author_id or ALLOW_ANY_EMAIL
manual.tags.add("manual") manual.tags.add("manual")
store_until = business_logic.get_store_archive_until_or_never(db, manual.group_id) store_until = business_logic.get_store_archive_until_or_never(
logger.debug(f"[MANUAL ARCHIVE] {manual.author_id} {manual.url} {store_until}") db, manual.group_id
)
logger.debug(
f"[MANUAL ARCHIVE] {manual.author_id} {manual.url} {store_until}"
)
try: try:
archive = schemas.ArchiveCreate( archive = schemas.ArchiveCreate(
@@ -51,8 +63,15 @@ def submit_manual_archive(
) )
db_archive = worker_crud.store_archived_url(db, archive) db_archive = worker_crud.store_archived_url(db, archive)
logger.debug(f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}") logger.debug(
return JSONResponse({"id": db_archive.id}, status_code=201) f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}"
)
return JSONResponse(
{"id": db_archive.id}, status_code=HTTPStatus.CREATED
)
except sqlalchemy.exc.IntegrityError as e: except sqlalchemy.exc.IntegrityError as e:
log_error(e) log_error(e)
raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error, likely duplicate urls.") raise HTTPException(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
detail="Cannot insert into DB due to integrity error, likely duplicate urls.",
) from e

132
app/web/routers/sheet.py Normal file
View File

@@ -0,0 +1,132 @@
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from sqlalchemy import exc
from sqlalchemy.orm import Session
from app.shared.db.database import get_db_dependency
from app.shared.schemas import (
DeleteResponse,
SheetAdd,
SheetResponse,
SubmitSheet,
)
from app.shared.task_messaging import get_celery
from app.web.db import crud
from app.web.db.user_state import UserState
from app.web.security import get_user_state
# Router collecting the Google Sheet endpoints; every route declared on it is
# served under the "/sheet" prefix and grouped under the tag below.
router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
# Celery handle returned by the shared task-messaging helper; the endpoints in
# this module use it to build task signatures and enqueue them.
celery = get_celery()
@router.post(
    "/create",
    status_code=HTTPStatus.CREATED,
    summary="Store a new Google Sheet for regular archiving.",
)
def create_sheet(
    sheet: SheetAdd,
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency),
) -> SheetResponse:
    # Register a sheet for the authenticated user after three ordered checks:
    # group membership, monthly sheet quota, and allowed archiving frequency.
    # Each check is wrapped in a lambda so it is only evaluated when every
    # earlier check has passed; the first failing one decides the response.
    guards = (
        (
            lambda: user.in_group(sheet.group_id),
            HTTPStatus.FORBIDDEN,
            "User does not have access to this group.",
        ),
        (
            lambda: user.has_quota_monthly_sheets(sheet.group_id),
            HTTPStatus.TOO_MANY_REQUESTS,
            "User has reached their sheet quota for this group.",
        ),
        (
            lambda: user.is_sheet_frequency_allowed(
                sheet.group_id, sheet.frequency
            ),
            HTTPStatus.UNPROCESSABLE_ENTITY,
            "Invalid frequency selected for this group.",
        ),
    )
    for is_allowed, failure_status, failure_detail in guards:
        if not is_allowed():
            raise HTTPException(
                status_code=failure_status, detail=failure_detail
            )

    try:
        created = crud.create_sheet(
            db,
            sheet.id,
            sheet.name,
            user.email,
            sheet.group_id,
            sheet.frequency,
        )
    except exc.IntegrityError as integrity_error:
        # An integrity error from the insert is reported to the client as an
        # already-registered sheet id.
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail="Sheet with this ID is already being archived.",
        ) from integrity_error
    return created
@router.get(
    "/mine",
    status_code=HTTPStatus.OK,
    summary="Get the authenticated user's Google Sheets.",
)
def get_user_sheets(
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency),
) -> list[SheetResponse]:
    # Look the sheets up by the authenticated user's email address.
    owner_email = user.email
    return crud.get_user_sheets(db, owner_email)
@router.delete("/{sheet_id}", summary="Delete a Google Sheet by ID.")
def delete_sheet(
    sheet_id: str,
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency),
) -> DeleteResponse:
    # The CRUD call receives the caller's email together with the sheet id;
    # whatever it returns is reported back in the `deleted` field.
    was_deleted = crud.delete_sheet(db, sheet_id, user.email)
    return DeleteResponse(id=sheet_id, deleted=was_deleted)
@router.post(
    "/{sheet_id}/archive",
    status_code=HTTPStatus.CREATED,
    summary="Trigger an archiving task for a GSheet you own.",
    response_description="task_id for the archiving task.",
)
def archive_user_sheet(
    sheet_id: str,
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency),
) -> JSONResponse:
    # Enqueue a "create_sheet_task" for a sheet looked up by the caller's
    # email and the sheet id, and answer with the id of the queued task.
    sheet = crud.get_user_sheet(db, user.email, sheet_id=sheet_id)
    if not sheet:
        raise HTTPException(
            status_code=HTTPStatus.FORBIDDEN, detail="No access to this sheet."
        )

    group_id = sheet.group_id
    if not user.in_group(group_id):
        raise HTTPException(
            status_code=HTTPStatus.FORBIDDEN,
            detail="User does not have access to this group.",
        )
    if not user.can_manually_trigger(group_id):
        raise HTTPException(
            status_code=HTTPStatus.TOO_MANY_REQUESTS,
            detail="User cannot manually trigger sheet archiving in this group.",
        )

    # The queue options are resolved before the payload is serialised, and are
    # passed to apply_async as keyword arguments.
    queue_options = user.priority_group(group_id)
    payload = SubmitSheet(
        sheet_id=sheet_id, author_id=user.email, group_id=group_id
    ).model_dump_json()
    signature = celery.signature("create_sheet_task", args=[payload])
    task = signature.apply_async(**queue_options)
    return JSONResponse({"id": task.id}, status_code=HTTPStatus.CREATED)

52
app/web/routers/task.py Normal file
View File

@@ -0,0 +1,52 @@
from celery.result import AsyncResult
from fastapi import APIRouter, Depends
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from app.shared import schemas
from app.shared.constants import STATUS_FAILURE
from app.shared.log import log_error
from app.shared.task_messaging import get_celery
from app.web.security import get_token_or_user_auth
from app.web.utils.misc import custom_jsonable_encoder
# Router for the task-status endpoint; routes declared on it are served under
# the "/task" prefix and grouped under the tag below.
router = APIRouter(prefix="/task", tags=["Async task operations"])
# Celery handle returned by the shared task-messaging helper; passed to
# AsyncResult as the app when looking a task up by id.
celery = get_celery()
@router.get(
    "/{task_id}",
    summary="Check the status of an async task by its id, works for URLs and Sheet tasks.",
)
def get_status(
    task_id, email=Depends(get_token_or_user_auth)
) -> schemas.TaskResult:
    # `email` is not read in the body; declaring the dependency is what makes
    # the route require authentication.
    task = AsyncResult(task_id, app=celery)
    try:
        if task.status == STATUS_FAILURE:
            # *FAILURE* The task raised an exception, or has exceeded the
            # retry limit. The :attr:`result` attribute then contains the
            # exception raised by the task.
            # https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
            raise task.result

        snapshot = {"id": task_id, "status": task.status, "result": task.result}
        encoded = jsonable_encoder(
            snapshot,
            exclude_unset=True,
            custom_encoder={bytes: custom_jsonable_encoder},
        )
        return JSONResponse(encoded)
    except Exception as err:
        # Any failure (the re-raised task exception as well as errors while
        # reading or encoding the result) is logged and reported as a
        # failure payload instead of propagating.
        log_error(err)
        failure_payload = {
            "id": task_id,
            "status": STATUS_FAILURE,
            "result": {"error": str(err)},
        }
        return JSONResponse(failure_payload)

125
app/web/routers/url.py Normal file
View File

@@ -0,0 +1,125 @@
from datetime import datetime
from http import HTTPStatus
from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from loguru import logger
from sqlalchemy.orm import Session
from app.shared import schemas
from app.shared.db.database import get_db_dependency
from app.shared.schemas import DeleteResponse
from app.shared.task_messaging import get_celery
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud
from app.web.db.user_state import UserState
from app.web.security import get_token_or_user_auth, get_user_state
from app.web.utils.misc import convert_priority_to_queue_dict
# Router for the single-URL endpoints; routes declared on it are served under
# the "/url" prefix and grouped under the tag below.
router = APIRouter(prefix="/url", tags=["Single URL operations"])
# Celery handle returned by the shared task-messaging helper; used below to
# build and enqueue the "create_archive_task" signature.
celery = get_celery()
@router.post(
    "/archive",
    status_code=HTTPStatus.CREATED,
    summary="Submit a single URL archive request, starts an archiving task.",
    response_description="task_id for the archiving task, will match the archive id.",
)
def archive_url(
    archive: schemas.ArchiveTrigger,
    email=Depends(get_token_or_user_auth),
    db: Session = Depends(get_db_dependency),
) -> JSONResponse:
    # Validate the submitted URL, apply per-user access/quota checks unless
    # the caller authenticated as ALLOW_ANY_EMAIL, then enqueue the
    # archiving task and return its id.
    logger.info(
        f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}"
    )

    # Both a scheme and a network location are required.
    url_parts = urlparse(archive.url)
    if not (url_parts.scheme and url_parts.netloc):
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST, detail="Invalid URL received."
        )

    archive_create = schemas.ArchiveCreate(**archive.model_dump())

    if email == ALLOW_ANY_EMAIL:
        # ALLOW_ANY_EMAIL caller: the author may be supplied in the request,
        # no user checks are run, and the task goes to the high-priority queue.
        archive_create.author_id = archive.author_id or email
        queue_options = convert_priority_to_queue_dict("high")
    else:
        # Regular user: the author is always the authenticated email.
        archive_create.author_id = email
        user = UserState(db, email)
        if archive.group_id and not user.in_group(archive.group_id):
            raise HTTPException(
                status_code=HTTPStatus.FORBIDDEN,
                detail="User does not have access to this group.",
            )
        if not user.has_quota_max_monthly_urls(archive.group_id):
            raise HTTPException(
                status_code=HTTPStatus.TOO_MANY_REQUESTS,
                detail="User has reached their monthly URL quota.",
            )
        if not user.has_quota_max_monthly_mbs(archive.group_id):
            raise HTTPException(
                status_code=HTTPStatus.TOO_MANY_REQUESTS,
                detail="User has reached their monthly MB quota.",
            )
        queue_options = user.priority_group(archive_create.group_id)

    signature = celery.signature(
        "create_archive_task", args=[archive_create.model_dump_json()]
    )
    task = signature.apply_async(**queue_options)
    return JSONResponse(
        schemas.Task(id=task.id).model_dump(), status_code=HTTPStatus.CREATED
    )
@router.get("/search", summary="Search for archive entries by URL.")
def search_by_url(
    url: str,
    skip: int = 0,
    limit: int = 25,
    archived_after: datetime = None,
    archived_before: datetime = None,
    db: Session = Depends(get_db_dependency),
    email: str = Depends(get_token_or_user_auth),
) -> list[schemas.ArchiveResult]:
    # Search archives by URL. For regular users the read permissions taken
    # from their UserState are forwarded to the query; for the
    # ALLOW_ANY_EMAIL caller both flags are passed as False.
    if email == ALLOW_ANY_EMAIL:
        read_groups, read_public = False, False
    else:
        user = UserState(db, email)
        # Reject users that have neither kind of read permission.
        if not (user.read or user.read_public):
            raise HTTPException(
                status_code=HTTPStatus.FORBIDDEN,
                detail="User does not have read access.",
            )
        read_groups, read_public = user.read, user.read_public

    # Surrounding whitespace is stripped from the URL before searching.
    cleaned_url = url.strip()
    return crud.search_archives_by_url(
        db,
        cleaned_url,
        email,
        read_groups,
        read_public,
        skip=skip,
        limit=limit,
        archived_after=archived_after,
        archived_before=archived_before,
    )
@router.delete("/{archive_id}", summary="Delete a single URL archive by id.")
def delete_archive(
    archive_id: str,
    user: UserState = Depends(get_user_state),
    db: Session = Depends(get_db_dependency),
) -> DeleteResponse:
    # Log the request, then hand the archive id and the caller's email to the
    # soft-delete CRUD call; its return value is reported in `deleted`.
    logger.info(
        f"deleting url archive task {archive_id} request by {user.email}"
    )
    was_deleted = crud.soft_delete_archive(db, archive_id, user.email)
    return DeleteResponse(id=archive_id, deleted=was_deleted)

View File

@@ -1,27 +1,32 @@
from loguru import logger import secrets
import requests, secrets from http import HTTPStatus
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
import firebase_admin import firebase_admin
from firebase_admin import credentials, auth, exceptions import requests
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from firebase_admin import auth, credentials, exceptions
from loguru import logger
from sqlalchemy.orm import Session
from app.web.config import ALLOW_ANY_EMAIL
from app.shared.settings import get_settings
from app.shared.db.database import get_db_dependency from app.shared.db.database import get_db_dependency
from app.shared.settings import get_settings
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db.user_state import UserState from app.web.db.user_state import UserState
settings = get_settings() settings = get_settings()
bearer_security = HTTPBearer() bearer_security = HTTPBearer()
FIREBASE_OAUTH_ENABLED = settings.FIREBASE_SERVICE_ACCOUNT_JSON != "" FIREBASE_OAUTH_ENABLED = settings.FIREBASE_SERVICE_ACCOUNT_JSON != ""
if FIREBASE_OAUTH_ENABLED: if FIREBASE_OAUTH_ENABLED:
logger.debug("Firebase OAUTH enabled, initializing...") logger.debug("Firebase OAUTH enabled, initializing...")
firebase_admin.initialize_app(credentials.Certificate(settings.FIREBASE_SERVICE_ACCOUNT_JSON)) firebase_admin.initialize_app(
credentials.Certificate(settings.FIREBASE_SERVICE_ACCOUNT_JSON)
)
def secure_compare(token, api_key): def secure_compare(token, api_key) -> bool:
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8")) return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
@@ -29,9 +34,13 @@ def secure_compare(token, api_key):
def api_key_auth(api_key): def api_key_auth(api_key):
assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars" assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars"
async def auth(bearer: HTTPAuthorizationCredentials = Depends(bearer_security), auto_error=True): async def auth(
bearer: HTTPAuthorizationCredentials = Depends(bearer_security),
auto_error=True,
):
is_correct = secure_compare(bearer.credentials, api_key) is_correct = secure_compare(bearer.credentials, api_key)
if is_correct: return True if is_correct:
return True
if auto_error: if auto_error:
raise HTTPException( raise HTTPException(
@@ -43,18 +52,23 @@ def api_key_auth(api_key):
return auth return auth
# --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus # --- Token Auth for AA itself to query the API, AA setup tool and Prometheus
token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN) token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN)
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): async def get_token_or_user_auth(
credentials: HTTPAuthorizationCredentials = Depends(bearer_security),
):
# tries to use the static API_KEY and defaults to google JWT auth # tries to use the static API_KEY and defaults to google JWT auth
if await token_api_key_auth(credentials, auto_error=False): return ALLOW_ANY_EMAIL if await token_api_key_auth(credentials, auto_error=False):
return ALLOW_ANY_EMAIL
return await get_user_auth(credentials) return await get_user_auth(credentials)
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): async def get_user_auth(
# validates the Bearer token in the case that it requires it credentials: HTTPAuthorizationCredentials = Depends(bearer_security),
):
# Validates the Bearer token in the case that it requires it
valid_user, info = authenticate_user(credentials.credentials) valid_user, info = authenticate_user(credentials.credentials)
if valid_user: if valid_user:
return info.lower() return info.lower()
@@ -66,39 +80,55 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
) )
def authenticate_user(access_token): def authenticate_user(access_token) -> (bool, str):
if FIREBASE_OAUTH_ENABLED: if FIREBASE_OAUTH_ENABLED:
try: try:
j = auth.verify_id_token(access_token) return firebase_login_attempt(access_token)
email = j.get('email', None) except exceptions.FirebaseError:
logger.debug(f"Successfully verified the ID token for {email}") # used a non-Firebase token, fallback to Google OAuth
if email is None: pass
return False, "email not found in token"
if email in settings.BLOCKED_EMAILS:
return False, f"email '{email}' not allowed"
return True, email
except exceptions.FirebaseError as e:
logger.warning(f"Error verifying ID token: {str(e)[:80]}...")
# https://cloud.google.com/docs/authentication/token-types#access # https://cloud.google.com/docs/authentication/token-types#access
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token" if not isinstance(access_token, str) or len(access_token) < 10:
r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token": access_token}) return False, "invalid access_token"
if r.status_code != 200: return False, "invalid token" r = requests.get(
"https://oauth2.googleapis.com/tokeninfo",
{"access_token": access_token},
)
if r.status_code != HTTPStatus.OK:
return False, "invalid token"
try: try:
j = r.json() j = r.json()
if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS: if (
return False, f"token does not belong to valid APP_ID" j.get("azp") not in settings.CHROME_APP_IDS
and j.get("aud") not in settings.CHROME_APP_IDS
):
return False, "token does not belong to valid APP_ID"
if j.get("email") in settings.BLOCKED_EMAILS: if j.get("email") in settings.BLOCKED_EMAILS:
return False, f"email '{j.get('email')}' not allowed" return False, f"email '{j.get('email')}' not allowed"
if j.get("email_verified") != "true": if j.get("email_verified") != "true":
return False, f"email '{j.get('email')}' not verified" return False, f"email '{j.get('email')}' not verified"
if int(j.get("expires_in", -1)) <= 0: if int(j.get("expires_in", -1)) <= 0:
return False, "Token expired" return False, "Token expired"
return True, j.get('email').lower() return True, j.get("email").lower()
except Exception as e: except Exception as e:
logger.warning(f"AUTH EXCEPTION occurred: {e}") logger.warning(f"AUTH EXCEPTION occurred: {e}")
return False, "exception occurred" return False, "exception occurred"
def get_user_state(email: str = Depends(get_user_auth), db: Session = Depends(get_db_dependency)): def firebase_login_attempt(access_token) -> (bool, str):
j = auth.verify_id_token(access_token)
email = j.get("email", None)
logger.debug(f"Successfully verified the ID token for {email}")
if email is None:
return False, "email not found in token"
if email in settings.BLOCKED_EMAILS:
return False, f"email '{email}' not allowed"
return True, email
def get_user_state(
email: str = Depends(get_user_auth),
db: Session = Depends(get_db_dependency),
) -> UserState:
return UserState(db, email) return UserState(db, email)

View File

@@ -2,52 +2,57 @@ import asyncio
import json import json
import os import os
import shutil import shutil
from prometheus_client import Counter, Gauge from prometheus_client import Counter, Gauge
from app.web.db import crud
from app.shared.db.database import get_db from app.shared.db.database import get_db
from app.shared.log import log_error from app.shared.log import log_error
from app.shared.task_messaging import get_redis from app.shared.task_messaging import get_redis
from app.web.db import crud
# Custom metrics # Custom metrics
EXCEPTION_COUNTER = Counter( EXCEPTION_COUNTER = Counter(
"exceptions", "exceptions",
"Number of times a certain exception has occurred.", "Number of times a certain exception has occurred.",
labelnames=["type", "location"] labelnames=["type", "location"],
) )
WORKER_EXCEPTION = Counter( WORKER_EXCEPTION = Counter(
"worker_exceptions_total", "worker_exceptions_total",
"Number of times a certain exception has occurred on the worker.", "Number of times a certain exception has occurred on the worker.",
labelnames=["type", "exception", "task", "traceback"] labelnames=["type", "exception", "task", "traceback"],
) )
DISK_UTILIZATION = Gauge( DISK_UTILIZATION = Gauge(
"disk_utilization", "disk_utilization", "Disk utilization in GB", labelnames=["type"]
"Disk utilization in GB",
labelnames=["type"]
) )
DATABASE_METRICS = Gauge( DATABASE_METRICS = Gauge(
"database_metrics", "database_metrics",
"Database metric readings at a certain point in time", "Database metric readings at a certain point in time",
labelnames=["query"] labelnames=["query"],
) )
DATABASE_METRICS_COUNTER = Counter( DATABASE_METRICS_COUNTER = Counter(
"database_metrics_counter", "database_metrics_counter",
"Database metrics that increase over time", "Database metrics that increase over time",
labelnames=["query", "user"] labelnames=["query", "user"],
) )
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL: str): async def redis_subscribe_worker_exceptions(redis_exceptions_channel: str):
# Subscribe to Redis channel and increment the counter for each exception with info on the exception and task # Subscribe to Redis channel and increment the counter for each exception
# with info on the exception and task
Redis = get_redis() Redis = get_redis()
PubSubExceptions = Redis.pubsub() PubSubExceptions = Redis.pubsub()
PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL) PubSubExceptions.subscribe(redis_exceptions_channel)
while True: while True:
message = PubSubExceptions.get_message() message = PubSubExceptions.get_message()
if message and message["type"] == "message": if message and message["type"] == "message":
data = json.loads(message["data"].decode("utf-8")) data = json.loads(message["data"].decode("utf-8"))
WORKER_EXCEPTION.labels(type=data["type"], exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc() WORKER_EXCEPTION.labels(
type=data["type"],
exception=data["exception"],
task=data["task"],
traceback=data["traceback"],
).inc()
await asyncio.sleep(1) await asyncio.sleep(1)
@@ -58,12 +63,19 @@ async def measure_regular_metrics(sqlite_db_url: str, repeat_in_seconds: int):
try: try:
fs = os.stat(sqlite_db_url.replace("sqlite:///", "")) fs = os.stat(sqlite_db_url.replace("sqlite:///", ""))
DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30)) DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30))
except Exception as e: log_error(e) except Exception as e:
log_error(e)
with get_db() as db: with get_db() as db:
DATABASE_METRICS.labels(query="count_archives").set(crud.count_archives(db)) DATABASE_METRICS.labels(query="count_archives").set(
DATABASE_METRICS.labels(query="count_archive_urls").set(crud.count_archive_urls(db)) crud.count_archives(db)
)
DATABASE_METRICS.labels(query="count_archive_urls").set(
crud.count_archive_urls(db)
)
DATABASE_METRICS.labels(query="count_users").set(crud.count_users(db)) DATABASE_METRICS.labels(query="count_users").set(crud.count_users(db))
for user in crud.count_by_user_since(db, repeat_in_seconds): for user in crud.count_by_user_since(db, repeat_in_seconds):
DATABASE_METRICS_COUNTER.labels(query="count_by_user", user=user.author_id).inc(user.total) DATABASE_METRICS_COUNTER.labels(
query="count_by_user", user=user.author_id
).inc(user.total)

View File

@@ -1,15 +1,62 @@
import base64 import base64
from typing import List
from auto_archiver.core import Media, Metadata
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from loguru import logger
from app.shared.db import models
def custom_jsonable_encoder(obj): def custom_jsonable_encoder(obj):
if isinstance(obj, bytes): if isinstance(obj, bytes):
return base64.b64encode(obj).decode('utf-8') return base64.b64encode(obj).decode("utf-8")
return jsonable_encoder(obj) return jsonable_encoder(obj)
def convert_priority_to_queue_dict(priority: str) -> dict: def convert_priority_to_queue_dict(priority: str) -> dict:
return { return {
"priority": 0 if priority == "high" else 10, "priority": 0 if priority == "high" else 10,
"queue": f"{priority}_priority" "queue": f"{priority}_priority",
} }
def convert_if_media(media):
    """Return ``media`` as a ``Media`` object, or ``False`` when it is not one.

    A ``Media`` instance is returned unchanged. A ``dict`` is passed to
    ``Media.from_dict``; if that raises, the error is logged at debug level
    and ``False`` is returned. Any other input yields ``False``.
    """
    if isinstance(media, Media):
        return media
    if not isinstance(media, dict):
        return False
    try:
        return Media.from_dict(media)
    except Exception as e:
        logger.debug(f"error parsing {media} : {e}")
        return False
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
    """Collect an ``ArchiveUrl`` row for every URL found in ``result``.

    For each media item this gathers, in order:

    * the item's own URLs, keyed by its ``"id"`` or ``media_<index>``;
    * URLs of any property value that ``convert_if_media`` accepts, keyed by
      its ``"id"`` or ``<property>_<index>``;
    * URLs of media found inside list-valued properties, keyed by their
      ``"id"`` or ``<property><media key>_<list index>.<url index>``.
    """
    collected = []
    for media in result.media:
        # URLs attached directly to the media item.
        collected.extend(
            models.ArchiveUrl(url=url, key=media.get("id", f"media_{idx}"))
            for idx, url in enumerate(media.urls)
        )

        for name, prop in media.properties.items():
            # A property value that is itself media (or a media-like dict).
            nested = convert_if_media(prop)
            if nested:
                collected.extend(
                    models.ArchiveUrl(
                        url=url, key=nested.get("id", f"{name}_{idx}")
                    )
                    for idx, url in enumerate(nested.urls)
                )

            # A property value that is a list which may contain media.
            if isinstance(prop, list):
                for pos, item in enumerate(prop):
                    item_media = convert_if_media(item)
                    if not item_media:
                        continue
                    collected.extend(
                        models.ArchiveUrl(
                            url=url,
                            key=item_media.get(
                                "id", f"{name}{item_media.key}_{pos}.{sub}"
                            ),
                        )
                        for sub, url in enumerate(item_media.urls)
                    )
    return collected

View File

@@ -1,21 +1,22 @@
import datetime
import json import json
import traceback
import traceback, datetime from auto_archiver.core.orchestrator import ArchivingOrchestrator
from celery.signals import task_failure from celery.signals import task_failure
from loguru import logger from loguru import logger
from sqlalchemy import exc from sqlalchemy import exc
from auto_archiver.core.orchestrator import ArchivingOrchestrator
from app.shared.db import models from app.shared import business_logic, constants, schemas
from app.shared.db import models, worker_crud
from app.shared.db.database import get_db from app.shared.db.database import get_db
from app.shared import business_logic, schemas
from app.shared.task_messaging import get_celery, get_redis
from app.shared.settings import get_settings
from app.shared.log import log_error from app.shared.log import log_error
from app.shared.aa_utils import get_all_urls from app.shared.settings import get_settings
from app.shared.db import worker_crud from app.shared.task_messaging import get_celery, get_redis
from app.web.utils.misc import get_all_urls
from app.worker.worker_log import setup_celery_logger from app.worker.worker_log import setup_celery_logger
settings = get_settings() settings = get_settings()
celery = get_celery("worker") celery = get_celery("worker")
@@ -24,26 +25,36 @@ Redis = get_redis()
USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
setup_celery_logger(celery) setup_celery_logger(celery)
AA_LOGGER_ID = None
# TODO: these are temporary PATCHES for new aa's functionality
# logger.add("app/worker/worker_log.log", level="DEBUG")
logger.remove = lambda x: print(f"logger.remove({x})")
# TODO: after release, as it requires updating past entries with sheet_id where tag is used, drop tags # TODO: after release, as it requires updating past entries with sheet_id where tag
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 1}) # is used, drop tags
@celery.task(
name="create_archive_task",
bind=True,
autoretry_for=(Exception,),
retry_backoff=True,
retry_kwargs={"max_retries": 1},
)
def create_archive_task(self, archive_json: str): def create_archive_task(self, archive_json: str):
global AA_LOGGER_ID
archive = schemas.ArchiveCreate.model_validate_json(archive_json) archive = schemas.ArchiveCreate.model_validate_json(archive_json)
# call auto-archiver # call auto-archiver
args = get_orchestrator_args(archive.group_id, False, [archive.url]) args = get_orchestrator_args(archive.group_id, False, [archive.url])
result = None
try: try:
orchestrator = ArchivingOrchestrator() orchestrator = ArchivingOrchestrator()
orchestrator.logger_id = AA_LOGGER_ID # ensure single logger
orchestrator.setup(args) orchestrator.setup(args)
result = next(orchestrator.feed()) AA_LOGGER_ID = orchestrator.logger_id
for orch_res in orchestrator.feed():
result = orch_res
except SystemExit as e: except SystemExit as e:
log_error(e, f"create_archive_task: SystemExit from AA") log_error(e, "create_archive_task: SystemExit from AA")
except Exception as e: except Exception as e:
log_error(e, f"create_archive_task") log_error(e, "create_archive_task")
raise e raise e
assert result, f"UNABLE TO archive: {archive.url}" assert result, f"UNABLE TO archive: {archive.url}"
@@ -59,13 +70,20 @@ def create_archive_task(self, archive_json: str):
@celery.task(name="create_sheet_task", bind=True) @celery.task(name="create_sheet_task", bind=True)
def create_sheet_task(self, sheet_json: str): def create_sheet_task(self, sheet_json: str):
global AA_LOGGER_ID
sheet = schemas.SubmitSheet.model_validate_json(sheet_json) sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
queue_name = (create_sheet_task.request.delivery_info or {}).get('routing_key', 'unknown') queue_name = (create_sheet_task.request.delivery_info or {}).get(
"routing_key", "unknown"
)
logger.info(f"[queue={queue_name}] SHEET START {sheet=}") logger.info(f"[queue={queue_name}] SHEET START {sheet=}")
args = get_orchestrator_args(sheet.group_id, True, ["--gsheet_feeder.sheet_id", sheet.sheet_id]) args = get_orchestrator_args(
sheet.group_id, True, [constants.SHEET_ID, sheet.sheet_id]
)
orchestrator = ArchivingOrchestrator() orchestrator = ArchivingOrchestrator()
orchestrator.logger_id = AA_LOGGER_ID # ensure single logger
orchestrator.setup(args) orchestrator.setup(args)
AA_LOGGER_ID = orchestrator.logger_id
stats = {"archived": 0, "failed": 0, "errors": []} stats = {"archived": 0, "failed": 0, "errors": []}
try: try:
@@ -81,7 +99,7 @@ def create_sheet_task(self, sheet_json: str):
result=json.loads(result.to_json()), result=json.loads(result.to_json()),
sheet_id=sheet.sheet_id, sheet_id=sheet.sheet_id,
urls=get_all_urls(result), urls=get_all_urls(result),
store_until=get_store_until(sheet.group_id) store_until=get_store_until(sheet.group_id),
) )
insert_result_into_db(archive) insert_result_into_db(archive)
stats["archived"] += 1 stats["archived"] += 1
@@ -94,25 +112,38 @@ def create_sheet_task(self, sheet_json: str):
stats["errors"].append(str(e)) stats["errors"].append(str(e))
except SystemExit as e: except SystemExit as e:
log_error(e, f"create_sheet_task: SystemExit from AA") log_error(e, "create_sheet_task: SystemExit from AA")
if stats["archived"] > 0: if stats["archived"] > 0:
with get_db() as session: with get_db() as session:
worker_crud.update_sheet_last_url_archived_at(session, sheet.sheet_id) worker_crud.update_sheet_last_url_archived_at(
session, sheet.sheet_id
)
logger.info(f"SHEET DONE {sheet=}") logger.info(f"SHEET DONE {sheet=}")
# TODO: is this used anywhere? maybe drop it # TODO: is this used anywhere? maybe drop it
return schemas.CelerySheetTask(success=True, sheet_id=sheet.sheet_id, time=datetime.datetime.now().isoformat(), stats=stats).model_dump() return schemas.CelerySheetTask(
success=True,
sheet_id=sheet.sheet_id,
time=datetime.datetime.now().isoformat(),
stats=stats,
).model_dump()
def get_orchestrator_args(group_id: str, orchestrator_for_sheet: bool, cli_args: list = []) -> list: def get_orchestrator_args(
group_id: str, orchestrator_for_sheet: bool, cli_args: list = None
) -> list:
cli_args.append("--logging.enabled=false")
aa_configs = [] aa_configs = []
with get_db() as session: with get_db() as session:
group = worker_crud.get_group(session, group_id) group = worker_crud.get_group(session, group_id)
if orchestrator_for_sheet: if orchestrator_for_sheet:
orchestrator_fn = group.orchestrator_sheet orchestrator_fn = group.orchestrator_sheet
else: else:
orchestrator_fn = worker_crud.get_group(session, group_id).orchestrator orchestrator_fn = worker_crud.get_group(
session, group_id
).orchestrator
assert orchestrator_fn, f"no orchestrator found for {group_id}" assert orchestrator_fn, f"no orchestrator found for {group_id}"
aa_configs.extend(["--config", orchestrator_fn]) aa_configs.extend(["--config", orchestrator_fn])
aa_configs.extend(cli_args) aa_configs.extend(cli_args)
@@ -122,7 +153,9 @@ def get_orchestrator_args(group_id: str, orchestrator_for_sheet: bool, cli_args:
def insert_result_into_db(archive: schemas.ArchiveCreate) -> str: def insert_result_into_db(archive: schemas.ArchiveCreate) -> str:
with get_db() as session: with get_db() as session:
db_archive = worker_crud.store_archived_url(session, archive) db_archive = worker_crud.store_archived_url(session, archive)
logger.debug(f"[ARCHIVE STORED] {db_archive.author_id} {db_archive.url}") logger.debug(
f"[ARCHIVE STORED] {db_archive.author_id} {db_archive.url}"
)
return db_archive.id return db_archive.id
@@ -131,13 +164,22 @@ def get_store_until(group_id: str) -> datetime.datetime:
return business_logic.get_store_archive_until(session, group_id) return business_logic.get_store_archive_until(session, group_id)
def redis_publish_exception(exception, task_name, traceback: str = ""): def redis_publish_exception(exception, task_name, trace_back: str = ""):
REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL
try: try:
exception_data = {"task": task_name, "type": exception.__class__.__name__, "exception": exception, "traceback": traceback} exception_data = {
Redis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps(exception_data, default=str)) "task": task_name,
"type": exception.__class__.__name__,
"exception": exception,
"traceback": trace_back,
}
Redis.publish(
REDIS_EXCEPTIONS_CHANNEL, json.dumps(exception_data, default=str)
)
except Exception as e: except Exception as e:
log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}") log_error(
e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}"
)
@task_failure.connect(sender=create_sheet_task) @task_failure.connect(sender=create_sheet_task)
@@ -145,6 +187,10 @@ def redis_publish_exception(exception, task_name, traceback: str = ""):
def task_failure_notifier(sender, **kwargs): def task_failure_notifier(sender, **kwargs):
# automatically capture exceptions in the worker tasks # automatically capture exceptions in the worker tasks
logger.warning(f"⚠️ worker task failed: {sender.name}") logger.warning(f"⚠️ worker task failed: {sender.name}")
traceback_msg = "\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback']))) traceback_msg = "\n".join(
log_error(kwargs['exception'], traceback_msg, f"task_failure: {sender.name}") traceback.format_list(traceback.extract_tb(kwargs["traceback"]))
redis_publish_exception(kwargs['exception'], sender.name, traceback_msg) )
log_error(
kwargs["exception"], traceback_msg, f"task_failure: {sender.name}"
)
redis_publish_exception(kwargs["exception"], sender.name, traceback_msg)

View File

@@ -1,29 +1,38 @@
from loguru import logger
from celery import Celery
import sys import sys
from loguru import logger
from app.shared.task_messaging import get_celery from app.shared.task_messaging import get_celery
celery = get_celery("worker") celery = get_celery("worker")
def setup_celery_logger(celery):
# Remove Celery's default handlers to prevent duplicate logs
celery_logger = celery.log.get_default_logger()
for handler in celery_logger.handlers[:]:
celery_logger.removeHandler(handler)
# Set up Loguru logging def setup_celery_logger(c):
logger.add("logs/celery_logs.log", retention="30 days", level="DEBUG") # Remove Celery's default handlers to prevent duplicate logs
logger.add("logs/celery_error_logs.log", retention="30 days", level="ERROR") celery_logger = c.log.get_default_logger()
for handler in celery_logger.handlers[:]:
celery_logger.removeHandler(handler)
# Redirect Celery logs to Loguru # Set up Loguru logging
class InterceptHandler: logger.add("logs/celery_logs.log", retention="30 days", level="DEBUG")
def write(self, message): logger.add("logs/celery_error_logs.log", retention="30 days", level="ERROR")
if message.strip():
logger.info(message.strip())
# Required to prevent issues with buffered output
def flush(self): pass
def isatty(self): return False
sys.stdout = InterceptHandler() # Redirect Celery logs to Loguru
sys.stderr = InterceptHandler() class InterceptHandler:
@staticmethod
def write(message):
if message.strip():
logger.info(message.strip())
# Required to prevent issues with buffered output
@staticmethod
def flush():
pass
@staticmethod
def isatty():
return False
sys.stdout = InterceptHandler()
sys.stderr = InterceptHandler()

View File

@@ -7,7 +7,7 @@ services:
web: web:
build: build:
context: . context: .
dockerfile: web.Dockerfile dockerfile: docker/web/Dockerfile
restart: always restart: always
env_file: .env.prod env_file: .env.prod
environment: environment:
@@ -31,7 +31,7 @@ services:
worker: worker:
build: build:
context: . context: .
dockerfile: worker.Dockerfile dockerfile: docker/worker/Dockerfile
restart: always restart: always
env_file: .env.prod env_file: .env.prod
command: celery --app=app.worker.main.celery worker -Q high_priority,low_priority --concurrency=${CONCURRENCY} --max-tasks-per-child=100 -O fair command: celery --app=app.worker.main.celery worker -Q high_priority,low_priority --concurrency=${CONCURRENCY} --max-tasks-per-child=100 -O fair

View File

@@ -10,13 +10,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
RUN pip install --no-cache-dir poetry RUN pip install --no-cache-dir poetry
COPY pyproject.toml poetry.lock README.md . COPY ../../pyproject.toml ../../poetry.lock ../../README.md ./
RUN poetry install --with web --no-interaction --no-ansi --no-cache RUN poetry install --with web --no-interaction --no-ansi --no-cache
# Copy the application code and configurations # Copy the application code and configurations
COPY alembic.ini ./ COPY ../../app ./app/
COPY ./app/ ./app/ COPY ../../user-groups.* ./app/
COPY user-groups.* ./app/
# Run the FastAPI app with Uvicorn # Run the FastAPI app with Uvicorn
ENTRYPOINT ["poetry", "run"] ENTRYPOINT ["poetry", "run"]

View File

@@ -20,14 +20,13 @@ RUN apt update -y && \
python3 -m venv ./poetry-venv && \ python3 -m venv ./poetry-venv && \
./poetry-venv/bin/python -m pip install --upgrade pip && \ ./poetry-venv/bin/python -m pip install --upgrade pip && \
./poetry-venv/bin/python -m pip install "poetry>=2.0.0,<3.0.0" ./poetry-venv/bin/python -m pip install "poetry>=2.0.0,<3.0.0"
COPY pyproject.toml poetry.lock ./ COPY ../../pyproject.toml ../../poetry.lock ./
RUN ./poetry-venv/bin/poetry install --without dev --no-root --no-cache RUN ./poetry-venv/bin/poetry install --without dev --no-root --no-cache
# install dependencies # install dependencies
# copy source code and .env files over # copy source code and .env files over
COPY alembic.ini ./ COPY ../../app ./app/
COPY ./app/ ./app/ COPY ../../user-groups.* ./app/
COPY user-groups.* ./app/
ENTRYPOINT ["./poetry-venv/bin/poetry", "run"] ENTRYPOINT ["./poetry-venv/bin/poetry", "run"]

1669
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -22,7 +22,6 @@ requires-python = ">=3.10,<3.13"
dependencies = [ dependencies = [
"auto-archiver (>=0.13.1)", "auto-archiver (>=0.13.1)",
"oscrypto @ git+https://github.com/wbond/oscrypto.git@d5f3437ed24257895ae1edd9e503cfb352e635a8",
"celery (>=5.0)", "celery (>=5.0)",
"redis (==3.5.3)", "redis (==3.5.3)",
"loguru (>=0.7.3,<0.8.0)", "loguru (>=0.7.3,<0.8.0)",
@@ -31,6 +30,16 @@ dependencies = [
"requests (>=2.25.1)", "requests (>=2.25.1)",
"pyopenssl (>=23.3.0)", "pyopenssl (>=23.3.0)",
] ]
[tool.pytest.ini_options]
pythonpath = "."
[tool.coverage.run]
omit = ["app/migrations/*"]
[tool.ruff.lint.flake8-bugbear]
extend-immutable-calls = ["fastapi.Depends", "fastapi.Query"]
[tool.poetry.group.worker.dependencies] [tool.poetry.group.worker.dependencies]
watchdog = ">=6.0.0,<7.0.0" watchdog = ">=6.0.0,<7.0.0"
setuptools = "^75.8.0" setuptools = "^75.8.0"
@@ -53,4 +62,4 @@ pytest = ">=8.3.4,<9.0.0"
httpx = ">=0.28.1,<0.29.0" httpx = ">=0.28.1,<0.29.0"
coverage = ">=7.6.11,<8.0.0" coverage = ">=7.6.11,<8.0.0"
pytest-asyncio = ">=0.25.3,<0.26.0" pytest-asyncio = ">=0.25.3,<0.26.0"
pre-commit = "^4.1.0"

View File

@@ -59,4 +59,3 @@ groups:
permissions: permissions:
read: ["default"] read: ["default"]
read_public: true read_public: true