Merge pull request #72 from bellingcat/dev

Community contributions, code standardization, and AA v1.0.0
This commit is contained in:
Miguel Sozinho Ramalho
2025-04-03 19:51:56 +01:00
committed by GitHub
88 changed files with 5245 additions and 2550 deletions

View File

@@ -1,3 +0,0 @@
[run]
omit =
app/migrations/*

View File

@@ -2,4 +2,4 @@ CHROME_APP_IDS='["1234567890"]'
ALLOWED_ORIGINS='["allowed"]'
BLOCKED_EMAILS='[]'
DATABASE_PATH="sqlite:///./database/auto-archiver.db"
API_BEARER_TOKEN=THIS_API_TOKEN_SHOULD_NEVER_BE_USED
API_BEARER_TOKEN=THIS_API_TOKEN_SHOULD_NEVER_BE_USED

View File

@@ -35,4 +35,4 @@ MAIL_SSL_TLS=True
# celery workers config
CONCURRENCY=2
CONCURRENCY=2

View File

@@ -5,4 +5,4 @@ BLOCKED_EMAILS='["blocked@example.com"]'
DATABASE_PATH="sqlite:///auto-archiver.test.db"
API_BEARER_TOKEN=this_is_the_test_api_token
USER_GROUPS_FILENAME=app/tests/user-groups.test.yaml
USER_GROUPS_FILENAME=app/tests/user-groups.test.yaml

23
.github/pull_request_template.md vendored Normal file
View File

@@ -0,0 +1,23 @@
<!---
Please write your PR name in the present imperative tense. Examples of that tense are:
"Fix issue in the dispatcher where…", "Improve our handling of…", etc.
For more information on Pull Requests, you can reference here:
https://success.vanillaforums.com/kb/articles/228-using-pull-requests-to-contribute
-->
## Describe your changes
## Non-obvious technical information
## Checklist before requesting a review
<!---
These are suggested things you could add, but what you add will be dependent on
your repository's standards.
--->
- [ ] The code runs successfully.
```commandline
HERE IS SOME COMMAND LINE OUTPUT
```

View File

@@ -1,16 +1,12 @@
name: CI
on:
push:
branches:
- main
- dev
branches: [ main, dev ]
pull_request:
branches:
- main
- dev
branches: [ main, dev ]
jobs:
test:
test-with-coverage:
runs-on: ubuntu-latest
services:
@@ -41,4 +37,4 @@ jobs:
run: poetry run coverage run -m pytest -v -ra --color=yes app/tests/
- name: Report coverage
run: poetry run coverage report
run: poetry run coverage report

16
.github/workflows/format-and-fail.yml vendored Normal file
View File

@@ -0,0 +1,16 @@
name: Format and Fail
on:
push:
branches: [ main, dev ]
pull_request:
branches: [ main, dev ]
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.0

40
.github/workflows/test.yml vendored Normal file
View File

@@ -0,0 +1,40 @@
name: Run Tests
on:
push:
branches: [ main, dev ]
pull_request:
branches: [ main, dev ]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12']
services:
redis:
image: redis:6-alpine
ports:
- 6379:6379
steps:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Checkout code
uses: actions/checkout@v4
- name: Install Poetry
run: pipx install poetry
- name: Install dependencies
run: poetry install --no-interaction --with dev
- name: Set dev environment variable
run: echo "ENVIRONMENT_FILE=.env.test" >> $GITHUB_ENV
- name: Run tests
run: poetry run pytest app/tests

141
.gitignore vendored
View File

@@ -1,16 +1,10 @@
# Misc.
user-groups.dev.yaml
user-groups.yaml
orchestration.yaml
my-archives
*.pyc
.DS_Store
secrets/*
*.log
__pycache__
.pytest_cache
.env
.env.dev
.env.prod
*.db
redis/data/*
.ipynb_checkpoints*
@@ -18,8 +12,6 @@ app/user-groups.yaml
app/user-groups.dev.yaml
wit*
app/crawls
.coverage
.pytest_cache/
htmlcov
local_archive
local_archive_test
@@ -27,6 +19,133 @@ local_archive_test
*db-shm
copy-files.sh
temp/
.python-version
orchestration2.yaml
database
database
# IDE files
.idea
.vscode
**/.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env*
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

78
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,78 @@
repos:
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.9.1
hooks:
- id: nbqa-ruff
args:
- --fix
- --target-version=py310
- --ignore=E721,E722
- --line-length=80
- id: nbqa-black
args:
- --line-length=80
- id: nbqa-isort
args:
- --float-to-top
- --profile=black
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-docstring-first
- id: check-executables-have-shebangs
- id: check-json
- id: check-case-conflict
- id: check-toml
- id: check-merge-conflict
- id: check-xml
- id: check-yaml
exclude: app/tests/user-groups.test.broken.yaml
- id: end-of-file-fixer
- id: check-symlinks
- id: mixed-line-ending
- id: sort-simple-yaml
- id: fix-encoding-pragma
args:
- --remove
- id: pretty-format-json
args:
- --autofix
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: python-check-blanket-noqa
- id: python-check-mock-methods
- id: python-no-eval
- id: python-no-log-warn
- repo: https://github.com/PyCQA/isort
rev: 6.0.1
hooks:
- id: isort
name: Run isort to sort imports
files: \.py$
# To keep consistent with the global isort skip config defined in setup.cfg
exclude: ^build/.*$|^.tox/.*$|^venv/.*$
args:
- --lines-after-imports=2
- --profile=black
- --line-length=80
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.7
hooks:
- id: ruff
types_or: [python,pyi]
args:
- --fix
- --select=B,C,E,F,W,B9
- --line-length=80
- --ignore=E203,E402,E501,E261
- id: ruff-format
types_or: [ python,pyi]
args:
- --target-version=py310
- --line-length=80

1
CODEOWNERS Normal file
View File

@@ -0,0 +1 @@
* @msramalho

View File

@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
SOFTWARE.

View File

@@ -1,19 +1,33 @@
.PHONY: lint
lint:
poetry run pre-commit run --all-files
.PHONY: test
test:
export ENVIRONMENT_FILE=.env.test
poetry run coverage run -m pytest -v --disable-warnings --color=yes app/tests/
poetry run coverage report
.PHONY: clean-dev
clean-dev:
@echo -n "Are you sure? [yes/N] (this will delete volumes) " && read ans && [ $${ans:-N} = yes ]
docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes --remove-orphans
.PHONY: dev
dev:
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans
.PHONY: dev-redis-only
dev-redis-only:
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build redis
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans redis
.PHONY: stop-dev
stop-dev:
docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes
.PHONY: prod
prod:
docker compose --env-file .env.prod build
docker compose --env-file .env.prod up -d --remove-orphans
@@ -21,5 +35,6 @@ prod:
docker image prune -f
docker system df
.PHONY: stop-prod
stop-prod:
docker compose down
docker compose down

View File

@@ -12,9 +12,9 @@ To properly set up the API you need to install `docker` and to have these files,
2. a `user-groups.yaml` to manage user permissions
1. note that all local files referenced in `user-groups.yaml` and any orchestration.yaml files should be relative to the home directory so if your service account is in `secrets/orchestration.yaml` use that path and not just `orchestration.yaml`.
2. go through the example file and configure it according to your needs.
3. you will need to create and reference at least one `secrets/orchestration.yaml` file, you can do so by following the instructions in the [auto-archiver](https://github.com/bellingcat/auto-archiver#installation) that automatically generates one for you. If you use the archive sheets feature you will need to create a `orchestrationsheets-sheets.yaml` file as well that should have the `gsheet_feeder` and `gsheet_db` enabled and configured, the auto-archiver has [extensive documentation](https://auto-archiver.readthedocs.io/en/latest/) on how to set this up.
3. you will need to create and reference at least one `secrets/orchestration.yaml` file, you can do so by following the instructions in the [auto-archiver](https://github.com/bellingcat/auto-archiver#installation) that automatically generates one for you. If you use the archive sheets feature you will need to create a `orchestrationsheets-sheets.yaml` file as well that should have the `gsheet_feeder_db` feeder and database enabled and configured, the auto-archiver has [extensive documentation](https://auto-archiver.readthedocs.io/en/latest/) on how to set this up.
Do not commit those files, they are .gitignored by default.
Do not commit those files, they are .gitignored by default.
We also advise you to keep any sensitive files in the `secrets/` folder which is pinned and gitignored.
We have examples for both of those files (`.env.example` and `user-groups.example.yaml`), and here's how to set them up whether you're in development or production:
@@ -108,6 +108,27 @@ Make sure environment and user-groups files are up to date.
Then `make prod`.
## Development
```bash
# make sure all development dependencies are installed
poetry install --with dev
# this project uses pre-commit to enforce code style and formatting, set that up locally
poetry run pre-commit install
# you can test pre-commit with
poetry run pre-commit run --all-files
# this means pre-commit will always run with git commit, to skip it use
git commit --no-verify
# see the Makefile for more commands, but linting and formatting can be done with
make lint
# run all tests
make test
```
### Testing
```bash
# set the testing environment variables

View File

@@ -2,7 +2,7 @@
[alembic]
# path to migration scripts
script_location = app/migrations
script_location = ./app/migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
@@ -10,10 +10,6 @@ script_location = app/migrations
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be

View File

@@ -1,19 +1,20 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from sqlalchemy import engine_from_config, pool
from app.shared.settings import get_settings
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
config.set_main_option('sqlalchemy.url', get_settings().DATABASE_PATH)
config.set_main_option("sqlalchemy.url", get_settings().DATABASE_PATH)
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name, disable_existing_loggers=False) # disable_existing_loggers prevents loguru disabling
# disable_existing_loggers prevents loguru disabling
fileConfig(config.config_file_name, disable_existing_loggers=False)
# add your model's MetaData object here
# for 'autogenerate' support

View File

@@ -5,13 +5,14 @@ Revises: 1636724ec4b1
Create Date: 2025-02-08 15:22:20.392522
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = '02b2f6d17ed0'
down_revision = '1636724ec4b1'
revision = "02b2f6d17ed0"
down_revision = "1636724ec4b1"
branch_labels = None
depends_on = None
STORE_UNTIL_COL = "store_until"
@@ -20,15 +21,20 @@ STORE_UNTIL_COL = "store_until"
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('archives')]
columns = [col["name"] for col in inspector.get_columns("archives")]
if STORE_UNTIL_COL not in columns:
op.add_column('archives', sa.Column(STORE_UNTIL_COL, sa.DateTime(), nullable=True, default=None))
op.add_column(
"archives",
sa.Column(
STORE_UNTIL_COL, sa.DateTime(), nullable=True, default=None
),
)
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('archives')]
columns = [col["name"] for col in inspector.get_columns("archives")]
if STORE_UNTIL_COL in columns:
op.drop_column('archives', STORE_UNTIL_COL)
op.drop_column("archives", STORE_UNTIL_COL)

View File

@@ -5,13 +5,14 @@ Revises: a23aaf3ae930
Create Date: 2025-02-05 19:19:01.984396
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = '1636724ec4b1'
down_revision = 'a23aaf3ae930'
revision = "1636724ec4b1"
down_revision = "a23aaf3ae930"
branch_labels = None
depends_on = None
@@ -19,14 +20,18 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('sheets')]
if 'last_archived_at' in columns:
op.alter_column('sheets', 'last_archived_at', new_column_name='last_url_archived_at')
columns = [col["name"] for col in inspector.get_columns("sheets")]
if "last_archived_at" in columns:
op.alter_column(
"sheets", "last_archived_at", new_column_name="last_url_archived_at"
)
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('sheets')]
if 'last_url_archived_at' in columns:
op.alter_column('sheets', 'last_url_archived_at', new_column_name='last_archived_at')
columns = [col["name"] for col in inspector.get_columns("sheets")]
if "last_url_archived_at" in columns:
op.alter_column(
"sheets", "last_url_archived_at", new_column_name="last_archived_at"
)

View File

@@ -5,13 +5,14 @@ Revises: 02b2f6d17ed0
Create Date: 2025-02-11 21:53:23.293274
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = '63ac79df4ad0'
down_revision = '02b2f6d17ed0'
revision = "63ac79df4ad0"
down_revision = "02b2f6d17ed0"
branch_labels = None
depends_on = None
@@ -22,15 +23,17 @@ TABLE = "groups"
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns(TABLE)]
columns = [col["name"] for col in inspector.get_columns(TABLE)]
if NEW_COL not in columns:
op.add_column(TABLE, sa.Column(NEW_COL, sa.String, nullable=True, default=None))
op.add_column(
TABLE, sa.Column(NEW_COL, sa.String, nullable=True, default=None)
)
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns(TABLE)]
columns = [col["name"] for col in inspector.get_columns(TABLE)]
if NEW_COL in columns:
op.drop_column(TABLE, NEW_COL)

View File

@@ -5,14 +5,14 @@ Revises: fa012ec405b8
Create Date: 2024-11-04 11:12:30.237299
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector
from alembic import op
# revision identifiers, used by Alembic.
revision = '89121d2c96d8'
down_revision = 'fa012ec405b8'
revision = "89121d2c96d8"
down_revision = "fa012ec405b8"
branch_labels = None
depends_on = None
@@ -20,23 +20,27 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('archives')]
columns = [col["name"] for col in inspector.get_columns("archives")]
if 'sheet_id' not in columns:
with op.batch_alter_table('archives') as batch_op:
batch_op.add_column(sa.Column('sheet_id', sa.String(), nullable=True, default=None))
batch_op.create_foreign_key('fk_sheet_id', 'sheets', ['sheet_id'], ['id'])
if "sheet_id" not in columns:
with op.batch_alter_table("archives") as batch_op:
batch_op.add_column(
sa.Column("sheet_id", sa.String(), nullable=True, default=None)
)
batch_op.create_foreign_key(
"fk_sheet_id", "sheets", ["sheet_id"], ["id"]
)
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
foreign_keys = [fk['name'] for fk in inspector.get_foreign_keys('archives')]
columns = [col['name'] for col in inspector.get_columns('archives')]
foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("archives")]
columns = [col["name"] for col in inspector.get_columns("archives")]
with op.batch_alter_table('archives') as batch_op:
if 'fk_sheet_id' in foreign_keys:
batch_op.drop_constraint('fk_sheet_id', type_='foreignkey')
with op.batch_alter_table("archives") as batch_op:
if "fk_sheet_id" in foreign_keys:
batch_op.drop_constraint("fk_sheet_id", type_="foreignkey")
if 'sheet_id' in columns:
batch_op.drop_column('sheet_id')
if "sheet_id" in columns:
batch_op.drop_column("sheet_id")

View File

@@ -1,14 +1,16 @@
"""modify archive url to have uuid id instead of url unique constraint
Revision ID: 9369a264945b
Revises:
Revises:
Create Date: 2023-12-20 17:24:59.320691
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = '9369a264945b'
revision = "9369a264945b"
down_revision = None
branch_labels = None
depends_on = None
@@ -18,10 +20,11 @@ def upgrade() -> None:
# since the primary key constraint is not named, we have to recreate it first
with op.batch_alter_table("archive_urls") as batch_op:
batch_op.create_primary_key("pk_url", ["url"])
batch_op.drop_constraint("pk_url", type_='primary')
batch_op.drop_constraint("pk_url", type_="primary")
batch_op.create_primary_key("pk_url_archive_id", ["url", "archive_id"])
def downgrade() -> None:
with op.batch_alter_table("archive_urls") as batch_op:
batch_op.drop_constraint("pk_url_archive_id", type_='primary')
batch_op.drop_constraint("pk_url_archive_id", type_="primary")
batch_op.create_primary_key("url", ["url"])

View File

@@ -5,12 +5,13 @@ Revises: 9369a264945b
Create Date: 2023-12-20 18:33:27.132566
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = '93a611e4c066'
down_revision = '9369a264945b'
revision = "93a611e4c066"
down_revision = "9369a264945b"
branch_labels = None
depends_on = None
@@ -20,7 +21,9 @@ def upgrade() -> None:
with op.get_context().autocommit_block():
op.execute("VACUUM")
except Exception as e:
print("Unable to run vacuum, maybe there's not enough disk space. it should be 2x the size of the database")
print(
"Unable to run vacuum, maybe there's not enough disk space. it should be 2x the size of the database"
)
print(e)

View File

@@ -5,13 +5,14 @@ Revises: 89121d2c96d8
Create Date: 2025-02-04 12:19:20.753570
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = 'a23aaf3ae930'
down_revision = '89121d2c96d8'
revision = "a23aaf3ae930"
down_revision = "89121d2c96d8"
branch_labels = None
depends_on = None
@@ -19,16 +20,24 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('users')]
columns = [col["name"] for col in inspector.get_columns("users")]
if 'is_active' in columns:
op.drop_column('users', 'is_active')
if "is_active" in columns:
op.drop_column("users", "is_active")
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('users')]
columns = [col["name"] for col in inspector.get_columns("users")]
if 'is_active' not in columns:
op.add_column('users', sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.false()))
if "is_active" not in columns:
op.add_column(
"users",
sa.Column(
"is_active",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)

View File

@@ -5,14 +5,14 @@ Revises: 93a611e4c066
Create Date: 2024-10-31 09:36:50.360710
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector
from alembic import op
# revision identifiers, used by Alembic.
revision = 'fa012ec405b8'
down_revision = '93a611e4c066'
revision = "fa012ec405b8"
down_revision = "93a611e4c066"
branch_labels = None
depends_on = None
@@ -20,26 +20,41 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('groups')]
columns = [col["name"] for col in inspector.get_columns("groups")]
if 'description' not in columns:
op.add_column('groups', sa.Column('description', sa.String(), nullable=True))
if 'orchestrator' not in columns:
op.add_column('groups', sa.Column('orchestrator', sa.String(), nullable=True))
if 'orchestrator_sheet' not in columns:
op.add_column('groups', sa.Column('orchestrator_sheet', sa.String(), nullable=True))
if 'permissions' not in columns:
op.add_column('groups', sa.Column('permissions', sa.JSON(), nullable=True))
if 'domains' not in columns:
op.add_column('groups', sa.Column('domains', sa.JSON(), nullable=True))
if "description" not in columns:
op.add_column(
"groups", sa.Column("description", sa.String(), nullable=True)
)
if "orchestrator" not in columns:
op.add_column(
"groups", sa.Column("orchestrator", sa.String(), nullable=True)
)
if "orchestrator_sheet" not in columns:
op.add_column(
"groups",
sa.Column("orchestrator_sheet", sa.String(), nullable=True),
)
if "permissions" not in columns:
op.add_column(
"groups", sa.Column("permissions", sa.JSON(), nullable=True)
)
if "domains" not in columns:
op.add_column("groups", sa.Column("domains", sa.JSON(), nullable=True))
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('groups')]
columns = [col["name"] for col in inspector.get_columns("groups")]
column_names = ['description', 'orchestrator', 'orchestrator_sheet', 'permissions', 'domains']
column_names = [
"description",
"orchestrator",
"orchestrator_sheet",
"permissions",
"domains",
]
for column_name in column_names:
if column_name in columns:
op.drop_column('groups', column_name)
op.drop_column("groups", column_name)

View File

@@ -1,32 +0,0 @@
# TODO: code in this file should eventually be moved to the auto-archiver code base
from typing import List
from loguru import logger
from auto_archiver.core import Media, Metadata
from app.shared.db import models
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
db_urls = []
for m in result.media:
for i, url in enumerate(m.urls): db_urls.append(models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}")))
for k, prop in m.properties.items():
if prop_converted := convert_if_media(prop):
for i, url in enumerate(prop_converted.urls): db_urls.append(models.ArchiveUrl(url=url, key=prop_converted.get("id", f"{k}_{i}")))
if isinstance(prop, list):
for i, prop_media in enumerate(prop):
if prop_media := convert_if_media(prop_media):
for j, url in enumerate(prop_media.urls):
db_urls.append(models.ArchiveUrl(url=url, key=prop_media.get("id", f"{k}{prop_media.key}_{i}.{j}")))
return db_urls
def convert_if_media(media):
if isinstance(media, Media): return media
elif isinstance(media, dict):
try: return Media.from_dict(media)
except Exception as e:
logger.debug(f"error parsing {media} : {e}")
return False

View File

@@ -1,25 +1,35 @@
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do decide
import datetime
from typing import Union
from sqlalchemy.orm import Session
from app.shared.db import worker_crud
def get_store_archive_until(db: Session, group_id: str) -> datetime.datetime:
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do
# decide
def get_store_archive_until(
db: Session, group_id: str
) -> Union[datetime.datetime, None]:
group = worker_crud.get_group(db, group_id)
assert group, f"Group {group_id} not found."
assert group.permissions and type(group.permissions) == dict, f"Group {group_id} has no permissions."
assert group.permissions and isinstance(group.permissions, dict), (
f"Group {group_id} has no permissions."
)
max_lifespan = group.permissions.get("max_archive_lifespan_months", -1)
if max_lifespan == -1: return None
if max_lifespan == -1:
return None
return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan)
def get_store_archive_until_or_never(db: Session, group_id: str) -> datetime.datetime:
def get_store_archive_until_or_never(
db: Session, group_id: str
) -> Union[datetime.datetime, None]:
try:
return get_store_archive_until(db, group_id)
except AssertionError as e:
except AssertionError:
return None

7
app/shared/constants.py Normal file
View File

@@ -0,0 +1,7 @@
# Statuses
STATUS_FAILURE = "FAILURE"
STATUS_PENDING = "PENDING"
STATUS_SUCCESS = "SUCCESS"
# AA CLI CONFIGS
SHEET_ID = "--gsheet_feeder_db.sheet_id"

View File

@@ -1,8 +1,14 @@
from functools import lru_cache
from sqlalchemy import Engine, create_engine, event, text
from sqlalchemy.orm import sessionmaker
from contextlib import asynccontextmanager, contextmanager
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine, async_sessionmaker
from functools import lru_cache
from sqlalchemy import Engine, create_engine, event, text
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import sessionmaker
from app.shared.settings import get_settings
@@ -12,9 +18,9 @@ def make_engine(database_url: str):
engine = create_engine(
database_url,
connect_args={"check_same_thread": False},
pool_size=15, # Increase pool size
max_overflow=20, # Allow more temporary connections
pool_recycle=1800 # Recycle connections every 30 minutes
pool_size=15, # Increase pool size
max_overflow=20, # Allow more temporary connections
pool_recycle=1800, # Recycle connections every 30 minutes
)
@event.listens_for(engine, "connect")
@@ -34,8 +40,10 @@ def make_session_local(engine: Engine):
@contextmanager
def get_db():
session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
try: yield session
finally: session.close()
try:
yield session
finally:
session.close()
def get_db_dependency():
@@ -53,22 +61,32 @@ def wal_checkpoint():
# ASYNC connections
async def make_async_engine(database_url: str) -> AsyncEngine:
engine = create_async_engine(database_url, connect_args={"check_same_thread": False})
engine = create_async_engine(
database_url, connect_args={"check_same_thread": False}
)
async with engine.begin() as conn:
await conn.run_sync(lambda sync_conn: sync_conn.execute(text("PRAGMA journal_mode=WAL;")))
await conn.run_sync(
lambda sync_conn: sync_conn.execute(
text("PRAGMA journal_mode=WAL;")
)
)
return engine
async def make_async_session_local(engine: AsyncEngine) -> AsyncSession:
return async_sessionmaker(engine, expire_on_commit=False, autoflush=False, autocommit=False)
return async_sessionmaker(
engine, expire_on_commit=False, autoflush=False, autocommit=False
)
@asynccontextmanager
async def get_db_async():
engine = await make_async_engine(get_settings().ASYNC_DATABASE_PATH)
engine = await make_async_engine(get_settings().async_database_path)
async_session = await make_async_session_local(engine)
async with async_session() as session:
try: yield session
finally: await engine.dispose()
try:
yield session
finally:
await engine.dispose()

View File

@@ -1,8 +1,17 @@
from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship, declarative_base
import uuid
from sqlalchemy import (
JSON,
Boolean,
Column,
DateTime,
ForeignKey,
String,
Table,
)
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.sql import func
Base = declarative_base()
@@ -11,7 +20,7 @@ def generate_uuid():
return str(uuid.uuid4())
# many to many association tables
# many-to-many association tables
association_table_archive_tags = Table(
"mtm_archives_tags",
Base.metadata,
@@ -33,7 +42,9 @@ class Archive(Base):
id = Column(String, primary_key=True, index=True)
url = Column(String, index=True)
result = Column(JSON, default=None)
public = Column(Boolean, default=True) # if public=false, access by group and author
public = Column(
Boolean, default=True
) # if public=false, access by group and author
deleted = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
@@ -43,7 +54,11 @@ class Archive(Base):
author_id = Column(String, ForeignKey("users.email"))
sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags)
tags = relationship(
"Tag",
back_populates="archives",
secondary=association_table_archive_tags,
)
group = relationship("Group", back_populates="archives")
author = relationship("User", back_populates="archives")
urls = relationship("ArchiveUrl", back_populates="archive")
@@ -66,7 +81,11 @@ class Tag(Base):
id = Column(String, primary_key=True, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
archives = relationship(
"Archive",
back_populates="tags",
secondary=association_table_archive_tags,
)
class User(Base):
@@ -76,7 +95,9 @@ class User(Base):
archives = relationship("Archive", back_populates="author")
sheets = relationship("Sheet", back_populates="author")
groups = relationship("Group", back_populates="users", secondary=association_table_user_groups)
groups = relationship(
"Group", back_populates="users", secondary=association_table_user_groups
)
class Group(Base):
@@ -92,7 +113,9 @@ class Group(Base):
archives = relationship("Archive", back_populates="group")
sheets = relationship("Sheet", back_populates="group")
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
users = relationship(
"User", back_populates="groups", secondary=association_table_user_groups
)
class Sheet(Base):
@@ -101,11 +124,27 @@ class Sheet(Base):
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
name = Column(String, default=None)
author_id = Column(String, ForeignKey("users.email"))
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
group_id = Column(
String,
ForeignKey("groups.id"),
doc="Group ID, user must be in a group to create a sheet.",
)
frequency = Column(
String,
default="daily",
doc="Frequency of archiving: hourly, daily, weekly.",
)
# TODO: stats is not being used, consider removing
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...")
last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
stats = Column(
JSON,
default={},
doc="Sheet statistics like total links, total rows, ...",
)
last_url_archived_at = Column(
DateTime(timezone=True),
server_default=func.now(),
doc="Last time a new link was archived.",
)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())

View File

@@ -1,13 +1,17 @@
from sqlalchemy.orm import Session
from datetime import datetime
from app.shared.db import models
from sqlalchemy.orm import Session
from app.shared import schemas
from app.shared.db import models
# TODO: isolate database operations away from worker and into WEB
# ONLY WORKER
def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
db_sheet = (
db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
)
if db_sheet:
db_sheet.last_url_archived_at = datetime.now()
db.commit()
@@ -17,12 +21,17 @@ def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
# ONLY WORKER and INTEROP
def get_group(db: Session, group_name: str) -> models.Group:
return db.query(models.Group).filter(models.Group.id == group_name).first()
def create_or_get_user(db: Session, author_id: str) -> models.User:
if type(author_id) == str: author_id = author_id.lower()
db_user = db.query(models.User).filter(models.User.email == author_id).first()
if isinstance(author_id, str):
author_id = author_id.lower()
db_user = (
db.query(models.User).filter(models.User.email == author_id).first()
)
if not db_user:
db_user = models.User(email=author_id)
db.add(db_user)
@@ -41,8 +50,22 @@ def create_tag(db: Session, tag: str) -> models.Tag:
return db_tag
def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]) -> models.Archive:
db_archive = models.Archive(id=archive.id, url=archive.url, result=archive.result, public=archive.public, author_id=archive.author_id, group_id=archive.group_id, sheet_id=archive.sheet_id, store_until=archive.store_until)
def create_archive(
db: Session,
archive: schemas.ArchiveCreate,
tags: list[models.Tag],
urls: list[models.ArchiveUrl],
) -> models.Archive:
db_archive = models.Archive(
id=archive.id,
url=archive.url,
result=archive.result,
public=archive.public,
author_id=archive.author_id,
group_id=archive.group_id,
sheet_id=archive.sheet_id,
store_until=archive.store_until,
)
db_archive.tags = tags
db_archive.urls = urls
db.add(db_archive)
@@ -51,10 +74,14 @@ def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[model
return db_archive
def store_archived_url(db: Session, archive: schemas.ArchiveCreate) -> models.Archive:
def store_archived_url(
db: Session, archive: schemas.ArchiveCreate
) -> models.Archive:
# create and load user, tags, if needed
create_or_get_user(db, archive.author_id)
db_tags = [create_tag(db, tag) for tag in (archive.tags or [])]
# insert everything
db_archive = create_archive(db, archive=archive, tags=db_tags, urls=archive.urls)
db_archive = create_archive(
db, archive=archive, tags=db_tags, urls=archive.urls
)
return db_archive

View File

@@ -1,4 +1,5 @@
import traceback
from loguru import logger
@@ -6,8 +7,10 @@ from loguru import logger
logger.add("logs/api_logs.log", retention="30 days")
logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
def log_error(e: Exception, traceback_str: str = None, extra:str = ""):
if not traceback_str: traceback_str = traceback.format_exc()
if extra: extra = f"{extra}\n"
def log_error(e: Exception, traceback_str: str = None, extra: str = ""):
if not traceback_str:
traceback_str = traceback.format_exc()
if extra:
extra = f"{extra}\n"
logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")

View File

@@ -1,7 +1,8 @@
from datetime import datetime
from typing import Annotated
from annotated_types import Len
from pydantic import BaseModel
from datetime import datetime
class SubmitSheet(BaseModel):
@@ -10,6 +11,7 @@ class SubmitSheet(BaseModel):
group_id: str = "default"
tags: set[str] | None = set()
class ArchiveUrl(BaseModel):
url: str
public: bool = False
@@ -17,6 +19,7 @@ class ArchiveUrl(BaseModel):
group_id: str | None
tags: set[str] | None = set()
class ArchiveResult(BaseModel):
id: str
url: str

View File

@@ -1,31 +1,38 @@
from functools import lru_cache
import os
from functools import lru_cache
from typing import Annotated, Set
from annotated_types import Len
from fastapi_mail import ConnectionConfig
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Annotated, Set
from annotated_types import Len
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=os.environ.get("ENVIRONMENT_FILE") , env_file_encoding='utf-8', extra='ignore', str_strip_whitespace=True)
model_config = SettingsConfigDict(
env_file=os.environ.get("ENVIRONMENT_FILE"),
env_file_encoding="utf-8",
extra="ignore",
str_strip_whitespace=True,
)
# general
# general
SERVE_LOCAL_ARCHIVE: str | None = None
USER_GROUPS_FILENAME: str = "app/user-groups.yaml"
# database
# database
DATABASE_PATH: str
DATABASE_QUERY_LIMIT: int = 100
@property
def ASYNC_DATABASE_PATH(self) -> str:
def async_database_path(self) -> str:
return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://")
# security
# security
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)]
CHROME_APP_IDS: Annotated[Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
CHROME_APP_IDS: Annotated[
Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)
]
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
# if not provided only OAUTH access_tokens are allowed
FIREBASE_SERVICE_ACCOUNT_JSON: str = ""
@@ -34,20 +41,21 @@ class Settings(BaseSettings):
REDIS_PASSWORD: str = ""
REDIS_HOSTNAME: str = "localhost"
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
@property
def CELERY_BROKER_URL(self)-> str:
def celery_broker_url(self) -> str:
if self.REDIS_PASSWORD:
return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379"
return f"redis://{self.REDIS_HOSTNAME}:6379"
# cronjobs
CRON_ARCHIVE_SHEETS: bool = False
CRON_DELETE_STALE_SHEETS: bool = False
DELETE_STALE_SHEETS_DAYS: int = 14
CRON_DELETE_SCHEDULED_ARCHIVES: bool = False
DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7
# observability
# observability
REPEAT_COUNT_METRICS_SECONDS: int = 30
# email configuration, if needed
@@ -59,8 +67,9 @@ class Settings(BaseSettings):
MAIL_PORT: int = 587
MAIL_STARTTLS: bool = False
MAIL_SSL_TLS: bool = True
@property
def MAIL_CONFIG(self) -> str:
def mail_config(self) -> ConnectionConfig:
return ConnectionConfig(
MAIL_FROM=self.MAIL_FROM,
MAIL_FROM_NAME=self.MAIL_FROM_NAME,
@@ -75,4 +84,4 @@ class Settings(BaseSettings):
@lru_cache
def get_settings():
return Settings()
return Settings()

View File

@@ -1,8 +1,8 @@
from functools import lru_cache
from celery import Celery
import redis
from celery import Celery
import redis
from app.shared.settings import get_settings
@@ -10,14 +10,14 @@ from app.shared.settings import get_settings
def get_celery(name: str = "") -> Celery:
return Celery(
name,
broker_url=get_settings().CELERY_BROKER_URL,
result_backend=get_settings().CELERY_BROKER_URL,
broker_url=get_settings().celery_broker_url,
result_backend=get_settings().celery_broker_url,
broker_connection_retry_on_startup=False,
broker_transport_options={
'queue_order_strategy': 'priority',
}
"queue_order_strategy": "priority",
},
)
def get_redis() -> redis.Redis:
return redis.Redis.from_url(get_settings().CELERY_BROKER_URL)
return redis.Redis.from_url(get_settings().celery_broker_url)

View File

@@ -1,9 +1,16 @@
import json
import os
from typing import Dict, List, Set
import yaml
from loguru import logger
from pydantic import BaseModel, computed_field, field_validator, Field, model_validator
from typing import Dict, List, Set
from pydantic import (
BaseModel,
Field,
computed_field,
field_validator,
model_validator,
)
from typing_extensions import Self
@@ -12,13 +19,16 @@ class UserGroups:
user_groups = self.read_yaml(filename)
self.validate_and_load(user_groups)
def read_yaml(self, user_groups_filename):
@staticmethod
def read_yaml(user_groups_filename):
# read yaml safely
with open(user_groups_filename) as inf:
try:
return yaml.safe_load(inf)
except yaml.YAMLError as e:
logger.error(f"could not open user groups filename {user_groups_filename}: {e}")
logger.error(
f"could not open user groups filename {user_groups_filename}: {e}"
)
raise e
def validate_and_load(self, user_groups):
@@ -45,22 +55,36 @@ class GroupPermissions(BaseModel):
max_monthly_mbs: int = 0
priority: str = "low"
@field_validator('max_sheets', 'max_archive_lifespan_months', 'max_monthly_urls', 'max_monthly_mbs', mode='before')
@classmethod
@field_validator(
"max_sheets",
"max_archive_lifespan_months",
"max_monthly_urls",
"max_monthly_mbs",
mode="before",
)
def validate_max_values(cls, v):
if v < -1:
raise ValueError("max_* values should be positive integers or -1 (for no limit).")
raise ValueError(
"max_* values should be positive integers or -1 (for no limit)."
)
return v
@field_validator('sheet_frequency', mode='before')
@classmethod
@field_validator("sheet_frequency", mode="before")
def validate_sheet_frequency(cls, v):
if not v: return []
if not v:
return []
allowed = ["daily", "hourly"]
for k in v:
if k not in allowed:
raise ValueError(f"Invalid sheet frequency: '{k}', expected one of {allowed}")
raise ValueError(
f"Invalid sheet frequency: '{k}', expected one of {allowed}"
)
return v
@field_validator('priority', mode='before')
@classmethod
@field_validator("priority", mode="before")
def validate_priority(cls, v):
v = v.lower()
if v not in ["low", "high"]:
@@ -70,19 +94,31 @@ class GroupPermissions(BaseModel):
class GroupModel(BaseModel):
description: str
orchestrator: str
orchestrator_sheet: str
orchestrator: str | None = None
orchestrator_sheet: str | None = None
permissions: GroupPermissions
@field_validator('orchestrator', 'orchestrator_sheet', mode='before')
@classmethod
@field_validator("orchestrator", mode="before")
def validate_orchestrator(cls, v):
if not os.path.exists(v):
# orchestrator is only needed if the group has archive_url permission
if cls.permissions.archive_url and not os.path.exists(v):
raise ValueError(f"Orchestrator file not found with this path: {v}")
return v
@classmethod
@field_validator("orchestrator_sheet", mode="before")
def validate_orchestrator_sheet(cls, v):
# orchestrator_sheet is only needed if the group has archive_sheet permission
if cls.permissions.archive_sheet and not os.path.exists(v):
raise ValueError(f"Orchestrator file not found with this path: {v}")
return v
@computed_field
@property
def service_account_email(self) -> str:
if self.orchestrator_sheet is None:
return ""
if hasattr(self, "_service_account_email"):
return self._service_account_email
orch = yaml.safe_load(open(self.orchestrator_sheet))
@@ -98,13 +134,17 @@ class GroupModel(BaseModel):
service_account_json = find_service_account_email(orch)
if not service_account_json:
raise ValueError(f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}.")
raise ValueError(
f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}."
)
with open(service_account_json) as f:
self._service_account_email = json.load(f).get("client_email")
if not self._service_account_email:
raise ValueError(f"Service account email not found in {service_account_json}.")
raise ValueError(
f"Service account email not found in {service_account_json}."
)
return self._service_account_email
@@ -114,29 +154,45 @@ class UserGroupModel(BaseModel):
domains: Dict[str, List[str]] = Field(default_factory=dict)
groups: Dict[str, GroupModel] = Field(default_factory=dict)
@field_validator('users', mode='before')
@classmethod
@field_validator("users", mode="before")
def validate_emails(cls, v):
for email in v.keys():
if '@' not in email:
raise ValueError(f"Invalid user, it should be an address: {email}")
if "@" not in email:
raise ValueError(
f"Invalid user, it should be an address: {email}"
)
if not v[email]:
raise ValueError(f"User {email} has no explicitly listed groups, only include them here if they should be in a group.")
raise ValueError(
f"User {email} has no explicitly listed groups, only include them here if they should be in a group."
)
# all users belong to the default group
return {k.lower().strip(): list(set(["default"] + [g.lower().strip() for g in v])) for k, v in v.items()}
return {
k.lower().strip(): list(
set(["default"] + [g.lower().strip() for g in v])
)
for k, v in v.items()
}
@field_validator('domains', mode='before')
@classmethod
@field_validator("domains", mode="before")
def validate_domains(cls, v):
for domain, members in v.items():
if '.' not in domain:
raise ValueError(f"Invalid domain, it should contain a dot: {domain}")
if "." not in domain:
raise ValueError(
f"Invalid domain, it should contain a dot: {domain}"
)
if not members:
raise ValueError(f"Domain {domain} should have at least one member.")
return {k.lower().strip(): list(set([g.lower().strip() for g in v])) for k, v in v.items()}
raise ValueError(
f"Domain {domain} should have at least one member."
)
return {
k.lower().strip(): list({[g.lower().strip() for g in v]})
for k, v in v.items()
}
@field_validator('groups', mode='before')
@classmethod
@field_validator("groups", mode="before")
def validate_groups(cls, v):
if "default" not in v.keys():
raise ValueError("Please include a 'default' group.")
@@ -147,20 +203,28 @@ class UserGroupModel(BaseModel):
raise ValueError(f"Group names should be lowercase: {group}")
return v
@model_validator(mode='after')
@model_validator(mode="after")
def check_groups_consistency(self) -> Self:
groups_in_domains = set([g for domain in self.domains for g in self.domains[domain]])
groups_in_users = set([g for user in self.users for g in self.users[user]])
groups_in_domains = {
g for domain in self.domains for g in self.domains[domain]
}
groups_in_users = {g for user in self.users for g in self.users[user]}
configured_groups = set(self.groups.keys())
# groups mentioned in domains and users should be defined, but this is not a ValueError since historical DB data may require it
# groups mentioned in domains and users should be defined, but this is
# not a ValueError since historical DB data may require it
if groups_in_domains - configured_groups:
logger.warning(f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}")
logger.warning(
f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}"
)
if groups_in_users - configured_groups:
logger.warning(f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}")
logger.warning(
f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}"
)
return self
# for the API return values

View File

@@ -1,10 +1,14 @@
def fnv1a_hash_mod(s: str, modulo:int) -> int:
# receives a string and returns a number in [0:modulo-1], ensures an even distribution over the modulo range
hash = 0x811c9dc5 # FNV offset basis
fnv_prime = 0x01000193 # FNV prime
def fnv1a_hash_mod(s: str, modulo: int) -> int:
# receives a string and returns a number in [0:modulo-1], ensures an even
# distribution over the modulo range
offset_basis_hash = 0x811C9DC5 # FNV offset basis
fnv_prime = 0x01000193 # FNV prime
for char in s:
hash ^= ord(char)
hash *= fnv_prime
hash &= 0xFFFFFFFF # Keep it 32-bit
return (hash if hash < 0x80000000 else hash - 0x100000000) % modulo
offset_basis_hash ^= ord(char)
offset_basis_hash *= fnv_prime
offset_basis_hash &= 0xFFFFFFFF # Keep it 32-bit
return (
offset_basis_hash
if offset_basis_hash < 0x80000000
else offset_basis_hash - 0x100000000
) % modulo

View File

@@ -1,19 +1,39 @@
import os
from datetime import datetime
from http import HTTPStatus
from typing import AsyncGenerator
from fastapi.testclient import TestClient
import pytest
from unittest.mock import patch
import pytest
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
from app.web.config import ALLOW_ANY_EMAIL
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from app.shared.db import models
from app.shared.db.database import (
make_async_engine,
make_async_session_local,
make_engine,
make_session_local,
)
from app.shared.settings import Settings
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud
from app.web.db.crud import get_user_group_names
from app.web.db.user_state import UserState
from app.web.main import app_factory
from app.web.security import (
get_token_or_user_auth,
get_user_auth,
get_user_state,
token_api_key_auth,
)
@pytest.fixture(autouse=True)
def mock_logger_add():
"""Fixture to mock loguru.logger.add for all tests."""
with patch('loguru.logger.add') as mock_add:
with patch("loguru.logger.add") as mock_add:
yield mock_add # This makes the mock available to tests
@@ -24,23 +44,22 @@ def get_settings():
@pytest.fixture(autouse=True)
def mock_settings():
with patch('app.shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings:
with patch(
"app.shared.settings.Settings",
return_value=Settings(_env_file=".env.test"),
) as mock_settings:
yield mock_settings
@pytest.fixture()
def test_db(get_settings: Settings):
from app.shared.db import models
from app.shared.db.database import make_engine
from app.web.db.crud import get_user_group_names
get_user_group_names.cache_clear()
make_engine.cache_clear()
engine = make_engine(get_settings.DATABASE_PATH)
fs = get_settings.DATABASE_PATH.replace("sqlite:///", "")
if not os.path.exists(fs):
open(fs, 'w').close()
open(fs, "w").close()
models.Base.metadata.create_all(engine)
@@ -57,7 +76,6 @@ def test_db(get_settings: Settings):
@pytest.fixture()
def db_session(test_db):
from app.shared.db.database import make_session_local
session_local = make_session_local(test_db)
with session_local() as session:
yield session
@@ -65,17 +83,12 @@ def db_session(test_db):
@pytest_asyncio.fixture()
async def async_test_db(get_settings: Settings):
from app.shared.db import models
from app.shared.db.database import make_async_engine
from app.web.db.crud import get_user_group_names
import asyncio
get_user_group_names.cache_clear()
engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH)
engine = await make_async_engine(get_settings.async_database_path)
fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "")
fs = get_settings.async_database_path.replace("sqlite+aiosqlite:///", "")
if not os.path.exists(fs):
open(fs, 'w').close()
open(fs, "w").close()
async def create_all():
async with engine.begin() as conn:
@@ -99,8 +112,9 @@ async def async_test_db(get_settings: Settings):
@pytest_asyncio.fixture()
async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
from app.shared.db.database import make_async_session_local
async def async_db_session(
async_test_db: AsyncEngine,
) -> AsyncGenerator[AsyncSession, None]:
session_local = await make_async_session_local(async_test_db)
async with session_local() as session:
yield session
@@ -108,8 +122,6 @@ async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSe
@pytest.fixture()
def app(db_session):
from app.web.main import app_factory
from app.web.db import crud
app = app_factory()
crud.upsert_user_groups(db_session)
return app
@@ -123,10 +135,13 @@ def client(app):
@pytest.fixture()
def app_with_auth(app, db_session):
from app.web.security import get_token_or_user_auth, get_user_auth, get_user_state
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
app.dependency_overrides[get_token_or_user_auth] = (
lambda: "rick@example.com"
)
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
app.dependency_overrides[get_user_state] = lambda: UserState(db_session, "MORTY@example.com")
app.dependency_overrides[get_user_state] = lambda: UserState(
db_session, "MORTY@example.com"
)
return app
@@ -138,7 +153,6 @@ def client_with_auth(app_with_auth):
@pytest.fixture()
def app_with_token(app):
from app.web.security import token_api_key_auth, get_token_or_user_auth
app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL
app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL
return app
@@ -155,6 +169,93 @@ def test_no_auth():
# reusable code to ensure a method/endpoint combination is unauthorized
def no_auth(http_method, endpoint):
response = http_method(endpoint)
assert response.status_code == 403
assert response.status_code == HTTPStatus.FORBIDDEN
assert response.json() == {"detail": "Not authenticated"}
return no_auth
@pytest.fixture()
def test_data(db_session):
author_emails = [
"rick@example.com",
"morty@example.com",
"jerry@example.com",
]
# creates 3 users
for email in author_emails:
db_session.add(models.User(email=email))
db_session.commit()
assert db_session.query(models.User).count() == 3
# creates 100 archives for 3 users over 2 months with repeating URLs
for i in range(100):
author = author_emails[i % 3]
archive = models.Archive(
id=f"archive-id-456-{i}",
url=f"https://example-{i % 3}.com",
result={},
public=author == "jerry@example.com",
author_id=author,
group_id="spaceship"
if author == "morty@example.com" and i % 2 == 0
else None,
created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1),
)
if i % 5 == 0:
archive.tags.append(models.Tag(id=f"tag-{i}"))
if i % 10 == 0:
archive.tags.append(models.Tag(id=f"tag-second-{i}"))
if i % 4 == 0:
archive.tags.append(models.Tag(id=f"tag-third-{i}"))
for j in range(10):
archive.urls.append(
models.ArchiveUrl(
url=f"https://example-{i}.com/{j}", key=f"media_{j}"
)
)
db_session.add(archive)
# creates a sheet for each user
for i, email in enumerate(
["rick@example.com", "morty@example.com", "jerry@example.com"]
):
db_session.add(
models.Sheet(
id=f"sheet-{i}",
name=f"sheet-{i}",
author_id=email,
group_id=None,
frequency="daily",
)
)
if email == "rick@example.com":
db_session.add(
models.Sheet(
id=f"sheet-{i}-2",
name=f"sheet-{i}-2",
author_id=email,
group_id="spaceship",
frequency="hourly",
)
)
db_session.commit()
assert db_session.query(models.Archive).count() == 100
assert db_session.query(models.Tag).count() == 20 + 10 + 25
assert db_session.query(models.ArchiveUrl).count() == 1000
assert (
db_session.query(models.ArchiveUrl)
.filter(models.ArchiveUrl.archive_id == "archive-id-456-0")
.count()
== 10
)
# setup groups
assert db_session.query(models.Group).count() == 0
crud.upsert_user_groups(db_session)
assert db_session.query(models.Group).count() == 4
assert db_session.query(models.User).count() == 3

View File

@@ -1,3 +1,3 @@
{
"client_email": "fake_service_account@fake_service_account.iam.gserviceaccount.com"
}
"client_email": "fake_service_account@fake_service_account.iam.gserviceaccount.com"
}

View File

@@ -1,7 +1,8 @@
steps:
feeder: cli_feeder
feeders:
- cli_feeder
archivers: # order matters
- youtubedl_archiver
- generic_extractor
enrichers:
- hash_enricher
@@ -12,10 +13,10 @@ steps:
- console_db
configurations:
gsheet_feeder:
gsheet_feeder_db:
service_account: "app/tests/fake_service_account.json"
cli_feeder:
urls:
urls:
- "url1"
hash_enricher:
algorithm: "SHA-256"

View File

@@ -1,6 +1,7 @@
def test_generate_uuid():
from app.shared.db.models import generate_uuid
from app.shared.db.models import generate_uuid
assert generate_uuid() != generate_uuid()
assert len(generate_uuid()) == 36
assert generate_uuid().count("-") == 4
def test_generate_uuid():
assert generate_uuid() != generate_uuid()
assert len(generate_uuid()) == 36
assert generate_uuid().count("-") == 4

View File

@@ -1,12 +1,10 @@
from app.shared.db import models
from app.shared.db import worker_crud, models
from datetime import datetime
from app.shared import schemas
from app.shared.db import models, worker_crud
from app.tests.web.db.test_crud import test_data
def test_update_sheet_last_url_archived_at(db_session):
# Create test sheet
test_sheet = models.Sheet(id="sheet-123")
db_session.add(test_sheet)
@@ -15,17 +13,24 @@ def test_update_sheet_last_url_archived_at(db_session):
# Test updating existing sheet
assert isinstance(test_sheet.last_url_archived_at, datetime)
before = test_sheet.last_url_archived_at
assert worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123") is True
assert (
worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123")
is True
)
db_session.refresh(test_sheet)
assert isinstance(test_sheet.last_url_archived_at, datetime)
assert test_sheet.last_url_archived_at > before
# Test non-existent sheet
assert worker_crud.update_sheet_last_url_archived_at(db_session, "non-existent-sheet") is False
assert (
worker_crud.update_sheet_last_url_archived_at(
db_session, "non-existent-sheet"
)
is False
)
def test_get_group(test_data, db_session):
from app.shared.db import worker_crud
assert worker_crud.get_group(db_session, "spaceship") is not None
assert worker_crud.get_group(db_session, "interdimensional") is not None
assert worker_crud.get_group(db_session, "animated-characters") is not None
@@ -33,24 +38,24 @@ def test_get_group(test_data, db_session):
def test_create_or_get_user(test_data, db_session):
from app.shared.db import worker_crud
assert db_session.query(models.User).count() == 3
# already exists
assert (u1 := worker_crud.create_or_get_user(db_session, "rick@example.com")) is not None
assert (
u1 := worker_crud.create_or_get_user(db_session, "rick@example.com")
) is not None
assert u1.email == "rick@example.com"
# new user
assert (u2 := worker_crud.create_or_get_user(db_session, "beth@example.com")) is not None
assert (
u2 := worker_crud.create_or_get_user(db_session, "beth@example.com")
) is not None
assert u2.email == "beth@example.com"
assert db_session.query(models.User).count() == 4
def test_create_tag(db_session):
from app.shared.db import worker_crud
assert db_session.query(models.Tag).count() == 0
# create first
@@ -58,7 +63,10 @@ def test_create_tag(db_session):
assert create_tag is not None
assert create_tag.id == "tag-101"
assert db_session.query(models.Tag).count() == 1
assert db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() == create_tag
assert (
db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first()
== create_tag
)
# same id does not add new db entry
existing_tag = worker_crud.create_tag(db_session, "tag-101")
@@ -73,9 +81,6 @@ def test_create_tag(db_session):
def test_create_task(db_session):
from app.shared.db import worker_crud
from app.shared import schemas
task = schemas.ArchiveCreate(
id="archive-id-456-101",
url="https://example-0.com",
@@ -84,17 +89,22 @@ def test_create_task(db_session):
author_id="rick@example.com",
group_id="spaceship",
tags=[],
urls=[]
urls=[],
)
# with tags and urls
nt = worker_crud.create_archive(db_session, task, [models.Tag(id="tag-101")], [models.ArchiveUrl(url="https://example-0.com/0", key="media_0")])
nt = worker_crud.create_archive(
db_session,
task,
[models.Tag(id="tag-101")],
[models.ArchiveUrl(url="https://example-0.com/0", key="media_0")],
)
assert nt is not None
assert nt.id == "archive-id-456-101"
assert nt.url == "https://example-0.com"
assert nt.author_id == "rick@example.com"
assert nt.public == False
assert nt.public is False
assert nt.group_id == "spaceship"
assert len(nt.tags) == 1
assert nt.tags[0].id == "tag-101"
@@ -110,8 +120,8 @@ def test_create_task(db_session):
assert nt.id == "archive-id-456-102"
assert nt.url == "https://example-0.com"
assert nt.author_id == "rick@example.com"
assert nt.public == False
assert nt.public is False
assert nt.group_id == "spaceship"
assert len(nt.tags) == 0
assert len(nt.urls) == 0
assert nt.created_at is not None
assert nt.created_at is not None

View File

@@ -1,10 +1,15 @@
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
from app.shared.business_logic import get_store_archive_until, get_store_archive_until_or_never
from app.shared.business_logic import (
get_store_archive_until,
get_store_archive_until_or_never,
)
class Test_get_store_archive_until:
class TestGetStoreArchiveUntil:
GROUP_ID = "test-group"
def test_group_not_found(self, db_session):
@@ -12,7 +17,10 @@ class Test_get_store_archive_until:
get_store_archive_until(db_session, self.GROUP_ID)
assert str(exc.value) == f"Group {self.GROUP_ID} not found."
@patch("app.shared.db.worker_crud.get_group", return_value=MagicMock(permissions=None))
@patch(
"app.shared.db.worker_crud.get_group",
return_value=MagicMock(permissions=None),
)
def test_group_no_permissions(self, db_session):
with pytest.raises(AssertionError) as exc:
get_store_archive_until(db_session, self.GROUP_ID)
@@ -43,14 +51,17 @@ class Test_get_store_archive_until:
mock_get_group.assert_called_once_with(db_session, self.GROUP_ID)
class Test_get_store_archive_until_or_never:
class TestGetStoreArchiveUntilOrNever:
GROUP_ID = "test-group"
def test_group_not_found(self, db_session):
result = get_store_archive_until_or_never(db_session, self.GROUP_ID)
assert result is None
@patch("app.shared.db.worker_crud.get_group", return_value=MagicMock(permissions=None))
@patch(
"app.shared.db.worker_crud.get_group",
return_value=MagicMock(permissions=None),
)
def test_group_no_permissions(self, db_session):
result = get_store_archive_until_or_never(db_session, self.GROUP_ID)
assert result is None

View File

@@ -11,7 +11,7 @@ def test_fnv1a_hash_mod():
# Test different modulos
hash1 = fnv1a_hash_mod("test", 5)
hash2 = fnv1a_hash_mod("test", 10)
hash2 = fnv1a_hash_mod("test", 10)
assert 0 <= hash1 < 5
assert 0 <= hash2 < 10
@@ -28,4 +28,4 @@ def test_fnv1a_hash_mod():
assert 0 <= fnv1a_hash_mod("测试", 10) < 10
# Test modulo = 1 edge case
assert fnv1a_hash_mod("test", 1) == 0
assert fnv1a_hash_mod("test", 1) == 0

View File

@@ -3,4 +3,4 @@ This is just an invalid yaml for testing
still broken: True
- one
- two
- two

View File

@@ -84,4 +84,4 @@ groups:
# max_archive_lifespan_months: 12
max_monthly_urls: 1
# max_monthly_mbs: 50
priority: "low"
priority: "low"

View File

@@ -2,123 +2,379 @@ from datetime import datetime, timedelta
from unittest.mock import patch
import pytest
import sqlalchemy
import yaml
from sqlalchemy import false, true
from sqlalchemy.sql import select
from app.shared.db import models
from app.shared.settings import Settings
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
@pytest.fixture()
def test_data(db_session):
# creates 3 users
for email in authors:
db_session.add(models.User(email=email))
db_session.commit()
assert db_session.query(models.User).count() == 3
# creates 100 archives for 3 users over 2 months with repeating URLs
for i in range(100):
author = authors[i % 3]
archive = models.Archive(
id=f"archive-id-456-{i}",
url=f"https://example-{i%3}.com",
result={},
public=author == "jerry@example.com",
author_id=author,
group_id="spaceship" if author == "morty@example.com" and i % 2 == 0 else None,
created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1)
)
if i % 5 == 0:
archive.tags.append(models.Tag(id=f"tag-{i}"))
if i % 10 == 0:
archive.tags.append(models.Tag(id=f"tag-second-{i}"))
if i % 4 == 0:
archive.tags.append(models.Tag(id=f"tag-third-{i}"))
for j in range(10):
archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
db_session.add(archive)
# creates a sheet for each user
for i, email in enumerate(authors):
db_session.add(models.Sheet(id=f"sheet-{i}", name=f"sheet-{i}", author_id=email, group_id=None, frequency="daily"))
if email == "rick@example.com":
db_session.add(models.Sheet(id=f"sheet-{i}-2", name=f"sheet-{i}-2", author_id=email, group_id="spaceship", frequency="hourly"))
db_session.commit()
assert db_session.query(models.Archive).count() == 100
assert db_session.query(models.Tag).count() == 20 + 10 + 25
assert db_session.query(models.ArchiveUrl).count() == 1000
assert db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.archive_id == "archive-id-456-0").count() == 10
# setup groups
assert db_session.query(models.Group).count() == 0
from app.web.db import crud
crud.upsert_user_groups(db_session)
assert db_session.query(models.Group).count() == 4
assert db_session.query(models.User).count() == 3
def test_search_archives_by_url(test_data, db_session):
from app.web.config import ALLOW_ANY_EMAIL
# rick's archives are private
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", True, False)) == 34
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", [], False)) == 34
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", [], True)) == 34
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL, [], False)) == 34
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL, True, False)) == 34
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com", [], False)) == 0
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com", [], True)) == 0
# Rick's archives are private
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-0.com",
"rick@example.com",
True,
False,
)
)
== 34
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-0.com",
"rick@example.com",
[],
False,
)
)
== 34
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-0.com",
"rick@example.com",
[],
True,
)
)
== 34
)
assert (
len(
crud.search_archives_by_url(
db_session, "https://example-0.com", ALLOW_ANY_EMAIL, [], False
)
)
== 34
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-0.com",
ALLOW_ANY_EMAIL,
True,
False,
)
)
== 34
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-0.com",
"morty@example.com",
[],
False,
)
)
== 0
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-0.com",
"morty@example.com",
[],
True,
)
)
== 0
)
# morty's archives are public but half are in spaceship group
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com", ["spaceship"], False)) == 16
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com", True, False)) == 16
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "jerry@example.com", True, True)) == 16
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-1.com",
"rick@example.com",
["spaceship"],
False,
)
)
== 16
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-1.com",
"rick@example.com",
True,
False,
)
)
== 16
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-1.com",
"jerry@example.com",
True,
True,
)
)
== 16
)
# jerry's archives are public
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "jerry@example.com", [], True)) == 33
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "rick@example.com", [], True)) == 33
# Jerry's archives are public
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-2.com",
"jerry@example.com",
[],
True,
)
)
== 33
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-2.com",
"rick@example.com",
[],
True,
)
)
== 33
)
# fuzzy search
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False)) == 100
assert len(crud.search_archives_by_url(db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL, False, False)) == 100
assert len(crud.search_archives_by_url(db_session, "2.com", ALLOW_ANY_EMAIL, False, False)) == 33
assert (
len(
crud.search_archives_by_url(
db_session, "https://example", ALLOW_ANY_EMAIL, False, False
)
)
== 100
)
assert (
len(
crud.search_archives_by_url(
db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL, False, False
)
)
== 100
)
assert (
len(
crud.search_archives_by_url(
db_session, "2.com", ALLOW_ANY_EMAIL, False, False
)
)
== 33
)
# absolute search
assert len(crud.search_archives_by_url(db_session, "example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True)) == 0
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True)) == 33
assert (
len(
crud.search_archives_by_url(
db_session,
"example-2.com",
ALLOW_ANY_EMAIL,
[],
False,
absolute_search=True,
)
)
== 0
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example-2.com",
ALLOW_ANY_EMAIL,
[],
False,
absolute_search=True,
)
)
== 33
)
# archived_after
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, archived_after=datetime(2010, 1, 1))) == 100
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 15))) == 70
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2031, 1, 1))) == 0
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
True,
True,
archived_after=datetime(2010, 1, 1),
)
)
== 100
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_after=datetime(2021, 1, 15),
)
)
== 70
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_after=datetime(2031, 1, 1),
)
)
== 0
)
# archived before
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2010, 1, 1))) == 0
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2021, 1, 15))) == 28
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2031, 1, 1))) == 100
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_before=datetime(2010, 1, 1),
)
)
== 0
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_before=datetime(2021, 1, 15),
)
)
== 28
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_before=datetime(2031, 1, 1),
)
)
== 100
)
# archived before and after
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2001, 1, 1), archived_before=datetime(2031, 1, 11))) == 100
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 14), archived_before=datetime(2021, 1, 16))) == 2
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_after=datetime(2001, 1, 1),
archived_before=datetime(2031, 1, 11),
)
)
== 100
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
archived_after=datetime(2021, 1, 14),
archived_before=datetime(2021, 1, 16),
)
)
== 2
)
# limit
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, limit=10)) == 10
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, limit=-1)) == 1
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
limit=10,
)
)
== 10
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
limit=-1,
)
)
== 1
)
# skip
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, skip=10)) == 90
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
False,
False,
skip=10,
)
)
== 90
)
def test_search_archives_by_email(test_data, db_session):
from app.web.config import ALLOW_ANY_EMAIL
# lower/upper case
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34
assert (
len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34
)
# ALLOW_ANY_EMAIL is not a user
assert len(crud.search_archives_by_email(db_session, ALLOW_ANY_EMAIL)) == 0
@@ -136,45 +392,108 @@ def test_search_archives_by_email(test_data, db_session):
@patch("app.web.db.crud.DATABASE_QUERY_LIMIT", new=25)
def test_max_query_limit(test_data, db_session):
from app.web.config import ALLOW_ANY_EMAIL
assert (
len(
crud.search_archives_by_url(
db_session, "https://example", ALLOW_ANY_EMAIL, [], False
)
)
== 25
)
assert (
len(
crud.search_archives_by_url(
db_session,
"https://example",
ALLOW_ANY_EMAIL,
True,
True,
limit=1000,
)
)
== 25
)
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, [], False)) == 25
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, limit=1000)) == 25
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25
assert len(crud.search_archives_by_email(db_session, "rick@example.com", limit=1000)) == 25
assert (
len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25
)
assert (
len(
crud.search_archives_by_email(
db_session, "rick@example.com", limit=1000
)
)
== 25
)
def test_soft_delete(test_data, db_session):
# none deleted yet
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").first() is not None
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 0
assert (
db_session.query(models.Archive)
.filter(models.Archive.id == "archive-id-456-0")
.first()
is not None
)
assert (
db_session.query(models.Archive)
.filter(models.Archive.deleted.is_(true()))
.count()
== 0
)
# delete
assert crud.soft_delete_archive(db_session, "archive-id-456-0", "rick@example.com") == True
assert (
crud.soft_delete_archive(
db_session, "archive-id-456-0", "rick@example.com"
)
is True
)
# ensure soft delete
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 1
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").first() is None
assert (
db_session.query(models.Archive)
.filter(models.Archive.deleted.is_(true()))
.count()
== 1
)
assert (
db_session.query(models.Archive)
.filter(models.Archive.id == "archive-id-456-0")
.filter(models.Archive.deleted.is_(false()))
.first()
is None
)
# already deleted
assert crud.soft_delete_archive(db_session, "archive-id-456-0", "rick@example.com") == False
assert (
crud.soft_delete_archive(
db_session, "archive-id-456-0", "rick@example.com"
)
is False
)
def test_count_archives(test_data, db_session):
assert crud.count_archives(db_session) == 100
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
db_session.query(models.Archive).filter(
models.Archive.id == "archive-id-456-0"
).delete()
db_session.commit()
assert crud.count_archives(db_session) == 99
def test_count_archive_urls(test_data, db_session):
assert crud.count_archive_urls(db_session) == 1000
db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.url == "https://example-0.com/0").delete()
db_session.query(models.ArchiveUrl).filter(
models.ArchiveUrl.url == "https://example-0.com/0"
).delete()
db_session.commit()
assert crud.count_archive_urls(db_session) == 999
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
db_session.query(models.Archive).filter(
models.Archive.id == "archive-id-456-0"
).delete()
db_session.commit()
# no Cascade is enabled
assert crud.count_archives(db_session) == 99
@@ -183,16 +502,23 @@ def test_count_archive_urls(test_data, db_session):
def test_count_users(test_data, db_session):
assert crud.count_users(db_session) == 3
db_session.query(models.User).filter(models.User.email == "rick@example.com").delete()
db_session.query(models.User).filter(
models.User.email == "rick@example.com"
).delete()
db_session.commit()
assert crud.count_users(db_session) == 2
def test_count_by_users_since(test_data, db_session):
from app.web.db import crud
# 100y window
assert len(cu := crud.count_by_user_since(db_session, 60 * 60 * 24 * 31 * 12 * 100)) == 3
assert (
len(
cu := crud.count_by_user_since(
db_session, 60 * 60 * 24 * 31 * 12 * 100
)
)
== 3
)
assert cu[0].total == 34
assert cu[1].total == 33
assert cu[2].total == 33
@@ -201,9 +527,18 @@ def test_count_by_users_since(test_data, db_session):
def test_upsert_group(test_data, db_session):
assert db_session.query(models.Group).count() == 4
repeatable_params = ["desc 1", "orch.yaml", "sheet.yaml", "service_account_email@example.com", {"read": ["all"]}, ["example.com"]]
repeatable_params = [
"desc 1",
"orch.yaml",
"sheet.yaml",
"service_account_email@example.com",
{"read": ["all"]},
["example.com"],
]
assert (g1 := crud.upsert_group(db_session, "spaceship", *repeatable_params)) is not None
assert (
g1 := crud.upsert_group(db_session, "spaceship", *repeatable_params)
) is not None
assert g1.id == "spaceship"
assert g1.description == "desc 1"
assert g1.orchestrator == "orch.yaml"
@@ -212,14 +547,25 @@ def test_upsert_group(test_data, db_session):
assert g1.permissions == {"read": ["all"]}
assert g1.domains == ["example.com"]
assert len(g1.users) == 2
assert [u.email for u in g1.users] == ["rick@example.com", "morty@example.com"]
assert [u.email for u in g1.users] == [
"rick@example.com",
"morty@example.com",
]
assert (g2 := crud.upsert_group(db_session, "interdimensional", *repeatable_params)) is not None
assert (
g2 := crud.upsert_group(
db_session, "interdimensional", *repeatable_params
)
) is not None
assert g2.id == "interdimensional"
assert len(g2.users) == 1
assert [u.email for u in g2.users] == ["rick@example.com"]
assert (g3 := crud.upsert_group(db_session, "this-is-a-new-group", *repeatable_params)) is not None
assert (
g3 := crud.upsert_group(
db_session, "this-is-a-new-group", *repeatable_params
)
) is not None
assert g3.id == "this-is-a-new-group"
assert len(g3.users) == 0
@@ -227,29 +573,38 @@ def test_upsert_group(test_data, db_session):
def test_upsert_user_groups(db_session):
@patch('app.web.db.crud.get_settings', new=lambda: bad_setings)
@patch("app.web.db.crud.get_settings", new=lambda: bad_settings)
def test_missing_yaml(db_session):
with pytest.raises(FileNotFoundError):
crud.upsert_user_groups(db_session)
@patch('app.web.db.crud.get_settings', new=lambda: bad_setings)
@patch("app.web.db.crud.get_settings", new=lambda: bad_settings)
def test_broken_yaml(db_session):
with pytest.raises(yaml.YAMLError):
crud.upsert_user_groups(db_session)
bad_setings = Settings(_env_file=".env.test")
bad_settings = Settings(_env_file=".env.test")
bad_setings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.missing.yaml"
bad_settings.USER_GROUPS_FILENAME = (
"app/tests/user-groups.test.missing.yaml"
)
test_missing_yaml(db_session)
bad_setings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.broken.yaml"
bad_settings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.broken.yaml"
test_broken_yaml(db_session)
def test_create_sheet(db_session):
assert db_session.query(models.Sheet).count() == 0
s = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly")
s = crud.create_sheet(
db_session,
"sheet-id-123",
"sheet name",
"email@example.com",
"group-id",
"hourly",
)
assert s is not None
assert s.id == "sheet-id-123"
assert s.name == "sheet name"
@@ -259,19 +614,35 @@ def test_create_sheet(db_session):
assert db_session.query(models.Sheet).count() == 1
# duplicate id
import sqlalchemy
with pytest.raises(sqlalchemy.exc.IntegrityError):
crud.create_sheet(db_session, "sheet-id-123", "I thought this was another sheet", "email", "group-id", "hourly")
crud.create_sheet(
db_session,
"sheet-id-123",
"I thought this was another sheet",
"email",
"group-id",
"hourly",
)
def test_get_user_sheet(test_data, db_session):
assert crud.get_user_sheet(db_session, "", "sheet-0") is None
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None
assert (
crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None
)
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0") is not None
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2") is not None
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-1") is not None
assert (
crud.get_user_sheet(db_session, "rick@example.com", "sheet-0")
is not None
)
assert (
crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2")
is not None
)
assert (
crud.get_user_sheet(db_session, "morty@example.com", "sheet-1")
is not None
)
def test_get_user_sheets(test_data, db_session):
@@ -283,9 +654,9 @@ def test_get_user_sheets(test_data, db_session):
def test_delete_sheet(test_data, db_session):
assert crud.delete_sheet(db_session, "sheet-0", "") == False
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == True
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == False
assert crud.delete_sheet(db_session, "sheet-0", "") is False
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") is True
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") is False
@pytest.mark.asyncio
@@ -297,21 +668,21 @@ async def test_find_by_store_until(async_db_session):
url="https://example-expired-1.com",
result={},
author_id="rick@example.com",
store_until=now - timedelta(days=1)
store_until=now - timedelta(days=1),
)
archive2 = models.Archive(
id="archive-expired-2",
url="https://example-expired-2.com",
result={},
author_id="rick@example.com",
store_until=now - timedelta(hours=1)
store_until=now - timedelta(hours=1),
)
archive3 = models.Archive(
id="archive-active",
url="https://example-active.com",
result={},
author_id="rick@example.com",
store_until=now + timedelta(days=1)
store_until=now + timedelta(days=1),
)
async_db_session.add_all([archive1, archive2, archive3])
await async_db_session.commit()
@@ -321,11 +692,15 @@ async def test_find_by_store_until(async_db_session):
assert len(list(expired)) == 2
# Should find 1 archive expired before 2 hours ago
expired = await crud.find_by_store_until(async_db_session, now - timedelta(hours=2))
expired = await crud.find_by_store_until(
async_db_session, now - timedelta(hours=2)
)
assert len(list(expired)) == 1
# Should find no archives expired before 2 days ago
expired = await crud.find_by_store_until(async_db_session, now - timedelta(days=2))
expired = await crud.find_by_store_until(
async_db_session, now - timedelta(days=2)
)
assert len(list(expired)) == 0
# Should not find deleted archives
@@ -337,44 +712,82 @@ async def test_find_by_store_until(async_db_session):
@pytest.mark.asyncio
async def test_get_sheets_by_id_hash(async_db_session):
author_emails = [
"rick@example.com",
"morty@example.com",
"jerry@example.com",
]
# Add test data
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
sheets = [
models.Sheet(id="sheet-0", name="sheet-0", author_id=authors[0], group_id=None, frequency="daily"),
models.Sheet(id="sheet-0-2", name="sheet-0-2", author_id=authors[0], group_id="spaceship", frequency="hourly"),
models.Sheet(id="sheet-1", name="sheet-1", author_id=authors[1], group_id=None, frequency="daily"),
models.Sheet(id="sheet-2", name="sheet-2", author_id=authors[2], group_id=None, frequency="daily")
models.Sheet(
id="sheet-0",
name="sheet-0",
author_id=author_emails[0],
group_id=None,
frequency="daily",
),
models.Sheet(
id="sheet-0-2",
name="sheet-0-2",
author_id=author_emails[0],
group_id="spaceship",
frequency="hourly",
),
models.Sheet(
id="sheet-1",
name="sheet-1",
author_id=author_emails[1],
group_id=None,
frequency="daily",
),
models.Sheet(
id="sheet-2",
name="sheet-2",
author_id=author_emails[2],
group_id=None,
frequency="daily",
),
]
async_db_session.add_all(sheets)
await async_db_session.commit()
with patch("app.web.db.crud.fnv1a_hash_mod", return_value=1):
# Test retrieving hourly sheets
hourly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "hourly", 4, 1)
hourly_sheets = await crud.get_sheets_by_id_hash(
async_db_session, "hourly", 4, 1
)
assert len(hourly_sheets) == 1
assert hourly_sheets[0].id == "sheet-0-2"
assert hourly_sheets[0].frequency == "hourly"
# Test retrieving daily sheets
daily_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 1)
daily_sheets = await crud.get_sheets_by_id_hash(
async_db_session, "daily", 4, 1
)
assert len(daily_sheets) == 3
assert all(sheet.frequency == "daily" for sheet in daily_sheets)
assert {sheet.id for sheet in daily_sheets} == {"sheet-0", "sheet-1", "sheet-2"}
assert {sheet.id for sheet in daily_sheets} == {
"sheet-0",
"sheet-1",
"sheet-2",
}
# Test with non-matching hash
no_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 3)
no_sheets = await crud.get_sheets_by_id_hash(
async_db_session, "daily", 4, 3
)
assert len(no_sheets) == 0
# Test with non-existent frequency
weekly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "weekly", 4, 1)
weekly_sheets = await crud.get_sheets_by_id_hash(
async_db_session, "weekly", 4, 1
)
assert len(weekly_sheets) == 0
@pytest.mark.asyncio
async def test_delete_stale_sheets(async_db_session):
from datetime import datetime, timedelta
from sqlalchemy.sql import select
now = datetime.now()
active_date = now - timedelta(days=5)
stale_date = now - timedelta(days=15)
@@ -386,29 +799,29 @@ async def test_delete_stale_sheets(async_db_session):
name="Active Sheet 1",
author_id="rick@example.com",
frequency="daily",
last_url_archived_at=active_date
last_url_archived_at=active_date,
),
models.Sheet(
id="sheet-active-2",
name="Active Sheet 2",
author_id="morty@example.com",
frequency="hourly",
last_url_archived_at=active_date
last_url_archived_at=active_date,
),
models.Sheet(
id="sheet-stale-1",
name="Stale Sheet 1",
author_id="rick@example.com",
frequency="daily",
last_url_archived_at=stale_date
last_url_archived_at=stale_date,
),
models.Sheet(
id="sheet-stale-2",
name="Stale Sheet 2",
author_id="morty@example.com",
frequency="daily",
last_url_archived_at=stale_date
)
last_url_archived_at=stale_date,
),
]
async_db_session.add_all(sheets)
await async_db_session.commit()
@@ -435,4 +848,4 @@ async def test_delete_stale_sheets(async_db_session):
# Running again should not delete anything
deleted = await crud.delete_stale_sheets(async_db_session, 7)
assert len(deleted) == 0
assert len(deleted) == 0

View File

@@ -1,10 +1,11 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from app.shared.db import models
from app.shared.user_groups import GroupInfo, GroupPermissions
from app.web.db.user_state import UserState
from app.web.utils.misc import convert_priority_to_queue_dict
def fresh_user_state():
@@ -20,39 +21,73 @@ def user_state():
def user_state_with_groups(user_state):
user_groups = [
models.Group(id="no-permissions", permissions={}),
models.Group(id="group1", description="this is g1", service_account_email="sa1@example.com", permissions={"read": ["group1", "no-permissions"], "read_public": True, "archive_url": True, "archive_sheet": True, "max_archive_lifespan_months": 24, "max_monthly_urls": 100, "max_monthly_mbs": 1000, "priority": "high"}),
models.Group(id="group2", description="this is g2", service_account_email="sa2@example.com", permissions={"read": ["all"], "read_public": True, "archive_url": False, "archive_sheet": False, "max_archive_lifespan_months": -1, "max_monthly_urls": -1, "max_monthly_mbs": -1, "priority": "low", "sheet_frequency": {"daily"}}),
models.Group(
id="group1",
description="this is g1",
service_account_email="sa1@example.com",
permissions={
"read": ["group1", "no-permissions"],
"read_public": True,
"archive_url": True,
"archive_sheet": True,
"max_archive_lifespan_months": 24,
"max_monthly_urls": 100,
"max_monthly_mbs": 1000,
"priority": "high",
},
),
models.Group(
id="group2",
description="this is g2",
service_account_email="sa2@example.com",
permissions={
"read": ["all"],
"read_public": True,
"archive_url": False,
"archive_sheet": False,
"max_archive_lifespan_months": -1,
"max_monthly_urls": -1,
"max_monthly_mbs": -1,
"priority": "low",
"sheet_frequency": {"daily"},
},
),
]
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=user_groups):
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=user_groups,
):
yield user_state
def test_permissions(user_state_with_groups):
permissions = user_state_with_groups.permissions
assert permissions["all"].read == True
assert permissions["all"].read_public == True
assert permissions["all"].archive_url == True
assert permissions["all"].archive_sheet == True
assert permissions["all"].read is True
assert permissions["all"].read_public is True
assert permissions["all"].archive_url is True
assert permissions["all"].archive_sheet is True
assert permissions["all"].max_archive_lifespan_months == -1
assert permissions["all"].max_monthly_urls == -1
assert permissions["all"].max_monthly_mbs == -1
assert permissions["all"].priority == "high"
assert permissions["group1"].read == set(["group1", "no-permissions"])
assert permissions["group1"].read_public == True
assert permissions["group1"].archive_url == True
assert permissions["group1"].archive_sheet == True
assert permissions["group1"].read == {"group1", "no-permissions"}
assert permissions["group1"].read_public is True
assert permissions["group1"].archive_url is True
assert permissions["group1"].archive_sheet is True
assert permissions["group1"].max_archive_lifespan_months == 24
assert permissions["group1"].max_monthly_urls == 100
assert permissions["group1"].max_monthly_mbs == 1000
assert permissions["group1"].priority == "high"
assert permissions["group2"].read == set(["all"])
assert permissions["group2"].read_public == True
assert permissions["group2"].archive_url == False
assert permissions["group2"].archive_sheet == False
assert permissions["group2"].read == {"all"}
assert permissions["group2"].read_public is True
assert permissions["group2"].archive_url is False
assert permissions["group2"].archive_sheet is False
assert permissions["group2"].max_archive_lifespan_months == -1
assert permissions["group2"].max_monthly_urls == -1
assert permissions["group2"].max_monthly_mbs == -1
@@ -62,13 +97,19 @@ def test_permissions(user_state_with_groups):
def test_user_groups_names(user_state):
with patch('app.web.db.crud.get_user_group_names', return_value=["group1", "group2"]) as mock:
with patch(
"app.web.db.crud.get_user_group_names",
return_value=["group1", "group2"],
) as mock:
assert user_state.user_groups_names == ["group1", "group2", "default"]
mock.assert_called_once_with(None, "test@example.com")
def test_user_groups(user_state):
with patch('app.web.db.crud.get_user_groups_by_name', return_value=[MagicMock(), MagicMock()]) as mock:
with patch(
"app.web.db.crud.get_user_groups_by_name",
return_value=[MagicMock(), MagicMock()],
) as mock:
user_state._user_groups_names = ["group1", "group2"]
assert len(user_state.user_groups) == 2
mock.assert_called_once_with(None, ["group1", "group2"])
@@ -77,85 +118,166 @@ def test_user_groups(user_state):
def test_read():
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_read")
assert us.read == set()
assert us._read == set()
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read": ["group1", "no-permissions"]})]):
assert us.read == set(["group1", "no-permissions"])
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(
id="group1", permissions={"read": ["group1", "no-permissions"]}
)
],
):
assert us.read == {"group1", "no-permissions"}
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read": ["all"]})]):
assert us.read == True
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="group1", permissions={"read": ["all"]})],
):
assert us.read is True
def test_read_public():
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_read_public")
assert us.read_public == False
assert us._read_public == False
assert us.read_public is False
assert us._read_public is False
mock.assert_called_once()
# no new calls
assert us.read_public == False
assert us.read_public is False
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read_public": True})]):
assert us.read_public == True
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"read_public": True})
],
):
assert us.read_public is True
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read_public": False})]):
assert us.read_public == False
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"read_public": False})
],
):
assert us.read_public is False
def test_archive_url():
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_archive_url")
assert us.archive_url == False
assert us._archive_url == False
assert us.archive_url is False
assert us._archive_url is False
mock.assert_called_once()
# no new calls
assert us.archive_url == False
assert us.archive_url is False
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_url": False})]):
assert us.archive_url == False
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"archive_url": False})
],
):
assert us.archive_url is False
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_url": True})]):
assert us.archive_url == True
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"archive_url": True})
],
):
assert us.archive_url is True
def test_archive_sheet():
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_archive_sheet")
assert us.archive_sheet == False
assert us._archive_sheet == False
assert us.archive_sheet is False
assert us._archive_sheet is False
mock.assert_called_once()
# no new calls
assert us.archive_sheet == False
assert us.archive_sheet is False
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_sheet": False})]):
assert us.archive_sheet == False
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"archive_sheet": False})
],
):
assert us.archive_sheet is False
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_sheet": True})]):
assert us.archive_sheet == True
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"archive_sheet": True})
],
):
assert us.archive_sheet is True
def test_sheet_frequency():
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_sheet_frequency")
assert us.sheet_frequency == set()
assert us._sheet_frequency == set()
@@ -165,18 +287,42 @@ def test_sheet_frequency():
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"sheet_frequency": ["daily", "hourly"]})]):
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(
id="group1",
permissions={"sheet_frequency": ["daily", "hourly"]},
)
],
):
assert us.sheet_frequency == {"daily", "hourly"}
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"sheet_frequency": []})]):
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"sheet_frequency": []})
],
):
assert us.sheet_frequency == set()
def test_max_archive_lifespan_months():
us = fresh_user_state()
default = GroupPermissions.model_fields["max_archive_lifespan_months"].default
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
default = GroupPermissions.model_fields[
"max_archive_lifespan_months"
].default
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_max_archive_lifespan_months")
assert us.max_archive_lifespan_months == default
assert us._max_archive_lifespan_months == default
@@ -186,18 +332,44 @@ def test_max_archive_lifespan_months():
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_archive_lifespan_months": 24})]):
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(
id="group1", permissions={"max_archive_lifespan_months": 24}
)
],
):
assert us.max_archive_lifespan_months == 24
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_archive_lifespan_months": 150}), models.Group(id="group2", permissions={"max_archive_lifespan_months": -1})]):
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(
id="group1", permissions={"max_archive_lifespan_months": 150}
),
models.Group(
id="group2", permissions={"max_archive_lifespan_months": -1}
),
],
):
assert us.max_archive_lifespan_months == -1
def test_max_monthly_urls():
us = fresh_user_state()
default = GroupPermissions.model_fields["max_monthly_urls"].default
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_max_monthly_urls")
assert us.max_monthly_urls == default
assert us._max_monthly_urls == default
@@ -207,18 +379,38 @@ def test_max_monthly_urls():
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_urls": 100})]):
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"max_monthly_urls": 100})
],
):
assert us.max_monthly_urls == 100
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_urls": 150}), models.Group(id="group2", permissions={"max_monthly_urls": -1})]):
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"max_monthly_urls": 150}),
models.Group(id="group2", permissions={"max_monthly_urls": -1}),
],
):
assert us.max_monthly_urls == -1
def test_max_monthly_mbs():
us = fresh_user_state()
default = GroupPermissions.model_fields["max_monthly_mbs"].default
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(us, "_max_monthly_mbs")
assert us.max_monthly_mbs == default
assert us._max_monthly_mbs == default
@@ -228,17 +420,37 @@ def test_max_monthly_mbs():
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_mbs": 1000})]):
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"max_monthly_mbs": 1000})
],
):
assert us.max_monthly_mbs == 1000
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_mbs": 1500}), models.Group(id="group2", permissions={"max_monthly_mbs": -1})]):
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"max_monthly_mbs": 1500}),
models.Group(id="group2", permissions={"max_monthly_mbs": -1}),
],
):
assert us.max_monthly_mbs == -1
def test_priority(user_state):
default = GroupPermissions.model_fields["priority"].default
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[models.Group(id="no-permissions", permissions={})],
) as mock:
assert not hasattr(user_state, "_priority")
assert user_state.priority == default
assert user_state._priority == default
@@ -248,11 +460,26 @@ def test_priority(user_state):
mock.assert_called_once()
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"priority": "high"})]):
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"priority": "high"})
],
):
assert us.priority == "high"
us = fresh_user_state()
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"priority": "low"}), models.Group(id="group2", permissions={"priority": "medium"})]):
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"priority": "low"}),
models.Group(id="group2", permissions={"priority": "medium"}),
],
):
assert us.priority == "low"
@@ -262,21 +489,45 @@ def test_active():
(True, False, False, False, True),
(False, True, False, False, True),
(False, False, True, False, True),
(False, False, False, True, True)
(False, False, False, True, True),
]:
us = fresh_user_state()
with patch.object(UserState, 'read', new_callable=PropertyMock, return_value=read), \
patch.object(UserState, 'read_public', new_callable=PropertyMock, return_value=read_public), \
patch.object(UserState, 'archive_url', new_callable=PropertyMock, return_value=archive_url), \
patch.object(UserState, 'archive_sheet', new_callable=PropertyMock, return_value=archive_sheet):
with (
patch.object(
UserState, "read", new_callable=PropertyMock, return_value=read
),
patch.object(
UserState,
"read_public",
new_callable=PropertyMock,
return_value=read_public,
),
patch.object(
UserState,
"archive_url",
new_callable=PropertyMock,
return_value=archive_url,
),
patch.object(
UserState,
"archive_sheet",
new_callable=PropertyMock,
return_value=archive_sheet,
),
):
assert us.active == is_active
def test_in_group(user_state):
with patch.object(UserState, 'user_groups_names', new_callable=PropertyMock, return_value=["group1", "group2"]):
assert user_state.in_group("group1") == True
assert user_state.in_group("group2") == True
assert user_state.in_group("group3") == False
with patch.object(
UserState,
"user_groups_names",
new_callable=PropertyMock,
return_value=["group1", "group2"],
):
assert user_state.in_group("group1") is True
assert user_state.in_group("group2") is True
assert user_state.in_group("group3") is False
def test_usage(db_session):
@@ -294,10 +545,34 @@ def test_usage(db_session):
]
megabytes = int(sum(bytes) / 1024 / 1024)
with patch.object(db_session, 'query', side_effect=[
MagicMock(filter=MagicMock(return_value=MagicMock(group_by=MagicMock(return_value=MagicMock(all=MagicMock(return_value=user_sheets)))))),
MagicMock(filter=MagicMock(return_value=MagicMock(group_by=MagicMock(return_value=MagicMock(all=MagicMock(return_value=urls_by_group))))))
]):
with patch.object(
db_session,
"query",
side_effect=[
MagicMock(
filter=MagicMock(
return_value=MagicMock(
group_by=MagicMock(
return_value=MagicMock(
all=MagicMock(return_value=user_sheets)
)
)
)
)
),
MagicMock(
filter=MagicMock(
return_value=MagicMock(
group_by=MagicMock(
return_value=MagicMock(
all=MagicMock(return_value=urls_by_group)
)
)
)
)
),
],
):
usage_response = user_state.usage()
assert usage_response.monthly_urls == 155
@@ -305,11 +580,15 @@ def test_usage(db_session):
assert usage_response.total_sheets == 115
assert usage_response.groups["group1"].monthly_urls == 50
assert usage_response.groups["group1"].monthly_mbs == int(bytes[0] / 1024 / 1024)
assert usage_response.groups["group1"].monthly_mbs == int(
bytes[0] / 1024 / 1024
)
assert usage_response.groups["group1"].total_sheets == 5
assert usage_response.groups["group2"].monthly_urls == 100
assert usage_response.groups["group2"].monthly_mbs == int(bytes[1] / 1024 / 1024)
assert usage_response.groups["group2"].monthly_mbs == int(
bytes[1] / 1024 / 1024
)
assert usage_response.groups["group2"].total_sheets == 10
assert usage_response.groups["group3"].monthly_urls == 0
@@ -317,7 +596,9 @@ def test_usage(db_session):
assert usage_response.groups["group3"].total_sheets == 100
assert usage_response.groups["group4"].monthly_urls == 5
assert usage_response.groups["group4"].monthly_mbs == int(bytes[2] / 1024 / 1024)
assert usage_response.groups["group4"].monthly_mbs == int(
bytes[2] / 1024 / 1024
)
assert usage_response.groups["group4"].total_sheets == 0
@@ -333,8 +614,23 @@ def test_has_quota_monthly_sheets(db_session):
]
for permissions, count, expected in test_cases:
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))):
with patch.object(
UserState,
"permissions",
new_callable=PropertyMock,
return_value=permissions,
):
with patch.object(
us.db,
"query",
return_value=MagicMock(
filter=MagicMock(
return_value=MagicMock(
count=MagicMock(return_value=count)
)
)
),
):
assert us.has_quota_monthly_sheets("group1") == expected
@@ -349,8 +645,23 @@ def test_has_quota_max_monthly_urls(db_session):
]
for permissions, count, expected in test_cases:
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))):
with patch.object(
UserState,
"permissions",
new_callable=PropertyMock,
return_value=permissions,
):
with patch.object(
us.db,
"query",
return_value=MagicMock(
filter=MagicMock(
return_value=MagicMock(
count=MagicMock(return_value=count)
)
)
),
):
assert us.has_quota_max_monthly_urls("group1") == expected
test_cases = [
(-1, 1000, True),
@@ -360,8 +671,23 @@ def test_has_quota_max_monthly_urls(db_session):
]
for max_urls, count, expected in test_cases:
with patch.object(UserState, 'max_monthly_urls', new_callable=PropertyMock, return_value=max_urls):
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))):
with patch.object(
UserState,
"max_monthly_urls",
new_callable=PropertyMock,
return_value=max_urls,
):
with patch.object(
us.db,
"query",
return_value=MagicMock(
filter=MagicMock(
return_value=MagicMock(
count=MagicMock(return_value=count)
)
)
),
):
assert us.has_quota_max_monthly_urls("") == expected
@@ -376,8 +702,29 @@ def test_has_quota_max_monthly_mbs(db_session):
]
for permissions, mbs, expected in test_cases:
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(with_entities=MagicMock(return_value=MagicMock(scalar=MagicMock(return_value=mbs * 1024 * 1024))))))):
with patch.object(
UserState,
"permissions",
new_callable=PropertyMock,
return_value=permissions,
):
with patch.object(
us.db,
"query",
return_value=MagicMock(
filter=MagicMock(
return_value=MagicMock(
with_entities=MagicMock(
return_value=MagicMock(
scalar=MagicMock(
return_value=mbs * 1024 * 1024
)
)
)
)
)
),
):
assert us.has_quota_max_monthly_mbs("group1") == expected
test_cases = [
@@ -388,8 +735,29 @@ def test_has_quota_max_monthly_mbs(db_session):
]
for max_mbs, mbs, expected in test_cases:
with patch.object(UserState, 'max_monthly_mbs', new_callable=PropertyMock, return_value=max_mbs):
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(with_entities=MagicMock(return_value=MagicMock(scalar=MagicMock(return_value=mbs * 1024 * 1024))))))):
with patch.object(
UserState,
"max_monthly_mbs",
new_callable=PropertyMock,
return_value=max_mbs,
):
with patch.object(
us.db,
"query",
return_value=MagicMock(
filter=MagicMock(
return_value=MagicMock(
with_entities=MagicMock(
return_value=MagicMock(
scalar=MagicMock(
return_value=mbs * 1024 * 1024
)
)
)
)
)
),
):
assert us.has_quota_max_monthly_mbs("") == expected
@@ -399,10 +767,15 @@ def test_can_manually_trigger(user_state):
"group2": GroupInfo(manually_trigger_sheet=False),
}
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
assert user_state.can_manually_trigger("group1") == True
assert user_state.can_manually_trigger("group2") == False
assert user_state.can_manually_trigger("group3") == False
with patch.object(
UserState,
"permissions",
new_callable=PropertyMock,
return_value=permissions,
):
assert user_state.can_manually_trigger("group1") is True
assert user_state.can_manually_trigger("group2") is False
assert user_state.can_manually_trigger("group3") is False
def test_is_sheet_frequency_allowed(user_state):
@@ -411,23 +784,44 @@ def test_is_sheet_frequency_allowed(user_state):
"group2": GroupInfo(sheet_frequency={"daily"}),
}
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
assert user_state.is_sheet_frequency_allowed("group1", "daily") == True
assert user_state.is_sheet_frequency_allowed("group1", "hourly") == True
assert user_state.is_sheet_frequency_allowed("group1", "weekly") == False
assert user_state.is_sheet_frequency_allowed("group2", "hourly") == False
assert user_state.is_sheet_frequency_allowed("group2", "daily") == True
assert user_state.is_sheet_frequency_allowed("group3", "daily") == False
with patch.object(
UserState,
"permissions",
new_callable=PropertyMock,
return_value=permissions,
):
assert user_state.is_sheet_frequency_allowed("group1", "daily") is True
assert user_state.is_sheet_frequency_allowed("group1", "hourly") is True
assert (
user_state.is_sheet_frequency_allowed("group1", "weekly") is False
)
assert (
user_state.is_sheet_frequency_allowed("group2", "hourly") is False
)
assert user_state.is_sheet_frequency_allowed("group2", "daily") is True
assert user_state.is_sheet_frequency_allowed("group3", "daily") is False
def test_priority_group(user_state):
from app.web.utils.misc import convert_priority_to_queue_dict
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[
models.Group(id="group1", permissions={"priority": "high"}),
models.Group(id="group2", permissions={"priority": "medium"}),
models.Group(id="group3", permissions={"priority": "low"}),
]):
assert user_state.priority_group("group1") == convert_priority_to_queue_dict("high")
assert user_state.priority_group("group2") == convert_priority_to_queue_dict("medium")
assert user_state.priority_group("group3") == convert_priority_to_queue_dict("low")
assert user_state.priority_group("group4") == convert_priority_to_queue_dict("low")
with patch.object(
UserState,
"user_groups",
new_callable=PropertyMock,
return_value=[
models.Group(id="group1", permissions={"priority": "high"}),
models.Group(id="group2", permissions={"priority": "medium"}),
models.Group(id="group3", permissions={"priority": "low"}),
],
):
assert user_state.priority_group(
"group1"
) == convert_priority_to_queue_dict("high")
assert user_state.priority_group(
"group2"
) == convert_priority_to_queue_dict("medium")
assert user_state.priority_group(
"group3"
) == convert_priority_to_queue_dict("low")
assert user_state.priority_group(
"group4"
) == convert_priority_to_queue_dict("low")

View File

@@ -1,60 +0,0 @@
from datetime import datetime
import json
from unittest.mock import MagicMock, patch
from app.shared.db import models
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud
def test_submit_manual_archive_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/interop/submit-archive")
def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
test_no_auth(client_with_auth.post, "/interop/submit-archive")
@patch("app.web.endpoints.interoperability.business_logic", return_value=MagicMock(get_store_archive_until=MagicMock(return_value=datetime)))
def test_submit_manual_archive(m1, client_with_token, db_session):
# normal workflow
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]})
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"})
assert r.status_code == 201
assert "id" in r.json()
inserted = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"]).first()
assert inserted.url == "http://example.com"
assert inserted.group_id == "spaceship"
assert inserted.author_id == "jerry@gmail.com"
assert sorted([t.id for t in inserted.tags]) == sorted(["test", "manual"])
assert inserted.public
assert type(inserted.result) == dict
assert [u.url for u in inserted.urls] == ["http://example.s3.com"]
assert type(inserted.store_until) == datetime
# cannot have the same URL twice
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]})
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "tags": ["test"], "url": "http://example.com"})
assert r.status_code == 422
assert r.json() == {"detail": "Cannot insert into DB due to integrity error, likely duplicate urls."}
# test with invalid JSON
def test_submit_manual_archive_invalid_json(client_with_token):
r = client_with_token.post("/interop/submit-archive", json={"result": "invalid json", "public": False, "author_id": "jer", "tags": ["test"], "url": "http://example.com"})
assert r.status_code == 422
assert r.json() == {"detail": "Invalid JSON in result field."}
@patch("app.web.endpoints.interoperability.business_logic.get_store_archive_until", side_effect=AssertionError("AssertionError"))
def test_submit_manual_archive_no_store_until(m_sau, client_with_token, db_session):
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]})
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"})
assert r.status_code == 201
assert len(r.json()["id"]) == 36
res = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"]).first()
assert res.store_until is None
# testing that store_until = None is not comparable with datetime, and will always return False
res = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"], models.Archive.store_until < datetime.now()).first()
assert res is None

View File

@@ -1,193 +0,0 @@
from datetime import datetime
import json
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
from app.shared.schemas import TaskResult
def test_endpoints_no_auth(client, test_no_auth):
test_no_auth(client.post, "/sheet/create")
test_no_auth(client.get, "/sheet/mine")
test_no_auth(client.delete, "/sheet/123-sheet-id")
test_no_auth(client.post, "/sheet/123-sheet-id/archive")
def test_create_sheet_endpoint(app_with_auth, db_session):
client_with_auth = TestClient(app_with_auth)
good_data = {
"id": "123-sheet-id",
"name": "Test Sheet",
"group_id": "spaceship",
"frequency": "daily"
}
# with good data
response = client_with_auth.post("/sheet/create", json=good_data)
assert response.status_code == 201
j = response.json()
assert datetime.fromisoformat(j.pop("created_at"))
assert datetime.fromisoformat(j.pop("last_url_archived_at"))
assert j.pop("author_id") == 'morty@example.com'
assert j == good_data
# already exists
response = client_with_auth.post("/sheet/create", json=good_data)
assert response.status_code == 400
assert response.json() == {"detail": "Sheet with this ID is already being archived."}
# bad group
bad_data = good_data.copy()
bad_data["group_id"] = "not a group"
response = client_with_auth.post("/sheet/create", json=bad_data)
assert response.status_code == 403
assert response.json() == {"detail": "User does not have access to this group."}
# switch to jerry who's got less quota/permissions
from app.web.security import get_user_state
from app.web.db.user_state import UserState
app_with_auth.dependency_overrides[get_user_state] = lambda: UserState(db_session, "jerry@example.com")
client_jerry = TestClient(app_with_auth)
# frequency not allowed
jerry_data = good_data.copy()
jerry_data["group_id"] = "animated-characters"
jerry_data["frequency"] = "hourly"
jerry_data["id"] = "jerry-sheet-id"
response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == 422
assert response.json() == {"detail": "Invalid frequency selected for this group."}
jerry_data["frequency"] = "daily"
# success for the first sheet, bad quota on second
response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == 201
response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == 429
assert response.json() == {"detail": "User has reached their sheet quota for this group."}
def test_get_user_sheets_endpoint(client_with_auth, db_session):
# no data
response = client_with_auth.get("/sheet/mine")
assert response.status_code == 200
assert response.json() == []
# with data
from app.shared.db import models
db_session.add(
models.Sheet(id="123", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")
)
db_session.commit()
db_session.add_all([
models.Sheet(id="456", name="Test Sheet 2", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
models.Sheet(id="789", name="Test Sheet 3", author_id="rick@example.com", group_id="interdimensional", frequency="hourly"),
])
db_session.commit()
response = client_with_auth.get("/sheet/mine")
assert response.status_code == 200
r = response.json()
assert isinstance(r, list)
assert len(r) == 2
assert datetime.fromisoformat(r[0].pop("created_at"))
assert datetime.fromisoformat(r[0].pop("last_url_archived_at"))
assert datetime.fromisoformat(r[1].pop("created_at"))
assert datetime.fromisoformat(r[1].pop("last_url_archived_at"))
assert r[0] == {
'id': '123',
'author_id': 'morty@example.com',
'frequency': 'hourly',
'group_id': 'spaceship',
'name': 'Test Sheet 1',
}
assert r[1] == {
'id': '456',
'author_id': 'morty@example.com',
'frequency': 'daily',
'group_id': 'interdimensional',
'name': 'Test Sheet 2',
}
def test_delete_sheet_endpoint(client_with_auth, db_session):
# missing sheet
response = client_with_auth.delete("/sheet/123-sheet-id")
assert response.status_code == 200
assert response.json() == {
"id": "123-sheet-id",
"deleted": False
}
# add sheets for deletion
from app.shared.db import models
db_session.add_all([
models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
models.Sheet(id="456-sheet-id", name="Test Sheet 2", author_id="rick@example.com", group_id="spaceship", frequency="hourly"),
])
db_session.commit()
# morty can delete his
response = client_with_auth.delete("/sheet/123-sheet-id")
assert response.status_code == 200
assert response.json() == {"id": "123-sheet-id", "deleted": True}
# but only once
response = client_with_auth.delete("/sheet/123-sheet-id")
assert response.status_code == 200
assert response.json() == {"id": "123-sheet-id", "deleted": False}
# and not rick's
response = client_with_auth.delete("/sheet/456-sheet-id")
assert response.status_code == 200
assert response.json() == {"id": "456-sheet-id", "deleted": False}
class TestArchiveUserSheetEndpoint:
@patch("app.web.endpoints.sheet.celery", return_value=MagicMock())
def test_normal_flow(self, m_celery, client_with_auth, db_session):
from app.shared.db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly"))
db_session.commit()
m_signature = MagicMock()
m_signature.apply_async.return_value = TaskResult(id="123-taskid", status="PENDING", result="")
m_celery.signature.return_value = m_signature
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 201
assert r.json() == {"id": "123-taskid"}
m_celery.signature.assert_called_once()
m_signature.apply_async.assert_called_once()
def test_token_auth(self, client_with_token, test_no_auth):
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
def test_missing_data(self, client_with_auth):
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "No access to this sheet."}
def test_no_access(self, client_with_auth, db_session):
from app.shared.db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="rick@example.com", group_id="spaceship", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "No access to this sheet."}
def test_user_not_in_group(self, client_with_auth, db_session):
from app.shared.db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 403
assert r.json() == {"detail": "User does not have access to this group."}
def test_user_cannot_manually_trigger(self, client_with_auth, db_session):
from app.shared.db import models
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="default", frequency="hourly"))
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == 429
assert r.json() == {"detail": "User cannot manually trigger sheet archiving in this group."}

View File

@@ -1,202 +0,0 @@
import json
from unittest.mock import MagicMock, patch
from app.shared.schemas import ArchiveCreate, TaskResult
from app.web.config import ALLOW_ANY_EMAIL
def test_archive_url_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/url/archive")
@patch("app.web.endpoints.url.UserState")
@patch("app.web.endpoints.url.celery", return_value=MagicMock())
def test_archive_url(m_celery, m2, client_with_auth):
m_signature = MagicMock()
m_signature.apply_async.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
m_celery.signature.return_value = m_signature
m_user_state = MagicMock()
m2.return_value = m_user_state
# url is too short
response = client_with_auth.post("/url/archive", json={"url": "bad"})
assert response.status_code == 422
assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
m_celery.signature.assert_not_called()
# url is invalid
response = client_with_auth.post("/url/archive", json={"url": "example.com"})
assert response.status_code == 400
assert response.json()["detail"] == "Invalid URL received."
# valid request
m_user_state.has_quota_max_monthly_urls.return_value = True
m_user_state.has_quota_max_monthly_mbs.return_value = True
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
m_celery.signature.assert_called_once()
m_signature.apply_async.assert_called_once()
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_archive_task"
assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": "rick@example.com", "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None}
m_user_state.has_quota_max_monthly_urls.assert_called_once()
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
m_user_state.in_group.assert_called_once_with("default")
# user is not in group
m_user_state.in_group.return_value = False
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "new-group"})
assert response.status_code == 403
assert response.json()["detail"] == "User does not have access to this group."
m_user_state.in_group.assert_called_with("new-group")
# user is in group
m_user_state.in_group.return_value = True
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
assert m_celery.signature.call_count == 2
assert m_signature.apply_async.call_count == 2
called_val = m_celery.signature.call_args
assert json.loads(called_val[1]['args'][0])["group_id"] == "spaceship"
m_user_state.in_group.assert_called_with("spaceship")
# user is over monthly URL quota
m_user_state.has_quota_max_monthly_urls.return_value = False
m_user_state.has_quota_max_monthly_mbs.return_value = True
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
assert response.status_code == 429
assert response.json()["detail"] == "User has reached their monthly URL quota."
m_user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")
# user is over monthly MB quota
m_user_state.has_quota_max_monthly_urls.return_value = True
m_user_state.has_quota_max_monthly_mbs.return_value = False
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spacesuit"})
assert response.status_code == 429
assert response.json()["detail"] == "User has reached their monthly MB quota."
m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
assert m_celery.signature.call_count == 2
assert m_signature.apply_async.call_count == 2
@patch("app.web.endpoints.url.UserState")
def test_archive_url_quotas(m1, client_with_auth):
m_user_state = MagicMock()
m1.return_value = m_user_state
# misses on monthly URLs quota
m_user_state.has_quota_max_monthly_urls.return_value = False
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 429
assert response.json()["detail"] == "User has reached their monthly URL quota."
m_user_state.has_quota_max_monthly_urls.assert_called_once()
# misses on monthly MBs quota
m_user_state.has_quota_max_monthly_urls.return_value = True
m_user_state.has_quota_max_monthly_mbs.return_value = False
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
assert response.status_code == 429
assert response.json()["detail"] == "User has reached their monthly MB quota."
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
@patch("app.web.endpoints.url.celery", return_value=MagicMock())
def test_archive_url_with_api_token(m_celery, client_with_token):
m_signature = MagicMock()
m_signature.apply_async.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
m_celery.signature.return_value = m_signature
response = client_with_token.post("/url/archive", json={"url": "https://example.com", "author_id": "someone@example.com"})
assert response.status_code == 201
assert response.json() == {'id': '123-456-789'}
m_celery.signature.assert_called_once()
m_signature.apply_async.assert_called_once()
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_archive_task"
assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": "someone@example.com", "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None}
# missing id should use ALLOW_ANY_EMAIL
response = client_with_token.post("/url/archive", json={"url": "https://example.com", "author_id": None})
assert response.status_code == 201
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_archive_task"
assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": ALLOW_ANY_EMAIL, "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None}
def test_search_by_url_unauthenticated(client, test_no_auth):
test_no_auth(client.get, "/url/search")
def test_search_by_url(client_with_auth, client_with_token, db_session):
# tests the search endpoint, including through some db data for the endpoint params
response = client_with_auth.get("/url/search")
assert response.status_code == 422
assert response.json()["detail"][0]["msg"] == "Field required"
response = client_with_auth.get("/url/search?url=https://example.com")
assert response.status_code == 200
assert response.json() == []
from app.shared import schemas
from app.shared.db import worker_crud
for i in range(11):
worker_crud.create_archive(db_session, ArchiveCreate(id=f"url-456-{i}", url="https://example.com" if i < 10 else "https://something-else.com", result={}, public=True, author_id="rick@example.com"), [], [])
# NB: this insertion is too fast for the ordering to be correct as they are within the same second
response = client_with_auth.get("/url/search?url=https://example.com")
assert response.status_code == 200
assert len(j := response.json()) == 10
assert "url-456-0" in [i["id"] for i in j]
assert "url-456-9" in [i["id"] for i in j]
assert "url-456-10" not in [i["id"] for i in j]
assert j[0].keys() == schemas.ArchiveResult.model_fields.keys()
response = client_with_auth.get("/url/search?url=https://example.com&limit=5")
assert response.status_code == 200
assert len(response.json()) == 5
response = client_with_auth.get("/url/search?url=https://example.com&skip=5&limit=2")
assert response.status_code == 200
assert len(response.json()) == 2
response = client_with_auth.get("/url/search?url=https://example.com&archived_before=2010-01-01")
assert response.status_code == 200
assert len(response.json()) == 0
response = client_with_auth.get("/url/search?url=https://example.com&archived_after=2010-01-01")
assert response.status_code == 200
assert len(response.json()) == 10
# API token will also work
response = client_with_token.get("/url/search?url=https://example.com&archived_after=2010-01-01")
assert response.status_code == 200
assert len(response.json()) == 10
@patch("app.web.endpoints.url.UserState")
def test_search_no_read_access(mock_user_state, client_with_auth):
mock_user_state.return_value.read = False
mock_user_state.return_value.read_public = False
response = client_with_auth.get("/url/search?url=https://example.com")
assert response.status_code == 403
assert response.json() == {"detail": "User does not have read access."}
def test_delete_task_unauthenticated(client, test_no_auth):
test_no_auth(client.delete, "/url/123-456-789")
def test_delete_task(client_with_auth, db_session):
response = client_with_auth.delete("/url/delete-123-456-789")
assert response.status_code == 200
assert response.json() == {"id": "delete-123-456-789", "deleted": False}
from app.shared.db import worker_crud
worker_crud.create_archive(db_session, ArchiveCreate(id="delete-123-456-789", url="https://example.com", result={}, public=True, author_id="morty@example.com"), [], [])
response = client_with_auth.delete("/url/delete-123-456-789")
assert response.status_code == 200
assert response.json() == {"id": "delete-123-456-789", "deleted": True}

View File

@@ -1,15 +1,20 @@
from http import HTTPStatus
from unittest.mock import MagicMock
from fastapi.testclient import TestClient
import pytest
from fastapi.testclient import TestClient
from loguru import logger
from app.shared.schemas import Usage, UsageResponse
from app.shared.user_groups import GroupInfo
from app.web.config import VERSION
from app.tests.web.db.test_crud import test_data
from app.web.security import get_user_state
from app.web.utils.metrics import measure_regular_metrics
def test_endpoint_home(client_with_auth):
r = client_with_auth.get("/")
assert r.status_code == 200
assert r.status_code == HTTPStatus.OK
j = r.json()
assert "version" in j and j["version"] == VERSION
assert "breakingChanges" in j
@@ -18,7 +23,7 @@ def test_endpoint_home(client_with_auth):
def test_endpoint_health(client_with_auth):
r = client_with_auth.get("/health")
assert r.status_code == 200
assert r.status_code == HTTPStatus.OK
assert r.json() == {"status": "ok"}
@@ -29,32 +34,31 @@ def test_endpoint_active_no_auth(client, test_no_auth):
def test_endpoint_active(app):
m_user_state = MagicMock()
from app.web.security import get_user_state
app.dependency_overrides[get_user_state] = lambda: m_user_state
# inactive user
m_user_state.active = False
client = TestClient(app)
r = client.get("/user/active")
assert r.status_code == 200
assert r.status_code == HTTPStatus.OK
assert r.json() == {"active": False}
# active user
m_user_state.active = True
client = TestClient(app)
r = client.get("/user/active")
assert r.status_code == 200
assert r.status_code == HTTPStatus.OK
assert r.json() == {"active": True}
def test_no_serve_local_archive_by_default(client_with_auth):
r = client_with_auth.get("/app/local_archive_test/temp.txt")
assert r.status_code == 404
assert r.status_code == HTTPStatus.NOT_FOUND
def test_favicon(client_with_auth):
r = client_with_auth.get("/favicon.ico")
assert r.status_code == 200
assert r.status_code == HTTPStatus.OK
assert r.headers["content-type"] == "image/vnd.microsoft.icon"
@@ -70,8 +74,10 @@ def test_endpoint_test_prometheus_no_user_auth(client_with_auth, test_no_auth):
async def test_prometheus_metrics(test_data, client_with_token, get_settings):
# before metrics calculation
r = client_with_token.get("/metrics")
assert r.status_code == 200
assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
assert r.status_code == HTTPStatus.OK
assert (
r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
)
assert "disk_utilization" in r.text
assert "database_metrics" in r.text
assert "exceptions" in r.text
@@ -79,8 +85,9 @@ async def test_prometheus_metrics(test_data, client_with_token, get_settings):
assert 'disk_utilization{type="used"}' not in r.text
# after metrics calculation
from app.web.utils.metrics import measure_regular_metrics
await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100)
await measure_regular_metrics(
get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100
)
r2 = client_with_token.get("/metrics")
assert 'disk_utilization{type="used"}' in r2.text
assert 'disk_utilization{type="free"}' in r2.text
@@ -88,20 +95,37 @@ async def test_prometheus_metrics(test_data, client_with_token, get_settings):
assert 'database_metrics{query="count_archives"} 100.0' in r2.text
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r2.text
assert 'database_metrics{query="count_users"} 3.0' in r2.text
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r2.text
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r2.text
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r2.text
assert (
'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0'
in r2.text
)
assert (
'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0'
in r2.text
)
assert (
'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0'
in r2.text
)
# 30s window, should not change the gauges nor the total in the counters
from app.web.utils.metrics import measure_regular_metrics
await measure_regular_metrics(get_settings.DATABASE_PATH, 30)
r3 = client_with_token.get("/metrics")
assert 'database_metrics{query="count_archives"} 100.0' in r3.text
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text
assert 'database_metrics{query="count_users"} 3.0' in r3.text
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r3.text
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r3.text
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r3.text
assert (
'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0'
in r3.text
)
assert (
'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0'
in r3.text
)
assert (
'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0'
in r3.text
)
def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth):
@@ -109,14 +133,12 @@ def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth):
def test_endpoint_get_user_permissions(app):
from app.web.security import get_user_state
m_user_state = MagicMock()
rv = {
"all": GroupInfo(read=True),
"group1": GroupInfo(archive_url=True),
}
from loguru import logger
logger.info(rv)
m_user_state.permissions = rv
@@ -124,13 +146,13 @@ def test_endpoint_get_user_permissions(app):
client = TestClient(app)
r = client.get("/user/permissions")
assert r.status_code == 200
assert r.status_code == HTTPStatus.OK
response = r.json()
assert response.keys() == {"all", "group1"}
assert response["all"]["read"]
assert response["group1"]["read"] == []
assert response["group1"]["archive_url"]
assert response["all"]["archive_url"] == False
assert response["all"]["archive_url"] is False
def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth):
@@ -138,8 +160,6 @@ def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth):
def test_endpoint_get_user_usage_inactive(app):
from app.web.security import get_user_state
m_user_state = MagicMock()
m_user_state.active = False
@@ -147,13 +167,11 @@ def test_endpoint_get_user_usage_inactive(app):
client = TestClient(app)
r = client.get("/user/usage")
assert r.status_code == 403
assert r.status_code == HTTPStatus.FORBIDDEN
assert r.json() == {"detail": "User is not active."}
def test_endpoint_get_user_usage_active(app):
from app.web.security import get_user_state
m_user_state = MagicMock()
m_user_state.active = True
mock_usage = UsageResponse(
@@ -162,8 +180,8 @@ def test_endpoint_get_user_usage_active(app):
total_sheets=3,
groups={
"group1": Usage(monthly_urls=4, monthly_mbs=5, total_sheets=6),
"group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9)
}
"group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9),
},
)
m_user_state.usage.return_value = mock_usage
@@ -171,5 +189,5 @@ def test_endpoint_get_user_usage_active(app):
client = TestClient(app)
r = client.get("/user/usage")
assert r.status_code == 200
assert r.status_code == HTTPStatus.OK
assert UsageResponse(**r.json()) == mock_usage

View File

@@ -0,0 +1,147 @@
import json
from datetime import datetime
from http import HTTPStatus
from unittest.mock import MagicMock, patch
from app.shared.db import models
def test_submit_manual_archive_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/interop/submit-archive")
def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
test_no_auth(client_with_auth.post, "/interop/submit-archive")
@patch(
"app.web.routers.interoperability.business_logic",
return_value=MagicMock(
get_store_archive_until=MagicMock(return_value=datetime)
),
)
def test_submit_manual_archive(m1, client_with_token, db_session):
# normal workflow
aa_metadata = json.dumps(
{
"status": "test: success",
"metadata": {"url": "http://example.com"},
"media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}],
}
)
r = client_with_token.post(
"/interop/submit-archive",
json={
"result": aa_metadata,
"public": True,
"author_id": "jerry@gmail.com",
"group_id": "spaceship",
"tags": ["test"],
"url": "http://example.com",
},
)
assert r.status_code == HTTPStatus.CREATED
assert "id" in r.json()
inserted = (
db_session.query(models.Archive)
.filter(models.Archive.id == r.json()["id"])
.first()
)
assert inserted.url == "http://example.com"
assert inserted.group_id == "spaceship"
assert inserted.author_id == "jerry@gmail.com"
assert sorted([t.id for t in inserted.tags]) == sorted(["test", "manual"])
assert inserted.public
assert isinstance(inserted.result, dict)
assert [u.url for u in inserted.urls] == ["http://example.s3.com"]
assert isinstance(inserted.store_until, datetime)
# cannot have the same URL twice
aa_metadata = json.dumps(
{
"status": "test: success",
"metadata": {"url": "http://example.com"},
"media": [
{
"filename": "fn1",
"urls": ["http://example.com", "http://example.com"],
}
],
}
)
r = client_with_token.post(
"/interop/submit-archive",
json={
"result": aa_metadata,
"public": False,
"author_id": "jerry@gmail.com",
"tags": ["test"],
"url": "http://example.com",
},
)
assert r.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert r.json() == {
"detail": "Cannot insert into DB due to integrity error, likely duplicate urls."
}
# test with invalid JSON
def test_submit_manual_archive_invalid_json(client_with_token):
r = client_with_token.post(
"/interop/submit-archive",
json={
"result": "invalid json",
"public": False,
"author_id": "jer",
"tags": ["test"],
"url": "http://example.com",
},
)
assert r.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert r.json() == {"detail": "Invalid JSON in result field."}
@patch(
"app.web.routers.interoperability.business_logic.get_store_archive_until",
side_effect=AssertionError("AssertionError"),
)
def test_submit_manual_archive_no_store_until(
m_sau, client_with_token, db_session
):
aa_metadata = json.dumps(
{
"status": "test: success",
"metadata": {"url": "http://example.com"},
"media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}],
}
)
r = client_with_token.post(
"/interop/submit-archive",
json={
"result": aa_metadata,
"public": True,
"author_id": "jerry@gmail.com",
"group_id": "spaceship",
"tags": ["test"],
"url": "http://example.com",
},
)
assert r.status_code == HTTPStatus.CREATED
assert len(r.json()["id"]) == 36
res = (
db_session.query(models.Archive)
.filter(models.Archive.id == r.json()["id"])
.first()
)
assert res.store_until is None
# testing that store_until = None is not comparable with datetime, and will always return False
res = (
db_session.query(models.Archive)
.filter(
models.Archive.id == r.json()["id"],
models.Archive.store_until < datetime.now(),
)
.first()
)
assert res is None

View File

@@ -0,0 +1,268 @@
from datetime import datetime
from http import HTTPStatus
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
from app.shared.constants import STATUS_PENDING
from app.shared.db import models
from app.shared.schemas import TaskResult
from app.web.db.user_state import UserState
from app.web.security import get_user_state
def test_endpoints_no_auth(client, test_no_auth):
test_no_auth(client.post, "/sheet/create")
test_no_auth(client.get, "/sheet/mine")
test_no_auth(client.delete, "/sheet/123-sheet-id")
test_no_auth(client.post, "/sheet/123-sheet-id/archive")
def test_create_sheet_endpoint(app_with_auth, db_session):
client_with_auth = TestClient(app_with_auth)
good_data = {
"id": "123-sheet-id",
"name": "Test Sheet",
"group_id": "spaceship",
"frequency": "daily",
}
# with good data
response = client_with_auth.post("/sheet/create", json=good_data)
assert response.status_code == HTTPStatus.CREATED
j = response.json()
assert datetime.fromisoformat(j.pop("created_at"))
assert datetime.fromisoformat(j.pop("last_url_archived_at"))
assert j.pop("author_id") == "morty@example.com"
assert j == good_data
# already exists
response = client_with_auth.post("/sheet/create", json=good_data)
assert response.status_code == HTTPStatus.BAD_REQUEST
assert response.json() == {
"detail": "Sheet with this ID is already being archived."
}
# bad group
bad_data = good_data.copy()
bad_data["group_id"] = "not a group"
response = client_with_auth.post("/sheet/create", json=bad_data)
assert response.status_code == HTTPStatus.FORBIDDEN
assert response.json() == {
"detail": "User does not have access to this group."
}
# switch to jerry who's got less quota/permissions
app_with_auth.dependency_overrides[get_user_state] = lambda: UserState(
db_session, "jerry@example.com"
)
client_jerry = TestClient(app_with_auth)
# frequency not allowed
jerry_data = good_data.copy()
jerry_data["group_id"] = "animated-characters"
jerry_data["frequency"] = "hourly"
jerry_data["id"] = "jerry-sheet-id"
response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert response.json() == {
"detail": "Invalid frequency selected for this group."
}
jerry_data["frequency"] = "daily"
# success for the first sheet, bad quota on second
response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == HTTPStatus.CREATED
response = client_jerry.post("/sheet/create", json=jerry_data)
assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
assert response.json() == {
"detail": "User has reached their sheet quota for this group."
}
def test_get_user_sheets_endpoint(client_with_auth, db_session):
# no data
response = client_with_auth.get("/sheet/mine")
assert response.status_code == HTTPStatus.OK
assert response.json() == []
# with data
db_session.add(
models.Sheet(
id="123",
name="Test Sheet 1",
author_id="morty@example.com",
group_id="spaceship",
frequency="hourly",
)
)
db_session.commit()
db_session.add_all(
[
models.Sheet(
id="456",
name="Test Sheet 2",
author_id="morty@example.com",
group_id="interdimensional",
frequency="daily",
),
models.Sheet(
id="789",
name="Test Sheet 3",
author_id="rick@example.com",
group_id="interdimensional",
frequency="hourly",
),
]
)
db_session.commit()
response = client_with_auth.get("/sheet/mine")
assert response.status_code == HTTPStatus.OK
r = response.json()
assert isinstance(r, list)
assert len(r) == 2
assert datetime.fromisoformat(r[0].pop("created_at"))
assert datetime.fromisoformat(r[0].pop("last_url_archived_at"))
assert datetime.fromisoformat(r[1].pop("created_at"))
assert datetime.fromisoformat(r[1].pop("last_url_archived_at"))
assert r[0] == {
"id": "123",
"author_id": "morty@example.com",
"frequency": "hourly",
"group_id": "spaceship",
"name": "Test Sheet 1",
}
assert r[1] == {
"id": "456",
"author_id": "morty@example.com",
"frequency": "daily",
"group_id": "interdimensional",
"name": "Test Sheet 2",
}
def test_delete_sheet_endpoint(client_with_auth, db_session):
# missing sheet
response = client_with_auth.delete("/sheet/123-sheet-id")
assert response.status_code == HTTPStatus.OK
assert response.json() == {"id": "123-sheet-id", "deleted": False}
# add sheets for deletion
db_session.add_all(
[
models.Sheet(
id="123-sheet-id",
name="Test Sheet 1",
author_id="morty@example.com",
group_id="interdimensional",
frequency="daily",
),
models.Sheet(
id="456-sheet-id",
name="Test Sheet 2",
author_id="rick@example.com",
group_id="spaceship",
frequency="hourly",
),
]
)
db_session.commit()
# morty can delete his
response = client_with_auth.delete("/sheet/123-sheet-id")
assert response.status_code == HTTPStatus.OK
assert response.json() == {"id": "123-sheet-id", "deleted": True}
# but only once
response = client_with_auth.delete("/sheet/123-sheet-id")
assert response.status_code == HTTPStatus.OK
assert response.json() == {"id": "123-sheet-id", "deleted": False}
# and not Rick's
response = client_with_auth.delete("/sheet/456-sheet-id")
assert response.status_code == HTTPStatus.OK
assert response.json() == {"id": "456-sheet-id", "deleted": False}
class TestArchiveUserSheetEndpoint:
@patch("app.web.routers.sheet.celery", return_value=MagicMock())
def test_normal_flow(self, m_celery, client_with_auth, db_session):
db_session.add(
models.Sheet(
id="123-sheet-id",
name="Test Sheet 1",
author_id="morty@example.com",
group_id="spaceship",
frequency="hourly",
)
)
db_session.commit()
m_signature = MagicMock()
m_signature.apply_async.return_value = TaskResult(
id="123-taskid", status=STATUS_PENDING, result=""
)
m_celery.signature.return_value = m_signature
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == HTTPStatus.CREATED
assert r.json() == {"id": "123-taskid"}
m_celery.signature.assert_called_once()
m_signature.apply_async.assert_called_once()
def test_token_auth(self, client_with_token, test_no_auth):
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
def test_missing_data(self, client_with_auth):
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == HTTPStatus.FORBIDDEN
assert r.json() == {"detail": "No access to this sheet."}
def test_no_access(self, client_with_auth, db_session):
db_session.add(
models.Sheet(
id="123-sheet-id",
name="Test Sheet 1",
author_id="rick@example.com",
group_id="spaceship",
frequency="hourly",
)
)
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == HTTPStatus.FORBIDDEN
assert r.json() == {"detail": "No access to this sheet."}
def test_user_not_in_group(self, client_with_auth, db_session):
db_session.add(
models.Sheet(
id="123-sheet-id",
name="Test Sheet 1",
author_id="morty@example.com",
group_id="interdimensional",
frequency="hourly",
)
)
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == HTTPStatus.FORBIDDEN
assert r.json() == {
"detail": "User does not have access to this group."
}
def test_user_cannot_manually_trigger(self, client_with_auth, db_session):
db_session.add(
models.Sheet(
id="123-sheet-id",
name="Test Sheet 1",
author_id="morty@example.com",
group_id="default",
frequency="hourly",
)
)
db_session.commit()
r = client_with_auth.post("/sheet/123-sheet-id/archive")
assert r.status_code == HTTPStatus.TOO_MANY_REQUESTS
assert r.json() == {
"detail": "User cannot manually trigger sheet archiving in this group."
}

View File

@@ -1,51 +1,53 @@
from http import HTTPStatus
from unittest.mock import patch
from app.shared.constants import STATUS_FAILURE, STATUS_PENDING, STATUS_SUCCESS
def test_endpoint_task_status_no_auth(client, test_no_auth):
test_no_auth(client.get, "/task/test-task-id")
@patch("app.web.endpoints.task.AsyncResult")
@patch("app.web.routers.task.AsyncResult")
def test_get_status_success(mock_async_result, client_with_auth):
mock_async_result.return_value.status = "SUCCESS"
mock_async_result.return_value.status = STATUS_SUCCESS
mock_async_result.return_value.result = {"data": "some result"}
response = client_with_auth.get("/task/test-task-id")
assert response.status_code == 200
assert response.status_code == HTTPStatus.OK
assert response.json() == {
"id": "test-task-id",
"status": "SUCCESS",
"result": {"data": "some result"}
"status": STATUS_SUCCESS,
"result": {"data": "some result"},
}
@patch("app.web.endpoints.task.AsyncResult")
@patch("app.web.routers.task.AsyncResult")
def test_get_status_failure(mock_async_result, client_with_auth):
mock_async_result.return_value.status = "FAILURE"
mock_async_result.return_value.status = STATUS_FAILURE
mock_async_result.return_value.result = Exception("Some error")
response = client_with_auth.get("/task/test-task-id")
assert response.status_code == 200
assert response.status_code == HTTPStatus.OK
assert response.json() == {
"id": "test-task-id",
"status": "FAILURE",
"result": {"error": "Some error"}
"status": STATUS_FAILURE,
"result": {"error": "Some error"},
}
@patch("app.web.endpoints.task.AsyncResult")
@patch("app.web.routers.task.AsyncResult")
def test_get_status_pending(mock_async_result, client_with_auth):
mock_async_result.return_value.status = "PENDING"
mock_async_result.return_value.status = STATUS_PENDING
mock_async_result.return_value.result = None
response = client_with_auth.get("/task/test-task-id")
assert response.status_code == 200
assert response.status_code == HTTPStatus.OK
assert response.json() == {
"id": "test-task-id",
"status": "PENDING",
"result": None
"status": STATUS_PENDING,
"result": None,
}

View File

@@ -0,0 +1,312 @@
import json
from http import HTTPStatus
from unittest.mock import MagicMock, patch
from app.shared import schemas
from app.shared.constants import STATUS_PENDING
from app.shared.db import worker_crud
from app.shared.schemas import ArchiveCreate, TaskResult
from app.web.config import ALLOW_ANY_EMAIL
def test_archive_url_unauthenticated(client, test_no_auth):
test_no_auth(client.post, "/url/archive")
@patch("app.web.routers.url.UserState")
@patch("app.web.routers.url.celery", return_value=MagicMock())
def test_archive_url(m_celery, m2, client_with_auth):
m_signature = MagicMock()
m_signature.apply_async.return_value = TaskResult(
id="123-456-789", status=STATUS_PENDING, result=""
)
m_celery.signature.return_value = m_signature
m_user_state = MagicMock()
m2.return_value = m_user_state
# url is too short
response = client_with_auth.post("/url/archive", json={"url": "bad"})
assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert (
response.json()["detail"][0]["msg"]
== "String should have at least 5 characters"
)
m_celery.signature.assert_not_called()
# url is invalid
response = client_with_auth.post(
"/url/archive", json={"url": "example.com"}
)
assert response.status_code == HTTPStatus.BAD_REQUEST
assert response.json()["detail"] == "Invalid URL received."
# valid request
m_user_state.has_quota_max_monthly_urls.return_value = True
m_user_state.has_quota_max_monthly_mbs.return_value = True
response = client_with_auth.post(
"/url/archive", json={"url": "https://example.com"}
)
assert response.status_code == HTTPStatus.CREATED
assert response.json() == {"id": "123-456-789"}
m_celery.signature.assert_called_once()
m_signature.apply_async.assert_called_once()
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_archive_task"
assert json.loads(called_val[1]["args"][0]) == {
"id": None,
"url": "https://example.com",
"result": None,
"public": False,
"author_id": "rick@example.com",
"group_id": "default",
"tags": None,
"sheet_id": None,
"store_until": None,
"urls": None,
}
m_user_state.has_quota_max_monthly_urls.assert_called_once()
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
m_user_state.in_group.assert_called_once_with("default")
# user is not in group
m_user_state.in_group.return_value = False
response = client_with_auth.post(
"/url/archive",
json={"url": "https://example.com", "group_id": "new-group"},
)
assert response.status_code == HTTPStatus.FORBIDDEN
assert (
response.json()["detail"] == "User does not have access to this group."
)
m_user_state.in_group.assert_called_with("new-group")
# user is in group
m_user_state.in_group.return_value = True
response = client_with_auth.post(
"/url/archive",
json={"url": "https://example.com", "group_id": "spaceship"},
)
assert response.status_code == HTTPStatus.CREATED
assert response.json() == {"id": "123-456-789"}
assert m_celery.signature.call_count == 2
assert m_signature.apply_async.call_count == 2
called_val = m_celery.signature.call_args
assert json.loads(called_val[1]["args"][0])["group_id"] == "spaceship"
m_user_state.in_group.assert_called_with("spaceship")
# user is over monthly URL quota
m_user_state.has_quota_max_monthly_urls.return_value = False
m_user_state.has_quota_max_monthly_mbs.return_value = True
response = client_with_auth.post(
"/url/archive",
json={"url": "https://example.com", "group_id": "spaceship"},
)
assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
assert (
response.json()["detail"] == "User has reached their monthly URL quota."
)
m_user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")
# user is over monthly MB quota
m_user_state.has_quota_max_monthly_urls.return_value = True
m_user_state.has_quota_max_monthly_mbs.return_value = False
response = client_with_auth.post(
"/url/archive",
json={"url": "https://example.com", "group_id": "spacesuit"},
)
assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
assert (
response.json()["detail"] == "User has reached their monthly MB quota."
)
m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
assert m_celery.signature.call_count == 2
assert m_signature.apply_async.call_count == 2
@patch("app.web.routers.url.UserState")
def test_archive_url_quotas(m1, client_with_auth):
m_user_state = MagicMock()
m1.return_value = m_user_state
# misses on monthly URLs quota
m_user_state.has_quota_max_monthly_urls.return_value = False
response = client_with_auth.post(
"/url/archive", json={"url": "https://example.com"}
)
assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
assert (
response.json()["detail"] == "User has reached their monthly URL quota."
)
m_user_state.has_quota_max_monthly_urls.assert_called_once()
# misses on monthly MBs quota
m_user_state.has_quota_max_monthly_urls.return_value = True
m_user_state.has_quota_max_monthly_mbs.return_value = False
response = client_with_auth.post(
"/url/archive", json={"url": "https://example.com"}
)
assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
assert (
response.json()["detail"] == "User has reached their monthly MB quota."
)
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
@patch("app.web.routers.url.celery", return_value=MagicMock())
def test_archive_url_with_api_token(m_celery, client_with_token):
m_signature = MagicMock()
m_signature.apply_async.return_value = TaskResult(
id="123-456-789", status=STATUS_PENDING, result=""
)
m_celery.signature.return_value = m_signature
response = client_with_token.post(
"/url/archive",
json={"url": "https://example.com", "author_id": "someone@example.com"},
)
assert response.status_code == HTTPStatus.CREATED
assert response.json() == {"id": "123-456-789"}
m_celery.signature.assert_called_once()
m_signature.apply_async.assert_called_once()
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_archive_task"
assert json.loads(called_val[1]["args"][0]) == {
"id": None,
"url": "https://example.com",
"result": None,
"public": False,
"author_id": "someone@example.com",
"group_id": "default",
"tags": None,
"sheet_id": None,
"store_until": None,
"urls": None,
}
# missing id should use ALLOW_ANY_EMAIL
response = client_with_token.post(
"/url/archive", json={"url": "https://example.com", "author_id": None}
)
assert response.status_code == HTTPStatus.CREATED
called_val = m_celery.signature.call_args
assert called_val[0][0] == "create_archive_task"
assert json.loads(called_val[1]["args"][0]) == {
"id": None,
"url": "https://example.com",
"result": None,
"public": False,
"author_id": ALLOW_ANY_EMAIL,
"group_id": "default",
"tags": None,
"sheet_id": None,
"store_until": None,
"urls": None,
}
def test_search_by_url_unauthenticated(client, test_no_auth):
test_no_auth(client.get, "/url/search")
def test_search_by_url(client_with_auth, client_with_token, db_session):
# tests the search endpoint, including through some db data for the endpoint params
response = client_with_auth.get("/url/search")
assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert response.json()["detail"][0]["msg"] == "Field required"
response = client_with_auth.get("/url/search?url=https://example.com")
assert response.status_code == HTTPStatus.OK
assert response.json() == []
for i in range(11):
worker_crud.create_archive(
db_session,
ArchiveCreate(
id=f"url-456-{i}",
url="https://example.com"
if i < 10
else "https://something-else.com",
result={},
public=True,
author_id="rick@example.com",
),
[],
[],
)
# NB: this insertion is too fast for the ordering to be correct as they are within the same second
response = client_with_auth.get("/url/search?url=https://example.com")
assert response.status_code == HTTPStatus.OK
assert len(j := response.json()) == 10
assert "url-456-0" in [i["id"] for i in j]
assert "url-456-9" in [i["id"] for i in j]
assert "url-456-10" not in [i["id"] for i in j]
assert j[0].keys() == schemas.ArchiveResult.model_fields.keys()
response = client_with_auth.get(
"/url/search?url=https://example.com&limit=5"
)
assert response.status_code == HTTPStatus.OK
assert len(response.json()) == 5
response = client_with_auth.get(
"/url/search?url=https://example.com&skip=5&limit=2"
)
assert response.status_code == HTTPStatus.OK
assert len(response.json()) == 2
response = client_with_auth.get(
"/url/search?url=https://example.com&archived_before=2010-01-01"
)
assert response.status_code == HTTPStatus.OK
assert len(response.json()) == 0
response = client_with_auth.get(
"/url/search?url=https://example.com&archived_after=2010-01-01"
)
assert response.status_code == HTTPStatus.OK
assert len(response.json()) == 10
# API token will also work
response = client_with_token.get(
"/url/search?url=https://example.com&archived_after=2010-01-01"
)
assert response.status_code == HTTPStatus.OK
assert len(response.json()) == 10
@patch("app.web.routers.url.UserState")
def test_search_no_read_access(mock_user_state, client_with_auth):
mock_user_state.return_value.read = False
mock_user_state.return_value.read_public = False
response = client_with_auth.get("/url/search?url=https://example.com")
assert response.status_code == HTTPStatus.FORBIDDEN
assert response.json() == {"detail": "User does not have read access."}
def test_delete_task_unauthenticated(client, test_no_auth):
test_no_auth(client.delete, "/url/123-456-789")
def test_delete_task(client_with_auth, db_session):
response = client_with_auth.delete("/url/delete-123-456-789")
assert response.status_code == HTTPStatus.OK
assert response.json() == {"id": "delete-123-456-789", "deleted": False}
worker_crud.create_archive(
db_session,
ArchiveCreate(
id="delete-123-456-789",
url="https://example.com",
result={},
public=True,
author_id="morty@example.com",
),
[],
[],
)
response = client_with_auth.delete("/url/delete-123-456-789")
assert response.status_code == HTTPStatus.OK
assert response.json() == {"id": "delete-123-456-789", "deleted": True}

View File

@@ -1,31 +1,55 @@
import os
import shutil
from http import HTTPStatus
from unittest.mock import patch
import alembic.config
import pytest
from fastapi.testclient import TestClient
import shutil
from app.web.main import app_factory
from app.web.utils.metrics import EXCEPTION_COUNTER
import pytest
def test_lifespan(app):
with TestClient(app) as client:
r = client.get("/health")
assert r.status_code == 200
assert r.status_code == HTTPStatus.OK
assert r.json() == {"status": "ok"}
def test_alembic(db_session):
import alembic.config
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
alembic.config.main(argv=['--raiseerr', 'downgrade', 'base'])
@patch("app.web.endpoints.url.crud.soft_delete_archive", side_effect=Exception('mocked error'))
def test_alembic(db_session):
alembic.config.main(
argv=[
"-c",
"./app/migrations/alembic.ini",
"--raiseerr",
"upgrade",
"head",
]
)
alembic.config.main(
argv=[
"-c",
"./app/migrations/alembic.ini",
"--raiseerr",
"downgrade",
"base",
]
)
@patch(
"app.web.routers.url.crud.soft_delete_archive",
side_effect=Exception("mocked error"),
)
def test_logging_middleware(m1, client_with_auth):
from app.web.utils.metrics import EXCEPTION_COUNTER
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 0
with pytest.raises(Exception, match="mocked error"):
client_with_auth.delete("/url/123")
# creates one empty and one from above
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 2
def test_serve_local_archive_logic(get_settings):
# create a test file first
@@ -36,13 +60,13 @@ def test_serve_local_archive_logic(get_settings):
try:
# modify the settings
get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test"
from app.web.main import app_factory
app = app_factory(get_settings)
# test
client = TestClient(app)
r = client.get("/app/local_archive_test/temp.txt")
assert r.status_code == 200
assert r.status_code == HTTPStatus.OK
assert r.text == "test"
finally:
# cleanup

View File

@@ -1,101 +1,168 @@
from http import HTTPStatus
from unittest import mock
from unittest.mock import Mock, patch
import pytest
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
import pytest
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db.user_state import UserState
from app.web.security import (
authenticate_user,
get_token_or_user_auth,
get_user_auth,
get_user_state,
secure_compare,
token_api_key_auth,
)
def test_secure_compare():
from app.web.security import secure_compare
assert secure_compare("test", "test")
assert not secure_compare("test", "test2")
@pytest.mark.asyncio
async def test_get_token_or_user_auth_with_api():
from app.web.security import get_token_or_user_auth
mock_api = HTTPAuthorizationCredentials(scheme="lorem", credentials="this_is_the_test_api_token")
mock_api = HTTPAuthorizationCredentials(
scheme="lorem", credentials="this_is_the_test_api_token"
)
assert await get_token_or_user_auth(mock_api) == ALLOW_ANY_EMAIL
@pytest.mark.asyncio
async def test_get_token_or_user_auth_with_user():
from app.web.security import get_token_or_user_auth
bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="invalid")
e: pytest.ExceptionInfo = None
bad_user = HTTPAuthorizationCredentials(
scheme="ipsum", credentials="invalid"
)
with pytest.raises(HTTPException) as e:
await get_token_or_user_auth(bad_user)
assert e.value.status_code == 401
assert e.value.status_code == HTTPStatus.UNAUTHORIZED
assert e.value.detail == "invalid access_token"
@patch("app.web.security.authenticate_user", return_value=(True, "summer@example.com"))
@patch(
"app.web.security.authenticate_user",
return_value=(True, "summer@example.com"),
)
@pytest.mark.asyncio
async def test_get_user_auth(m1):
from app.web.security import get_user_auth
good_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good")
good_user = HTTPAuthorizationCredentials(
scheme="ipsum", credentials="valid-and-good"
)
assert await get_user_auth(good_user) == "summer@example.com"
@patch("app.web.security.secure_compare", return_value=False)
@pytest.mark.asyncio
async def test_token_api_key_auth_exception(m1):
from app.web.security import token_api_key_auth
e: pytest.ExceptionInfo = None
with pytest.raises(HTTPException) as e:
await token_api_key_auth(HTTPAuthorizationCredentials(scheme="ipsum", credentials="does-not-matter"), auto_error=True)
assert e.value.status_code == 401
await token_api_key_auth(
HTTPAuthorizationCredentials(
scheme="ipsum", credentials="does-not-matter"
),
auto_error=True,
)
assert e.value.status_code == HTTPStatus.UNAUTHORIZED
assert e.value.detail == "Wrong auth credentials"
@pytest.mark.asyncio
async def test_authenticate_user():
from app.web.security import authenticate_user
assert authenticate_user("test") == (False, "invalid access_token")
assert authenticate_user(123) == (False, "invalid access_token")
with patch("app.web.security.requests.get") as mock_get:
# bad response from oauth2
mock_get.return_value.status_code = 403
assert authenticate_user("this-will-call-requests") == (False, "invalid token")
mock_get.return_value.status_code = HTTPStatus.FORBIDDEN
assert authenticate_user("this-will-call-requests") == (
False,
"invalid token",
)
assert mock_get.call_count == 1
# 200 but invalid json
mock_get.return_value.status_code = 200
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
mock_get.return_value.status_code = HTTPStatus.OK
assert authenticate_user("this-will-call-requests") == (
False,
"token does not belong to valid APP_ID",
)
assert mock_get.call_count == 2
# 200 but invalid azp and aud
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app"}
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
mock_get.return_value.json.return_value = {
"email": "summer@example.com",
"azp": "not_an_app",
}
assert authenticate_user("this-will-call-requests") == (
False,
"token does not belong to valid APP_ID",
)
mock_get.return_value.json.return_value = {"email": "summer@example.com", "aud": "not_an_app"}
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
mock_get.return_value.json.return_value = {
"email": "summer@example.com",
"aud": "not_an_app",
}
assert authenticate_user("this-will-call-requests") == (
False,
"token does not belong to valid APP_ID",
)
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "not_an_app"}
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
mock_get.return_value.json.return_value = {
"email": "summer@example.com",
"azp": "not_an_app",
"aud": "not_an_app",
}
assert authenticate_user("this-will-call-requests") == (
False,
"token does not belong to valid APP_ID",
)
# blocked email
mock_get.return_value.json.return_value = {"email": "blocked@example.com", "azp": "test_app_id_1", "aud": "not_an_app"}
assert authenticate_user("this-will-call-requests") == (False, "email 'blocked@example.com' not allowed")
mock_get.return_value.json.return_value = {
"email": "blocked@example.com",
"azp": "test_app_id_1",
"aud": "not_an_app",
}
assert authenticate_user("this-will-call-requests") == (
False,
"email 'blocked@example.com' not allowed",
)
# not verified
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "test_app_id_1"}
assert authenticate_user("this-will-call-requests") == (False, "email 'summer@example.com' not verified")
mock_get.return_value.json.return_value = {
"email": "summer@example.com",
"azp": "not_an_app",
"aud": "test_app_id_1",
}
assert authenticate_user("this-will-call-requests") == (
False,
"email 'summer@example.com' not verified",
)
# token expired
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true"}
assert authenticate_user("this-will-call-requests") == (False, "Token expired")
mock_get.return_value.json.return_value = {
"email": "summer@example.com",
"azp": "test_app_id_2",
"email_verified": "true",
}
assert authenticate_user("this-will-call-requests") == (
False,
"Token expired",
)
# 200 and valid azp and aup and verified
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true", "expires_in": 100}
assert authenticate_user("this-will-call-requests") == (True, "summer@example.com")
mock_get.return_value.json.return_value = {
"email": "summer@example.com",
"azp": "test_app_id_2",
"email_verified": "true",
"expires_in": 100,
}
assert authenticate_user("this-will-call-requests") == (
True,
"summer@example.com",
)
assert mock_get.call_count == 9
@@ -104,6 +171,7 @@ async def test_authenticate_user():
@pytest.mark.asyncio
async def test_authenticate_user_with_id_token(m_init):
from firebase_admin import exceptions
from app.web.security import authenticate_user
with pytest.raises(ValueError) as e:
@@ -113,12 +181,20 @@ async def test_authenticate_user_with_id_token(m_init):
with patch("app.web.security.auth.verify_id_token") as mock_verify:
# missing email
mock_verify.return_value = {"email": None}
assert authenticate_user("fake_token") == (False, "email not found in token")
assert authenticate_user("fake_token") == (
False,
"email not found in token",
)
assert mock_verify.call_count == 1
# blocked email
mock_verify.return_value = {"email": "blocked@example.com", }
assert authenticate_user("fake_token") == (False, "email 'blocked@example.com' not allowed")
mock_verify.return_value = {
"email": "blocked@example.com",
}
assert authenticate_user("fake_token") == (
False,
"email 'blocked@example.com' not allowed",
)
assert mock_verify.call_count == 2
# valid email
@@ -132,17 +208,16 @@ async def test_authenticate_user_with_id_token(m_init):
@pytest.mark.asyncio
async def test_authenticate_user_exception():
from app.web.security import authenticate_user
with patch("app.web.security.requests.get") as mock_get:
mock_get.return_value.status_code = 200
mock_get.return_value.status_code = HTTPStatus.OK
mock_get.return_value.json.side_effect = Exception("mocked error")
assert authenticate_user("this-will-call-requests") == (False, "exception occurred")
assert authenticate_user("this-will-call-requests") == (
False,
"exception occurred",
)
def test_get_user_state():
from app.web.security import get_user_state
from app.web.db.user_state import UserState
mock_session = Mock()
test_email = "test@example.com"

View File

@@ -1,29 +1,47 @@
from datetime import datetime
from unittest.mock import patch
import pytest
from app.shared.db import models
from app.shared import schemas
from auto_archiver.core import Media, Metadata
from app.shared import constants, schemas
from app.shared.db import models
from app.web.utils.misc import get_all_urls
from app.worker.main import create_archive_task, create_sheet_task
class Test_create_archive_task():
class TestCreateArchiveTask:
URL = "https://example-live.com"
archive = schemas.ArchiveCreate(url=URL, tags=["tag-celery"], public=True, author_id="rick@example.com", group_id="interstellar")
archive = schemas.ArchiveCreate(
url=URL,
tags=["tag-celery"],
public=True,
author_id="rick@example.com",
group_id="interstellar",
)
@patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.get_all_urls", return_value=[])
@patch("app.worker.main.insert_result_into_db")
@patch("app.worker.main.get_store_until", return_value=datetime.now())
@patch("app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"])
@patch(
"app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"]
)
@patch("celery.app.task.Task.request")
def test_success(self, m_req, m_args, m_store, m_insert, m_urls, m_orchestrator, db_session):
from app.worker.main import create_archive_task
def test_success(
self,
m_req,
m_args,
m_store,
m_insert,
m_urls,
m_orchestrator,
db_session,
):
m_req.id = "this-just-in"
m_orchestrator.return_value.feed.return_value = iter([Metadata().set_url(self.URL).success()])
m_orchestrator.return_value.feed.return_value = iter(
[Metadata().set_url(self.URL).success()]
)
task = create_archive_task(self.archive.model_dump_json())
@@ -39,15 +57,15 @@ class Test_create_archive_task():
assert len(task["media"]) == 0
def test_raise_invalid(self):
from app.worker.main import create_archive_task
with pytest.raises(Exception):
with pytest.raises(Exception) as _:
create_archive_task(self.archive.model_dump_json())
@patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.get_orchestrator_args")
def test_raise_db_error(self, m_args, m_orchestrator):
from app.worker.main import create_archive_task
m_orchestrator.return_value.feed.side_effect = Exception("Orchestrator failed")
m_orchestrator.return_value.feed.side_effect = Exception(
"Orchestrator failed"
)
with pytest.raises(Exception) as e:
create_archive_task(self.archive.model_dump_json())
@@ -59,7 +77,6 @@ class Test_create_archive_task():
@patch("app.worker.main.insert_result_into_db", return_value=None)
@patch("app.worker.main.get_orchestrator_args")
def test_raise_empty_result(self, m_args, m_insert, m_orchestrator):
from app.worker.main import create_archive_task
m_orchestrator.return_value.feed.return_value = iter([None])
with pytest.raises(Exception) as e:
@@ -68,61 +85,83 @@ class Test_create_archive_task():
m_orchestrator.return_value.feed.assert_called_once()
class Test_create_sheet_task():
class TestCreateSheetTask:
URL = "https://example-live.com"
sheet = schemas.SubmitSheet(sheet_id="123", author_id="rick@example.com", group_id="interstellar", tags=["spaceship"])
sheet = schemas.SubmitSheet(
sheet_id="123",
author_id="rick@example.com",
group_id="interstellar",
tags=["spaceship"],
)
@patch("app.worker.main.get_all_urls", return_value=[])
@patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.models.generate_uuid", return_value="constant-uuid")
@patch("app.worker.main.get_store_until", return_value=datetime.now())
@patch("app.worker.main.get_orchestrator_args")
def test_success(self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session):
from app.worker.main import create_sheet_task
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
def test_success(
self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session
):
assert (
db_session.query(models.Archive)
.filter(models.Archive.url == self.URL)
.count()
== 0
)
mock_metadata = Metadata().set_url(self.URL).success()
mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"]))
m_orchestrator.return_value.feed.return_value = iter([False, mock_metadata, mock_metadata])
m_orchestrator.return_value.feed.return_value = iter(
[False, mock_metadata, mock_metadata]
)
res = create_sheet_task(self.sheet.model_dump_json())
m_args.assert_called_once_with("interstellar", True, ["--gsheet_feeder.sheet_id", "123"])
m_args.assert_called_once_with(
"interstellar", True, [constants.SHEET_ID, "123"]
)
m_orchestrator.return_value.setup.assert_called_once()
m_orchestrator.return_value.feed.assert_called_once()
m_store.assert_called_with("interstellar")
m_store.call_count == 2
m_uuid.call_count == 2
assert type(res) == dict
assert m_store.call_count == 2
assert m_uuid.call_count == 2
assert isinstance(res, dict)
assert res["stats"]["archived"] == 1
assert res["stats"]["failed"] == 1
assert len(res["stats"]["errors"]) == 1
assert res["sheet_id"] == "123"
assert res["success"]
assert type(res["time"]) == datetime
assert isinstance(res["time"], datetime)
# query created archive entry
inserted = db_session.query(models.Archive).filter(models.Archive.url == self.URL).one()
inserted = (
db_session.query(models.Archive)
.filter(models.Archive.url == self.URL)
.one()
)
assert inserted is not None
assert inserted.url == self.URL
assert len(inserted.tags) == 1
assert inserted.tags[0].id == "spaceship"
assert inserted.group_id == "interstellar"
assert inserted.author_id == "rick@example.com"
assert inserted.public == False
assert inserted.public is False
def test_get_all_urls(db_session):
from app.worker.main import get_all_urls
meta = Metadata().set_url("https://example.com")
m1 = meta.add_media(Media("fn1.txt", urls=["outcome1.com"]))
m2 = meta.add_media(Media("fn2.txt", urls=["outcome2.com"]))
m3 = meta.add_media(Media("fn3.txt", urls=["outcome3.com"]))
m1.set("screenshot", Media("screenshot.png", urls=["screenshot.com"]))
m2.set("thumbnails", [Media("thumb1.png", urls=["thumb1.com"]), Media("thumb2.png", urls=["thumb2.com"])])
m2.set(
"thumbnails",
[
Media("thumb1.png", urls=["thumb1.com"]),
Media("thumb2.png", urls=["thumb2.com"]),
],
)
m3.set("ssl_data", Media("ssl_data.txt", urls=["ssl_data.com"]).to_dict())
m3.set("bad_data", {"bad": "dict is ignored"})

View File

@@ -1,3 +1,4 @@
from app.web.main import app_factory
app = app_factory
app = app_factory

View File

@@ -1,14 +1,17 @@
VERSION = "0.9.4"
VERSION = "0.10.0"
API_DESCRIPTION = """
#### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets.
**Usage notes:**
- The API requires a Bearer token for most operations, which you can obtain by logging in with your Google account.
- You can use this API to archive single URLs or entire Google Sheets.
- You can use this API to archive single URLs or entire Google Sheets.
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
"""
BREAKING_CHANGES = {"minVersion": "0.4.0", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
BREAKING_CHANGES = {
"minVersion": "0.4.0",
"message": "The latest update has breaking changes, please update the extension to the most recent version.",
}
# changing this will corrupt the database logic
ALLOW_ANY_EMAIL = "*"

View File

@@ -1,18 +1,30 @@
from collections import defaultdict
from functools import lru_cache
from sqlalchemy.orm import Session, load_only
from sqlalchemy import Column, or_, func, select
from loguru import logger
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Any, Type
from cachetools import LRUCache, cached
from cachetools.keys import hashkey
from loguru import logger
from sqlalchemy import (
Column,
ColumnElement,
ScalarResult,
false,
func,
not_,
or_,
select,
true,
)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, load_only
from app.web.config import ALLOW_ANY_EMAIL
from app.shared.db import models
from app.shared.db.models import Archive, Group
from app.shared.settings import get_settings
from app.shared.user_groups import UserGroups
from app.shared.utils.misc import fnv1a_hash_mod
from app.web.config import ALLOW_ANY_EMAIL
from app.web.utils.misc import convert_priority_to_queue_dict
@@ -22,24 +34,48 @@ DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
def get_limit(user_limit: int):
return max(1, min(user_limit, DATABASE_QUERY_LIMIT))
# --------------- TASK = Archive
def base_query(db: Session):
# NOTE: load_only is for optimization and not obfuscation, use .with_entities() if needed
return db.query(models.Archive)\
.filter(models.Archive.deleted == False)\
.options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result, models.Archive.store_until))
# NOTE: load_only is for optimization and not obfuscation, use
# .with_entities() if needed
return (
db.query(models.Archive)
.filter(not_(models.Archive.deleted))
.options(
load_only(
models.Archive.id,
models.Archive.created_at,
models.Archive.url,
models.Archive.result,
models.Archive.store_until,
)
)
)
def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool | set[str], read_public: bool, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False) -> list[models.Archive]:
# searches for partial URLs, if email is * no ownership (or read/read_public) filtering happens
def search_archives_by_url(
db: Session,
url: str,
email: str,
read_groups: bool | set[str],
read_public: bool,
skip: int = 0,
limit: int = 100,
archived_after: datetime = None,
archived_before: datetime = None,
absolute_search: bool = False,
) -> list[Type[Archive]]:
# searches for partial URLs, if email is * no ownership
# (or read/read_public) filtering happens
query = base_query(db)
if email != ALLOW_ANY_EMAIL:
or_filters = [models.Archive.author_id == email]
if read_public:
or_filters.append(models.Archive.public == True)
if read_groups == True:
or_filters.append(models.Archive.public.is_(true()))
if read_groups is True:
or_filters.append(models.Archive.group_id.isnot(None))
else:
or_filters.append(models.Archive.group_id.in_(read_groups))
@@ -47,21 +83,43 @@ def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool
if absolute_search:
query = query.filter(models.Archive.url == url)
else:
query = query.filter(models.Archive.url.like(f'%{url}%'))
query = query.filter(models.Archive.url.like(f"%{url}%"))
if archived_after:
query = query.filter(models.Archive.created_at > archived_after)
if archived_before:
query = query.filter(models.Archive.created_at < archived_before)
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
return (
query.order_by(models.Archive.created_at.desc())
.offset(skip)
.limit(get_limit(limit))
.all()
)
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
def search_archives_by_email(
db: Session, email: str, skip: int = 0, limit: int = 100
):
return (
base_query(db)
.filter(models.Archive.author_id == email)
.order_by(models.Archive.created_at.desc())
.offset(skip)
.limit(get_limit(limit))
.all()
)
def soft_delete_archive(db: Session, id: str, email: str) -> bool:
# TODO: implement hard-delete with cronjob that deletes from S3
db_archive = db.query(models.Archive).filter(models.Archive.id == id, models.Archive.author_id == email, models.Archive.deleted == False).first()
db_archive = (
db.query(models.Archive)
.filter(
models.Archive.id == id,
models.Archive.author_id == email,
models.Archive.deleted.is_(false()),
)
.first()
)
if db_archive:
db_archive.deleted = True
db.commit()
@@ -82,22 +140,29 @@ def count_users(db: Session):
def count_by_user_since(db: Session, seconds_delta: int = 15):
time_threshold = datetime.now() - timedelta(seconds=seconds_delta)
return db.query(models.Archive.author_id, func.count().label('total'))\
.filter(models.Archive.created_at >= time_threshold)\
.group_by(models.Archive.author_id)\
.order_by(func.count().desc())\
.limit(500).all()
return (
db.query(models.Archive.author_id, func.count().label("total"))
.filter(models.Archive.created_at >= time_threshold)
.group_by(models.Archive.author_id)
.order_by(func.count().desc())
.limit(500)
.all()
)
async def find_by_store_until(db: AsyncSession, store_until_is_before: datetime) -> list[models.Archive]:
async def find_by_store_until(
db: AsyncSession, store_until_is_before: datetime
) -> ScalarResult[Archive]:
res = await db.execute(
select(models.Archive)
.filter(models.Archive.deleted == False, models.Archive.store_until < store_until_is_before)
select(models.Archive).filter(
models.Archive.deleted.is_(false()),
models.Archive.store_until < store_until_is_before,
)
)
return res.scalars()
async def soft_delete_expired_archives(db: AsyncSession) -> dict:
async def soft_delete_expired_archives(db: AsyncSession) -> int:
to_delete = await find_by_store_until(db, datetime.now())
counter = 0
for archive in to_delete:
@@ -105,47 +170,86 @@ async def soft_delete_expired_archives(db: AsyncSession) -> dict:
counter += 1
await db.commit()
return counter
# --------------- TAG
async def get_group_priority_async(db: AsyncSession, group_id: str) -> dict:
db_group = await db.get(models.Group, group_id)
priority = db_group.permissions.get("priority", "low") if db_group else "low"
priority = (
db_group.permissions.get("priority", "low") if db_group else "low"
)
return convert_priority_to_queue_dict(priority)
@cached(cache=LRUCache(maxsize=128), key=lambda db, email: hashkey(email))
def get_user_group_names(db: Session, email: str) -> list[str]:
def get_user_group_names(
db: Session, email: str
) -> list[Any] | list[ColumnElement[Any]]:
"""
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user.
given an email retrieves the user groups from the DB and then the
email-domain groups from a global variable, the email does not need to
belong to an existing user.
"""
# TODO: the read: [group1, group2] permissions don't currently work
if not email or not len(email) or "@" not in email: return []
if not email or not len(email) or "@" not in email:
return []
# get user groups
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
user_groups = (
db.query(models.association_table_user_groups)
.filter_by(user_id=email)
.with_entities(Column("group_id"))
.all()
)
user_level_groups_names = [g[0] for g in user_groups]
# get domain groups
domain = email.split('@')[1]
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
domain = email.split("@")[1]
domain_level_groups = (
db.query(models.Group.id)
.filter(models.Group.domains.contains(domain))
.with_entities(Column("id"))
.all()
)
domain_level_groups_names = [g[0] for g in domain_level_groups]
return list(set(user_level_groups_names + domain_level_groups_names))
def get_user_groups_by_name(db: Session, groups: list[str]) -> list[models.Group]:
return db.query(models.Group).filter(
models.Group.id.in_(groups)
).all()
def get_user_groups_by_name(
db: Session, groups: list[str]
) -> list[Type[Group]]:
return db.query(models.Group).filter(models.Group.id.in_(groups)).all()
# --------------- INIT User-Groups
def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, service_account_email: str, permissions: dict, domains: list) -> models.Group:
db_group = db.query(models.Group).filter(models.Group.id == group_name).first()
def upsert_group(
db: Session,
group_name: str,
description: str,
orchestrator: str,
orchestrator_sheet: str,
service_account_email: str,
permissions: dict,
domains: list,
) -> models.Group:
db_group = (
db.query(models.Group).filter(models.Group.id == group_name).first()
)
if db_group is None:
db_group = models.Group(id=group_name, description=description, orchestrator=orchestrator, orchestrator_sheet=orchestrator_sheet, service_account_email=service_account_email, permissions=permissions, domains=domains)
db_group = models.Group(
id=group_name,
description=description,
orchestrator=orchestrator,
orchestrator_sheet=orchestrator_sheet,
service_account_email=service_account_email,
permissions=permissions,
domains=domains,
)
db.add(db_group)
else:
db_group.description = description
@@ -172,8 +276,9 @@ def upsert_user(db: Session, email: str):
def upsert_user_groups(db: Session):
def display_email_pii(email: str):
return f"'{email[0:3]}...@{email.split('@')[1]}'"
"""
reads the user_groups yaml file and inserts any new users, groups,
reads the user_groups yaml file and inserts any new users, groups,
along with new participation of users in groups
"""
filename = get_settings().USER_GROUPS_FILENAME
@@ -192,18 +297,33 @@ def upsert_user_groups(db: Session):
for group in explicit_groups:
group_domains[group].add(domain)
import json
# upsert groups and save a map of groupid -> dbobject
for group_id, g in ug.groups.items():
upsert_group(db, group_id, g.description, g.orchestrator, g.orchestrator_sheet, g.service_account_email, json.loads(g.permissions.model_dump_json()), list(group_domains.get(group_id, [])))
db_groups: dict[str, models.Group] = {g.id: g for g in db.query(models.Group).all()}
upsert_group(
db,
group_id,
g.description,
g.orchestrator,
g.orchestrator_sheet,
g.service_account_email,
json.loads(g.permissions.model_dump_json()),
list(group_domains.get(group_id, [])),
)
db_groups: dict[str, models.Group] = {
g.id: g for g in db.query(models.Group).all()
}
# integrity checks
for group_in_domains in group_domains:
if group_in_domains not in db_groups:
logger.warning(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.")
logger.warning(
f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work."
)
# reinsert users in their EXPLICITLY DEFINED groups
# domain groups are check live, as there may be new users that are not explicitly registered but belong to a domain
# domain groups are check live, as there may be new users that are not
# explicitly registered but belong to a domain
for email, explicit_groups in ug.users.items():
explicit_groups = explicit_groups or []
logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}")
@@ -213,7 +333,9 @@ def upsert_user_groups(db: Session):
# connect users to groups
for group_id in explicit_groups:
if group_id not in db_groups:
logger.warning(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.")
logger.warning(
f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}."
)
continue
db_groups[group_id].users.append(db_user)
@@ -221,12 +343,27 @@ def upsert_user_groups(db: Session):
count_user_groups = db.query(models.association_table_user_groups).count()
count_groups = db.query(func.count(models.Group.id)).scalar()
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
logger.success(
f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}]."
)
# --------------- SHEET
def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: str, frequency: str):
db_sheet = models.Sheet(id=sheet_id, name=name, author_id=email, group_id=group_id, frequency=frequency)
def create_sheet(
db: Session,
sheet_id: str,
name: str,
email: str,
group_id: str,
frequency: str,
):
db_sheet = models.Sheet(
id=sheet_id,
name=name,
author_id=email,
group_id=group_id,
frequency=frequency,
)
db.add(db_sheet)
db.commit()
db.refresh(db_sheet)
@@ -234,20 +371,31 @@ def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: st
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
return (
db.query(models.Sheet)
.filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id)
.first()
)
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_url_archived_at.desc()).all()
return (
db.query(models.Sheet)
.filter(models.Sheet.author_id == email)
.order_by(models.Sheet.last_url_archived_at.desc())
.all()
)
async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: int) -> list[models.Sheet]:
async def get_sheets_by_id_hash(
db: AsyncSession, frequency: str, modulo: str, id_hash: int
) -> list[models.Sheet]:
result = await db.execute(
select(models.Sheet).filter(models.Sheet.frequency == frequency)
)
filtered = []
for sheet in result.scalars():
if fnv1a_hash_mod(sheet.id, modulo) == id_hash:
if fnv1a_hash_mod(sheet.id, int(modulo)) == id_hash:
filtered.append(sheet)
return filtered
@@ -255,7 +403,9 @@ async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, i
async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
time_threshold = datetime.now() - timedelta(days=inactivity_days)
result = await db.execute(
select(models.Sheet).filter(models.Sheet.last_url_archived_at < time_threshold)
select(models.Sheet).filter(
models.Sheet.last_url_archived_at < time_threshold
)
)
deleted = defaultdict(list)
for sheet in result.scalars():
@@ -266,7 +416,11 @@ async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
db_sheet = (
db.query(models.Sheet)
.filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email)
.first()
)
if db_sheet:
db.delete(db_sheet)
db.commit()

View File

@@ -1,13 +1,13 @@
from typing import Dict, Set
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy import func
from datetime import datetime
from typing import Dict, Set
import sqlalchemy
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.shared.db import models
from app.shared.user_groups import GroupInfo, GroupPermissions
from app.shared.schemas import Usage, UsageResponse
from app.shared.user_groups import GroupInfo, GroupPermissions
from app.web.db import crud
from app.web.utils.misc import convert_priority_to_queue_dict
@@ -20,14 +20,15 @@ class UserState:
def __init__(self, db: Session, email: str):
self.db = db
self.email = email.lower()
self._permissions = {}
@property
def permissions(self) -> Dict[str, GroupInfo]:
"""
Returns a dict of all group permissions and a special {"all": read/archive_url/archive_sheet} key
Returns a dict of all group permissions and a special
{"all": read/archive_url/archive_sheet} key
"""
if not hasattr(self, '_permissions'):
self._permissions = {}
if not self._permissions:
self._permissions["all"] = GroupInfo(
read=self.read,
read_public=self.read_public,
@@ -37,23 +38,33 @@ class UserState:
max_archive_lifespan_months=self.max_archive_lifespan_months,
max_monthly_urls=self.max_monthly_urls,
max_monthly_mbs=self.max_monthly_mbs,
priority=self.priority
priority=self.priority,
)
for group in self.user_groups:
if not group.permissions: continue
self._permissions[group.id] = GroupInfo(**group.permissions, description=group.description, service_account_email=group.service_account_email)
if not group.permissions:
continue
self._permissions[group.id] = GroupInfo(
**group.permissions,
description=group.description,
service_account_email=group.service_account_email,
)
return self._permissions
@property
def user_groups_names(self):
if not hasattr(self, '_user_groups_names'):
self._user_groups_names = crud.get_user_group_names(self.db, self.email) + ["default"]
if not hasattr(self, "_user_groups_names"):
# TODO: Define hidden properties in __init__ method
self._user_groups_names = crud.get_user_group_names(
self.db, self.email
) + ["default"]
return self._user_groups_names
@property
def user_groups(self):
if not hasattr(self, '_user_groups'):
self._user_groups = crud.get_user_groups_by_name(self.db, self.user_groups_names)
if not hasattr(self, "_user_groups"):
self._user_groups = crud.get_user_groups_by_name(
self.db, self.user_groups_names
)
return self._user_groups
@property
@@ -61,10 +72,11 @@ class UserState:
"""
Read can be a list of group names or True, if all can be read.
"""
if not hasattr(self, '_read'):
if not hasattr(self, "_read"):
self._read = set()
for group in self.user_groups:
if not group.permissions: continue
if not group.permissions:
continue
group_read_permissions = group.permissions.get("read", [])
if "all" in group_read_permissions:
self._read = True
@@ -78,10 +90,11 @@ class UserState:
"""
Read public permission
"""
if not hasattr(self, '_read_public'):
if not hasattr(self, "_read_public"):
self._read_public = False
for group in self.user_groups:
if not group.permissions: continue
if not group.permissions:
continue
if group.permissions.get("read_public", False):
self._read_public = True
return self._read_public
@@ -92,10 +105,11 @@ class UserState:
"""
Archive URL permission
"""
if not hasattr(self, '_archive_url'):
if not hasattr(self, "_archive_url"):
self._archive_url = False
for group in self.user_groups:
if not group.permissions: continue
if not group.permissions:
continue
if group.permissions.get("archive_url", False):
self._archive_url = True
return self._archive_url
@@ -106,10 +120,11 @@ class UserState:
"""
Archive sheet permission
"""
if not hasattr(self, '_archive_sheet'):
if not hasattr(self, "_archive_sheet"):
self._archive_sheet = False
for group in self.user_groups:
if not group.permissions: continue
if not group.permissions:
continue
if group.permissions.get("archive_sheet", False):
self._archive_sheet = True
return self._archive_sheet
@@ -117,37 +132,53 @@ class UserState:
@property
def sheet_frequency(self):
if not hasattr(self, '_sheet_frequency'):
if not hasattr(self, "_sheet_frequency"):
self._sheet_frequency = set()
for group in self.user_groups:
if not group.permissions: continue
self._sheet_frequency.update(group.permissions.get("sheet_frequency", None))
if not group.permissions:
continue
self._sheet_frequency.update(
group.permissions.get("sheet_frequency", None)
)
return self._sheet_frequency
@property
def max_archive_lifespan_months(self) -> int:
if not hasattr(self, '_max_archive_lifespan_months'):
self._max_archive_lifespan_months = self._helper_for_grouping_max_numerical_permissions("max_archive_lifespan_months")
if not hasattr(self, "_max_archive_lifespan_months"):
self._max_archive_lifespan_months = (
self._helper_for_grouping_max_numerical_permissions(
"max_archive_lifespan_months"
)
)
return self._max_archive_lifespan_months
@property
def max_monthly_urls(self) -> int:
if not hasattr(self, '_max_monthly_urls'):
self._max_monthly_urls = self._helper_for_grouping_max_numerical_permissions("max_monthly_urls")
if not hasattr(self, "_max_monthly_urls"):
self._max_monthly_urls = (
self._helper_for_grouping_max_numerical_permissions(
"max_monthly_urls"
)
)
return self._max_monthly_urls
@property
def max_monthly_mbs(self) -> int:
if not hasattr(self, '_max_monthly_mbs'):
self._max_monthly_mbs = self._helper_for_grouping_max_numerical_permissions("max_monthly_mbs")
if not hasattr(self, "_max_monthly_mbs"):
self._max_monthly_mbs = (
self._helper_for_grouping_max_numerical_permissions(
"max_monthly_mbs"
)
)
return self._max_monthly_mbs
@property
def priority(self) -> str:
if not hasattr(self, '_priority'):
if not hasattr(self, "_priority"):
self._priority = "low"
for group in self.user_groups:
if not group.permissions: continue
if not group.permissions:
continue
if group.permissions.get("priority", self._priority) == "high":
self._priority = "high"
break
@@ -158,18 +189,28 @@ class UserState:
"""
A user is active if they can read/archive anything
"""
if not hasattr(self, '_active'):
self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet)
if not hasattr(self, "_active"):
self._active = bool(
self.read
or self.read_public
or self.archive_url
or self.archive_sheet
)
return self._active
def _helper_for_grouping_max_numerical_permissions(self, permission_name: str) -> int:
def _helper_for_grouping_max_numerical_permissions(
self, permission_name: str
) -> int:
"""
Iterates one of the numerical permissions where -1 means no restrictions and returns either -1 or the maximum value, defaults according to GroupPermissions
Iterates one of the numerical permissions where -1 means no restrictions
and returns either -1 or the maximum value, defaults according to
GroupPermissions
"""
default = GroupPermissions.model_fields[permission_name].default
max_value = default
for group in self.user_groups:
if not group.permissions: continue
if not group.permissions:
continue
group_value = group.permissions.get(permission_name, default)
if group_value == -1:
max_value = -1
@@ -180,43 +221,65 @@ class UserState:
def in_group(self, group_id: str) -> bool:
return group_id in self.user_groups_names
def usage(self) -> Dict:
def usage(self) -> UsageResponse:
"""
returns the monthly quotas for the URLs/MBs and the totals for Sheets
Returns the monthly quotas for the URLs/MBs and the totals for Sheets
"""
current_month = datetime.now().month
current_year = datetime.now().year
# find and sum all user sheets over this month
user_sheets = self.db.query(
models.Sheet.group_id,
func.count(models.Sheet.id).label('sheet_count')
).filter(models.Sheet.author_id == self.email).group_by(models.Sheet.group_id).all()
user_sheets = (
self.db.query(
models.Sheet.group_id,
func.count(models.Sheet.id).label("sheet_count"),
)
.filter(models.Sheet.author_id == self.email)
.group_by(models.Sheet.group_id)
.all()
)
sheets_by_group = {sheet.group_id: sheet.sheet_count for sheet in user_sheets}
sheets_by_group = {
sheet.group_id: sheet.sheet_count for sheet in user_sheets
}
# find and sum all user urls over this month
urls_by_group = self.db.query(
models.Archive.group_id,
func.count(models.Archive.id).label('url_count'),
func.coalesce(func.sum(
urls_by_group = (
self.db.query(
models.Archive.group_id,
func.count(models.Archive.id).label("url_count"),
func.coalesce(
func.cast(
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
sqlalchemy.Integer
), 0
)
), 0).label('total_bytes')
).filter(
models.Archive.author_id == self.email,
func.extract('month', models.Archive.created_at) == current_month,
func.extract('year', models.Archive.created_at) == current_year
).group_by(models.Archive.group_id).all()
func.sum(
func.coalesce(
func.cast(
func.json_extract(
models.Archive.result,
"$.metadata.total_bytes",
),
sqlalchemy.Integer,
),
0,
)
),
0,
).label("total_bytes"),
)
.filter(
models.Archive.author_id == self.email,
func.extract("month", models.Archive.created_at)
== current_month,
func.extract("year", models.Archive.created_at) == current_year,
)
.group_by(models.Archive.group_id)
.all()
)
# merge the two queries
usage_by_group: Dict[str, Usage] = {
(url.group_id or ""):
Usage(monthly_urls=url.url_count, monthly_mbs=int(url.total_bytes / 1024 / 1024))
(url.group_id or ""): Usage(
monthly_urls=url.url_count,
monthly_mbs=int(url.total_bytes / 1024 / 1024),
)
for url in urls_by_group
}
for group_id, sheet_count in sheets_by_group.items():
@@ -235,7 +298,7 @@ class UserState:
monthly_urls=total_urls,
monthly_mbs=int(total_bytes / 1024 / 1024),
total_sheets=total_sheets,
groups=usage_by_group
groups=usage_by_group,
)
def has_quota_monthly_sheets(self, group_id: str) -> bool:
@@ -245,7 +308,14 @@ class UserState:
if group_id not in self.permissions:
return False
user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email, models.Sheet.group_id == group_id).count()
user_sheets = (
self.db.query(models.Sheet)
.filter(
models.Sheet.author_id == self.email,
models.Sheet.group_id == group_id,
)
.count()
)
sheet_quota = self.permissions[group_id].max_sheets
if sheet_quota == -1:
@@ -254,13 +324,15 @@ class UserState:
def has_quota_max_monthly_urls(self, group_id: str) -> bool:
"""
checks if a user has reached their monthly url quota for a group, if global then group should be empty string
Checks if a user has reached their monthly url quota for a group, if
global then group should be empty string
"""
quota = 0
if not group_id:
quota = self.max_monthly_urls
else:
if group_id not in self.permissions: return False
if group_id not in self.permissions:
return False
quota = self.permissions[group_id].max_monthly_urls
if quota == -1:
@@ -268,24 +340,31 @@ class UserState:
current_month = datetime.now().month
current_year = datetime.now().year
user_urls = self.db.query(models.Archive).filter(
models.Archive.author_id == self.email,
models.Archive.group_id == group_id,
func.extract('month', models.Archive.created_at) == current_month,
func.extract('year', models.Archive.created_at) == current_year
).count()
user_urls = (
self.db.query(models.Archive)
.filter(
models.Archive.author_id == self.email,
models.Archive.group_id == group_id,
func.extract("month", models.Archive.created_at)
== current_month,
func.extract("year", models.Archive.created_at) == current_year,
)
.count()
)
return user_urls < quota
def has_quota_max_monthly_mbs(self, group_id: str) -> bool:
"""
checks if a user has reached their monthly MBs quota for a group, if global then group should be empty string
Checks if a user has reached their monthly MBs quota for a group, if
global then group should be empty string
"""
quota = 0
if not group_id:
quota = self.max_monthly_mbs
else:
if group_id not in self.permissions: return False
if group_id not in self.permissions:
return False
quota = self.permissions[group_id].max_monthly_mbs
if quota == -1:
@@ -295,19 +374,34 @@ class UserState:
current_year = datetime.now().year
# find and sum all user bytes over this month
user_bytes = self.db.query(models.Archive).filter(
models.Archive.author_id == self.email,
models.Archive.group_id == group_id,
func.extract('month', models.Archive.created_at) == current_month,
func.extract('year', models.Archive.created_at) == current_year
).with_entities(func.coalesce(func.sum(
func.coalesce(
func.cast(
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
sqlalchemy.Integer
), 0
user_bytes = (
self.db.query(models.Archive)
.filter(
models.Archive.author_id == self.email,
models.Archive.group_id == group_id,
func.extract("month", models.Archive.created_at)
== current_month,
func.extract("year", models.Archive.created_at) == current_year,
)
), 0).label('total')).scalar()
.with_entities(
func.coalesce(
func.sum(
func.coalesce(
func.cast(
func.json_extract(
models.Archive.result,
"$.metadata.total_bytes",
),
sqlalchemy.Integer,
),
0,
)
),
0,
).label("total")
)
.scalar()
)
# convert bytes to mb
user_mbs = int(user_bytes / 1024 / 1024)
@@ -315,7 +409,7 @@ class UserState:
def can_manually_trigger(self, group_id: str) -> bool:
"""
checks if a user is allowed to manually trigger a sheet
Checks if a user is allowed to manually trigger a sheet
"""
if group_id not in self.permissions:
return False
@@ -324,18 +418,21 @@ class UserState:
def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool:
"""
checks if a user is allowed to create a sheet with this frequency for this group
Checks if a user is allowed to create a sheet with this frequency for
this group
"""
if group_id not in self.permissions:
return False
return frequency in self.permissions[group_id].sheet_frequency
def priority_group(self, group_id: str) -> str:
def priority_group(self, group_id: str) -> dict:
priority = "low"
for group in self.user_groups:
if group.id != group_id: continue
if not group.permissions: continue
if group.id != group_id:
continue
if not group.permissions:
continue
priority = group.permissions.get("priority", priority)
break
return convert_priority_to_queue_dict(priority)

View File

@@ -1,50 +0,0 @@
from typing import Dict
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from app.web.config import VERSION, BREAKING_CHANGES
from app.shared.schemas import ActiveUser, UsageResponse
from app.web.db.user_state import UserState
from app.web.security import get_user_state
from app.shared.user_groups import GroupInfo
default_router = APIRouter()
@default_router.get("/")
async def home():
return JSONResponse({"version": VERSION, "breakingChanges": BREAKING_CHANGES})
@default_router.get("/health")
async def health():
return JSONResponse({"status": "ok"})
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
async def active(
user: UserState = Depends(get_user_state),
) -> ActiveUser:
return {"active": user.active}
@default_router.get("/user/permissions", summary="Get the user's global 'all' permissions and the permissions for each group they belong to.")
def get_user_permissions(
user: UserState = Depends(get_user_state),
) -> Dict[str, GroupInfo]:
return user.permissions
@default_router.get("/user/usage", summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.")
def get_user_usage(
user: UserState = Depends(get_user_state),
) -> UsageResponse:
if not user.active:
raise HTTPException(status_code=403, detail="User is not active.")
return user.usage()
@default_router.get('/favicon.ico', include_in_schema=False)
async def favicon() -> FileResponse:
return FileResponse("app/web/static/favicon.ico")

View File

@@ -1,81 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from sqlalchemy import exc
from sqlalchemy.orm import Session
from app.web.db.user_state import UserState
from app.shared import schemas
from app.shared.task_messaging import get_celery
from app.web.security import get_user_state
from app.web.db import crud
from app.shared.db.database import get_db_dependency
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
celery = get_celery()
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
def create_sheet(
sheet: schemas.SheetAdd,
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency),
) -> schemas.SheetResponse:
if not user.in_group(sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
if not user.has_quota_monthly_sheets(sheet.group_id):
raise HTTPException(status_code=429, detail="User has reached their sheet quota for this group.")
if not user.is_sheet_frequency_allowed(sheet.group_id, sheet.frequency):
raise HTTPException(status_code=422, detail="Invalid frequency selected for this group.")
try:
return crud.create_sheet(db, sheet.id, sheet.name, user.email, sheet.group_id, sheet.frequency)
except exc.IntegrityError as e:
raise HTTPException(status_code=400, detail="Sheet with this ID is already being archived.") from e
@sheet_router.get("/mine", status_code=200, summary="Get the authenticated user's Google Sheets.")
def get_user_sheets(
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency)
) -> list[schemas.SheetResponse]:
return crud.get_user_sheets(db, user.email)
@sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.")
def delete_sheet(
id: str,
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency),
) -> schemas.DeleteResponse:
return JSONResponse({
"id": id,
"deleted": crud.delete_sheet(db, id, user.email)
})
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
def archive_user_sheet(
id: str,
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency),
) -> schemas.Task:
sheet = crud.get_user_sheet(db, user.email, sheet_id=id)
if not sheet:
raise HTTPException(status_code=403, detail="No access to this sheet.")
if not user.in_group(sheet.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
if not user.can_manually_trigger(sheet.group_id):
raise HTTPException(status_code=429, detail="User cannot manually trigger sheet archiving in this group.")
group_queue = user.priority_group(sheet.group_id)
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=id, author_id=user.email, group_id=sheet.group_id).model_dump_json()]).apply_async(**group_queue)
return JSONResponse({"id": task.id}, status_code=201)

View File

@@ -1,40 +0,0 @@
from celery.result import AsyncResult
from fastapi import APIRouter, Depends
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from app.shared.task_messaging import get_celery
from app.web.security import get_token_or_user_auth
from app.shared import schemas
from app.shared.log import log_error
from app.web.utils.misc import custom_jsonable_encoder
task_router = APIRouter(prefix="/task", tags=["Async task operations"])
celery = get_celery()
@task_router.get("/{task_id}", summary="Check the status of an async task by its id, works for URLs and Sheet tasks.")
def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult:
task = AsyncResult(task_id, app=celery)
try:
if task.status == "FAILURE":
# *FAILURE* The task raised an exception, or has exceeded the retry limit.
# The :attr:`result` attribute then contains the exception raised by the task.
# https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
raise task.result
response = {
"id": task_id,
"status": task.status,
"result": task.result
}
return JSONResponse(jsonable_encoder(response, exclude_unset=True, custom_encoder={bytes: custom_jsonable_encoder}))
except Exception as e:
log_error(e)
return JSONResponse({
"id": task_id,
"status": "FAILURE",
"result": {"error": str(e)}
})

View File

@@ -1,85 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from datetime import datetime
from loguru import logger
from sqlalchemy.orm import Session
from app.web.config import ALLOW_ANY_EMAIL
from app.shared import schemas
from app.shared.task_messaging import get_celery
from app.web.security import get_token_or_user_auth, get_user_state
from app.web.db import crud
from app.web.db.user_state import UserState
from app.shared.db.database import get_db_dependency
from urllib.parse import urlparse
from app.web.utils.misc import convert_priority_to_queue_dict
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
celery = get_celery()
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
def archive_url(
archive: schemas.ArchiveTrigger,
email=Depends(get_token_or_user_auth),
db: Session = Depends(get_db_dependency)
) -> schemas.Task:
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
parsed_url = urlparse(archive.url)
if not all([parsed_url.scheme, parsed_url.netloc]):
raise HTTPException(status_code=400, detail="Invalid URL received.")
archive_create = schemas.ArchiveCreate(**archive.model_dump())
if email != ALLOW_ANY_EMAIL:
archive_create.author_id = email
user = UserState(db, email)
if archive.group_id and not user.in_group(archive.group_id):
raise HTTPException(status_code=403, detail="User does not have access to this group.")
if not user.has_quota_max_monthly_urls(archive.group_id):
raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.")
if not user.has_quota_max_monthly_mbs(archive.group_id):
raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.")
group_queue = user.priority_group(archive_create.group_id)
else:
archive_create.author_id = archive.author_id or email
group_queue = convert_priority_to_queue_dict("high")
task = celery.signature("create_archive_task", args=[archive_create.model_dump_json()]).apply_async(**group_queue)
task_response = schemas.Task(id=task.id)
return JSONResponse(task_response.model_dump(), status_code=201)
@url_router.get("/search", summary="Search for archive entries by URL.")
def search_by_url(
url: str, skip: int = 0, limit: int = 25,
archived_after: datetime = None, archived_before: datetime = None,
db: Session = Depends(get_db_dependency),
email: str = Depends(get_token_or_user_auth)
) -> list[schemas.ArchiveResult]:
read_groups, read_public = False, False
if email != ALLOW_ANY_EMAIL:
user = UserState(db, email)
if not user.read and not user.read_public:
raise HTTPException(status_code=403, detail="User does not have read access.")
read_groups = user.read
read_public = user.read_public
return crud.search_archives_by_url(db, url.strip(), email, read_groups, read_public, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
@url_router.delete("/{id}", summary="Delete a single URL archive by id.")
def delete_archive(
id:str,
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency)
) -> schemas.DeleteResponse:
logger.info(f"deleting url archive task {id} request by {user.email}")
return JSONResponse({
"id": id,
"deleted": crud.soft_delete_archive(db, id, user.email)
})

View File

@@ -1,22 +1,32 @@
import asyncio
from collections import defaultdict
import datetime
import logging
from collections import defaultdict
from contextlib import asynccontextmanager
import alembic.config
from fastapi import FastAPI
from contextlib import asynccontextmanager
from fastapi_mail import FastMail, MessageSchema, MessageType
from fastapi_utils.tasks import repeat_every
from loguru import logger
from fastapi_mail import FastMail, MessageSchema, MessageType
from app.shared.db import models
from app.shared.db.database import get_db, get_db_async, make_engine, wal_checkpoint
from app.shared import schemas
from app.shared.db import models
from app.shared.db.database import (
get_db,
get_db_async,
make_engine,
wal_checkpoint,
)
from app.shared.settings import get_settings
from app.shared.task_messaging import get_celery
from app.web.db import crud
from app.web.middleware import increase_exceptions_counter
from app.web.utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
from app.web.utils.metrics import (
measure_regular_metrics,
redis_subscribe_worker_exceptions,
)
celery = get_celery()
@@ -28,9 +38,22 @@ async def lifespan(app: FastAPI):
# STARTUP
engine = make_engine(get_settings().DATABASE_PATH)
models.Base.metadata.create_all(bind=engine)
alembic.config.main(prog="alembic", argv=['--raiseerr', 'upgrade', 'head'])
alembic.config.main(
prog="alembic",
argv=[
"-c",
"./app/migrations/alembic.ini",
"--raiseerr",
"upgrade",
"head",
],
)
logging.getLogger("uvicorn.access").disabled = True # loguru
asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL))
asyncio.create_task(
redis_subscribe_worker_exceptions(
get_settings().REDIS_EXCEPTIONS_CHANNEL
)
)
asyncio.create_task(repeat_measure_regular_metrics())
with get_db() as db:
crud.upsert_user_groups(db)
@@ -61,41 +84,74 @@ async def lifespan(app: FastAPI):
# CRON JOBS
@repeat_every(seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS, on_exception=increase_exceptions_counter)
@repeat_every(
seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS,
on_exception=increase_exceptions_counter,
)
async def repeat_measure_regular_metrics():
await measure_regular_metrics(get_settings().DATABASE_PATH, get_settings().REPEAT_COUNT_METRICS_SECONDS)
await measure_regular_metrics(
get_settings().DATABASE_PATH,
get_settings().REPEAT_COUNT_METRICS_SECONDS,
)
@repeat_every(seconds=60, wait_first=120, on_exception=increase_exceptions_counter)
@repeat_every(
seconds=60, wait_first=120, on_exception=increase_exceptions_counter
)
async def archive_hourly_sheets_cronjob():
await archive_sheets_cronjob("hourly", 60, datetime.datetime.now().minute)
@repeat_every(seconds=3600, wait_first=120, on_exception=increase_exceptions_counter)
@repeat_every(
seconds=3600, wait_first=120, on_exception=increase_exceptions_counter
)
async def archive_daily_sheets_cronjob():
await archive_sheets_cronjob("daily", 24, datetime.datetime.now().hour)
async def archive_sheets_cronjob(frequency: str, interval: int, current_time_unit: int):
async def archive_sheets_cronjob(
frequency: str, interval: int, current_time_unit: int
):
triggered_jobs = []
async with get_db_async() as db:
sheets = await crud.get_sheets_by_id_hash(db, frequency, interval, current_time_unit)
sheets = await crud.get_sheets_by_id_hash(
db, frequency, str(interval), current_time_unit
)
for s in sheets:
group_queue = await crud.get_group_priority_async(db, s.group_id)
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group_id=s.group_id).model_dump_json()]).apply_async(**group_queue)
task = celery.signature(
"create_sheet_task",
args=[
schemas.SubmitSheet(
sheet_id=s.id,
author_id=s.author_id,
group_id=s.group_id,
).model_dump_json()
],
).apply_async(**group_queue)
triggered_jobs.append({"sheet_id": s.id, "task_id": task.id})
logger.debug(f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}")
logger.debug(
f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}"
)
# TODO: on exception should logerror but also prometheus counter
DELETE_WINDOW = get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS * 24 * 60 * 60
DELETE_WINDOW = (
get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS * 24 * 60 * 60
)
@repeat_every(seconds=DELETE_WINDOW, wait_first=180, on_exception=increase_exceptions_counter)
@repeat_every(
seconds=DELETE_WINDOW,
wait_first=180,
on_exception=increase_exceptions_counter,
)
async def notify_about_expired_archives():
notify_from = datetime.datetime.now() + datetime.timedelta(days=get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS)
notify_from = datetime.datetime.now() + datetime.timedelta(
days=get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS
)
async with get_db_async() as db:
scheduled_deletions = await crud.find_by_store_until(db, notify_from)
@@ -104,10 +160,15 @@ async def notify_about_expired_archives():
user_archives[archive.author_id].append(archive)
if user_archives:
fastmail = FastMail(get_settings().MAIL_CONFIG)
fastmail = FastMail(get_settings().mail_config)
# notify users
for email in user_archives:
list_of_archives = "\n".join([f'{a.url}, {a.id}, {a.store_until.isoformat()}<br/>' for a in user_archives[email]])
list_of_archives = "\n".join(
[
f"{a.url}, {a.id}, {a.store_until.isoformat()}<br/>"
for a in user_archives[email]
]
)
# TODO: how can users download them in bulk?
message = MessageSchema(
subject="Auto Archiver: Archives Scheduled for Deletion",
@@ -127,16 +188,23 @@ async def notify_about_expired_archives():
</body>
</html>
""",
subtype=MessageType.html
subtype=MessageType.html,
)
await fastmail.send_message(message)
logger.debug(f"[CRON] Email sent to {email} about {len(user_archives[email])} scheduled archives deletion.")
logger.debug(
f"[CRON] Email sent to {email} about {len(user_archives[email])} scheduled archives deletion."
)
# now schedule the deletion event
asyncio.create_task(delete_expired_archives())
@repeat_every(max_repetitions=1, wait_first=10, seconds=0, on_exception=increase_exceptions_counter)
@repeat_every(
max_repetitions=1,
wait_first=10,
seconds=0,
on_exception=increase_exceptions_counter,
)
async def delete_expired_archives():
async with get_db_async() as db:
count_deleted = await crud.soft_delete_expired_archives(db)
@@ -144,19 +212,27 @@ async def delete_expired_archives():
logger.debug(f"[CRON] Deleted {count_deleted} archives.")
@repeat_every(seconds=86400, wait_first=150, on_exception=increase_exceptions_counter)
@repeat_every(
seconds=86400, wait_first=150, on_exception=increase_exceptions_counter
)
async def delete_stale_sheets():
STALE_DAYS = get_settings().DELETE_STALE_SHEETS_DAYS
logger.debug(f"[CRON] Deleting stale sheets older than {STALE_DAYS} days.")
async with get_db_async() as db:
user_sheets = await crud.delete_stale_sheets(db, STALE_DAYS)
if not user_sheets: return
if not user_sheets:
return
fastmail = FastMail(get_settings().MAIL_CONFIG)
fastmail = FastMail(get_settings().mail_config)
# notify users
for email in user_sheets:
list_of_sheets = "\n".join([f'<li><a href="https://docs.google.com/spreadsheets/d/{s.id}">{s.name}</a></li>' for s in user_sheets[email]])
list_of_sheets = "\n".join(
[
f'<li><a href="https://docs.google.com/spreadsheets/d/{s.id}">{s.name}</a></li>'
for s in user_sheets[email]
]
)
message = MessageSchema(
subject="Auto Archiver: Stale Sheets Removed",
recipients=[email],
@@ -173,14 +249,16 @@ async def delete_stale_sheets():
</body>
</html>
""",
subtype=MessageType.html
subtype=MessageType.html,
)
await fastmail.send_message(message)
logger.debug(f"[CRON] Email sent to {email} about stale sheets deletion.")
logger.debug(
f"[CRON] Email sent to {email} about stale sheets deletion."
)
# @repeat_at
async def generate_users_export_csv():
#TODO: implement a cronjob that regularly requested user data to a CSV file
# TODO: implement a cronjob that regularly requested user data to a CSV file
# see https://colab.research.google.com/drive/1QDbo3QXHPBdiTuANlA1AWVvN-rqxuCPa?authuser=0#scrollTo=4nPXeSdK8RBT
pass
pass

View File

@@ -1,34 +1,42 @@
import os
from fastapi import FastAPI, Depends
from fastapi.staticfiles import StaticFiles
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from prometheus_fastapi_instrumentator import Instrumentator
from fastapi.staticfiles import StaticFiles
from loguru import logger
from prometheus_fastapi_instrumentator import Instrumentator
from app.web.middleware import logging_middleware
from app.shared.settings import Settings, get_settings
from app.shared.task_messaging import get_celery
from app.web.security import token_api_key_auth
from app.web.config import VERSION, API_DESCRIPTION
from app.web.config import API_DESCRIPTION, VERSION
from app.web.events import lifespan
from app.shared.settings import get_settings
from app.web.middleware import logging_middleware
from app.web.routers.default import router as default_router
from app.web.routers.interoperability import router as interoperability_router
from app.web.routers.sheet import router as sheet_router
from app.web.routers.task import router as task_router
from app.web.routers.url import router as url_router
from app.web.security import token_api_key_auth
from app.web.endpoints.default import default_router
from app.web.endpoints.url import url_router
from app.web.endpoints.sheet import sheet_router
from app.web.endpoints.task import task_router
from app.web.endpoints.interoperability import interoperability_router
celery = get_celery()
def app_factory(settings = get_settings()):
def app_factory(settings: Settings = None):
# TODO: Create dev, test, and prod versions of settings that do not have
# TODO: to be passed in as a parameter
if settings is None:
settings = get_settings()
app = FastAPI(
title="Auto-Archiver API",
description=API_DESCRIPTION,
version=VERSION,
contact={"name": "GitHub", "url": "https://github.com/bellingcat/auto-archiver-api"},
lifespan=lifespan
contact={
"name": "GitHub",
"url": "https://github.com/bellingcat/auto-archiver-api",
},
lifespan=lifespan,
)
app.add_middleware(
@@ -47,14 +55,30 @@ def app_factory(settings = get_settings()):
app.include_router(interoperability_router)
# prometheus exposed in /metrics with authentication
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health", "/openapi.json", "/favicon.ico"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
Instrumentator(
should_group_status_codes=False,
excluded_handlers=[
"/metrics",
"/health",
"/openapi.json",
"/favicon.ico",
],
).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
if settings.SERVE_LOCAL_ARCHIVE:
local_dir = settings.SERVE_LOCAL_ARCHIVE
if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")):
if not os.path.isdir(local_dir) and os.path.isdir(
local_dir.replace("/app", ".")
):
local_dir = local_dir.replace("/app", ".")
if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
logger.warning(f"MOUNTing local archive, use this in development only {settings.SERVE_LOCAL_ARCHIVE}")
app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE)
logger.warning(
f"MOUNTing local archive, use this in development only {settings.SERVE_LOCAL_ARCHIVE}"
)
app.mount(
settings.SERVE_LOCAL_ARCHIVE,
StaticFiles(directory=local_dir),
name=settings.SERVE_LOCAL_ARCHIVE,
)
return app
return app

View File

@@ -1,7 +1,8 @@
import traceback
from loguru import logger
from fastapi import Request
from loguru import logger
from app.shared.log import log_error
from app.web.utils.metrics import EXCEPTION_COUNTER
@@ -9,23 +10,33 @@ from app.web.utils.metrics import EXCEPTION_COUNTER
async def logging_middleware(request: Request, call_next):
try:
response = await call_next(request)
#TODO: use Origin to have summary prometheus metrics on where requests come from
# TODO: use Origin to have summary prometheus metrics on where
# requests come from
# origin = request.headers.get("origin")
logger.info(f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}")
logger.info(
f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}"
)
return response
except Exception as e:
location = f"{request.method} {request.url._url}"
await increase_exceptions_counter(e, location)
logger.info(f"{request.client.host}:{request.client.port} {location} - {e.__class__.__name__} {e}")
logger.info(
f"{request.client.host}:{request.client.port} {location} - {e.__class__.__name__} {e}"
)
raise e
async def increase_exceptions_counter(e: Exception, location:str="cronjob"):
async def increase_exceptions_counter(
e: Exception, location: str = "cronjob"
) -> None:
if location == "cronjob":
try:
last_trace = traceback.extract_tb(e.__traceback__)[-1]
_file, _line, func_name, _text = last_trace
location = func_name
except Exception as e:
logger.error(f"Unable to get function name from cronjob exception traceback: {e}")
except Exception as e:
logger.error(
f"Unable to get function name from cronjob exception traceback: {e}"
)
EXCEPTION_COUNTER.labels(type=e.__class__.__name__, location=location).inc()
log_error(e)
log_error(e)

View File

View File

@@ -0,0 +1,64 @@
from http import HTTPStatus
from typing import Dict
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from app.shared.schemas import ActiveUser, UsageResponse
from app.shared.user_groups import GroupInfo
from app.web.config import BREAKING_CHANGES, VERSION
from app.web.db.user_state import UserState
from app.web.security import get_user_state
router = APIRouter()
@router.get("/")
async def home() -> JSONResponse:
return JSONResponse(
{"version": VERSION, "breakingChanges": BREAKING_CHANGES}
)
@router.get("/health")
async def health() -> JSONResponse:
return JSONResponse({"status": "ok"})
@router.get(
"/user/active", summary="Check if the user is active and can use the tool."
)
async def active(
user: UserState = Depends(get_user_state),
) -> ActiveUser:
return ActiveUser(active=user.active)
@router.get(
"/user/permissions",
summary="Get the user's global 'all' permissions and the permissions for each group they belong to.",
)
def get_user_permissions(
user: UserState = Depends(get_user_state),
) -> Dict[str, GroupInfo]:
return user.permissions
@router.get(
"/user/usage",
summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.",
)
def get_user_usage(
user: UserState = Depends(get_user_state),
) -> UsageResponse:
if not user.active:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="User is not active."
)
return user.usage()
@router.get("/favicon.ico", include_in_schema=False)
async def favicon() -> FileResponse:
return FileResponse("app/web/static/favicon.ico")

View File

@@ -1,41 +1,53 @@
import json
from http import HTTPStatus
import sqlalchemy
from auto_archiver.core import Metadata
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from loguru import logger
import sqlalchemy
from auto_archiver.core import Metadata
from sqlalchemy.orm import Session
from app.shared.aa_utils import get_all_urls
from app.web.config import ALLOW_ANY_EMAIL
from app.shared import business_logic, schemas
from app.shared.db import worker_crud
from app.shared.db import models, worker_crud
from app.shared.db.database import get_db_dependency
from app.web.security import token_api_key_auth
from app.shared.db import models
from app.shared.log import log_error
from app.web.config import ALLOW_ANY_EMAIL
from app.web.security import token_api_key_auth
from app.web.utils.misc import get_all_urls
interoperability_router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."])
router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."])
# ----- endpoint to submit data archived elsewhere
@interoperability_router.post("/submit-archive", status_code=201, summary="Submit a manual archive entry, for data that was archived elsewhere.")
@router.post(
"/submit-archive",
status_code=HTTPStatus.CREATED,
summary="Submit a manual archive entry, for data that was archived elsewhere.",
)
def submit_manual_archive(
manual: schemas.SubmitManualArchive,
auth=Depends(token_api_key_auth),
db: Session = Depends(get_db_dependency)
db: Session = Depends(get_db_dependency),
):
try:
result: Metadata = Metadata.from_json(manual.result)
except json.JSONDecodeError as e:
log_error(e)
raise HTTPException(status_code=422, detail="Invalid JSON in result field.")
raise HTTPException(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
detail="Invalid JSON in result field.",
) from e
manual.author_id = manual.author_id or ALLOW_ANY_EMAIL
manual.tags.add("manual")
store_until = business_logic.get_store_archive_until_or_never(db, manual.group_id)
logger.debug(f"[MANUAL ARCHIVE] {manual.author_id} {manual.url} {store_until}")
store_until = business_logic.get_store_archive_until_or_never(
db, manual.group_id
)
logger.debug(
f"[MANUAL ARCHIVE] {manual.author_id} {manual.url} {store_until}"
)
try:
archive = schemas.ArchiveCreate(
@@ -51,8 +63,15 @@ def submit_manual_archive(
)
db_archive = worker_crud.store_archived_url(db, archive)
logger.debug(f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}")
return JSONResponse({"id": db_archive.id}, status_code=201)
logger.debug(
f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}"
)
return JSONResponse(
{"id": db_archive.id}, status_code=HTTPStatus.CREATED
)
except sqlalchemy.exc.IntegrityError as e:
log_error(e)
raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error, likely duplicate urls.")
raise HTTPException(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
detail="Cannot insert into DB due to integrity error, likely duplicate urls.",
) from e

132
app/web/routers/sheet.py Normal file
View File

@@ -0,0 +1,132 @@
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from sqlalchemy import exc
from sqlalchemy.orm import Session
from app.shared.db.database import get_db_dependency
from app.shared.schemas import (
DeleteResponse,
SheetAdd,
SheetResponse,
SubmitSheet,
)
from app.shared.task_messaging import get_celery
from app.web.db import crud
from app.web.db.user_state import UserState
from app.web.security import get_user_state
router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
celery = get_celery()
@router.post(
"/create",
status_code=HTTPStatus.CREATED,
summary="Store a new Google Sheet for regular archiving.",
)
def create_sheet(
sheet: SheetAdd,
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency),
) -> SheetResponse:
if not user.in_group(sheet.group_id):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="User does not have access to this group.",
)
if not user.has_quota_monthly_sheets(sheet.group_id):
raise HTTPException(
status_code=HTTPStatus.TOO_MANY_REQUESTS,
detail="User has reached their sheet quota for this group.",
)
if not user.is_sheet_frequency_allowed(sheet.group_id, sheet.frequency):
raise HTTPException(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
detail="Invalid frequency selected for this group.",
)
try:
return crud.create_sheet(
db,
sheet.id,
sheet.name,
user.email,
sheet.group_id,
sheet.frequency,
)
except exc.IntegrityError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Sheet with this ID is already being archived.",
) from e
@router.get(
"/mine",
status_code=HTTPStatus.OK,
summary="Get the authenticated user's Google Sheets.",
)
def get_user_sheets(
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency),
) -> list[SheetResponse]:
return crud.get_user_sheets(db, user.email)
@router.delete("/{sheet_id}", summary="Delete a Google Sheet by ID.")
def delete_sheet(
sheet_id: str,
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency),
) -> DeleteResponse:
return DeleteResponse(
id=sheet_id, deleted=crud.delete_sheet(db, sheet_id, user.email)
)
@router.post(
"/{sheet_id}/archive",
status_code=HTTPStatus.CREATED,
summary="Trigger an archiving task for a GSheet you own.",
response_description="task_id for the archiving task.",
)
def archive_user_sheet(
sheet_id: str,
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency),
) -> JSONResponse:
sheet = crud.get_user_sheet(db, user.email, sheet_id=sheet_id)
if not sheet:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="No access to this sheet."
)
if not user.in_group(sheet.group_id):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="User does not have access to this group.",
)
if not user.can_manually_trigger(sheet.group_id):
raise HTTPException(
status_code=HTTPStatus.TOO_MANY_REQUESTS,
detail="User cannot manually trigger sheet archiving in this group.",
)
group_queue = user.priority_group(sheet.group_id)
task = celery.signature(
"create_sheet_task",
args=[
SubmitSheet(
sheet_id=sheet_id, author_id=user.email, group_id=sheet.group_id
).model_dump_json()
],
).apply_async(**group_queue)
return JSONResponse({"id": task.id}, status_code=HTTPStatus.CREATED)

52
app/web/routers/task.py Normal file
View File

@@ -0,0 +1,52 @@
from celery.result import AsyncResult
from fastapi import APIRouter, Depends
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from app.shared import schemas
from app.shared.constants import STATUS_FAILURE
from app.shared.log import log_error
from app.shared.task_messaging import get_celery
from app.web.security import get_token_or_user_auth
from app.web.utils.misc import custom_jsonable_encoder
router = APIRouter(prefix="/task", tags=["Async task operations"])
celery = get_celery()
@router.get(
"/{task_id}",
summary="Check the status of an async task by its id, works for URLs and Sheet tasks.",
)
def get_status(
task_id, email=Depends(get_token_or_user_auth)
) -> schemas.TaskResult:
task = AsyncResult(task_id, app=celery)
try:
if task.status == STATUS_FAILURE:
# *FAILURE* The task raised an exception, or has exceeded the retry limit.
# The :attr:`result` attribute then contains the exception raised by
# the task.
# https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
raise task.result
response = {"id": task_id, "status": task.status, "result": task.result}
return JSONResponse(
jsonable_encoder(
response,
exclude_unset=True,
custom_encoder={bytes: custom_jsonable_encoder},
)
)
except Exception as e:
log_error(e)
return JSONResponse(
{
"id": task_id,
"status": STATUS_FAILURE,
"result": {"error": str(e)},
}
)

125
app/web/routers/url.py Normal file
View File

@@ -0,0 +1,125 @@
from datetime import datetime
from http import HTTPStatus
from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from loguru import logger
from sqlalchemy.orm import Session
from app.shared import schemas
from app.shared.db.database import get_db_dependency
from app.shared.schemas import DeleteResponse
from app.shared.task_messaging import get_celery
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud
from app.web.db.user_state import UserState
from app.web.security import get_token_or_user_auth, get_user_state
from app.web.utils.misc import convert_priority_to_queue_dict
router = APIRouter(prefix="/url", tags=["Single URL operations"])
celery = get_celery()
@router.post(
"/archive",
status_code=HTTPStatus.CREATED,
summary="Submit a single URL archive request, starts an archiving task.",
response_description="task_id for the archiving task, will match the archive id.",
)
def archive_url(
archive: schemas.ArchiveTrigger,
email=Depends(get_token_or_user_auth),
db: Session = Depends(get_db_dependency),
) -> JSONResponse:
logger.info(
f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}"
)
parsed_url = urlparse(archive.url)
if not all([parsed_url.scheme, parsed_url.netloc]):
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail="Invalid URL received."
)
archive_create = schemas.ArchiveCreate(**archive.model_dump())
if email != ALLOW_ANY_EMAIL:
archive_create.author_id = email
user = UserState(db, email)
if archive.group_id and not user.in_group(archive.group_id):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="User does not have access to this group.",
)
if not user.has_quota_max_monthly_urls(archive.group_id):
raise HTTPException(
status_code=HTTPStatus.TOO_MANY_REQUESTS,
detail="User has reached their monthly URL quota.",
)
if not user.has_quota_max_monthly_mbs(archive.group_id):
raise HTTPException(
status_code=HTTPStatus.TOO_MANY_REQUESTS,
detail="User has reached their monthly MB quota.",
)
group_queue = user.priority_group(archive_create.group_id)
else:
archive_create.author_id = archive.author_id or email
group_queue = convert_priority_to_queue_dict("high")
task = celery.signature(
"create_archive_task", args=[archive_create.model_dump_json()]
).apply_async(**group_queue)
task_response = schemas.Task(id=task.id)
return JSONResponse(
task_response.model_dump(), status_code=HTTPStatus.CREATED
)
@router.get("/search", summary="Search for archive entries by URL.")
def search_by_url(
url: str,
skip: int = 0,
limit: int = 25,
archived_after: datetime = None,
archived_before: datetime = None,
db: Session = Depends(get_db_dependency),
email: str = Depends(get_token_or_user_auth),
) -> list[schemas.ArchiveResult]:
read_groups, read_public = False, False
if email != ALLOW_ANY_EMAIL:
user = UserState(db, email)
if not user.read and not user.read_public:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="User does not have read access.",
)
read_groups = user.read
read_public = user.read_public
return crud.search_archives_by_url(
db,
url.strip(),
email,
read_groups,
read_public,
skip=skip,
limit=limit,
archived_after=archived_after,
archived_before=archived_before,
)
@router.delete("/{archive_id}", summary="Delete a single URL archive by id.")
def delete_archive(
archive_id: str,
user: UserState = Depends(get_user_state),
db: Session = Depends(get_db_dependency),
) -> DeleteResponse:
logger.info(
f"deleting url archive task {archive_id} request by {user.email}"
)
return DeleteResponse(
id=archive_id,
deleted=crud.soft_delete_archive(db, archive_id, user.email),
)

View File

@@ -1,27 +1,32 @@
from loguru import logger
import requests, secrets
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
import secrets
from http import HTTPStatus
import firebase_admin
from firebase_admin import credentials, auth, exceptions
import requests
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from firebase_admin import auth, credentials, exceptions
from loguru import logger
from sqlalchemy.orm import Session
from app.web.config import ALLOW_ANY_EMAIL
from app.shared.settings import get_settings
from app.shared.db.database import get_db_dependency
from app.shared.settings import get_settings
from app.web.config import ALLOW_ANY_EMAIL
from app.web.db.user_state import UserState
settings = get_settings()
bearer_security = HTTPBearer()
FIREBASE_OAUTH_ENABLED = settings.FIREBASE_SERVICE_ACCOUNT_JSON != ""
if FIREBASE_OAUTH_ENABLED:
logger.debug("Firebase OAUTH enabled, initializing...")
firebase_admin.initialize_app(credentials.Certificate(settings.FIREBASE_SERVICE_ACCOUNT_JSON))
firebase_admin.initialize_app(
credentials.Certificate(settings.FIREBASE_SERVICE_ACCOUNT_JSON)
)
def secure_compare(token, api_key):
def secure_compare(token, api_key) -> bool:
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
@@ -29,9 +34,13 @@ def secure_compare(token, api_key):
def api_key_auth(api_key):
assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars"
async def auth(bearer: HTTPAuthorizationCredentials = Depends(bearer_security), auto_error=True):
async def auth(
bearer: HTTPAuthorizationCredentials = Depends(bearer_security),
auto_error=True,
):
is_correct = secure_compare(bearer.credentials, api_key)
if is_correct: return True
if is_correct:
return True
if auto_error:
raise HTTPException(
@@ -43,18 +52,23 @@ def api_key_auth(api_key):
return auth
# --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus
# --- Token Auth for AA itself to query the API, AA setup tool and Prometheus
token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN)
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
async def get_token_or_user_auth(
credentials: HTTPAuthorizationCredentials = Depends(bearer_security),
):
# tries to use the static API_KEY and defaults to google JWT auth
if await token_api_key_auth(credentials, auto_error=False): return ALLOW_ANY_EMAIL
if await token_api_key_auth(credentials, auto_error=False):
return ALLOW_ANY_EMAIL
return await get_user_auth(credentials)
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# validates the Bearer token in the case that it requires it
async def get_user_auth(
credentials: HTTPAuthorizationCredentials = Depends(bearer_security),
):
# Validates the Bearer token in the case that it requires it
valid_user, info = authenticate_user(credentials.credentials)
if valid_user:
return info.lower()
@@ -66,39 +80,55 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
)
def authenticate_user(access_token):
def authenticate_user(access_token) -> (bool, str):
if FIREBASE_OAUTH_ENABLED:
try:
j = auth.verify_id_token(access_token)
email = j.get('email', None)
logger.debug(f"Successfully verified the ID token for {email}")
if email is None:
return False, "email not found in token"
if email in settings.BLOCKED_EMAILS:
return False, f"email '{email}' not allowed"
return True, email
except exceptions.FirebaseError as e:
logger.warning(f"Error verifying ID token: {str(e)[:80]}...")
return firebase_login_attempt(access_token)
except exceptions.FirebaseError:
# used a non-Firebase token, fallback to Google OAuth
pass
# https://cloud.google.com/docs/authentication/token-types#access
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token": access_token})
if r.status_code != 200: return False, "invalid token"
if not isinstance(access_token, str) or len(access_token) < 10:
return False, "invalid access_token"
r = requests.get(
"https://oauth2.googleapis.com/tokeninfo",
{"access_token": access_token},
)
if r.status_code != HTTPStatus.OK:
return False, "invalid token"
try:
j = r.json()
if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS:
return False, f"token does not belong to valid APP_ID"
if (
j.get("azp") not in settings.CHROME_APP_IDS
and j.get("aud") not in settings.CHROME_APP_IDS
):
return False, "token does not belong to valid APP_ID"
if j.get("email") in settings.BLOCKED_EMAILS:
return False, f"email '{j.get('email')}' not allowed"
if j.get("email_verified") != "true":
return False, f"email '{j.get('email')}' not verified"
if int(j.get("expires_in", -1)) <= 0:
return False, "Token expired"
return True, j.get('email').lower()
return True, j.get("email").lower()
except Exception as e:
logger.warning(f"AUTH EXCEPTION occurred: {e}")
return False, "exception occurred"
def get_user_state(email: str = Depends(get_user_auth), db: Session = Depends(get_db_dependency)):
def firebase_login_attempt(access_token) -> (bool, str):
j = auth.verify_id_token(access_token)
email = j.get("email", None)
logger.debug(f"Successfully verified the ID token for {email}")
if email is None:
return False, "email not found in token"
if email in settings.BLOCKED_EMAILS:
return False, f"email '{email}' not allowed"
return True, email
def get_user_state(
email: str = Depends(get_user_auth),
db: Session = Depends(get_db_dependency),
) -> UserState:
return UserState(db, email)

View File

@@ -2,52 +2,57 @@ import asyncio
import json
import os
import shutil
from prometheus_client import Counter, Gauge
from app.web.db import crud
from app.shared.db.database import get_db
from app.shared.log import log_error
from app.shared.task_messaging import get_redis
from app.web.db import crud
# Custom metrics
EXCEPTION_COUNTER = Counter(
"exceptions",
"Number of times a certain exception has occurred.",
labelnames=["type", "location"]
labelnames=["type", "location"],
)
WORKER_EXCEPTION = Counter(
"worker_exceptions_total",
"Number of times a certain exception has occurred on the worker.",
labelnames=["type", "exception", "task", "traceback"]
labelnames=["type", "exception", "task", "traceback"],
)
DISK_UTILIZATION = Gauge(
"disk_utilization",
"Disk utilization in GB",
labelnames=["type"]
"disk_utilization", "Disk utilization in GB", labelnames=["type"]
)
DATABASE_METRICS = Gauge(
"database_metrics",
"Database metric readings at a certain point in time",
labelnames=["query"]
labelnames=["query"],
)
DATABASE_METRICS_COUNTER = Counter(
"database_metrics_counter",
"Database metrics that increase over time",
labelnames=["query", "user"]
labelnames=["query", "user"],
)
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL: str):
# Subscribe to Redis channel and increment the counter for each exception with info on the exception and task
async def redis_subscribe_worker_exceptions(redis_exceptions_channel: str):
# Subscribe to Redis channel and increment the counter for each exception
# with info on the exception and task
Redis = get_redis()
PubSubExceptions = Redis.pubsub()
PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL)
PubSubExceptions.subscribe(redis_exceptions_channel)
while True:
message = PubSubExceptions.get_message()
if message and message["type"] == "message":
data = json.loads(message["data"].decode("utf-8"))
WORKER_EXCEPTION.labels(type=data["type"], exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc()
WORKER_EXCEPTION.labels(
type=data["type"],
exception=data["exception"],
task=data["task"],
traceback=data["traceback"],
).inc()
await asyncio.sleep(1)
@@ -58,12 +63,19 @@ async def measure_regular_metrics(sqlite_db_url: str, repeat_in_seconds: int):
try:
fs = os.stat(sqlite_db_url.replace("sqlite:///", ""))
DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30))
except Exception as e: log_error(e)
except Exception as e:
log_error(e)
with get_db() as db:
DATABASE_METRICS.labels(query="count_archives").set(crud.count_archives(db))
DATABASE_METRICS.labels(query="count_archive_urls").set(crud.count_archive_urls(db))
DATABASE_METRICS.labels(query="count_archives").set(
crud.count_archives(db)
)
DATABASE_METRICS.labels(query="count_archive_urls").set(
crud.count_archive_urls(db)
)
DATABASE_METRICS.labels(query="count_users").set(crud.count_users(db))
for user in crud.count_by_user_since(db, repeat_in_seconds):
DATABASE_METRICS_COUNTER.labels(query="count_by_user", user=user.author_id).inc(user.total)
DATABASE_METRICS_COUNTER.labels(
query="count_by_user", user=user.author_id
).inc(user.total)

View File

@@ -1,15 +1,62 @@
import base64
from typing import List
from auto_archiver.core import Media, Metadata
from fastapi.encoders import jsonable_encoder
from loguru import logger
from app.shared.db import models
def custom_jsonable_encoder(obj):
if isinstance(obj, bytes):
return base64.b64encode(obj).decode('utf-8')
return base64.b64encode(obj).decode("utf-8")
return jsonable_encoder(obj)
def convert_priority_to_queue_dict(priority: str) -> dict:
return {
"priority": 0 if priority == "high" else 10,
"queue": f"{priority}_priority"
"queue": f"{priority}_priority",
}
def convert_if_media(media):
if isinstance(media, Media):
return media
elif isinstance(media, dict):
try:
return Media.from_dict(media)
except Exception as e:
logger.debug(f"error parsing {media} : {e}")
return False
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
db_urls = []
for m in result.media:
for i, url in enumerate(m.urls):
db_urls.append(
models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}"))
)
for k, prop in m.properties.items():
if prop_converted := convert_if_media(prop):
for i, url in enumerate(prop_converted.urls):
db_urls.append(
models.ArchiveUrl(
url=url, key=prop_converted.get("id", f"{k}_{i}")
)
)
if isinstance(prop, list):
for i, prop_media in enumerate(prop):
if prop_media := convert_if_media(prop_media):
for j, url in enumerate(prop_media.urls):
db_urls.append(
models.ArchiveUrl(
url=url,
key=prop_media.get(
"id", f"{k}{prop_media.key}_{i}.{j}"
),
)
)
return db_urls

View File

@@ -1,21 +1,22 @@
import datetime
import json
import traceback
import traceback, datetime
from auto_archiver.core.orchestrator import ArchivingOrchestrator
from celery.signals import task_failure
from loguru import logger
from sqlalchemy import exc
from auto_archiver.core.orchestrator import ArchivingOrchestrator
from app.shared.db import models
from app.shared import business_logic, constants, schemas
from app.shared.db import models, worker_crud
from app.shared.db.database import get_db
from app.shared import business_logic, schemas
from app.shared.task_messaging import get_celery, get_redis
from app.shared.settings import get_settings
from app.shared.log import log_error
from app.shared.aa_utils import get_all_urls
from app.shared.db import worker_crud
from app.shared.settings import get_settings
from app.shared.task_messaging import get_celery, get_redis
from app.web.utils.misc import get_all_urls
from app.worker.worker_log import setup_celery_logger
settings = get_settings()
celery = get_celery("worker")
@@ -24,26 +25,36 @@ Redis = get_redis()
USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
setup_celery_logger(celery)
AA_LOGGER_ID = None
# TODO: these are temporary PATCHES for new aa's functionality
# logger.add("app/worker/worker_log.log", level="DEBUG")
logger.remove = lambda x: print(f"logger.remove({x})")
# TODO: after release, as it requires updating past entries with sheet_id where tag is used, drop tags
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 1})
# TODO: after release, as it requires updating past entries with sheet_id where tag
# is used, drop tags
@celery.task(
name="create_archive_task",
bind=True,
autoretry_for=(Exception,),
retry_backoff=True,
retry_kwargs={"max_retries": 1},
)
def create_archive_task(self, archive_json: str):
global AA_LOGGER_ID
archive = schemas.ArchiveCreate.model_validate_json(archive_json)
# call auto-archiver
args = get_orchestrator_args(archive.group_id, False, [archive.url])
result = None
try:
orchestrator = ArchivingOrchestrator()
orchestrator.logger_id = AA_LOGGER_ID # ensure single logger
orchestrator.setup(args)
result = next(orchestrator.feed())
AA_LOGGER_ID = orchestrator.logger_id
for orch_res in orchestrator.feed():
result = orch_res
except SystemExit as e:
log_error(e, f"create_archive_task: SystemExit from AA")
log_error(e, "create_archive_task: SystemExit from AA")
except Exception as e:
log_error(e, f"create_archive_task")
log_error(e, "create_archive_task")
raise e
assert result, f"UNABLE TO archive: {archive.url}"
@@ -59,13 +70,20 @@ def create_archive_task(self, archive_json: str):
@celery.task(name="create_sheet_task", bind=True)
def create_sheet_task(self, sheet_json: str):
global AA_LOGGER_ID
sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
queue_name = (create_sheet_task.request.delivery_info or {}).get('routing_key', 'unknown')
queue_name = (create_sheet_task.request.delivery_info or {}).get(
"routing_key", "unknown"
)
logger.info(f"[queue={queue_name}] SHEET START {sheet=}")
args = get_orchestrator_args(sheet.group_id, True, ["--gsheet_feeder.sheet_id", sheet.sheet_id])
args = get_orchestrator_args(
sheet.group_id, True, [constants.SHEET_ID, sheet.sheet_id]
)
orchestrator = ArchivingOrchestrator()
orchestrator.logger_id = AA_LOGGER_ID # ensure single logger
orchestrator.setup(args)
AA_LOGGER_ID = orchestrator.logger_id
stats = {"archived": 0, "failed": 0, "errors": []}
try:
@@ -81,7 +99,7 @@ def create_sheet_task(self, sheet_json: str):
result=json.loads(result.to_json()),
sheet_id=sheet.sheet_id,
urls=get_all_urls(result),
store_until=get_store_until(sheet.group_id)
store_until=get_store_until(sheet.group_id),
)
insert_result_into_db(archive)
stats["archived"] += 1
@@ -94,25 +112,38 @@ def create_sheet_task(self, sheet_json: str):
stats["errors"].append(str(e))
except SystemExit as e:
log_error(e, f"create_sheet_task: SystemExit from AA")
log_error(e, "create_sheet_task: SystemExit from AA")
if stats["archived"] > 0:
with get_db() as session:
worker_crud.update_sheet_last_url_archived_at(session, sheet.sheet_id)
worker_crud.update_sheet_last_url_archived_at(
session, sheet.sheet_id
)
logger.info(f"SHEET DONE {sheet=}")
# TODO: is this used anywhere? maybe drop it
return schemas.CelerySheetTask(success=True, sheet_id=sheet.sheet_id, time=datetime.datetime.now().isoformat(), stats=stats).model_dump()
return schemas.CelerySheetTask(
success=True,
sheet_id=sheet.sheet_id,
time=datetime.datetime.now().isoformat(),
stats=stats,
).model_dump()
def get_orchestrator_args(group_id: str, orchestrator_for_sheet: bool, cli_args: list = []) -> list:
def get_orchestrator_args(
group_id: str, orchestrator_for_sheet: bool, cli_args: list = None
) -> list:
cli_args.append("--logging.enabled=false")
aa_configs = []
with get_db() as session:
group = worker_crud.get_group(session, group_id)
if orchestrator_for_sheet:
orchestrator_fn = group.orchestrator_sheet
else:
orchestrator_fn = worker_crud.get_group(session, group_id).orchestrator
orchestrator_fn = worker_crud.get_group(
session, group_id
).orchestrator
assert orchestrator_fn, f"no orchestrator found for {group_id}"
aa_configs.extend(["--config", orchestrator_fn])
aa_configs.extend(cli_args)
@@ -122,7 +153,9 @@ def get_orchestrator_args(group_id: str, orchestrator_for_sheet: bool, cli_args:
def insert_result_into_db(archive: schemas.ArchiveCreate) -> str:
with get_db() as session:
db_archive = worker_crud.store_archived_url(session, archive)
logger.debug(f"[ARCHIVE STORED] {db_archive.author_id} {db_archive.url}")
logger.debug(
f"[ARCHIVE STORED] {db_archive.author_id} {db_archive.url}"
)
return db_archive.id
@@ -131,13 +164,22 @@ def get_store_until(group_id: str) -> datetime.datetime:
return business_logic.get_store_archive_until(session, group_id)
def redis_publish_exception(exception, task_name, traceback: str = ""):
def redis_publish_exception(exception, task_name, trace_back: str = ""):
REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL
try:
exception_data = {"task": task_name, "type": exception.__class__.__name__, "exception": exception, "traceback": traceback}
Redis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps(exception_data, default=str))
exception_data = {
"task": task_name,
"type": exception.__class__.__name__,
"exception": exception,
"traceback": trace_back,
}
Redis.publish(
REDIS_EXCEPTIONS_CHANNEL, json.dumps(exception_data, default=str)
)
except Exception as e:
log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}")
log_error(
e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}"
)
@task_failure.connect(sender=create_sheet_task)
@@ -145,6 +187,10 @@ def redis_publish_exception(exception, task_name, traceback: str = ""):
def task_failure_notifier(sender, **kwargs):
# automatically capture exceptions in the worker tasks
logger.warning(f"⚠️ worker task failed: {sender.name}")
traceback_msg = "\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback'])))
log_error(kwargs['exception'], traceback_msg, f"task_failure: {sender.name}")
redis_publish_exception(kwargs['exception'], sender.name, traceback_msg)
traceback_msg = "\n".join(
traceback.format_list(traceback.extract_tb(kwargs["traceback"]))
)
log_error(
kwargs["exception"], traceback_msg, f"task_failure: {sender.name}"
)
redis_publish_exception(kwargs["exception"], sender.name, traceback_msg)

View File

@@ -1,29 +1,38 @@
from loguru import logger
from celery import Celery
import sys
from loguru import logger
from app.shared.task_messaging import get_celery
celery = get_celery("worker")
def setup_celery_logger(celery):
# Remove Celery's default handlers to prevent duplicate logs
celery_logger = celery.log.get_default_logger()
for handler in celery_logger.handlers[:]:
celery_logger.removeHandler(handler)
# Set up Loguru logging
logger.add("logs/celery_logs.log", retention="30 days", level="DEBUG")
logger.add("logs/celery_error_logs.log", retention="30 days", level="ERROR")
def setup_celery_logger(c):
# Remove Celery's default handlers to prevent duplicate logs
celery_logger = c.log.get_default_logger()
for handler in celery_logger.handlers[:]:
celery_logger.removeHandler(handler)
# Redirect Celery logs to Loguru
class InterceptHandler:
def write(self, message):
if message.strip():
logger.info(message.strip())
# Required to prevent issues with buffered output
def flush(self): pass
def isatty(self): return False
# Set up Loguru logging
logger.add("logs/celery_logs.log", retention="30 days", level="DEBUG")
logger.add("logs/celery_error_logs.log", retention="30 days", level="ERROR")
sys.stdout = InterceptHandler()
sys.stderr = InterceptHandler()
# Redirect Celery logs to Loguru
class InterceptHandler:
@staticmethod
def write(message):
if message.strip():
logger.info(message.strip())
# Required to prevent issues with buffered output
@staticmethod
def flush():
pass
@staticmethod
def isatty():
return False
sys.stdout = InterceptHandler()
sys.stderr = InterceptHandler()

View File

@@ -12,7 +12,7 @@ services:
- ALLOWED_ORIGINS=["http://localhost:8000","http://localhost:8004","http://localhost:8081","chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp"]
- USER_GROUPS_FILENAME=/aa-api/app/user-groups.dev.yaml
- DATABASE_PATH=sqlite:////aa-api/database/auto-archiver.db
worker:
# command: watchmedo auto-restart --patterns="*.py" --recursive --ignore-directories -- celery -- --app=app.worker.main.celery worker -Q high_priority,low_priority --concurrency=${CONCURRENCY} --max-tasks-per-child=100

View File

@@ -5,9 +5,9 @@ volumes:
name: "auto-archiver-api"
services:
web:
build:
build:
context: .
dockerfile: web.Dockerfile
dockerfile: docker/web/Dockerfile
restart: always
env_file: .env.prod
environment:
@@ -29,9 +29,9 @@ services:
retries: 3
worker:
build:
build:
context: .
dockerfile: worker.Dockerfile
dockerfile: docker/worker/Dockerfile
restart: always
env_file: .env.prod
command: celery --app=app.worker.main.celery worker -Q high_priority,low_priority --concurrency=${CONCURRENCY} --max-tasks-per-child=100 -O fair
@@ -68,4 +68,4 @@ services:
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
interval: 30s
timeout: 10s
retries: 3
retries: 3

View File

@@ -10,13 +10,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& rm -rf /var/lib/apt/lists/*
RUN pip install --no-cache-dir poetry
COPY pyproject.toml poetry.lock README.md .
COPY ../../pyproject.toml ../../poetry.lock ../../README.md ./
RUN poetry install --with web --no-interaction --no-ansi --no-cache
# Copy the application code and configurations
COPY alembic.ini ./
COPY ./app/ ./app/
COPY user-groups.* ./app/
COPY ../../app ./app/
COPY ../../user-groups.* ./app/
# Run the FastAPI app with Uvicorn
ENTRYPOINT ["poetry", "run"]

View File

@@ -20,14 +20,13 @@ RUN apt update -y && \
python3 -m venv ./poetry-venv && \
./poetry-venv/bin/python -m pip install --upgrade pip && \
./poetry-venv/bin/python -m pip install "poetry>=2.0.0,<3.0.0"
COPY pyproject.toml poetry.lock ./
COPY ../../pyproject.toml ../../poetry.lock ./
RUN ./poetry-venv/bin/poetry install --without dev --no-root --no-cache
# install dependencies
# copy source code and .env files over
COPY alembic.ini ./
COPY ./app/ ./app/
COPY user-groups.* ./app/
COPY ../../app ./app/
COPY ../../user-groups.* ./app/
ENTRYPOINT ["./poetry-venv/bin/poetry", "run"]
ENTRYPOINT ["./poetry-venv/bin/poetry", "run"]

1669
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -22,7 +22,6 @@ requires-python = ">=3.10,<3.13"
dependencies = [
"auto-archiver (>=0.13.1)",
"oscrypto @ git+https://github.com/wbond/oscrypto.git@d5f3437ed24257895ae1edd9e503cfb352e635a8",
"celery (>=5.0)",
"redis (==3.5.3)",
"loguru (>=0.7.3,<0.8.0)",
@@ -31,6 +30,16 @@ dependencies = [
"requests (>=2.25.1)",
"pyopenssl (>=23.3.0)",
]
[tool.pytest.ini_options]
pythonpath = "."
[tool.coverage.run]
omit = ["app/migrations/*"]
[tool.ruff.lint.flake8-bugbear]
extend-immutable-calls = ["fastapi.Depends", "fastapi.Query"]
[tool.poetry.group.worker.dependencies]
watchdog = ">=6.0.0,<7.0.0"
setuptools = "^75.8.0"
@@ -53,4 +62,4 @@ pytest = ">=8.3.4,<9.0.0"
httpx = ">=0.28.1,<0.29.0"
coverage = ">=7.6.11,<8.0.0"
pytest-asyncio = ">=0.25.3,<0.26.0"
pre-commit = "^4.1.0"

View File

@@ -59,4 +59,3 @@ groups:
permissions:
read: ["default"]
read_public: true