diff --git a/.coveragerc b/.coveragerc
deleted file mode 100644
index f6cc4e8..0000000
--- a/.coveragerc
+++ /dev/null
@@ -1,3 +0,0 @@
-[run]
-omit =
- app/migrations/*
diff --git a/.env.alembic b/.env.alembic
index 8691557..11bf2aa 100644
--- a/.env.alembic
+++ b/.env.alembic
@@ -2,4 +2,4 @@ CHROME_APP_IDS='["1234567890"]'
ALLOWED_ORIGINS='["allowed"]'
BLOCKED_EMAILS='[]'
DATABASE_PATH="sqlite:///./database/auto-archiver.db"
-API_BEARER_TOKEN=THIS_API_TOKEN_SHOULD_NEVER_BE_USED
\ No newline at end of file
+API_BEARER_TOKEN=THIS_API_TOKEN_SHOULD_NEVER_BE_USED
diff --git a/.env.example b/.env.example
index ef3935a..ea544ef 100644
--- a/.env.example
+++ b/.env.example
@@ -35,4 +35,4 @@ MAIL_SSL_TLS=True
# celery workers config
-CONCURRENCY=2
\ No newline at end of file
+CONCURRENCY=2
diff --git a/.env.test b/.env.test
index 32318f0..360f40e 100644
--- a/.env.test
+++ b/.env.test
@@ -5,4 +5,4 @@ BLOCKED_EMAILS='["blocked@example.com"]'
DATABASE_PATH="sqlite:///auto-archiver.test.db"
API_BEARER_TOKEN=this_is_the_test_api_token
-USER_GROUPS_FILENAME=app/tests/user-groups.test.yaml
\ No newline at end of file
+USER_GROUPS_FILENAME=app/tests/user-groups.test.yaml
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
new file mode 100644
index 0000000..3ec80e9
--- /dev/null
+++ b/.github/pull_request_template.md
@@ -0,0 +1,23 @@
+
+## Describe your changes
+
+
+## Non-obvious technical information
+
+
+## Checklist before requesting a review
+
+- [ ] The code runs successfully.
+
+```commandline
+HERE IS SOME COMMAND LINE OUTPUT
+```
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 9b63544..6b8c556 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -1,16 +1,12 @@
name: CI
on:
push:
- branches:
- - main
- - dev
+ branches: [ main, dev ]
pull_request:
- branches:
- - main
- - dev
+ branches: [ main, dev ]
jobs:
- test:
+ test-with-coverage:
runs-on: ubuntu-latest
services:
@@ -41,4 +37,4 @@ jobs:
run: poetry run coverage run -m pytest -v -ra --color=yes app/tests/
- name: Report coverage
- run: poetry run coverage report
\ No newline at end of file
+ run: poetry run coverage report
diff --git a/.github/workflows/format-and-fail.yml b/.github/workflows/format-and-fail.yml
new file mode 100644
index 0000000..c203aa9
--- /dev/null
+++ b/.github/workflows/format-and-fail.yml
@@ -0,0 +1,16 @@
+name: Format and Fail
+on:
+ push:
+ branches: [ main, dev ]
+ pull_request:
+ branches: [ main, dev ]
+
+jobs:
+ pre-commit:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.11"
+ - uses: pre-commit/action@v3.0.0
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..1ad37f3
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,40 @@
+name: Run Tests
+on:
+ push:
+ branches: [ main, dev ]
+ pull_request:
+ branches: [ main, dev ]
+
+jobs:
+ test:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ['3.10', '3.11', '3.12']
+
+ services:
+ redis:
+ image: redis:6-alpine
+ ports:
+ - 6379:6379
+
+ steps:
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Checkout code
+ uses: actions/checkout@v4
+
+ - name: Install Poetry
+ run: pipx install poetry
+
+ - name: Install dependencies
+ run: poetry install --no-interaction --with dev
+
+ - name: Set dev environment variable
+ run: echo "ENVIRONMENT_FILE=.env.test" >> $GITHUB_ENV
+
+ - name: Run tests
+ run: poetry run pytest app/tests
diff --git a/.gitignore b/.gitignore
index 562885d..ebe19c0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,16 +1,10 @@
+# Misc.
user-groups.dev.yaml
user-groups.yaml
orchestration.yaml
my-archives
*.pyc
-.DS_Store
secrets/*
-*.log
-__pycache__
-.pytest_cache
-.env
-.env.dev
-.env.prod
*.db
redis/data/*
.ipynb_checkpoints*
@@ -18,8 +12,6 @@ app/user-groups.yaml
app/user-groups.dev.yaml
wit*
app/crawls
-.coverage
-.pytest_cache/
htmlcov
local_archive
local_archive_test
@@ -27,6 +19,133 @@ local_archive_test
*db-shm
copy-files.sh
temp/
-.python-version
orchestration2.yaml
-database
\ No newline at end of file
+database
+
+# IDE files
+.idea
+.vscode
+**/.DS_Store
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env*
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..6707b21
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,78 @@
+repos:
+ - repo: https://github.com/nbQA-dev/nbQA
+ rev: 1.9.1
+ hooks:
+ - id: nbqa-ruff
+ args:
+ - --fix
+ - --target-version=py310
+ - --ignore=E721,E722
+ - --line-length=80
+ - id: nbqa-black
+ args:
+ - --line-length=80
+ - id: nbqa-isort
+ args:
+ - --float-to-top
+ - --profile=black
+
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v5.0.0
+ hooks:
+ - id: trailing-whitespace
+ - id: check-docstring-first
+ - id: check-executables-have-shebangs
+ - id: check-json
+ - id: check-case-conflict
+ - id: check-toml
+ - id: check-merge-conflict
+ - id: check-xml
+ - id: check-yaml
+ exclude: app/tests/user-groups.test.broken.yaml
+ - id: end-of-file-fixer
+ - id: check-symlinks
+ - id: mixed-line-ending
+ - id: sort-simple-yaml
+ - id: fix-encoding-pragma
+ args:
+ - --remove
+ - id: pretty-format-json
+ args:
+ - --autofix
+
+ - repo: https://github.com/pre-commit/pygrep-hooks
+ rev: v1.10.0
+ hooks:
+ - id: python-check-blanket-noqa
+ - id: python-check-mock-methods
+ - id: python-no-eval
+ - id: python-no-log-warn
+
+ - repo: https://github.com/PyCQA/isort
+ rev: 6.0.1
+ hooks:
+ - id: isort
+ name: Run isort to sort imports
+ files: \.py$
+ # To keep consistent with the global isort skip config defined in setup.cfg
+ exclude: ^build/.*$|^.tox/.*$|^venv/.*$
+ args:
+ - --lines-after-imports=2
+ - --profile=black
+ - --line-length=80
+
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.9.7
+ hooks:
+ - id: ruff
+ types_or: [python,pyi]
+ args:
+ - --fix
+ - --select=B,C,E,F,W,B9
+ - --line-length=80
+ - --ignore=E203,E402,E501,E261
+ - id: ruff-format
+ types_or: [ python,pyi]
+ args:
+ - --target-version=py310
+ - --line-length=80
diff --git a/CODEOWNERS b/CODEOWNERS
new file mode 100644
index 0000000..b5ffe06
--- /dev/null
+++ b/CODEOWNERS
@@ -0,0 +1 @@
+* @msramalho
diff --git a/LICENSE b/LICENSE
index e10dcd9..c5bae4c 100644
--- a/LICENSE
+++ b/LICENSE
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
\ No newline at end of file
+SOFTWARE.
diff --git a/Makefile b/Makefile
index 2fd462c..05a4bff 100644
--- a/Makefile
+++ b/Makefile
@@ -1,19 +1,33 @@
+.PHONY: lint
+lint:
+ poetry run pre-commit run --all-files
+
+.PHONY: test
+test:
+ export ENVIRONMENT_FILE=.env.test
+ poetry run coverage run -m pytest -v --disable-warnings --color=yes app/tests/
+ poetry run coverage report
+
+.PHONY: clean-dev
clean-dev:
@echo -n "Are you sure? [yes/N] (this will delete volumes) " && read ans && [ $${ans:-N} = yes ]
docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes --remove-orphans
+.PHONY: dev
dev:
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans
-
+.PHONY: dev-redis-only
dev-redis-only:
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build redis
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans redis
+.PHONY: stop-dev
stop-dev:
docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes
+.PHONY: prod
prod:
docker compose --env-file .env.prod build
docker compose --env-file .env.prod up -d --remove-orphans
@@ -21,5 +35,6 @@ prod:
docker image prune -f
docker system df
+.PHONY: stop-prod
stop-prod:
- docker compose down
\ No newline at end of file
+ docker compose down
diff --git a/README.md b/README.md
index 2c87342..7ebe7a1 100644
--- a/README.md
+++ b/README.md
@@ -12,9 +12,9 @@ To properly set up the API you need to install `docker` and to have these files,
2. a `user-groups.yaml` to manage user permissions
1. note that all local files referenced in `user-groups.yaml` and any orchestration.yaml files should be relative to the home directory so if your service account is in `secrets/orchestration.yaml` use that path and not just `orchestration.yaml`.
2. go through the example file and configure it according to your needs.
-3. you will need to create and reference at least one `secrets/orchestration.yaml` file, you can do so by following the instructions in the [auto-archiver](https://github.com/bellingcat/auto-archiver#installation) that automatically generates one for you. If you use the archive sheets feature you will need to create a `orchestrationsheets-sheets.yaml` file as well that should have the `gsheet_feeder` and `gsheet_db` enabled and configured, the auto-archiver has [extensive documentation](https://auto-archiver.readthedocs.io/en/latest/) on how to set this up.
+3. you will need to create and reference at least one `secrets/orchestration.yaml` file, you can do so by following the instructions in the [auto-archiver](https://github.com/bellingcat/auto-archiver#installation) that automatically generates one for you. If you use the archive sheets feature you will need to create a `orchestrationsheets-sheets.yaml` file as well that should have the `gsheet_feeder_db` feeder and database enabled and configured, the auto-archiver has [extensive documentation](https://auto-archiver.readthedocs.io/en/latest/) on how to set this up.
-Do not commit those files, they are .gitignored by default.
+Do not commit those files, they are .gitignored by default.
We also advise you to keep any sensitive files in the `secrets/` folder which is pinned and gitignored.
We have examples for both of those files (`.env.example` and `user-groups.example.yaml`), and here's how to set them up whether you're in development or production:
@@ -108,6 +108,27 @@ Make sure environment and user-groups files are up to date.
Then `make prod`.
+## Development
+```bash
+# make sure all development dependencies are installed
+poetry install --with dev
+
+# this project uses pre-commit to enforce code style and formatting, set that up locally
+poetry run pre-commit install
+
+# you can test pre-commit with
+poetry run pre-commit run --all-files
+
+# this means pre-commit will always run with git commit, to skip it use
+git commit --no-verify
+
+# see the Makefile for more commands, but linting and formatting can be done with
+make lint
+
+# run all tests
+make test
+```
+
### Testing
```bash
# set the testing environment variables
diff --git a/app/web/endpoints/__init__.py b/app/__init__.py
similarity index 100%
rename from app/web/endpoints/__init__.py
rename to app/__init__.py
diff --git a/alembic.ini b/app/migrations/alembic.ini
similarity index 95%
rename from alembic.ini
rename to app/migrations/alembic.ini
index 30d7030..fa1578b 100644
--- a/alembic.ini
+++ b/app/migrations/alembic.ini
@@ -2,7 +2,7 @@
[alembic]
# path to migration scripts
-script_location = app/migrations
+script_location = ./app/migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
@@ -10,10 +10,6 @@ script_location = app/migrations
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
-# sys.path path, will be prepended to sys.path if present.
-# defaults to the current working directory.
-prepend_sys_path = .
-
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
diff --git a/app/migrations/env.py b/app/migrations/env.py
index 870ef18..fc63a92 100644
--- a/app/migrations/env.py
+++ b/app/migrations/env.py
@@ -1,19 +1,20 @@
from logging.config import fileConfig
-from sqlalchemy import engine_from_config
-from sqlalchemy import pool
from alembic import context
+from sqlalchemy import engine_from_config, pool
from app.shared.settings import get_settings
+
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
-config.set_main_option('sqlalchemy.url', get_settings().DATABASE_PATH)
+config.set_main_option("sqlalchemy.url", get_settings().DATABASE_PATH)
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
- fileConfig(config.config_file_name, disable_existing_loggers=False) # disable_existing_loggers prevents loguru disabling
+ # disable_existing_loggers prevents loguru disabling
+ fileConfig(config.config_file_name, disable_existing_loggers=False)
# add your model's MetaData object here
# for 'autogenerate' support
diff --git a/app/migrations/versions/02b2f6d17ed0_create_archives_store_until_column.py b/app/migrations/versions/02b2f6d17ed0_create_archives_store_until_column.py
index d00fa2c..1bb5695 100644
--- a/app/migrations/versions/02b2f6d17ed0_create_archives_store_until_column.py
+++ b/app/migrations/versions/02b2f6d17ed0_create_archives_store_until_column.py
@@ -5,13 +5,14 @@ Revises: 1636724ec4b1
Create Date: 2025-02-08 15:22:20.392522
"""
-from alembic import op
+
import sqlalchemy as sa
+from alembic import op
# revision identifiers, used by Alembic.
-revision = '02b2f6d17ed0'
-down_revision = '1636724ec4b1'
+revision = "02b2f6d17ed0"
+down_revision = "1636724ec4b1"
branch_labels = None
depends_on = None
STORE_UNTIL_COL = "store_until"
@@ -20,15 +21,20 @@ STORE_UNTIL_COL = "store_until"
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
- columns = [col['name'] for col in inspector.get_columns('archives')]
+ columns = [col["name"] for col in inspector.get_columns("archives")]
if STORE_UNTIL_COL not in columns:
- op.add_column('archives', sa.Column(STORE_UNTIL_COL, sa.DateTime(), nullable=True, default=None))
+ op.add_column(
+ "archives",
+ sa.Column(
+ STORE_UNTIL_COL, sa.DateTime(), nullable=True, default=None
+ ),
+ )
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
- columns = [col['name'] for col in inspector.get_columns('archives')]
+ columns = [col["name"] for col in inspector.get_columns("archives")]
if STORE_UNTIL_COL in columns:
- op.drop_column('archives', STORE_UNTIL_COL)
+ op.drop_column("archives", STORE_UNTIL_COL)
diff --git a/app/migrations/versions/1636724ec4b1_rename_sheets_last_archived_col.py b/app/migrations/versions/1636724ec4b1_rename_sheets_last_archived_col.py
index 6c109f3..ad972fc 100644
--- a/app/migrations/versions/1636724ec4b1_rename_sheets_last_archived_col.py
+++ b/app/migrations/versions/1636724ec4b1_rename_sheets_last_archived_col.py
@@ -5,13 +5,14 @@ Revises: a23aaf3ae930
Create Date: 2025-02-05 19:19:01.984396
"""
-from alembic import op
+
import sqlalchemy as sa
+from alembic import op
# revision identifiers, used by Alembic.
-revision = '1636724ec4b1'
-down_revision = 'a23aaf3ae930'
+revision = "1636724ec4b1"
+down_revision = "a23aaf3ae930"
branch_labels = None
depends_on = None
@@ -19,14 +20,18 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
- columns = [col['name'] for col in inspector.get_columns('sheets')]
- if 'last_archived_at' in columns:
- op.alter_column('sheets', 'last_archived_at', new_column_name='last_url_archived_at')
+ columns = [col["name"] for col in inspector.get_columns("sheets")]
+ if "last_archived_at" in columns:
+ op.alter_column(
+ "sheets", "last_archived_at", new_column_name="last_url_archived_at"
+ )
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
- columns = [col['name'] for col in inspector.get_columns('sheets')]
- if 'last_url_archived_at' in columns:
- op.alter_column('sheets', 'last_url_archived_at', new_column_name='last_archived_at')
+ columns = [col["name"] for col in inspector.get_columns("sheets")]
+ if "last_url_archived_at" in columns:
+ op.alter_column(
+ "sheets", "last_url_archived_at", new_column_name="last_archived_at"
+ )
diff --git a/app/migrations/versions/63ac79df4ad0_add_new_service_account_email_column_to_.py b/app/migrations/versions/63ac79df4ad0_add_new_service_account_email_column_to_.py
index 7067746..0ea0b11 100644
--- a/app/migrations/versions/63ac79df4ad0_add_new_service_account_email_column_to_.py
+++ b/app/migrations/versions/63ac79df4ad0_add_new_service_account_email_column_to_.py
@@ -5,13 +5,14 @@ Revises: 02b2f6d17ed0
Create Date: 2025-02-11 21:53:23.293274
"""
-from alembic import op
+
import sqlalchemy as sa
+from alembic import op
# revision identifiers, used by Alembic.
-revision = '63ac79df4ad0'
-down_revision = '02b2f6d17ed0'
+revision = "63ac79df4ad0"
+down_revision = "02b2f6d17ed0"
branch_labels = None
depends_on = None
@@ -22,15 +23,17 @@ TABLE = "groups"
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
- columns = [col['name'] for col in inspector.get_columns(TABLE)]
+ columns = [col["name"] for col in inspector.get_columns(TABLE)]
if NEW_COL not in columns:
- op.add_column(TABLE, sa.Column(NEW_COL, sa.String, nullable=True, default=None))
+ op.add_column(
+ TABLE, sa.Column(NEW_COL, sa.String, nullable=True, default=None)
+ )
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
- columns = [col['name'] for col in inspector.get_columns(TABLE)]
+ columns = [col["name"] for col in inspector.get_columns(TABLE)]
if NEW_COL in columns:
op.drop_column(TABLE, NEW_COL)
diff --git a/app/migrations/versions/89121d2c96d8_add_sheet_id_to_archive_table.py b/app/migrations/versions/89121d2c96d8_add_sheet_id_to_archive_table.py
index 3011cf6..e34e06d 100644
--- a/app/migrations/versions/89121d2c96d8_add_sheet_id_to_archive_table.py
+++ b/app/migrations/versions/89121d2c96d8_add_sheet_id_to_archive_table.py
@@ -5,14 +5,14 @@ Revises: fa012ec405b8
Create Date: 2024-11-04 11:12:30.237299
"""
-from alembic import op
+
import sqlalchemy as sa
-from sqlalchemy.engine.reflection import Inspector
+from alembic import op
# revision identifiers, used by Alembic.
-revision = '89121d2c96d8'
-down_revision = 'fa012ec405b8'
+revision = "89121d2c96d8"
+down_revision = "fa012ec405b8"
branch_labels = None
depends_on = None
@@ -20,23 +20,27 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
- columns = [col['name'] for col in inspector.get_columns('archives')]
+ columns = [col["name"] for col in inspector.get_columns("archives")]
- if 'sheet_id' not in columns:
- with op.batch_alter_table('archives') as batch_op:
- batch_op.add_column(sa.Column('sheet_id', sa.String(), nullable=True, default=None))
- batch_op.create_foreign_key('fk_sheet_id', 'sheets', ['sheet_id'], ['id'])
+ if "sheet_id" not in columns:
+ with op.batch_alter_table("archives") as batch_op:
+ batch_op.add_column(
+ sa.Column("sheet_id", sa.String(), nullable=True, default=None)
+ )
+ batch_op.create_foreign_key(
+ "fk_sheet_id", "sheets", ["sheet_id"], ["id"]
+ )
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
- foreign_keys = [fk['name'] for fk in inspector.get_foreign_keys('archives')]
- columns = [col['name'] for col in inspector.get_columns('archives')]
+ foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("archives")]
+ columns = [col["name"] for col in inspector.get_columns("archives")]
- with op.batch_alter_table('archives') as batch_op:
- if 'fk_sheet_id' in foreign_keys:
- batch_op.drop_constraint('fk_sheet_id', type_='foreignkey')
+ with op.batch_alter_table("archives") as batch_op:
+ if "fk_sheet_id" in foreign_keys:
+ batch_op.drop_constraint("fk_sheet_id", type_="foreignkey")
- if 'sheet_id' in columns:
- batch_op.drop_column('sheet_id')
+ if "sheet_id" in columns:
+ batch_op.drop_column("sheet_id")
diff --git a/app/migrations/versions/9369a264945b_modify_archive_url_to_have_uuid_id_.py b/app/migrations/versions/9369a264945b_modify_archive_url_to_have_uuid_id_.py
index a2b708a..1f7f348 100644
--- a/app/migrations/versions/9369a264945b_modify_archive_url_to_have_uuid_id_.py
+++ b/app/migrations/versions/9369a264945b_modify_archive_url_to_have_uuid_id_.py
@@ -1,14 +1,16 @@
"""modify archive url to have uuid id instead of url unique constraint
Revision ID: 9369a264945b
-Revises:
+Revises:
Create Date: 2023-12-20 17:24:59.320691
"""
+
from alembic import op
+
# revision identifiers, used by Alembic.
-revision = '9369a264945b'
+revision = "9369a264945b"
down_revision = None
branch_labels = None
depends_on = None
@@ -18,10 +20,11 @@ def upgrade() -> None:
# since the primary key constraint is not named, we have to recreate it first
with op.batch_alter_table("archive_urls") as batch_op:
batch_op.create_primary_key("pk_url", ["url"])
- batch_op.drop_constraint("pk_url", type_='primary')
+ batch_op.drop_constraint("pk_url", type_="primary")
batch_op.create_primary_key("pk_url_archive_id", ["url", "archive_id"])
+
def downgrade() -> None:
with op.batch_alter_table("archive_urls") as batch_op:
- batch_op.drop_constraint("pk_url_archive_id", type_='primary')
+ batch_op.drop_constraint("pk_url_archive_id", type_="primary")
batch_op.create_primary_key("url", ["url"])
diff --git a/app/migrations/versions/93a611e4c066_vacuum_database_if_there_s_enough_space.py b/app/migrations/versions/93a611e4c066_vacuum_database_if_there_s_enough_space.py
index 6b8d8d2..b099f59 100644
--- a/app/migrations/versions/93a611e4c066_vacuum_database_if_there_s_enough_space.py
+++ b/app/migrations/versions/93a611e4c066_vacuum_database_if_there_s_enough_space.py
@@ -5,12 +5,13 @@ Revises: 9369a264945b
Create Date: 2023-12-20 18:33:27.132566
"""
+
from alembic import op
# revision identifiers, used by Alembic.
-revision = '93a611e4c066'
-down_revision = '9369a264945b'
+revision = "93a611e4c066"
+down_revision = "9369a264945b"
branch_labels = None
depends_on = None
@@ -20,7 +21,9 @@ def upgrade() -> None:
with op.get_context().autocommit_block():
op.execute("VACUUM")
except Exception as e:
- print("Unable to run vacuum, maybe there's not enough disk space. it should be 2x the size of the database")
+ print(
+ "Unable to run vacuum, maybe there's not enough disk space. it should be 2x the size of the database"
+ )
print(e)
diff --git a/app/migrations/versions/a23aaf3ae930_drop_active_column.py b/app/migrations/versions/a23aaf3ae930_drop_active_column.py
index 912f408..aa8f97b 100644
--- a/app/migrations/versions/a23aaf3ae930_drop_active_column.py
+++ b/app/migrations/versions/a23aaf3ae930_drop_active_column.py
@@ -5,13 +5,14 @@ Revises: 89121d2c96d8
Create Date: 2025-02-04 12:19:20.753570
"""
-from alembic import op
+
import sqlalchemy as sa
+from alembic import op
# revision identifiers, used by Alembic.
-revision = 'a23aaf3ae930'
-down_revision = '89121d2c96d8'
+revision = "a23aaf3ae930"
+down_revision = "89121d2c96d8"
branch_labels = None
depends_on = None
@@ -19,16 +20,24 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
- columns = [col['name'] for col in inspector.get_columns('users')]
+ columns = [col["name"] for col in inspector.get_columns("users")]
- if 'is_active' in columns:
- op.drop_column('users', 'is_active')
+ if "is_active" in columns:
+ op.drop_column("users", "is_active")
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
- columns = [col['name'] for col in inspector.get_columns('users')]
+ columns = [col["name"] for col in inspector.get_columns("users")]
- if 'is_active' not in columns:
- op.add_column('users', sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.false()))
+ if "is_active" not in columns:
+ op.add_column(
+ "users",
+ sa.Column(
+ "is_active",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.false(),
+ ),
+ )
diff --git a/app/migrations/versions/fa012ec405b8_add_columns_to_groups_table.py b/app/migrations/versions/fa012ec405b8_add_columns_to_groups_table.py
index f0577ea..f5cbcaa 100644
--- a/app/migrations/versions/fa012ec405b8_add_columns_to_groups_table.py
+++ b/app/migrations/versions/fa012ec405b8_add_columns_to_groups_table.py
@@ -5,14 +5,14 @@ Revises: 93a611e4c066
Create Date: 2024-10-31 09:36:50.360710
"""
-from alembic import op
+
import sqlalchemy as sa
-from sqlalchemy.engine.reflection import Inspector
+from alembic import op
# revision identifiers, used by Alembic.
-revision = 'fa012ec405b8'
-down_revision = '93a611e4c066'
+revision = "fa012ec405b8"
+down_revision = "93a611e4c066"
branch_labels = None
depends_on = None
@@ -20,26 +20,41 @@ depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
- columns = [col['name'] for col in inspector.get_columns('groups')]
+ columns = [col["name"] for col in inspector.get_columns("groups")]
- if 'description' not in columns:
- op.add_column('groups', sa.Column('description', sa.String(), nullable=True))
- if 'orchestrator' not in columns:
- op.add_column('groups', sa.Column('orchestrator', sa.String(), nullable=True))
- if 'orchestrator_sheet' not in columns:
- op.add_column('groups', sa.Column('orchestrator_sheet', sa.String(), nullable=True))
- if 'permissions' not in columns:
- op.add_column('groups', sa.Column('permissions', sa.JSON(), nullable=True))
- if 'domains' not in columns:
- op.add_column('groups', sa.Column('domains', sa.JSON(), nullable=True))
+ if "description" not in columns:
+ op.add_column(
+ "groups", sa.Column("description", sa.String(), nullable=True)
+ )
+ if "orchestrator" not in columns:
+ op.add_column(
+ "groups", sa.Column("orchestrator", sa.String(), nullable=True)
+ )
+ if "orchestrator_sheet" not in columns:
+ op.add_column(
+ "groups",
+ sa.Column("orchestrator_sheet", sa.String(), nullable=True),
+ )
+ if "permissions" not in columns:
+ op.add_column(
+ "groups", sa.Column("permissions", sa.JSON(), nullable=True)
+ )
+ if "domains" not in columns:
+ op.add_column("groups", sa.Column("domains", sa.JSON(), nullable=True))
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
- columns = [col['name'] for col in inspector.get_columns('groups')]
+ columns = [col["name"] for col in inspector.get_columns("groups")]
- column_names = ['description', 'orchestrator', 'orchestrator_sheet', 'permissions', 'domains']
+ column_names = [
+ "description",
+ "orchestrator",
+ "orchestrator_sheet",
+ "permissions",
+ "domains",
+ ]
for column_name in column_names:
if column_name in columns:
- op.drop_column('groups', column_name)
+ op.drop_column("groups", column_name)
diff --git a/app/shared/aa_utils.py b/app/shared/aa_utils.py
deleted file mode 100644
index 393a975..0000000
--- a/app/shared/aa_utils.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# TODO: code in this file should eventually be moved to the auto-archiver code base
-
-from typing import List
-from loguru import logger
-from auto_archiver.core import Media, Metadata
-
-from app.shared.db import models
-
-def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
- db_urls = []
- for m in result.media:
- for i, url in enumerate(m.urls): db_urls.append(models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}")))
- for k, prop in m.properties.items():
- if prop_converted := convert_if_media(prop):
- for i, url in enumerate(prop_converted.urls): db_urls.append(models.ArchiveUrl(url=url, key=prop_converted.get("id", f"{k}_{i}")))
- if isinstance(prop, list):
- for i, prop_media in enumerate(prop):
- if prop_media := convert_if_media(prop_media):
- for j, url in enumerate(prop_media.urls):
- db_urls.append(models.ArchiveUrl(url=url, key=prop_media.get("id", f"{k}{prop_media.key}_{i}.{j}")))
- return db_urls
-
-
-
-def convert_if_media(media):
- if isinstance(media, Media): return media
- elif isinstance(media, dict):
- try: return Media.from_dict(media)
- except Exception as e:
- logger.debug(f"error parsing {media} : {e}")
- return False
-
diff --git a/app/shared/business_logic.py b/app/shared/business_logic.py
index d179fda..fb8f2d7 100644
--- a/app/shared/business_logic.py
+++ b/app/shared/business_logic.py
@@ -1,25 +1,35 @@
-# TODO: temporary file for this code, maybe other code belongs here, maybe not. do decide
-
-
import datetime
+from typing import Union
+
from sqlalchemy.orm import Session
from app.shared.db import worker_crud
-def get_store_archive_until(db: Session, group_id: str) -> datetime.datetime:
+# TODO: temporary file for this code, maybe other code belongs here, maybe not. do
+# decide
+
+
+def get_store_archive_until(
+ db: Session, group_id: str
+) -> Union[datetime.datetime, None]:
group = worker_crud.get_group(db, group_id)
assert group, f"Group {group_id} not found."
- assert group.permissions and type(group.permissions) == dict, f"Group {group_id} has no permissions."
+ assert group.permissions and isinstance(group.permissions, dict), (
+ f"Group {group_id} has no permissions."
+ )
max_lifespan = group.permissions.get("max_archive_lifespan_months", -1)
- if max_lifespan == -1: return None
+ if max_lifespan == -1:
+ return None
return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan)
-def get_store_archive_until_or_never(db: Session, group_id: str) -> datetime.datetime:
+def get_store_archive_until_or_never(
+ db: Session, group_id: str
+) -> Union[datetime.datetime, None]:
try:
return get_store_archive_until(db, group_id)
- except AssertionError as e:
+ except AssertionError:
return None
diff --git a/app/shared/constants.py b/app/shared/constants.py
new file mode 100644
index 0000000..dd334fa
--- /dev/null
+++ b/app/shared/constants.py
@@ -0,0 +1,7 @@
+# Statuses
+STATUS_FAILURE = "FAILURE"
+STATUS_PENDING = "PENDING"
+STATUS_SUCCESS = "SUCCESS"
+
+# AA CLI CONFIGS
+SHEET_ID = "--gsheet_feeder_db.sheet_id"
diff --git a/app/shared/db/database.py b/app/shared/db/database.py
index 171b97b..08297a1 100644
--- a/app/shared/db/database.py
+++ b/app/shared/db/database.py
@@ -1,8 +1,14 @@
-from functools import lru_cache
-from sqlalchemy import Engine, create_engine, event, text
-from sqlalchemy.orm import sessionmaker
from contextlib import asynccontextmanager, contextmanager
-from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine, async_sessionmaker
+from functools import lru_cache
+
+from sqlalchemy import Engine, create_engine, event, text
+from sqlalchemy.ext.asyncio import (
+ AsyncEngine,
+ AsyncSession,
+ async_sessionmaker,
+ create_async_engine,
+)
+from sqlalchemy.orm import sessionmaker
from app.shared.settings import get_settings
@@ -12,9 +18,9 @@ def make_engine(database_url: str):
engine = create_engine(
database_url,
connect_args={"check_same_thread": False},
- pool_size=15, # Increase pool size
- max_overflow=20, # Allow more temporary connections
- pool_recycle=1800 # Recycle connections every 30 minutes
+ pool_size=15, # Increase pool size
+ max_overflow=20, # Allow more temporary connections
+ pool_recycle=1800, # Recycle connections every 30 minutes
)
@event.listens_for(engine, "connect")
@@ -34,8 +40,10 @@ def make_session_local(engine: Engine):
@contextmanager
def get_db():
session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
- try: yield session
- finally: session.close()
+ try:
+ yield session
+ finally:
+ session.close()
def get_db_dependency():
@@ -53,22 +61,32 @@ def wal_checkpoint():
# ASYNC connections
async def make_async_engine(database_url: str) -> AsyncEngine:
- engine = create_async_engine(database_url, connect_args={"check_same_thread": False})
+ engine = create_async_engine(
+ database_url, connect_args={"check_same_thread": False}
+ )
async with engine.begin() as conn:
- await conn.run_sync(lambda sync_conn: sync_conn.execute(text("PRAGMA journal_mode=WAL;")))
+ await conn.run_sync(
+ lambda sync_conn: sync_conn.execute(
+ text("PRAGMA journal_mode=WAL;")
+ )
+ )
return engine
async def make_async_session_local(engine: AsyncEngine) -> AsyncSession:
- return async_sessionmaker(engine, expire_on_commit=False, autoflush=False, autocommit=False)
+ return async_sessionmaker(
+ engine, expire_on_commit=False, autoflush=False, autocommit=False
+ )
@asynccontextmanager
async def get_db_async():
- engine = await make_async_engine(get_settings().ASYNC_DATABASE_PATH)
+ engine = await make_async_engine(get_settings().async_database_path)
async_session = await make_async_session_local(engine)
async with async_session() as session:
- try: yield session
- finally: await engine.dispose()
+ try:
+ yield session
+ finally:
+ await engine.dispose()
diff --git a/app/shared/db/models.py b/app/shared/db/models.py
index 1736224..8acedc1 100644
--- a/app/shared/db/models.py
+++ b/app/shared/db/models.py
@@ -1,8 +1,17 @@
-from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey
-from sqlalchemy.sql import func
-from sqlalchemy.orm import relationship, declarative_base
import uuid
+from sqlalchemy import (
+ JSON,
+ Boolean,
+ Column,
+ DateTime,
+ ForeignKey,
+ String,
+ Table,
+)
+from sqlalchemy.orm import declarative_base, relationship
+from sqlalchemy.sql import func
+
Base = declarative_base()
@@ -11,7 +20,7 @@ def generate_uuid():
return str(uuid.uuid4())
-# many to many association tables
+# many-to-many association tables
association_table_archive_tags = Table(
"mtm_archives_tags",
Base.metadata,
@@ -33,7 +42,9 @@ class Archive(Base):
id = Column(String, primary_key=True, index=True)
url = Column(String, index=True)
result = Column(JSON, default=None)
- public = Column(Boolean, default=True) # if public=false, access by group and author
+ public = Column(
+ Boolean, default=True
+ ) # if public=false, access by group and author
deleted = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
@@ -43,7 +54,11 @@ class Archive(Base):
author_id = Column(String, ForeignKey("users.email"))
sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
- tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags)
+ tags = relationship(
+ "Tag",
+ back_populates="archives",
+ secondary=association_table_archive_tags,
+ )
group = relationship("Group", back_populates="archives")
author = relationship("User", back_populates="archives")
urls = relationship("ArchiveUrl", back_populates="archive")
@@ -66,7 +81,11 @@ class Tag(Base):
id = Column(String, primary_key=True, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
- archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
+ archives = relationship(
+ "Archive",
+ back_populates="tags",
+ secondary=association_table_archive_tags,
+ )
class User(Base):
@@ -76,7 +95,9 @@ class User(Base):
archives = relationship("Archive", back_populates="author")
sheets = relationship("Sheet", back_populates="author")
- groups = relationship("Group", back_populates="users", secondary=association_table_user_groups)
+ groups = relationship(
+ "Group", back_populates="users", secondary=association_table_user_groups
+ )
class Group(Base):
@@ -92,7 +113,9 @@ class Group(Base):
archives = relationship("Archive", back_populates="group")
sheets = relationship("Sheet", back_populates="group")
- users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
+ users = relationship(
+ "User", back_populates="groups", secondary=association_table_user_groups
+ )
class Sheet(Base):
@@ -101,11 +124,27 @@ class Sheet(Base):
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
name = Column(String, default=None)
author_id = Column(String, ForeignKey("users.email"))
- group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
- frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
+ group_id = Column(
+ String,
+ ForeignKey("groups.id"),
+ doc="Group ID, user must be in a group to create a sheet.",
+ )
+ frequency = Column(
+ String,
+ default="daily",
+ doc="Frequency of archiving: hourly, daily, weekly.",
+ )
# TODO: stats is not being used, consider removing
- stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...")
- last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
+ stats = Column(
+ JSON,
+ default={},
+ doc="Sheet statistics like total links, total rows, ...",
+ )
+ last_url_archived_at = Column(
+ DateTime(timezone=True),
+ server_default=func.now(),
+ doc="Last time a new link was archived.",
+ )
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
diff --git a/app/shared/db/worker_crud.py b/app/shared/db/worker_crud.py
index 814689a..28d7af8 100644
--- a/app/shared/db/worker_crud.py
+++ b/app/shared/db/worker_crud.py
@@ -1,13 +1,17 @@
-from sqlalchemy.orm import Session
from datetime import datetime
-from app.shared.db import models
+from sqlalchemy.orm import Session
+
from app.shared import schemas
+from app.shared.db import models
+
# TODO: isolate database operations away from worker and into WEB
# ONLY WORKER
def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
- db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
+ db_sheet = (
+ db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
+ )
if db_sheet:
db_sheet.last_url_archived_at = datetime.now()
db.commit()
@@ -17,12 +21,17 @@ def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
# ONLY WORKER and INTEROP
+
def get_group(db: Session, group_name: str) -> models.Group:
return db.query(models.Group).filter(models.Group.id == group_name).first()
+
def create_or_get_user(db: Session, author_id: str) -> models.User:
- if type(author_id) == str: author_id = author_id.lower()
- db_user = db.query(models.User).filter(models.User.email == author_id).first()
+ if isinstance(author_id, str):
+ author_id = author_id.lower()
+ db_user = (
+ db.query(models.User).filter(models.User.email == author_id).first()
+ )
if not db_user:
db_user = models.User(email=author_id)
db.add(db_user)
@@ -41,8 +50,22 @@ def create_tag(db: Session, tag: str) -> models.Tag:
return db_tag
-def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]) -> models.Archive:
- db_archive = models.Archive(id=archive.id, url=archive.url, result=archive.result, public=archive.public, author_id=archive.author_id, group_id=archive.group_id, sheet_id=archive.sheet_id, store_until=archive.store_until)
+def create_archive(
+ db: Session,
+ archive: schemas.ArchiveCreate,
+ tags: list[models.Tag],
+ urls: list[models.ArchiveUrl],
+) -> models.Archive:
+ db_archive = models.Archive(
+ id=archive.id,
+ url=archive.url,
+ result=archive.result,
+ public=archive.public,
+ author_id=archive.author_id,
+ group_id=archive.group_id,
+ sheet_id=archive.sheet_id,
+ store_until=archive.store_until,
+ )
db_archive.tags = tags
db_archive.urls = urls
db.add(db_archive)
@@ -51,10 +74,14 @@ def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[model
return db_archive
-def store_archived_url(db: Session, archive: schemas.ArchiveCreate) -> models.Archive:
+def store_archived_url(
+ db: Session, archive: schemas.ArchiveCreate
+) -> models.Archive:
# create and load user, tags, if needed
create_or_get_user(db, archive.author_id)
db_tags = [create_tag(db, tag) for tag in (archive.tags or [])]
# insert everything
- db_archive = create_archive(db, archive=archive, tags=db_tags, urls=archive.urls)
+ db_archive = create_archive(
+ db, archive=archive, tags=db_tags, urls=archive.urls
+ )
return db_archive
diff --git a/app/shared/log.py b/app/shared/log.py
index 68587e2..152bcc0 100644
--- a/app/shared/log.py
+++ b/app/shared/log.py
@@ -1,4 +1,5 @@
import traceback
+
from loguru import logger
@@ -6,8 +7,10 @@ from loguru import logger
logger.add("logs/api_logs.log", retention="30 days")
logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
-
-def log_error(e: Exception, traceback_str: str = None, extra:str = ""):
- if not traceback_str: traceback_str = traceback.format_exc()
- if extra: extra = f"{extra}\n"
+
+def log_error(e: Exception, traceback_str: str = None, extra: str = ""):
+ if not traceback_str:
+ traceback_str = traceback.format_exc()
+ if extra:
+ extra = f"{extra}\n"
logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")
diff --git a/app/shared/schemas.py b/app/shared/schemas.py
index 66119f7..16dec0f 100644
--- a/app/shared/schemas.py
+++ b/app/shared/schemas.py
@@ -1,7 +1,8 @@
+from datetime import datetime
from typing import Annotated
+
from annotated_types import Len
from pydantic import BaseModel
-from datetime import datetime
class SubmitSheet(BaseModel):
@@ -10,6 +11,7 @@ class SubmitSheet(BaseModel):
group_id: str = "default"
tags: set[str] | None = set()
+
class ArchiveUrl(BaseModel):
url: str
public: bool = False
@@ -17,6 +19,7 @@ class ArchiveUrl(BaseModel):
group_id: str | None
tags: set[str] | None = set()
+
class ArchiveResult(BaseModel):
id: str
url: str
diff --git a/app/shared/settings.py b/app/shared/settings.py
index a9eae6a..d7c5c39 100644
--- a/app/shared/settings.py
+++ b/app/shared/settings.py
@@ -1,31 +1,38 @@
-
-from functools import lru_cache
import os
+from functools import lru_cache
+from typing import Annotated, Set
+
+from annotated_types import Len
from fastapi_mail import ConnectionConfig
from pydantic_settings import BaseSettings, SettingsConfigDict
-from typing import Annotated, Set
-from annotated_types import Len
class Settings(BaseSettings):
-
- model_config = SettingsConfigDict(env_file=os.environ.get("ENVIRONMENT_FILE") , env_file_encoding='utf-8', extra='ignore', str_strip_whitespace=True)
+ model_config = SettingsConfigDict(
+ env_file=os.environ.get("ENVIRONMENT_FILE"),
+ env_file_encoding="utf-8",
+ extra="ignore",
+ str_strip_whitespace=True,
+ )
- # general
+ # general
SERVE_LOCAL_ARCHIVE: str | None = None
USER_GROUPS_FILENAME: str = "app/user-groups.yaml"
- # database
+ # database
DATABASE_PATH: str
DATABASE_QUERY_LIMIT: int = 100
+
@property
- def ASYNC_DATABASE_PATH(self) -> str:
+ def async_database_path(self) -> str:
return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://")
- # security
+ # security
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)]
- CHROME_APP_IDS: Annotated[Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
+ CHROME_APP_IDS: Annotated[
+ Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)
+ ]
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
# if not provided only OAUTH access_tokens are allowed
FIREBASE_SERVICE_ACCOUNT_JSON: str = ""
@@ -34,20 +41,21 @@ class Settings(BaseSettings):
REDIS_PASSWORD: str = ""
REDIS_HOSTNAME: str = "localhost"
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
+
@property
- def CELERY_BROKER_URL(self)-> str:
+ def celery_broker_url(self) -> str:
if self.REDIS_PASSWORD:
return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379"
return f"redis://{self.REDIS_HOSTNAME}:6379"
-
+
# cronjobs
CRON_ARCHIVE_SHEETS: bool = False
CRON_DELETE_STALE_SHEETS: bool = False
DELETE_STALE_SHEETS_DAYS: int = 14
CRON_DELETE_SCHEDULED_ARCHIVES: bool = False
DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7
-
- # observability
+
+ # observability
REPEAT_COUNT_METRICS_SECONDS: int = 30
# email configuration, if needed
@@ -59,8 +67,9 @@ class Settings(BaseSettings):
MAIL_PORT: int = 587
MAIL_STARTTLS: bool = False
MAIL_SSL_TLS: bool = True
+
@property
- def MAIL_CONFIG(self) -> str:
+ def mail_config(self) -> ConnectionConfig:
return ConnectionConfig(
MAIL_FROM=self.MAIL_FROM,
MAIL_FROM_NAME=self.MAIL_FROM_NAME,
@@ -75,4 +84,4 @@ class Settings(BaseSettings):
@lru_cache
def get_settings():
- return Settings()
\ No newline at end of file
+ return Settings()
diff --git a/app/shared/task_messaging.py b/app/shared/task_messaging.py
index 21fb3d1..95438b5 100644
--- a/app/shared/task_messaging.py
+++ b/app/shared/task_messaging.py
@@ -1,8 +1,8 @@
-
from functools import lru_cache
-from celery import Celery
-import redis
+from celery import Celery
+
+import redis
from app.shared.settings import get_settings
@@ -10,14 +10,14 @@ from app.shared.settings import get_settings
def get_celery(name: str = "") -> Celery:
return Celery(
name,
- broker_url=get_settings().CELERY_BROKER_URL,
- result_backend=get_settings().CELERY_BROKER_URL,
+ broker_url=get_settings().celery_broker_url,
+ result_backend=get_settings().celery_broker_url,
broker_connection_retry_on_startup=False,
broker_transport_options={
- 'queue_order_strategy': 'priority',
- }
+ "queue_order_strategy": "priority",
+ },
)
def get_redis() -> redis.Redis:
- return redis.Redis.from_url(get_settings().CELERY_BROKER_URL)
+ return redis.Redis.from_url(get_settings().celery_broker_url)
diff --git a/app/shared/user_groups.py b/app/shared/user_groups.py
index 592e012..444fd6b 100644
--- a/app/shared/user_groups.py
+++ b/app/shared/user_groups.py
@@ -1,9 +1,16 @@
import json
import os
+from typing import Dict, List, Set
+
import yaml
from loguru import logger
-from pydantic import BaseModel, computed_field, field_validator, Field, model_validator
-from typing import Dict, List, Set
+from pydantic import (
+ BaseModel,
+ Field,
+ computed_field,
+ field_validator,
+ model_validator,
+)
from typing_extensions import Self
@@ -12,13 +19,16 @@ class UserGroups:
user_groups = self.read_yaml(filename)
self.validate_and_load(user_groups)
- def read_yaml(self, user_groups_filename):
+ @staticmethod
+ def read_yaml(user_groups_filename):
# read yaml safely
with open(user_groups_filename) as inf:
try:
return yaml.safe_load(inf)
except yaml.YAMLError as e:
- logger.error(f"could not open user groups filename {user_groups_filename}: {e}")
+ logger.error(
+ f"could not open user groups filename {user_groups_filename}: {e}"
+ )
raise e
def validate_and_load(self, user_groups):
@@ -45,22 +55,36 @@ class GroupPermissions(BaseModel):
max_monthly_mbs: int = 0
priority: str = "low"
- @field_validator('max_sheets', 'max_archive_lifespan_months', 'max_monthly_urls', 'max_monthly_mbs', mode='before')
+ @classmethod
+ @field_validator(
+ "max_sheets",
+ "max_archive_lifespan_months",
+ "max_monthly_urls",
+ "max_monthly_mbs",
+ mode="before",
+ )
def validate_max_values(cls, v):
if v < -1:
- raise ValueError("max_* values should be positive integers or -1 (for no limit).")
+ raise ValueError(
+ "max_* values should be positive integers or -1 (for no limit)."
+ )
return v
- @field_validator('sheet_frequency', mode='before')
+ @classmethod
+ @field_validator("sheet_frequency", mode="before")
def validate_sheet_frequency(cls, v):
- if not v: return []
+ if not v:
+ return []
allowed = ["daily", "hourly"]
for k in v:
if k not in allowed:
- raise ValueError(f"Invalid sheet frequency: '{k}', expected one of {allowed}")
+ raise ValueError(
+ f"Invalid sheet frequency: '{k}', expected one of {allowed}"
+ )
return v
- @field_validator('priority', mode='before')
+ @classmethod
+ @field_validator("priority", mode="before")
def validate_priority(cls, v):
v = v.lower()
if v not in ["low", "high"]:
@@ -70,19 +94,31 @@ class GroupPermissions(BaseModel):
class GroupModel(BaseModel):
description: str
- orchestrator: str
- orchestrator_sheet: str
+ orchestrator: str | None = None
+ orchestrator_sheet: str | None = None
permissions: GroupPermissions
- @field_validator('orchestrator', 'orchestrator_sheet', mode='before')
+ @classmethod
+ @field_validator("orchestrator", mode="before")
def validate_orchestrator(cls, v):
- if not os.path.exists(v):
+ # orchestrator is only needed if the group has archive_url permission
+ if cls.permissions.archive_url and not os.path.exists(v):
+ raise ValueError(f"Orchestrator file not found with this path: {v}")
+ return v
+
+ @classmethod
+ @field_validator("orchestrator_sheet", mode="before")
+ def validate_orchestrator_sheet(cls, v):
+ # orchestrator_sheet is only needed if the group has archive_sheet permission
+ if cls.permissions.archive_sheet and not os.path.exists(v):
raise ValueError(f"Orchestrator file not found with this path: {v}")
return v
@computed_field
@property
def service_account_email(self) -> str:
+ if self.orchestrator_sheet is None:
+ return ""
if hasattr(self, "_service_account_email"):
return self._service_account_email
orch = yaml.safe_load(open(self.orchestrator_sheet))
@@ -98,13 +134,17 @@ class GroupModel(BaseModel):
service_account_json = find_service_account_email(orch)
if not service_account_json:
- raise ValueError(f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}.")
+ raise ValueError(
+ f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}."
+ )
with open(service_account_json) as f:
self._service_account_email = json.load(f).get("client_email")
if not self._service_account_email:
- raise ValueError(f"Service account email not found in {service_account_json}.")
+ raise ValueError(
+ f"Service account email not found in {service_account_json}."
+ )
return self._service_account_email
@@ -114,29 +154,45 @@ class UserGroupModel(BaseModel):
domains: Dict[str, List[str]] = Field(default_factory=dict)
groups: Dict[str, GroupModel] = Field(default_factory=dict)
- @field_validator('users', mode='before')
@classmethod
+ @field_validator("users", mode="before")
def validate_emails(cls, v):
for email in v.keys():
- if '@' not in email:
- raise ValueError(f"Invalid user, it should be an address: {email}")
+ if "@" not in email:
+ raise ValueError(
+ f"Invalid user, it should be an address: {email}"
+ )
if not v[email]:
- raise ValueError(f"User {email} has no explicitly listed groups, only include them here if they should be in a group.")
+ raise ValueError(
+ f"User {email} has no explicitly listed groups, only include them here if they should be in a group."
+ )
# all users belong to the default group
- return {k.lower().strip(): list(set(["default"] + [g.lower().strip() for g in v])) for k, v in v.items()}
+ return {
+ k.lower().strip(): list(
+ set(["default"] + [g.lower().strip() for g in v])
+ )
+ for k, v in v.items()
+ }
- @field_validator('domains', mode='before')
@classmethod
+ @field_validator("domains", mode="before")
def validate_domains(cls, v):
for domain, members in v.items():
- if '.' not in domain:
- raise ValueError(f"Invalid domain, it should contain a dot: {domain}")
+ if "." not in domain:
+ raise ValueError(
+ f"Invalid domain, it should contain a dot: {domain}"
+ )
if not members:
- raise ValueError(f"Domain {domain} should have at least one member.")
- return {k.lower().strip(): list(set([g.lower().strip() for g in v])) for k, v in v.items()}
+ raise ValueError(
+ f"Domain {domain} should have at least one member."
+ )
+ return {
+ k.lower().strip(): list({[g.lower().strip() for g in v]})
+ for k, v in v.items()
+ }
- @field_validator('groups', mode='before')
@classmethod
+ @field_validator("groups", mode="before")
def validate_groups(cls, v):
if "default" not in v.keys():
raise ValueError("Please include a 'default' group.")
@@ -147,20 +203,28 @@ class UserGroupModel(BaseModel):
raise ValueError(f"Group names should be lowercase: {group}")
return v
- @model_validator(mode='after')
+ @model_validator(mode="after")
def check_groups_consistency(self) -> Self:
- groups_in_domains = set([g for domain in self.domains for g in self.domains[domain]])
- groups_in_users = set([g for user in self.users for g in self.users[user]])
+ groups_in_domains = {
+ g for domain in self.domains for g in self.domains[domain]
+ }
+ groups_in_users = {g for user in self.users for g in self.users[user]}
configured_groups = set(self.groups.keys())
- # groups mentioned in domains and users should be defined, but this is not a ValueError since historical DB data may require it
+ # groups mentioned in domains and users should be defined, but this is
+ # not a ValueError since historical DB data may require it
if groups_in_domains - configured_groups:
- logger.warning(f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}")
+ logger.warning(
+ f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}"
+ )
if groups_in_users - configured_groups:
- logger.warning(f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}")
+ logger.warning(
+ f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}"
+ )
return self
+
# for the API return values
diff --git a/app/shared/utils/misc.py b/app/shared/utils/misc.py
index 562b2c3..21e349a 100644
--- a/app/shared/utils/misc.py
+++ b/app/shared/utils/misc.py
@@ -1,10 +1,14 @@
-
-def fnv1a_hash_mod(s: str, modulo:int) -> int:
- # receives a string and returns a number in [0:modulo-1], ensures an even distribution over the modulo range
- hash = 0x811c9dc5 # FNV offset basis
- fnv_prime = 0x01000193 # FNV prime
+def fnv1a_hash_mod(s: str, modulo: int) -> int:
+ # receives a string and returns a number in [0:modulo-1], ensures an even
+ # distribution over the modulo range
+ offset_basis_hash = 0x811C9DC5 # FNV offset basis
+ fnv_prime = 0x01000193 # FNV prime
for char in s:
- hash ^= ord(char)
- hash *= fnv_prime
- hash &= 0xFFFFFFFF # Keep it 32-bit
- return (hash if hash < 0x80000000 else hash - 0x100000000) % modulo
\ No newline at end of file
+ offset_basis_hash ^= ord(char)
+ offset_basis_hash *= fnv_prime
+ offset_basis_hash &= 0xFFFFFFFF # Keep it 32-bit
+ return (
+ offset_basis_hash
+ if offset_basis_hash < 0x80000000
+ else offset_basis_hash - 0x100000000
+ ) % modulo
diff --git a/app/tests/conftest.py b/app/tests/conftest.py
index afa76f9..c41d249 100644
--- a/app/tests/conftest.py
+++ b/app/tests/conftest.py
@@ -1,19 +1,39 @@
import os
+from datetime import datetime
+from http import HTTPStatus
from typing import AsyncGenerator
-from fastapi.testclient import TestClient
-import pytest
from unittest.mock import patch
+
+import pytest
import pytest_asyncio
-from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
-from app.web.config import ALLOW_ANY_EMAIL
+from fastapi.testclient import TestClient
+from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
+
+from app.shared.db import models
+from app.shared.db.database import (
+ make_async_engine,
+ make_async_session_local,
+ make_engine,
+ make_session_local,
+)
from app.shared.settings import Settings
+from app.web.config import ALLOW_ANY_EMAIL
+from app.web.db import crud
+from app.web.db.crud import get_user_group_names
from app.web.db.user_state import UserState
+from app.web.main import app_factory
+from app.web.security import (
+ get_token_or_user_auth,
+ get_user_auth,
+ get_user_state,
+ token_api_key_auth,
+)
@pytest.fixture(autouse=True)
def mock_logger_add():
"""Fixture to mock loguru.logger.add for all tests."""
- with patch('loguru.logger.add') as mock_add:
+ with patch("loguru.logger.add") as mock_add:
yield mock_add # This makes the mock available to tests
@@ -24,23 +44,22 @@ def get_settings():
@pytest.fixture(autouse=True)
def mock_settings():
- with patch('app.shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings:
+ with patch(
+ "app.shared.settings.Settings",
+ return_value=Settings(_env_file=".env.test"),
+ ) as mock_settings:
yield mock_settings
@pytest.fixture()
def test_db(get_settings: Settings):
- from app.shared.db import models
- from app.shared.db.database import make_engine
- from app.web.db.crud import get_user_group_names
-
get_user_group_names.cache_clear()
make_engine.cache_clear()
engine = make_engine(get_settings.DATABASE_PATH)
fs = get_settings.DATABASE_PATH.replace("sqlite:///", "")
if not os.path.exists(fs):
- open(fs, 'w').close()
+ open(fs, "w").close()
models.Base.metadata.create_all(engine)
@@ -57,7 +76,6 @@ def test_db(get_settings: Settings):
@pytest.fixture()
def db_session(test_db):
- from app.shared.db.database import make_session_local
session_local = make_session_local(test_db)
with session_local() as session:
yield session
@@ -65,17 +83,12 @@ def db_session(test_db):
@pytest_asyncio.fixture()
async def async_test_db(get_settings: Settings):
- from app.shared.db import models
- from app.shared.db.database import make_async_engine
- from app.web.db.crud import get_user_group_names
- import asyncio
-
get_user_group_names.cache_clear()
- engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH)
+ engine = await make_async_engine(get_settings.async_database_path)
- fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "")
+ fs = get_settings.async_database_path.replace("sqlite+aiosqlite:///", "")
if not os.path.exists(fs):
- open(fs, 'w').close()
+ open(fs, "w").close()
async def create_all():
async with engine.begin() as conn:
@@ -99,8 +112,9 @@ async def async_test_db(get_settings: Settings):
@pytest_asyncio.fixture()
-async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
- from app.shared.db.database import make_async_session_local
+async def async_db_session(
+ async_test_db: AsyncEngine,
+) -> AsyncGenerator[AsyncSession, None]:
session_local = await make_async_session_local(async_test_db)
async with session_local() as session:
yield session
@@ -108,8 +122,6 @@ async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSe
@pytest.fixture()
def app(db_session):
- from app.web.main import app_factory
- from app.web.db import crud
app = app_factory()
crud.upsert_user_groups(db_session)
return app
@@ -123,10 +135,13 @@ def client(app):
@pytest.fixture()
def app_with_auth(app, db_session):
- from app.web.security import get_token_or_user_auth, get_user_auth, get_user_state
- app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
+ app.dependency_overrides[get_token_or_user_auth] = (
+ lambda: "rick@example.com"
+ )
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
- app.dependency_overrides[get_user_state] = lambda: UserState(db_session, "MORTY@example.com")
+ app.dependency_overrides[get_user_state] = lambda: UserState(
+ db_session, "MORTY@example.com"
+ )
return app
@@ -138,7 +153,6 @@ def client_with_auth(app_with_auth):
@pytest.fixture()
def app_with_token(app):
- from app.web.security import token_api_key_auth, get_token_or_user_auth
app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL
app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL
return app
@@ -155,6 +169,93 @@ def test_no_auth():
# reusable code to ensure a method/endpoint combination is unauthorized
def no_auth(http_method, endpoint):
response = http_method(endpoint)
- assert response.status_code == 403
+ assert response.status_code == HTTPStatus.FORBIDDEN
assert response.json() == {"detail": "Not authenticated"}
+
return no_auth
+
+
+@pytest.fixture()
+def test_data(db_session):
+ author_emails = [
+ "rick@example.com",
+ "morty@example.com",
+ "jerry@example.com",
+ ]
+
+ # creates 3 users
+ for email in author_emails:
+ db_session.add(models.User(email=email))
+ db_session.commit()
+ assert db_session.query(models.User).count() == 3
+
+ # creates 100 archives for 3 users over 2 months with repeating URLs
+ for i in range(100):
+ author = author_emails[i % 3]
+ archive = models.Archive(
+ id=f"archive-id-456-{i}",
+ url=f"https://example-{i % 3}.com",
+ result={},
+ public=author == "jerry@example.com",
+ author_id=author,
+ group_id="spaceship"
+ if author == "morty@example.com" and i % 2 == 0
+ else None,
+ created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1),
+ )
+ if i % 5 == 0:
+ archive.tags.append(models.Tag(id=f"tag-{i}"))
+ if i % 10 == 0:
+ archive.tags.append(models.Tag(id=f"tag-second-{i}"))
+ if i % 4 == 0:
+ archive.tags.append(models.Tag(id=f"tag-third-{i}"))
+ for j in range(10):
+ archive.urls.append(
+ models.ArchiveUrl(
+ url=f"https://example-{i}.com/{j}", key=f"media_{j}"
+ )
+ )
+ db_session.add(archive)
+
+ # creates a sheet for each user
+ for i, email in enumerate(
+ ["rick@example.com", "morty@example.com", "jerry@example.com"]
+ ):
+ db_session.add(
+ models.Sheet(
+ id=f"sheet-{i}",
+ name=f"sheet-{i}",
+ author_id=email,
+ group_id=None,
+ frequency="daily",
+ )
+ )
+ if email == "rick@example.com":
+ db_session.add(
+ models.Sheet(
+ id=f"sheet-{i}-2",
+ name=f"sheet-{i}-2",
+ author_id=email,
+ group_id="spaceship",
+ frequency="hourly",
+ )
+ )
+
+ db_session.commit()
+
+ assert db_session.query(models.Archive).count() == 100
+ assert db_session.query(models.Tag).count() == 20 + 10 + 25
+ assert db_session.query(models.ArchiveUrl).count() == 1000
+ assert (
+ db_session.query(models.ArchiveUrl)
+ .filter(models.ArchiveUrl.archive_id == "archive-id-456-0")
+ .count()
+ == 10
+ )
+
+ # setup groups
+ assert db_session.query(models.Group).count() == 0
+
+ crud.upsert_user_groups(db_session)
+ assert db_session.query(models.Group).count() == 4
+ assert db_session.query(models.User).count() == 3
diff --git a/app/tests/fake_service_account.json b/app/tests/fake_service_account.json
index 3d41bd9..10c0585 100644
--- a/app/tests/fake_service_account.json
+++ b/app/tests/fake_service_account.json
@@ -1,3 +1,3 @@
{
- "client_email": "fake_service_account@fake_service_account.iam.gserviceaccount.com"
-}
\ No newline at end of file
+ "client_email": "fake_service_account@fake_service_account.iam.gserviceaccount.com"
+}
diff --git a/app/tests/orchestration.test.yaml b/app/tests/orchestration.test.yaml
index 4ee1880..9d2e44b 100644
--- a/app/tests/orchestration.test.yaml
+++ b/app/tests/orchestration.test.yaml
@@ -1,7 +1,8 @@
steps:
- feeder: cli_feeder
+ feeders:
+ - cli_feeder
archivers: # order matters
- - youtubedl_archiver
+ - generic_extractor
enrichers:
- hash_enricher
@@ -12,10 +13,10 @@ steps:
- console_db
configurations:
- gsheet_feeder:
+ gsheet_feeder_db:
service_account: "app/tests/fake_service_account.json"
cli_feeder:
- urls:
+ urls:
- "url1"
hash_enricher:
algorithm: "SHA-256"
diff --git a/app/tests/shared/db/test_models.py b/app/tests/shared/db/test_models.py
index 35ba368..537532b 100644
--- a/app/tests/shared/db/test_models.py
+++ b/app/tests/shared/db/test_models.py
@@ -1,6 +1,7 @@
-def test_generate_uuid():
- from app.shared.db.models import generate_uuid
+from app.shared.db.models import generate_uuid
- assert generate_uuid() != generate_uuid()
- assert len(generate_uuid()) == 36
- assert generate_uuid().count("-") == 4
\ No newline at end of file
+
+def test_generate_uuid():
+ assert generate_uuid() != generate_uuid()
+ assert len(generate_uuid()) == 36
+ assert generate_uuid().count("-") == 4
diff --git a/app/tests/shared/db/test_worker_crud.py b/app/tests/shared/db/test_worker_crud.py
index 1098cbe..4e5a434 100644
--- a/app/tests/shared/db/test_worker_crud.py
+++ b/app/tests/shared/db/test_worker_crud.py
@@ -1,12 +1,10 @@
-from app.shared.db import models
-from app.shared.db import worker_crud, models
from datetime import datetime
+from app.shared import schemas
+from app.shared.db import models, worker_crud
-from app.tests.web.db.test_crud import test_data
def test_update_sheet_last_url_archived_at(db_session):
-
# Create test sheet
test_sheet = models.Sheet(id="sheet-123")
db_session.add(test_sheet)
@@ -15,17 +13,24 @@ def test_update_sheet_last_url_archived_at(db_session):
# Test updating existing sheet
assert isinstance(test_sheet.last_url_archived_at, datetime)
before = test_sheet.last_url_archived_at
- assert worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123") is True
+ assert (
+ worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123")
+ is True
+ )
db_session.refresh(test_sheet)
assert isinstance(test_sheet.last_url_archived_at, datetime)
assert test_sheet.last_url_archived_at > before
-
+
# Test non-existent sheet
- assert worker_crud.update_sheet_last_url_archived_at(db_session, "non-existent-sheet") is False
+ assert (
+ worker_crud.update_sheet_last_url_archived_at(
+ db_session, "non-existent-sheet"
+ )
+ is False
+ )
+
def test_get_group(test_data, db_session):
- from app.shared.db import worker_crud
-
assert worker_crud.get_group(db_session, "spaceship") is not None
assert worker_crud.get_group(db_session, "interdimensional") is not None
assert worker_crud.get_group(db_session, "animated-characters") is not None
@@ -33,24 +38,24 @@ def test_get_group(test_data, db_session):
def test_create_or_get_user(test_data, db_session):
- from app.shared.db import worker_crud
-
assert db_session.query(models.User).count() == 3
# already exists
- assert (u1 := worker_crud.create_or_get_user(db_session, "rick@example.com")) is not None
+ assert (
+ u1 := worker_crud.create_or_get_user(db_session, "rick@example.com")
+ ) is not None
assert u1.email == "rick@example.com"
# new user
- assert (u2 := worker_crud.create_or_get_user(db_session, "beth@example.com")) is not None
+ assert (
+ u2 := worker_crud.create_or_get_user(db_session, "beth@example.com")
+ ) is not None
assert u2.email == "beth@example.com"
assert db_session.query(models.User).count() == 4
def test_create_tag(db_session):
- from app.shared.db import worker_crud
-
assert db_session.query(models.Tag).count() == 0
# create first
@@ -58,7 +63,10 @@ def test_create_tag(db_session):
assert create_tag is not None
assert create_tag.id == "tag-101"
assert db_session.query(models.Tag).count() == 1
- assert db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() == create_tag
+ assert (
+ db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first()
+ == create_tag
+ )
# same id does not add new db entry
existing_tag = worker_crud.create_tag(db_session, "tag-101")
@@ -73,9 +81,6 @@ def test_create_tag(db_session):
def test_create_task(db_session):
- from app.shared.db import worker_crud
- from app.shared import schemas
-
task = schemas.ArchiveCreate(
id="archive-id-456-101",
url="https://example-0.com",
@@ -84,17 +89,22 @@ def test_create_task(db_session):
author_id="rick@example.com",
group_id="spaceship",
tags=[],
- urls=[]
+ urls=[],
)
# with tags and urls
- nt = worker_crud.create_archive(db_session, task, [models.Tag(id="tag-101")], [models.ArchiveUrl(url="https://example-0.com/0", key="media_0")])
+ nt = worker_crud.create_archive(
+ db_session,
+ task,
+ [models.Tag(id="tag-101")],
+ [models.ArchiveUrl(url="https://example-0.com/0", key="media_0")],
+ )
assert nt is not None
assert nt.id == "archive-id-456-101"
assert nt.url == "https://example-0.com"
assert nt.author_id == "rick@example.com"
- assert nt.public == False
+ assert nt.public is False
assert nt.group_id == "spaceship"
assert len(nt.tags) == 1
assert nt.tags[0].id == "tag-101"
@@ -110,8 +120,8 @@ def test_create_task(db_session):
assert nt.id == "archive-id-456-102"
assert nt.url == "https://example-0.com"
assert nt.author_id == "rick@example.com"
- assert nt.public == False
+ assert nt.public is False
assert nt.group_id == "spaceship"
assert len(nt.tags) == 0
assert len(nt.urls) == 0
- assert nt.created_at is not None
\ No newline at end of file
+ assert nt.created_at is not None
diff --git a/app/tests/shared/test_business_logic.py b/app/tests/shared/test_business_logic.py
index 225fb11..e10d402 100644
--- a/app/tests/shared/test_business_logic.py
+++ b/app/tests/shared/test_business_logic.py
@@ -1,10 +1,15 @@
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
+
import pytest
-from app.shared.business_logic import get_store_archive_until, get_store_archive_until_or_never
+
+from app.shared.business_logic import (
+ get_store_archive_until,
+ get_store_archive_until_or_never,
+)
-class Test_get_store_archive_until:
+class TestGetStoreArchiveUntil:
GROUP_ID = "test-group"
def test_group_not_found(self, db_session):
@@ -12,7 +17,10 @@ class Test_get_store_archive_until:
get_store_archive_until(db_session, self.GROUP_ID)
assert str(exc.value) == f"Group {self.GROUP_ID} not found."
- @patch("app.shared.db.worker_crud.get_group", return_value=MagicMock(permissions=None))
+ @patch(
+ "app.shared.db.worker_crud.get_group",
+ return_value=MagicMock(permissions=None),
+ )
def test_group_no_permissions(self, db_session):
with pytest.raises(AssertionError) as exc:
get_store_archive_until(db_session, self.GROUP_ID)
@@ -43,14 +51,17 @@ class Test_get_store_archive_until:
mock_get_group.assert_called_once_with(db_session, self.GROUP_ID)
-class Test_get_store_archive_until_or_never:
+class TestGetStoreArchiveUntilOrNever:
GROUP_ID = "test-group"
def test_group_not_found(self, db_session):
result = get_store_archive_until_or_never(db_session, self.GROUP_ID)
assert result is None
- @patch("app.shared.db.worker_crud.get_group", return_value=MagicMock(permissions=None))
+ @patch(
+ "app.shared.db.worker_crud.get_group",
+ return_value=MagicMock(permissions=None),
+ )
def test_group_no_permissions(self, db_session):
result = get_store_archive_until_or_never(db_session, self.GROUP_ID)
assert result is None
diff --git a/app/tests/shared/utils/test_misc.py b/app/tests/shared/utils/test_misc.py
index d7595c8..18db28d 100644
--- a/app/tests/shared/utils/test_misc.py
+++ b/app/tests/shared/utils/test_misc.py
@@ -11,7 +11,7 @@ def test_fnv1a_hash_mod():
# Test different modulos
hash1 = fnv1a_hash_mod("test", 5)
- hash2 = fnv1a_hash_mod("test", 10)
+ hash2 = fnv1a_hash_mod("test", 10)
assert 0 <= hash1 < 5
assert 0 <= hash2 < 10
@@ -28,4 +28,4 @@ def test_fnv1a_hash_mod():
assert 0 <= fnv1a_hash_mod("测试", 10) < 10
# Test modulo = 1 edge case
- assert fnv1a_hash_mod("test", 1) == 0
\ No newline at end of file
+ assert fnv1a_hash_mod("test", 1) == 0
diff --git a/app/tests/user-groups.test.broken.yaml b/app/tests/user-groups.test.broken.yaml
index 8bc59c5..9b41741 100644
--- a/app/tests/user-groups.test.broken.yaml
+++ b/app/tests/user-groups.test.broken.yaml
@@ -3,4 +3,4 @@ This is just an invalid yaml for testing
still broken: True
- one
- - two
\ No newline at end of file
+ - two
diff --git a/app/tests/user-groups.test.yaml b/app/tests/user-groups.test.yaml
index 16a3ba7..e9a446f 100644
--- a/app/tests/user-groups.test.yaml
+++ b/app/tests/user-groups.test.yaml
@@ -84,4 +84,4 @@ groups:
# max_archive_lifespan_months: 12
max_monthly_urls: 1
# max_monthly_mbs: 50
- priority: "low"
\ No newline at end of file
+ priority: "low"
diff --git a/app/tests/web/db/test_crud.py b/app/tests/web/db/test_crud.py
index aad9d4c..7b569b6 100644
--- a/app/tests/web/db/test_crud.py
+++ b/app/tests/web/db/test_crud.py
@@ -2,123 +2,379 @@ from datetime import datetime, timedelta
from unittest.mock import patch
import pytest
+import sqlalchemy
import yaml
+from sqlalchemy import false, true
+from sqlalchemy.sql import select
+
from app.shared.db import models
from app.shared.settings import Settings
-
+from app.web.config import ALLOW_ANY_EMAIL
from app.web.db import crud
-authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
-
-
-@pytest.fixture()
-def test_data(db_session):
-
- # creates 3 users
- for email in authors:
- db_session.add(models.User(email=email))
- db_session.commit()
- assert db_session.query(models.User).count() == 3
-
- # creates 100 archives for 3 users over 2 months with repeating URLs
- for i in range(100):
- author = authors[i % 3]
- archive = models.Archive(
- id=f"archive-id-456-{i}",
- url=f"https://example-{i%3}.com",
- result={},
- public=author == "jerry@example.com",
- author_id=author,
- group_id="spaceship" if author == "morty@example.com" and i % 2 == 0 else None,
- created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1)
- )
- if i % 5 == 0:
- archive.tags.append(models.Tag(id=f"tag-{i}"))
- if i % 10 == 0:
- archive.tags.append(models.Tag(id=f"tag-second-{i}"))
- if i % 4 == 0:
- archive.tags.append(models.Tag(id=f"tag-third-{i}"))
- for j in range(10):
- archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
- db_session.add(archive)
-
- # creates a sheet for each user
- for i, email in enumerate(authors):
- db_session.add(models.Sheet(id=f"sheet-{i}", name=f"sheet-{i}", author_id=email, group_id=None, frequency="daily"))
- if email == "rick@example.com":
- db_session.add(models.Sheet(id=f"sheet-{i}-2", name=f"sheet-{i}-2", author_id=email, group_id="spaceship", frequency="hourly"))
-
- db_session.commit()
-
- assert db_session.query(models.Archive).count() == 100
- assert db_session.query(models.Tag).count() == 20 + 10 + 25
- assert db_session.query(models.ArchiveUrl).count() == 1000
- assert db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.archive_id == "archive-id-456-0").count() == 10
-
- # setup groups
- assert db_session.query(models.Group).count() == 0
- from app.web.db import crud
- crud.upsert_user_groups(db_session)
- assert db_session.query(models.Group).count() == 4
- assert db_session.query(models.User).count() == 3
def test_search_archives_by_url(test_data, db_session):
- from app.web.config import ALLOW_ANY_EMAIL
-
- # rick's archives are private
- assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", True, False)) == 34
- assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", [], False)) == 34
- assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", [], True)) == 34
- assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL, [], False)) == 34
- assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL, True, False)) == 34
- assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com", [], False)) == 0
- assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com", [], True)) == 0
+ # Rick's archives are private
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example-0.com",
+ "rick@example.com",
+ True,
+ False,
+ )
+ )
+ == 34
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example-0.com",
+ "rick@example.com",
+ [],
+ False,
+ )
+ )
+ == 34
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example-0.com",
+ "rick@example.com",
+ [],
+ True,
+ )
+ )
+ == 34
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session, "https://example-0.com", ALLOW_ANY_EMAIL, [], False
+ )
+ )
+ == 34
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example-0.com",
+ ALLOW_ANY_EMAIL,
+ True,
+ False,
+ )
+ )
+ == 34
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example-0.com",
+ "morty@example.com",
+ [],
+ False,
+ )
+ )
+ == 0
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example-0.com",
+ "morty@example.com",
+ [],
+ True,
+ )
+ )
+ == 0
+ )
# morty's archives are public but half are in spaceship group
- assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com", ["spaceship"], False)) == 16
- assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com", True, False)) == 16
- assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "jerry@example.com", True, True)) == 16
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example-1.com",
+ "rick@example.com",
+ ["spaceship"],
+ False,
+ )
+ )
+ == 16
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example-1.com",
+ "rick@example.com",
+ True,
+ False,
+ )
+ )
+ == 16
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example-1.com",
+ "jerry@example.com",
+ True,
+ True,
+ )
+ )
+ == 16
+ )
- # jerry's archives are public
- assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "jerry@example.com", [], True)) == 33
- assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "rick@example.com", [], True)) == 33
+ # Jerry's archives are public
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example-2.com",
+ "jerry@example.com",
+ [],
+ True,
+ )
+ )
+ == 33
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example-2.com",
+ "rick@example.com",
+ [],
+ True,
+ )
+ )
+ == 33
+ )
# fuzzy search
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False)) == 100
- assert len(crud.search_archives_by_url(db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL, False, False)) == 100
- assert len(crud.search_archives_by_url(db_session, "2.com", ALLOW_ANY_EMAIL, False, False)) == 33
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session, "https://example", ALLOW_ANY_EMAIL, False, False
+ )
+ )
+ == 100
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL, False, False
+ )
+ )
+ == 100
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session, "2.com", ALLOW_ANY_EMAIL, False, False
+ )
+ )
+ == 33
+ )
# absolute search
- assert len(crud.search_archives_by_url(db_session, "example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True)) == 0
- assert len(crud.search_archives_by_url(db_session, "https://example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True)) == 33
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "example-2.com",
+ ALLOW_ANY_EMAIL,
+ [],
+ False,
+ absolute_search=True,
+ )
+ )
+ == 0
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example-2.com",
+ ALLOW_ANY_EMAIL,
+ [],
+ False,
+ absolute_search=True,
+ )
+ )
+ == 33
+ )
# archived_after
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, archived_after=datetime(2010, 1, 1))) == 100
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 15))) == 70
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2031, 1, 1))) == 0
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example",
+ ALLOW_ANY_EMAIL,
+ True,
+ True,
+ archived_after=datetime(2010, 1, 1),
+ )
+ )
+ == 100
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example",
+ ALLOW_ANY_EMAIL,
+ False,
+ False,
+ archived_after=datetime(2021, 1, 15),
+ )
+ )
+ == 70
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example",
+ ALLOW_ANY_EMAIL,
+ False,
+ False,
+ archived_after=datetime(2031, 1, 1),
+ )
+ )
+ == 0
+ )
# archived before
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2010, 1, 1))) == 0
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2021, 1, 15))) == 28
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2031, 1, 1))) == 100
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example",
+ ALLOW_ANY_EMAIL,
+ False,
+ False,
+ archived_before=datetime(2010, 1, 1),
+ )
+ )
+ == 0
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example",
+ ALLOW_ANY_EMAIL,
+ False,
+ False,
+ archived_before=datetime(2021, 1, 15),
+ )
+ )
+ == 28
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example",
+ ALLOW_ANY_EMAIL,
+ False,
+ False,
+ archived_before=datetime(2031, 1, 1),
+ )
+ )
+ == 100
+ )
# archived before and after
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2001, 1, 1), archived_before=datetime(2031, 1, 11))) == 100
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 14), archived_before=datetime(2021, 1, 16))) == 2
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example",
+ ALLOW_ANY_EMAIL,
+ False,
+ False,
+ archived_after=datetime(2001, 1, 1),
+ archived_before=datetime(2031, 1, 11),
+ )
+ )
+ == 100
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example",
+ ALLOW_ANY_EMAIL,
+ False,
+ False,
+ archived_after=datetime(2021, 1, 14),
+ archived_before=datetime(2021, 1, 16),
+ )
+ )
+ == 2
+ )
# limit
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, limit=10)) == 10
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, limit=-1)) == 1
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example",
+ ALLOW_ANY_EMAIL,
+ False,
+ False,
+ limit=10,
+ )
+ )
+ == 10
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example",
+ ALLOW_ANY_EMAIL,
+ False,
+ False,
+ limit=-1,
+ )
+ )
+ == 1
+ )
# skip
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, skip=10)) == 90
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example",
+ ALLOW_ANY_EMAIL,
+ False,
+ False,
+ skip=10,
+ )
+ )
+ == 90
+ )
def test_search_archives_by_email(test_data, db_session):
- from app.web.config import ALLOW_ANY_EMAIL
-
# lower/upper case
- assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34
+ assert (
+ len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34
+ )
# ALLOW_ANY_EMAIL is not a user
assert len(crud.search_archives_by_email(db_session, ALLOW_ANY_EMAIL)) == 0
@@ -136,45 +392,108 @@ def test_search_archives_by_email(test_data, db_session):
@patch("app.web.db.crud.DATABASE_QUERY_LIMIT", new=25)
def test_max_query_limit(test_data, db_session):
- from app.web.config import ALLOW_ANY_EMAIL
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session, "https://example", ALLOW_ANY_EMAIL, [], False
+ )
+ )
+ == 25
+ )
+ assert (
+ len(
+ crud.search_archives_by_url(
+ db_session,
+ "https://example",
+ ALLOW_ANY_EMAIL,
+ True,
+ True,
+ limit=1000,
+ )
+ )
+ == 25
+ )
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, [], False)) == 25
- assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, limit=1000)) == 25
-
- assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25
- assert len(crud.search_archives_by_email(db_session, "rick@example.com", limit=1000)) == 25
+ assert (
+ len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25
+ )
+ assert (
+ len(
+ crud.search_archives_by_email(
+ db_session, "rick@example.com", limit=1000
+ )
+ )
+ == 25
+ )
def test_soft_delete(test_data, db_session):
# none deleted yet
- db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").first() is not None
- assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 0
+ assert (
+ db_session.query(models.Archive)
+ .filter(models.Archive.id == "archive-id-456-0")
+ .first()
+ is not None
+ )
+ assert (
+ db_session.query(models.Archive)
+ .filter(models.Archive.deleted.is_(true()))
+ .count()
+ == 0
+ )
# delete
- assert crud.soft_delete_archive(db_session, "archive-id-456-0", "rick@example.com") == True
+ assert (
+ crud.soft_delete_archive(
+ db_session, "archive-id-456-0", "rick@example.com"
+ )
+ is True
+ )
# ensure soft delete
- assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 1
- db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").first() is None
+ assert (
+ db_session.query(models.Archive)
+ .filter(models.Archive.deleted.is_(true()))
+ .count()
+ == 1
+ )
+ assert (
+ db_session.query(models.Archive)
+ .filter(models.Archive.id == "archive-id-456-0")
+ .filter(models.Archive.deleted.is_(false()))
+ .first()
+ is None
+ )
# already deleted
- assert crud.soft_delete_archive(db_session, "archive-id-456-0", "rick@example.com") == False
+ assert (
+ crud.soft_delete_archive(
+ db_session, "archive-id-456-0", "rick@example.com"
+ )
+ is False
+ )
def test_count_archives(test_data, db_session):
assert crud.count_archives(db_session) == 100
- db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
+ db_session.query(models.Archive).filter(
+ models.Archive.id == "archive-id-456-0"
+ ).delete()
db_session.commit()
assert crud.count_archives(db_session) == 99
def test_count_archive_urls(test_data, db_session):
assert crud.count_archive_urls(db_session) == 1000
- db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.url == "https://example-0.com/0").delete()
+ db_session.query(models.ArchiveUrl).filter(
+ models.ArchiveUrl.url == "https://example-0.com/0"
+ ).delete()
db_session.commit()
assert crud.count_archive_urls(db_session) == 999
- db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
+ db_session.query(models.Archive).filter(
+ models.Archive.id == "archive-id-456-0"
+ ).delete()
db_session.commit()
# no Cascade is enabled
assert crud.count_archives(db_session) == 99
@@ -183,16 +502,23 @@ def test_count_archive_urls(test_data, db_session):
def test_count_users(test_data, db_session):
assert crud.count_users(db_session) == 3
- db_session.query(models.User).filter(models.User.email == "rick@example.com").delete()
+ db_session.query(models.User).filter(
+ models.User.email == "rick@example.com"
+ ).delete()
db_session.commit()
assert crud.count_users(db_session) == 2
def test_count_by_users_since(test_data, db_session):
- from app.web.db import crud
-
# 100y window
- assert len(cu := crud.count_by_user_since(db_session, 60 * 60 * 24 * 31 * 12 * 100)) == 3
+ assert (
+ len(
+ cu := crud.count_by_user_since(
+ db_session, 60 * 60 * 24 * 31 * 12 * 100
+ )
+ )
+ == 3
+ )
assert cu[0].total == 34
assert cu[1].total == 33
assert cu[2].total == 33
@@ -201,9 +527,18 @@ def test_count_by_users_since(test_data, db_session):
def test_upsert_group(test_data, db_session):
assert db_session.query(models.Group).count() == 4
- repeatable_params = ["desc 1", "orch.yaml", "sheet.yaml", "service_account_email@example.com", {"read": ["all"]}, ["example.com"]]
+ repeatable_params = [
+ "desc 1",
+ "orch.yaml",
+ "sheet.yaml",
+ "service_account_email@example.com",
+ {"read": ["all"]},
+ ["example.com"],
+ ]
- assert (g1 := crud.upsert_group(db_session, "spaceship", *repeatable_params)) is not None
+ assert (
+ g1 := crud.upsert_group(db_session, "spaceship", *repeatable_params)
+ ) is not None
assert g1.id == "spaceship"
assert g1.description == "desc 1"
assert g1.orchestrator == "orch.yaml"
@@ -212,14 +547,25 @@ def test_upsert_group(test_data, db_session):
assert g1.permissions == {"read": ["all"]}
assert g1.domains == ["example.com"]
assert len(g1.users) == 2
- assert [u.email for u in g1.users] == ["rick@example.com", "morty@example.com"]
+ assert [u.email for u in g1.users] == [
+ "rick@example.com",
+ "morty@example.com",
+ ]
- assert (g2 := crud.upsert_group(db_session, "interdimensional", *repeatable_params)) is not None
+ assert (
+ g2 := crud.upsert_group(
+ db_session, "interdimensional", *repeatable_params
+ )
+ ) is not None
assert g2.id == "interdimensional"
assert len(g2.users) == 1
assert [u.email for u in g2.users] == ["rick@example.com"]
- assert (g3 := crud.upsert_group(db_session, "this-is-a-new-group", *repeatable_params)) is not None
+ assert (
+ g3 := crud.upsert_group(
+ db_session, "this-is-a-new-group", *repeatable_params
+ )
+ ) is not None
assert g3.id == "this-is-a-new-group"
assert len(g3.users) == 0
@@ -227,29 +573,38 @@ def test_upsert_group(test_data, db_session):
def test_upsert_user_groups(db_session):
- @patch('app.web.db.crud.get_settings', new=lambda: bad_setings)
+ @patch("app.web.db.crud.get_settings", new=lambda: bad_settings)
def test_missing_yaml(db_session):
with pytest.raises(FileNotFoundError):
crud.upsert_user_groups(db_session)
- @patch('app.web.db.crud.get_settings', new=lambda: bad_setings)
+ @patch("app.web.db.crud.get_settings", new=lambda: bad_settings)
def test_broken_yaml(db_session):
with pytest.raises(yaml.YAMLError):
crud.upsert_user_groups(db_session)
- bad_setings = Settings(_env_file=".env.test")
+ bad_settings = Settings(_env_file=".env.test")
- bad_setings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.missing.yaml"
+ bad_settings.USER_GROUPS_FILENAME = (
+ "app/tests/user-groups.test.missing.yaml"
+ )
test_missing_yaml(db_session)
- bad_setings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.broken.yaml"
+ bad_settings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.broken.yaml"
test_broken_yaml(db_session)
def test_create_sheet(db_session):
assert db_session.query(models.Sheet).count() == 0
- s = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly")
+ s = crud.create_sheet(
+ db_session,
+ "sheet-id-123",
+ "sheet name",
+ "email@example.com",
+ "group-id",
+ "hourly",
+ )
assert s is not None
assert s.id == "sheet-id-123"
assert s.name == "sheet name"
@@ -259,19 +614,35 @@ def test_create_sheet(db_session):
assert db_session.query(models.Sheet).count() == 1
- # duplicate id
- import sqlalchemy
with pytest.raises(sqlalchemy.exc.IntegrityError):
- crud.create_sheet(db_session, "sheet-id-123", "I thought this was another sheet", "email", "group-id", "hourly")
+ crud.create_sheet(
+ db_session,
+ "sheet-id-123",
+ "I thought this was another sheet",
+ "email",
+ "group-id",
+ "hourly",
+ )
def test_get_user_sheet(test_data, db_session):
assert crud.get_user_sheet(db_session, "", "sheet-0") is None
- assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None
+ assert (
+ crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None
+ )
- assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0") is not None
- assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2") is not None
- assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-1") is not None
+ assert (
+ crud.get_user_sheet(db_session, "rick@example.com", "sheet-0")
+ is not None
+ )
+ assert (
+ crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2")
+ is not None
+ )
+ assert (
+ crud.get_user_sheet(db_session, "morty@example.com", "sheet-1")
+ is not None
+ )
def test_get_user_sheets(test_data, db_session):
@@ -283,9 +654,9 @@ def test_get_user_sheets(test_data, db_session):
def test_delete_sheet(test_data, db_session):
- assert crud.delete_sheet(db_session, "sheet-0", "") == False
- assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == True
- assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == False
+ assert crud.delete_sheet(db_session, "sheet-0", "") is False
+ assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") is True
+ assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") is False
@pytest.mark.asyncio
@@ -297,21 +668,21 @@ async def test_find_by_store_until(async_db_session):
url="https://example-expired-1.com",
result={},
author_id="rick@example.com",
- store_until=now - timedelta(days=1)
+ store_until=now - timedelta(days=1),
)
archive2 = models.Archive(
id="archive-expired-2",
url="https://example-expired-2.com",
result={},
author_id="rick@example.com",
- store_until=now - timedelta(hours=1)
+ store_until=now - timedelta(hours=1),
)
archive3 = models.Archive(
id="archive-active",
url="https://example-active.com",
result={},
author_id="rick@example.com",
- store_until=now + timedelta(days=1)
+ store_until=now + timedelta(days=1),
)
async_db_session.add_all([archive1, archive2, archive3])
await async_db_session.commit()
@@ -321,11 +692,15 @@ async def test_find_by_store_until(async_db_session):
assert len(list(expired)) == 2
# Should find 1 archive expired before 2 hours ago
- expired = await crud.find_by_store_until(async_db_session, now - timedelta(hours=2))
+ expired = await crud.find_by_store_until(
+ async_db_session, now - timedelta(hours=2)
+ )
assert len(list(expired)) == 1
# Should find no archives expired before 2 days ago
- expired = await crud.find_by_store_until(async_db_session, now - timedelta(days=2))
+ expired = await crud.find_by_store_until(
+ async_db_session, now - timedelta(days=2)
+ )
assert len(list(expired)) == 0
# Should not find deleted archives
@@ -337,44 +712,82 @@ async def test_find_by_store_until(async_db_session):
@pytest.mark.asyncio
async def test_get_sheets_by_id_hash(async_db_session):
+ author_emails = [
+ "rick@example.com",
+ "morty@example.com",
+ "jerry@example.com",
+ ]
+
# Add test data
- authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
sheets = [
- models.Sheet(id="sheet-0", name="sheet-0", author_id=authors[0], group_id=None, frequency="daily"),
- models.Sheet(id="sheet-0-2", name="sheet-0-2", author_id=authors[0], group_id="spaceship", frequency="hourly"),
- models.Sheet(id="sheet-1", name="sheet-1", author_id=authors[1], group_id=None, frequency="daily"),
- models.Sheet(id="sheet-2", name="sheet-2", author_id=authors[2], group_id=None, frequency="daily")
+ models.Sheet(
+ id="sheet-0",
+ name="sheet-0",
+ author_id=author_emails[0],
+ group_id=None,
+ frequency="daily",
+ ),
+ models.Sheet(
+ id="sheet-0-2",
+ name="sheet-0-2",
+ author_id=author_emails[0],
+ group_id="spaceship",
+ frequency="hourly",
+ ),
+ models.Sheet(
+ id="sheet-1",
+ name="sheet-1",
+ author_id=author_emails[1],
+ group_id=None,
+ frequency="daily",
+ ),
+ models.Sheet(
+ id="sheet-2",
+ name="sheet-2",
+ author_id=author_emails[2],
+ group_id=None,
+ frequency="daily",
+ ),
]
async_db_session.add_all(sheets)
await async_db_session.commit()
with patch("app.web.db.crud.fnv1a_hash_mod", return_value=1):
# Test retrieving hourly sheets
- hourly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "hourly", 4, 1)
+ hourly_sheets = await crud.get_sheets_by_id_hash(
+ async_db_session, "hourly", 4, 1
+ )
assert len(hourly_sheets) == 1
assert hourly_sheets[0].id == "sheet-0-2"
assert hourly_sheets[0].frequency == "hourly"
# Test retrieving daily sheets
- daily_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 1)
+ daily_sheets = await crud.get_sheets_by_id_hash(
+ async_db_session, "daily", 4, 1
+ )
assert len(daily_sheets) == 3
assert all(sheet.frequency == "daily" for sheet in daily_sheets)
- assert {sheet.id for sheet in daily_sheets} == {"sheet-0", "sheet-1", "sheet-2"}
+ assert {sheet.id for sheet in daily_sheets} == {
+ "sheet-0",
+ "sheet-1",
+ "sheet-2",
+ }
# Test with non-matching hash
- no_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 3)
+ no_sheets = await crud.get_sheets_by_id_hash(
+ async_db_session, "daily", 4, 3
+ )
assert len(no_sheets) == 0
# Test with non-existent frequency
- weekly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "weekly", 4, 1)
+ weekly_sheets = await crud.get_sheets_by_id_hash(
+ async_db_session, "weekly", 4, 1
+ )
assert len(weekly_sheets) == 0
@pytest.mark.asyncio
async def test_delete_stale_sheets(async_db_session):
- from datetime import datetime, timedelta
- from sqlalchemy.sql import select
-
now = datetime.now()
active_date = now - timedelta(days=5)
stale_date = now - timedelta(days=15)
@@ -386,29 +799,29 @@ async def test_delete_stale_sheets(async_db_session):
name="Active Sheet 1",
author_id="rick@example.com",
frequency="daily",
- last_url_archived_at=active_date
+ last_url_archived_at=active_date,
),
models.Sheet(
id="sheet-active-2",
name="Active Sheet 2",
author_id="morty@example.com",
frequency="hourly",
- last_url_archived_at=active_date
+ last_url_archived_at=active_date,
),
models.Sheet(
id="sheet-stale-1",
name="Stale Sheet 1",
author_id="rick@example.com",
frequency="daily",
- last_url_archived_at=stale_date
+ last_url_archived_at=stale_date,
),
models.Sheet(
id="sheet-stale-2",
name="Stale Sheet 2",
author_id="morty@example.com",
frequency="daily",
- last_url_archived_at=stale_date
- )
+ last_url_archived_at=stale_date,
+ ),
]
async_db_session.add_all(sheets)
await async_db_session.commit()
@@ -435,4 +848,4 @@ async def test_delete_stale_sheets(async_db_session):
# Running again should not delete anything
deleted = await crud.delete_stale_sheets(async_db_session, 7)
- assert len(deleted) == 0
\ No newline at end of file
+ assert len(deleted) == 0
diff --git a/app/tests/web/db/test_user_state.py b/app/tests/web/db/test_user_state.py
index 42c61d1..5d18cea 100644
--- a/app/tests/web/db/test_user_state.py
+++ b/app/tests/web/db/test_user_state.py
@@ -1,10 +1,11 @@
-
from unittest.mock import MagicMock, PropertyMock, patch
+
import pytest
from app.shared.db import models
from app.shared.user_groups import GroupInfo, GroupPermissions
from app.web.db.user_state import UserState
+from app.web.utils.misc import convert_priority_to_queue_dict
def fresh_user_state():
@@ -20,39 +21,73 @@ def user_state():
def user_state_with_groups(user_state):
user_groups = [
models.Group(id="no-permissions", permissions={}),
- models.Group(id="group1", description="this is g1", service_account_email="sa1@example.com", permissions={"read": ["group1", "no-permissions"], "read_public": True, "archive_url": True, "archive_sheet": True, "max_archive_lifespan_months": 24, "max_monthly_urls": 100, "max_monthly_mbs": 1000, "priority": "high"}),
- models.Group(id="group2", description="this is g2", service_account_email="sa2@example.com", permissions={"read": ["all"], "read_public": True, "archive_url": False, "archive_sheet": False, "max_archive_lifespan_months": -1, "max_monthly_urls": -1, "max_monthly_mbs": -1, "priority": "low", "sheet_frequency": {"daily"}}),
+ models.Group(
+ id="group1",
+ description="this is g1",
+ service_account_email="sa1@example.com",
+ permissions={
+ "read": ["group1", "no-permissions"],
+ "read_public": True,
+ "archive_url": True,
+ "archive_sheet": True,
+ "max_archive_lifespan_months": 24,
+ "max_monthly_urls": 100,
+ "max_monthly_mbs": 1000,
+ "priority": "high",
+ },
+ ),
+ models.Group(
+ id="group2",
+ description="this is g2",
+ service_account_email="sa2@example.com",
+ permissions={
+ "read": ["all"],
+ "read_public": True,
+ "archive_url": False,
+ "archive_sheet": False,
+ "max_archive_lifespan_months": -1,
+ "max_monthly_urls": -1,
+ "max_monthly_mbs": -1,
+ "priority": "low",
+ "sheet_frequency": {"daily"},
+ },
+ ),
]
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=user_groups):
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=user_groups,
+ ):
yield user_state
def test_permissions(user_state_with_groups):
permissions = user_state_with_groups.permissions
- assert permissions["all"].read == True
- assert permissions["all"].read_public == True
- assert permissions["all"].archive_url == True
- assert permissions["all"].archive_sheet == True
+ assert permissions["all"].read is True
+ assert permissions["all"].read_public is True
+ assert permissions["all"].archive_url is True
+ assert permissions["all"].archive_sheet is True
assert permissions["all"].max_archive_lifespan_months == -1
assert permissions["all"].max_monthly_urls == -1
assert permissions["all"].max_monthly_mbs == -1
assert permissions["all"].priority == "high"
- assert permissions["group1"].read == set(["group1", "no-permissions"])
- assert permissions["group1"].read_public == True
- assert permissions["group1"].archive_url == True
- assert permissions["group1"].archive_sheet == True
+ assert permissions["group1"].read == {"group1", "no-permissions"}
+ assert permissions["group1"].read_public is True
+ assert permissions["group1"].archive_url is True
+ assert permissions["group1"].archive_sheet is True
assert permissions["group1"].max_archive_lifespan_months == 24
assert permissions["group1"].max_monthly_urls == 100
assert permissions["group1"].max_monthly_mbs == 1000
assert permissions["group1"].priority == "high"
- assert permissions["group2"].read == set(["all"])
- assert permissions["group2"].read_public == True
- assert permissions["group2"].archive_url == False
- assert permissions["group2"].archive_sheet == False
+ assert permissions["group2"].read == {"all"}
+ assert permissions["group2"].read_public is True
+ assert permissions["group2"].archive_url is False
+ assert permissions["group2"].archive_sheet is False
assert permissions["group2"].max_archive_lifespan_months == -1
assert permissions["group2"].max_monthly_urls == -1
assert permissions["group2"].max_monthly_mbs == -1
@@ -62,13 +97,19 @@ def test_permissions(user_state_with_groups):
def test_user_groups_names(user_state):
- with patch('app.web.db.crud.get_user_group_names', return_value=["group1", "group2"]) as mock:
+ with patch(
+ "app.web.db.crud.get_user_group_names",
+ return_value=["group1", "group2"],
+ ) as mock:
assert user_state.user_groups_names == ["group1", "group2", "default"]
mock.assert_called_once_with(None, "test@example.com")
def test_user_groups(user_state):
- with patch('app.web.db.crud.get_user_groups_by_name', return_value=[MagicMock(), MagicMock()]) as mock:
+ with patch(
+ "app.web.db.crud.get_user_groups_by_name",
+ return_value=[MagicMock(), MagicMock()],
+ ) as mock:
user_state._user_groups_names = ["group1", "group2"]
assert len(user_state.user_groups) == 2
mock.assert_called_once_with(None, ["group1", "group2"])
@@ -77,85 +118,166 @@ def test_user_groups(user_state):
def test_read():
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[models.Group(id="no-permissions", permissions={})],
+ ) as mock:
assert not hasattr(us, "_read")
assert us.read == set()
assert us._read == set()
mock.assert_called_once()
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read": ["group1", "no-permissions"]})]):
- assert us.read == set(["group1", "no-permissions"])
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(
+ id="group1", permissions={"read": ["group1", "no-permissions"]}
+ )
+ ],
+ ):
+ assert us.read == {"group1", "no-permissions"}
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read": ["all"]})]):
- assert us.read == True
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[models.Group(id="group1", permissions={"read": ["all"]})],
+ ):
+ assert us.read is True
def test_read_public():
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[models.Group(id="no-permissions", permissions={})],
+ ) as mock:
assert not hasattr(us, "_read_public")
- assert us.read_public == False
- assert us._read_public == False
+ assert us.read_public is False
+ assert us._read_public is False
mock.assert_called_once()
# no new calls
- assert us.read_public == False
+ assert us.read_public is False
mock.assert_called_once()
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read_public": True})]):
- assert us.read_public == True
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"read_public": True})
+ ],
+ ):
+ assert us.read_public is True
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read_public": False})]):
- assert us.read_public == False
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"read_public": False})
+ ],
+ ):
+ assert us.read_public is False
def test_archive_url():
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[models.Group(id="no-permissions", permissions={})],
+ ) as mock:
assert not hasattr(us, "_archive_url")
- assert us.archive_url == False
- assert us._archive_url == False
+ assert us.archive_url is False
+ assert us._archive_url is False
mock.assert_called_once()
# no new calls
- assert us.archive_url == False
+ assert us.archive_url is False
mock.assert_called_once()
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_url": False})]):
- assert us.archive_url == False
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"archive_url": False})
+ ],
+ ):
+ assert us.archive_url is False
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_url": True})]):
- assert us.archive_url == True
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"archive_url": True})
+ ],
+ ):
+ assert us.archive_url is True
def test_archive_sheet():
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[models.Group(id="no-permissions", permissions={})],
+ ) as mock:
assert not hasattr(us, "_archive_sheet")
- assert us.archive_sheet == False
- assert us._archive_sheet == False
+ assert us.archive_sheet is False
+ assert us._archive_sheet is False
mock.assert_called_once()
# no new calls
- assert us.archive_sheet == False
+ assert us.archive_sheet is False
mock.assert_called_once()
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_sheet": False})]):
- assert us.archive_sheet == False
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"archive_sheet": False})
+ ],
+ ):
+ assert us.archive_sheet is False
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_sheet": True})]):
- assert us.archive_sheet == True
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"archive_sheet": True})
+ ],
+ ):
+ assert us.archive_sheet is True
def test_sheet_frequency():
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[models.Group(id="no-permissions", permissions={})],
+ ) as mock:
assert not hasattr(us, "_sheet_frequency")
assert us.sheet_frequency == set()
assert us._sheet_frequency == set()
@@ -165,18 +287,42 @@ def test_sheet_frequency():
mock.assert_called_once()
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"sheet_frequency": ["daily", "hourly"]})]):
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(
+ id="group1",
+ permissions={"sheet_frequency": ["daily", "hourly"]},
+ )
+ ],
+ ):
assert us.sheet_frequency == {"daily", "hourly"}
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"sheet_frequency": []})]):
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"sheet_frequency": []})
+ ],
+ ):
assert us.sheet_frequency == set()
def test_max_archive_lifespan_months():
us = fresh_user_state()
- default = GroupPermissions.model_fields["max_archive_lifespan_months"].default
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
+ default = GroupPermissions.model_fields[
+ "max_archive_lifespan_months"
+ ].default
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[models.Group(id="no-permissions", permissions={})],
+ ) as mock:
assert not hasattr(us, "_max_archive_lifespan_months")
assert us.max_archive_lifespan_months == default
assert us._max_archive_lifespan_months == default
@@ -186,18 +332,44 @@ def test_max_archive_lifespan_months():
mock.assert_called_once()
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_archive_lifespan_months": 24})]):
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(
+ id="group1", permissions={"max_archive_lifespan_months": 24}
+ )
+ ],
+ ):
assert us.max_archive_lifespan_months == 24
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_archive_lifespan_months": 150}), models.Group(id="group2", permissions={"max_archive_lifespan_months": -1})]):
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(
+ id="group1", permissions={"max_archive_lifespan_months": 150}
+ ),
+ models.Group(
+ id="group2", permissions={"max_archive_lifespan_months": -1}
+ ),
+ ],
+ ):
assert us.max_archive_lifespan_months == -1
def test_max_monthly_urls():
us = fresh_user_state()
default = GroupPermissions.model_fields["max_monthly_urls"].default
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[models.Group(id="no-permissions", permissions={})],
+ ) as mock:
assert not hasattr(us, "_max_monthly_urls")
assert us.max_monthly_urls == default
assert us._max_monthly_urls == default
@@ -207,18 +379,38 @@ def test_max_monthly_urls():
mock.assert_called_once()
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_urls": 100})]):
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"max_monthly_urls": 100})
+ ],
+ ):
assert us.max_monthly_urls == 100
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_urls": 150}), models.Group(id="group2", permissions={"max_monthly_urls": -1})]):
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"max_monthly_urls": 150}),
+ models.Group(id="group2", permissions={"max_monthly_urls": -1}),
+ ],
+ ):
assert us.max_monthly_urls == -1
def test_max_monthly_mbs():
us = fresh_user_state()
default = GroupPermissions.model_fields["max_monthly_mbs"].default
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[models.Group(id="no-permissions", permissions={})],
+ ) as mock:
assert not hasattr(us, "_max_monthly_mbs")
assert us.max_monthly_mbs == default
assert us._max_monthly_mbs == default
@@ -228,17 +420,37 @@ def test_max_monthly_mbs():
mock.assert_called_once()
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_mbs": 1000})]):
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"max_monthly_mbs": 1000})
+ ],
+ ):
assert us.max_monthly_mbs == 1000
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_mbs": 1500}), models.Group(id="group2", permissions={"max_monthly_mbs": -1})]):
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"max_monthly_mbs": 1500}),
+ models.Group(id="group2", permissions={"max_monthly_mbs": -1}),
+ ],
+ ):
assert us.max_monthly_mbs == -1
def test_priority(user_state):
default = GroupPermissions.model_fields["priority"].default
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[models.Group(id="no-permissions", permissions={})],
+ ) as mock:
assert not hasattr(user_state, "_priority")
assert user_state.priority == default
assert user_state._priority == default
@@ -248,11 +460,26 @@ def test_priority(user_state):
mock.assert_called_once()
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"priority": "high"})]):
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"priority": "high"})
+ ],
+ ):
assert us.priority == "high"
us = fresh_user_state()
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"priority": "low"}), models.Group(id="group2", permissions={"priority": "medium"})]):
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"priority": "low"}),
+ models.Group(id="group2", permissions={"priority": "medium"}),
+ ],
+ ):
assert us.priority == "low"
@@ -262,21 +489,45 @@ def test_active():
(True, False, False, False, True),
(False, True, False, False, True),
(False, False, True, False, True),
- (False, False, False, True, True)
+ (False, False, False, True, True),
]:
us = fresh_user_state()
- with patch.object(UserState, 'read', new_callable=PropertyMock, return_value=read), \
- patch.object(UserState, 'read_public', new_callable=PropertyMock, return_value=read_public), \
- patch.object(UserState, 'archive_url', new_callable=PropertyMock, return_value=archive_url), \
- patch.object(UserState, 'archive_sheet', new_callable=PropertyMock, return_value=archive_sheet):
+ with (
+ patch.object(
+ UserState, "read", new_callable=PropertyMock, return_value=read
+ ),
+ patch.object(
+ UserState,
+ "read_public",
+ new_callable=PropertyMock,
+ return_value=read_public,
+ ),
+ patch.object(
+ UserState,
+ "archive_url",
+ new_callable=PropertyMock,
+ return_value=archive_url,
+ ),
+ patch.object(
+ UserState,
+ "archive_sheet",
+ new_callable=PropertyMock,
+ return_value=archive_sheet,
+ ),
+ ):
assert us.active == is_active
def test_in_group(user_state):
- with patch.object(UserState, 'user_groups_names', new_callable=PropertyMock, return_value=["group1", "group2"]):
- assert user_state.in_group("group1") == True
- assert user_state.in_group("group2") == True
- assert user_state.in_group("group3") == False
+ with patch.object(
+ UserState,
+ "user_groups_names",
+ new_callable=PropertyMock,
+ return_value=["group1", "group2"],
+ ):
+ assert user_state.in_group("group1") is True
+ assert user_state.in_group("group2") is True
+ assert user_state.in_group("group3") is False
def test_usage(db_session):
@@ -294,10 +545,34 @@ def test_usage(db_session):
]
megabytes = int(sum(bytes) / 1024 / 1024)
- with patch.object(db_session, 'query', side_effect=[
- MagicMock(filter=MagicMock(return_value=MagicMock(group_by=MagicMock(return_value=MagicMock(all=MagicMock(return_value=user_sheets)))))),
- MagicMock(filter=MagicMock(return_value=MagicMock(group_by=MagicMock(return_value=MagicMock(all=MagicMock(return_value=urls_by_group))))))
- ]):
+ with patch.object(
+ db_session,
+ "query",
+ side_effect=[
+ MagicMock(
+ filter=MagicMock(
+ return_value=MagicMock(
+ group_by=MagicMock(
+ return_value=MagicMock(
+ all=MagicMock(return_value=user_sheets)
+ )
+ )
+ )
+ )
+ ),
+ MagicMock(
+ filter=MagicMock(
+ return_value=MagicMock(
+ group_by=MagicMock(
+ return_value=MagicMock(
+ all=MagicMock(return_value=urls_by_group)
+ )
+ )
+ )
+ )
+ ),
+ ],
+ ):
usage_response = user_state.usage()
assert usage_response.monthly_urls == 155
@@ -305,11 +580,15 @@ def test_usage(db_session):
assert usage_response.total_sheets == 115
assert usage_response.groups["group1"].monthly_urls == 50
- assert usage_response.groups["group1"].monthly_mbs == int(bytes[0] / 1024 / 1024)
+ assert usage_response.groups["group1"].monthly_mbs == int(
+ bytes[0] / 1024 / 1024
+ )
assert usage_response.groups["group1"].total_sheets == 5
assert usage_response.groups["group2"].monthly_urls == 100
- assert usage_response.groups["group2"].monthly_mbs == int(bytes[1] / 1024 / 1024)
+ assert usage_response.groups["group2"].monthly_mbs == int(
+ bytes[1] / 1024 / 1024
+ )
assert usage_response.groups["group2"].total_sheets == 10
assert usage_response.groups["group3"].monthly_urls == 0
@@ -317,7 +596,9 @@ def test_usage(db_session):
assert usage_response.groups["group3"].total_sheets == 100
assert usage_response.groups["group4"].monthly_urls == 5
- assert usage_response.groups["group4"].monthly_mbs == int(bytes[2] / 1024 / 1024)
+ assert usage_response.groups["group4"].monthly_mbs == int(
+ bytes[2] / 1024 / 1024
+ )
assert usage_response.groups["group4"].total_sheets == 0
@@ -333,8 +614,23 @@ def test_has_quota_monthly_sheets(db_session):
]
for permissions, count, expected in test_cases:
- with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
- with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))):
+ with patch.object(
+ UserState,
+ "permissions",
+ new_callable=PropertyMock,
+ return_value=permissions,
+ ):
+ with patch.object(
+ us.db,
+ "query",
+ return_value=MagicMock(
+ filter=MagicMock(
+ return_value=MagicMock(
+ count=MagicMock(return_value=count)
+ )
+ )
+ ),
+ ):
assert us.has_quota_monthly_sheets("group1") == expected
@@ -349,8 +645,23 @@ def test_has_quota_max_monthly_urls(db_session):
]
for permissions, count, expected in test_cases:
- with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
- with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))):
+ with patch.object(
+ UserState,
+ "permissions",
+ new_callable=PropertyMock,
+ return_value=permissions,
+ ):
+ with patch.object(
+ us.db,
+ "query",
+ return_value=MagicMock(
+ filter=MagicMock(
+ return_value=MagicMock(
+ count=MagicMock(return_value=count)
+ )
+ )
+ ),
+ ):
assert us.has_quota_max_monthly_urls("group1") == expected
test_cases = [
(-1, 1000, True),
@@ -360,8 +671,23 @@ def test_has_quota_max_monthly_urls(db_session):
]
for max_urls, count, expected in test_cases:
- with patch.object(UserState, 'max_monthly_urls', new_callable=PropertyMock, return_value=max_urls):
- with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))):
+ with patch.object(
+ UserState,
+ "max_monthly_urls",
+ new_callable=PropertyMock,
+ return_value=max_urls,
+ ):
+ with patch.object(
+ us.db,
+ "query",
+ return_value=MagicMock(
+ filter=MagicMock(
+ return_value=MagicMock(
+ count=MagicMock(return_value=count)
+ )
+ )
+ ),
+ ):
assert us.has_quota_max_monthly_urls("") == expected
@@ -376,8 +702,29 @@ def test_has_quota_max_monthly_mbs(db_session):
]
for permissions, mbs, expected in test_cases:
- with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
- with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(with_entities=MagicMock(return_value=MagicMock(scalar=MagicMock(return_value=mbs * 1024 * 1024))))))):
+ with patch.object(
+ UserState,
+ "permissions",
+ new_callable=PropertyMock,
+ return_value=permissions,
+ ):
+ with patch.object(
+ us.db,
+ "query",
+ return_value=MagicMock(
+ filter=MagicMock(
+ return_value=MagicMock(
+ with_entities=MagicMock(
+ return_value=MagicMock(
+ scalar=MagicMock(
+ return_value=mbs * 1024 * 1024
+ )
+ )
+ )
+ )
+ )
+ ),
+ ):
assert us.has_quota_max_monthly_mbs("group1") == expected
test_cases = [
@@ -388,8 +735,29 @@ def test_has_quota_max_monthly_mbs(db_session):
]
for max_mbs, mbs, expected in test_cases:
- with patch.object(UserState, 'max_monthly_mbs', new_callable=PropertyMock, return_value=max_mbs):
- with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(with_entities=MagicMock(return_value=MagicMock(scalar=MagicMock(return_value=mbs * 1024 * 1024))))))):
+ with patch.object(
+ UserState,
+ "max_monthly_mbs",
+ new_callable=PropertyMock,
+ return_value=max_mbs,
+ ):
+ with patch.object(
+ us.db,
+ "query",
+ return_value=MagicMock(
+ filter=MagicMock(
+ return_value=MagicMock(
+ with_entities=MagicMock(
+ return_value=MagicMock(
+ scalar=MagicMock(
+ return_value=mbs * 1024 * 1024
+ )
+ )
+ )
+ )
+ )
+ ),
+ ):
assert us.has_quota_max_monthly_mbs("") == expected
@@ -399,10 +767,15 @@ def test_can_manually_trigger(user_state):
"group2": GroupInfo(manually_trigger_sheet=False),
}
- with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
- assert user_state.can_manually_trigger("group1") == True
- assert user_state.can_manually_trigger("group2") == False
- assert user_state.can_manually_trigger("group3") == False
+ with patch.object(
+ UserState,
+ "permissions",
+ new_callable=PropertyMock,
+ return_value=permissions,
+ ):
+ assert user_state.can_manually_trigger("group1") is True
+ assert user_state.can_manually_trigger("group2") is False
+ assert user_state.can_manually_trigger("group3") is False
def test_is_sheet_frequency_allowed(user_state):
@@ -411,23 +784,44 @@ def test_is_sheet_frequency_allowed(user_state):
"group2": GroupInfo(sheet_frequency={"daily"}),
}
- with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
- assert user_state.is_sheet_frequency_allowed("group1", "daily") == True
- assert user_state.is_sheet_frequency_allowed("group1", "hourly") == True
- assert user_state.is_sheet_frequency_allowed("group1", "weekly") == False
- assert user_state.is_sheet_frequency_allowed("group2", "hourly") == False
- assert user_state.is_sheet_frequency_allowed("group2", "daily") == True
- assert user_state.is_sheet_frequency_allowed("group3", "daily") == False
+ with patch.object(
+ UserState,
+ "permissions",
+ new_callable=PropertyMock,
+ return_value=permissions,
+ ):
+ assert user_state.is_sheet_frequency_allowed("group1", "daily") is True
+ assert user_state.is_sheet_frequency_allowed("group1", "hourly") is True
+ assert (
+ user_state.is_sheet_frequency_allowed("group1", "weekly") is False
+ )
+ assert (
+ user_state.is_sheet_frequency_allowed("group2", "hourly") is False
+ )
+ assert user_state.is_sheet_frequency_allowed("group2", "daily") is True
+ assert user_state.is_sheet_frequency_allowed("group3", "daily") is False
def test_priority_group(user_state):
- from app.web.utils.misc import convert_priority_to_queue_dict
- with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[
- models.Group(id="group1", permissions={"priority": "high"}),
- models.Group(id="group2", permissions={"priority": "medium"}),
- models.Group(id="group3", permissions={"priority": "low"}),
- ]):
- assert user_state.priority_group("group1") == convert_priority_to_queue_dict("high")
- assert user_state.priority_group("group2") == convert_priority_to_queue_dict("medium")
- assert user_state.priority_group("group3") == convert_priority_to_queue_dict("low")
- assert user_state.priority_group("group4") == convert_priority_to_queue_dict("low")
+ with patch.object(
+ UserState,
+ "user_groups",
+ new_callable=PropertyMock,
+ return_value=[
+ models.Group(id="group1", permissions={"priority": "high"}),
+ models.Group(id="group2", permissions={"priority": "medium"}),
+ models.Group(id="group3", permissions={"priority": "low"}),
+ ],
+ ):
+ assert user_state.priority_group(
+ "group1"
+ ) == convert_priority_to_queue_dict("high")
+ assert user_state.priority_group(
+ "group2"
+ ) == convert_priority_to_queue_dict("medium")
+ assert user_state.priority_group(
+ "group3"
+ ) == convert_priority_to_queue_dict("low")
+ assert user_state.priority_group(
+ "group4"
+ ) == convert_priority_to_queue_dict("low")
diff --git a/app/tests/web/endpoints/test_interoperability.py b/app/tests/web/endpoints/test_interoperability.py
deleted file mode 100644
index 31cf8f0..0000000
--- a/app/tests/web/endpoints/test_interoperability.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from datetime import datetime
-import json
-from unittest.mock import MagicMock, patch
-
-from app.shared.db import models
-from app.web.config import ALLOW_ANY_EMAIL
-from app.web.db import crud
-
-
-def test_submit_manual_archive_unauthenticated(client, test_no_auth):
- test_no_auth(client.post, "/interop/submit-archive")
-
-
-def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
- test_no_auth(client_with_auth.post, "/interop/submit-archive")
-
-
-@patch("app.web.endpoints.interoperability.business_logic", return_value=MagicMock(get_store_archive_until=MagicMock(return_value=datetime)))
-def test_submit_manual_archive(m1, client_with_token, db_session):
- # normal workflow
- aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]})
- r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"})
- assert r.status_code == 201
- assert "id" in r.json()
-
- inserted = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"]).first()
- assert inserted.url == "http://example.com"
- assert inserted.group_id == "spaceship"
- assert inserted.author_id == "jerry@gmail.com"
- assert sorted([t.id for t in inserted.tags]) == sorted(["test", "manual"])
- assert inserted.public
- assert type(inserted.result) == dict
- assert [u.url for u in inserted.urls] == ["http://example.s3.com"]
- assert type(inserted.store_until) == datetime
-
- # cannot have the same URL twice
- aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]})
- r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "tags": ["test"], "url": "http://example.com"})
- assert r.status_code == 422
- assert r.json() == {"detail": "Cannot insert into DB due to integrity error, likely duplicate urls."}
-
-
-# test with invalid JSON
-def test_submit_manual_archive_invalid_json(client_with_token):
- r = client_with_token.post("/interop/submit-archive", json={"result": "invalid json", "public": False, "author_id": "jer", "tags": ["test"], "url": "http://example.com"})
- assert r.status_code == 422
- assert r.json() == {"detail": "Invalid JSON in result field."}
-
-
-@patch("app.web.endpoints.interoperability.business_logic.get_store_archive_until", side_effect=AssertionError("AssertionError"))
-def test_submit_manual_archive_no_store_until(m_sau, client_with_token, db_session):
- aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]})
- r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"})
- assert r.status_code == 201
- assert len(r.json()["id"]) == 36
- res = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"]).first()
- assert res.store_until is None
- # testing that store_until = None is not comparable with datetime, and will always return False
- res = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"], models.Archive.store_until < datetime.now()).first()
- assert res is None
diff --git a/app/tests/web/endpoints/test_sheet.py b/app/tests/web/endpoints/test_sheet.py
deleted file mode 100644
index 1396d85..0000000
--- a/app/tests/web/endpoints/test_sheet.py
+++ /dev/null
@@ -1,193 +0,0 @@
-from datetime import datetime
-import json
-from unittest.mock import MagicMock, patch
-
-from fastapi.testclient import TestClient
-
-from app.shared.schemas import TaskResult
-
-
-def test_endpoints_no_auth(client, test_no_auth):
- test_no_auth(client.post, "/sheet/create")
- test_no_auth(client.get, "/sheet/mine")
- test_no_auth(client.delete, "/sheet/123-sheet-id")
- test_no_auth(client.post, "/sheet/123-sheet-id/archive")
-
-
-def test_create_sheet_endpoint(app_with_auth, db_session):
- client_with_auth = TestClient(app_with_auth)
- good_data = {
- "id": "123-sheet-id",
- "name": "Test Sheet",
- "group_id": "spaceship",
- "frequency": "daily"
- }
-
- # with good data
- response = client_with_auth.post("/sheet/create", json=good_data)
- assert response.status_code == 201
- j = response.json()
- assert datetime.fromisoformat(j.pop("created_at"))
- assert datetime.fromisoformat(j.pop("last_url_archived_at"))
- assert j.pop("author_id") == 'morty@example.com'
- assert j == good_data
-
- # already exists
- response = client_with_auth.post("/sheet/create", json=good_data)
- assert response.status_code == 400
- assert response.json() == {"detail": "Sheet with this ID is already being archived."}
-
- # bad group
- bad_data = good_data.copy()
- bad_data["group_id"] = "not a group"
- response = client_with_auth.post("/sheet/create", json=bad_data)
- assert response.status_code == 403
- assert response.json() == {"detail": "User does not have access to this group."}
-
- # switch to jerry who's got less quota/permissions
- from app.web.security import get_user_state
- from app.web.db.user_state import UserState
- app_with_auth.dependency_overrides[get_user_state] = lambda: UserState(db_session, "jerry@example.com")
- client_jerry = TestClient(app_with_auth)
-
- # frequency not allowed
- jerry_data = good_data.copy()
- jerry_data["group_id"] = "animated-characters"
- jerry_data["frequency"] = "hourly"
- jerry_data["id"] = "jerry-sheet-id"
- response = client_jerry.post("/sheet/create", json=jerry_data)
- assert response.status_code == 422
- assert response.json() == {"detail": "Invalid frequency selected for this group."}
-
- jerry_data["frequency"] = "daily"
- # success for the first sheet, bad quota on second
- response = client_jerry.post("/sheet/create", json=jerry_data)
- assert response.status_code == 201
-
- response = client_jerry.post("/sheet/create", json=jerry_data)
- assert response.status_code == 429
- assert response.json() == {"detail": "User has reached their sheet quota for this group."}
-
-
-def test_get_user_sheets_endpoint(client_with_auth, db_session):
- # no data
- response = client_with_auth.get("/sheet/mine")
- assert response.status_code == 200
- assert response.json() == []
-
- # with data
- from app.shared.db import models
- db_session.add(
- models.Sheet(id="123", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")
- )
- db_session.commit()
- db_session.add_all([
- models.Sheet(id="456", name="Test Sheet 2", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
- models.Sheet(id="789", name="Test Sheet 3", author_id="rick@example.com", group_id="interdimensional", frequency="hourly"),
- ])
- db_session.commit()
-
- response = client_with_auth.get("/sheet/mine")
- assert response.status_code == 200
- r = response.json()
- assert isinstance(r, list)
- assert len(r) == 2
- assert datetime.fromisoformat(r[0].pop("created_at"))
- assert datetime.fromisoformat(r[0].pop("last_url_archived_at"))
- assert datetime.fromisoformat(r[1].pop("created_at"))
- assert datetime.fromisoformat(r[1].pop("last_url_archived_at"))
- assert r[0] == {
- 'id': '123',
- 'author_id': 'morty@example.com',
- 'frequency': 'hourly',
- 'group_id': 'spaceship',
- 'name': 'Test Sheet 1',
- }
- assert r[1] == {
- 'id': '456',
- 'author_id': 'morty@example.com',
- 'frequency': 'daily',
- 'group_id': 'interdimensional',
- 'name': 'Test Sheet 2',
- }
-
-
-def test_delete_sheet_endpoint(client_with_auth, db_session):
- # missing sheet
- response = client_with_auth.delete("/sheet/123-sheet-id")
- assert response.status_code == 200
- assert response.json() == {
- "id": "123-sheet-id",
- "deleted": False
- }
-
- # add sheets for deletion
- from app.shared.db import models
- db_session.add_all([
- models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
- models.Sheet(id="456-sheet-id", name="Test Sheet 2", author_id="rick@example.com", group_id="spaceship", frequency="hourly"),
- ])
- db_session.commit()
-
- # morty can delete his
- response = client_with_auth.delete("/sheet/123-sheet-id")
- assert response.status_code == 200
- assert response.json() == {"id": "123-sheet-id", "deleted": True}
- # but only once
- response = client_with_auth.delete("/sheet/123-sheet-id")
- assert response.status_code == 200
- assert response.json() == {"id": "123-sheet-id", "deleted": False}
- # and not rick's
- response = client_with_auth.delete("/sheet/456-sheet-id")
- assert response.status_code == 200
- assert response.json() == {"id": "456-sheet-id", "deleted": False}
-
-
-class TestArchiveUserSheetEndpoint:
- @patch("app.web.endpoints.sheet.celery", return_value=MagicMock())
- def test_normal_flow(self, m_celery, client_with_auth, db_session):
- from app.shared.db import models
- db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly"))
- db_session.commit()
-
- m_signature = MagicMock()
- m_signature.apply_async.return_value = TaskResult(id="123-taskid", status="PENDING", result="")
- m_celery.signature.return_value = m_signature
-
- r = client_with_auth.post("/sheet/123-sheet-id/archive")
- assert r.status_code == 201
- assert r.json() == {"id": "123-taskid"}
- m_celery.signature.assert_called_once()
- m_signature.apply_async.assert_called_once()
-
- def test_token_auth(self, client_with_token, test_no_auth):
- test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
-
- def test_missing_data(self, client_with_auth):
- r = client_with_auth.post("/sheet/123-sheet-id/archive")
- assert r.status_code == 403
- assert r.json() == {"detail": "No access to this sheet."}
-
- def test_no_access(self, client_with_auth, db_session):
- from app.shared.db import models
- db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="rick@example.com", group_id="spaceship", frequency="hourly"))
- db_session.commit()
- r = client_with_auth.post("/sheet/123-sheet-id/archive")
- assert r.status_code == 403
- assert r.json() == {"detail": "No access to this sheet."}
-
- def test_user_not_in_group(self, client_with_auth, db_session):
- from app.shared.db import models
- db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="hourly"))
- db_session.commit()
- r = client_with_auth.post("/sheet/123-sheet-id/archive")
- assert r.status_code == 403
- assert r.json() == {"detail": "User does not have access to this group."}
-
- def test_user_cannot_manually_trigger(self, client_with_auth, db_session):
- from app.shared.db import models
- db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="default", frequency="hourly"))
- db_session.commit()
- r = client_with_auth.post("/sheet/123-sheet-id/archive")
- assert r.status_code == 429
- assert r.json() == {"detail": "User cannot manually trigger sheet archiving in this group."}
diff --git a/app/tests/web/endpoints/test_url.py b/app/tests/web/endpoints/test_url.py
deleted file mode 100644
index 1b6ee85..0000000
--- a/app/tests/web/endpoints/test_url.py
+++ /dev/null
@@ -1,202 +0,0 @@
-import json
-from unittest.mock import MagicMock, patch
-
-from app.shared.schemas import ArchiveCreate, TaskResult
-from app.web.config import ALLOW_ANY_EMAIL
-
-
-def test_archive_url_unauthenticated(client, test_no_auth):
- test_no_auth(client.post, "/url/archive")
-
-
-@patch("app.web.endpoints.url.UserState")
-@patch("app.web.endpoints.url.celery", return_value=MagicMock())
-def test_archive_url(m_celery, m2, client_with_auth):
- m_signature = MagicMock()
- m_signature.apply_async.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
- m_celery.signature.return_value = m_signature
-
- m_user_state = MagicMock()
- m2.return_value = m_user_state
-
- # url is too short
- response = client_with_auth.post("/url/archive", json={"url": "bad"})
- assert response.status_code == 422
- assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
- m_celery.signature.assert_not_called()
-
- # url is invalid
- response = client_with_auth.post("/url/archive", json={"url": "example.com"})
- assert response.status_code == 400
- assert response.json()["detail"] == "Invalid URL received."
-
- # valid request
- m_user_state.has_quota_max_monthly_urls.return_value = True
- m_user_state.has_quota_max_monthly_mbs.return_value = True
- response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
- assert response.status_code == 201
- assert response.json() == {'id': '123-456-789'}
- m_celery.signature.assert_called_once()
- m_signature.apply_async.assert_called_once()
- called_val = m_celery.signature.call_args
- assert called_val[0][0] == "create_archive_task"
- assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": "rick@example.com", "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None}
- m_user_state.has_quota_max_monthly_urls.assert_called_once()
- m_user_state.has_quota_max_monthly_mbs.assert_called_once()
- m_user_state.in_group.assert_called_once_with("default")
-
- # user is not in group
- m_user_state.in_group.return_value = False
- response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "new-group"})
- assert response.status_code == 403
- assert response.json()["detail"] == "User does not have access to this group."
- m_user_state.in_group.assert_called_with("new-group")
-
- # user is in group
- m_user_state.in_group.return_value = True
- response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
- assert response.status_code == 201
- assert response.json() == {'id': '123-456-789'}
- assert m_celery.signature.call_count == 2
- assert m_signature.apply_async.call_count == 2
- called_val = m_celery.signature.call_args
- assert json.loads(called_val[1]['args'][0])["group_id"] == "spaceship"
- m_user_state.in_group.assert_called_with("spaceship")
-
- # user is over monthly URL quota
- m_user_state.has_quota_max_monthly_urls.return_value = False
- m_user_state.has_quota_max_monthly_mbs.return_value = True
- response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
- assert response.status_code == 429
- assert response.json()["detail"] == "User has reached their monthly URL quota."
- m_user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")
-
- # user is over monthly MB quota
- m_user_state.has_quota_max_monthly_urls.return_value = True
- m_user_state.has_quota_max_monthly_mbs.return_value = False
- response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spacesuit"})
- assert response.status_code == 429
- assert response.json()["detail"] == "User has reached their monthly MB quota."
- m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
- assert m_celery.signature.call_count == 2
- assert m_signature.apply_async.call_count == 2
-
-
-@patch("app.web.endpoints.url.UserState")
-def test_archive_url_quotas(m1, client_with_auth):
- m_user_state = MagicMock()
- m1.return_value = m_user_state
-
- # misses on monthly URLs quota
- m_user_state.has_quota_max_monthly_urls.return_value = False
- response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
- assert response.status_code == 429
- assert response.json()["detail"] == "User has reached their monthly URL quota."
- m_user_state.has_quota_max_monthly_urls.assert_called_once()
-
- # misses on monthly MBs quota
- m_user_state.has_quota_max_monthly_urls.return_value = True
- m_user_state.has_quota_max_monthly_mbs.return_value = False
- response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
- assert response.status_code == 429
- assert response.json()["detail"] == "User has reached their monthly MB quota."
- m_user_state.has_quota_max_monthly_mbs.assert_called_once()
-
-
-@patch("app.web.endpoints.url.celery", return_value=MagicMock())
-def test_archive_url_with_api_token(m_celery, client_with_token):
- m_signature = MagicMock()
- m_signature.apply_async.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
- m_celery.signature.return_value = m_signature
- response = client_with_token.post("/url/archive", json={"url": "https://example.com", "author_id": "someone@example.com"})
- assert response.status_code == 201
- assert response.json() == {'id': '123-456-789'}
- m_celery.signature.assert_called_once()
- m_signature.apply_async.assert_called_once()
- called_val = m_celery.signature.call_args
- assert called_val[0][0] == "create_archive_task"
- assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": "someone@example.com", "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None}
-
- # missing id should use ALLOW_ANY_EMAIL
- response = client_with_token.post("/url/archive", json={"url": "https://example.com", "author_id": None})
- assert response.status_code == 201
- called_val = m_celery.signature.call_args
- assert called_val[0][0] == "create_archive_task"
- assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": ALLOW_ANY_EMAIL, "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None}
-
-
-def test_search_by_url_unauthenticated(client, test_no_auth):
- test_no_auth(client.get, "/url/search")
-
-
-def test_search_by_url(client_with_auth, client_with_token, db_session):
- # tests the search endpoint, including through some db data for the endpoint params
- response = client_with_auth.get("/url/search")
- assert response.status_code == 422
- assert response.json()["detail"][0]["msg"] == "Field required"
-
- response = client_with_auth.get("/url/search?url=https://example.com")
- assert response.status_code == 200
- assert response.json() == []
-
- from app.shared import schemas
- from app.shared.db import worker_crud
- for i in range(11):
- worker_crud.create_archive(db_session, ArchiveCreate(id=f"url-456-{i}", url="https://example.com" if i < 10 else "https://something-else.com", result={}, public=True, author_id="rick@example.com"), [], [])
- # NB: this insertion is too fast for the ordering to be correct as they are within the same second
-
- response = client_with_auth.get("/url/search?url=https://example.com")
- assert response.status_code == 200
- assert len(j := response.json()) == 10
- assert "url-456-0" in [i["id"] for i in j]
- assert "url-456-9" in [i["id"] for i in j]
- assert "url-456-10" not in [i["id"] for i in j]
- assert j[0].keys() == schemas.ArchiveResult.model_fields.keys()
-
- response = client_with_auth.get("/url/search?url=https://example.com&limit=5")
- assert response.status_code == 200
- assert len(response.json()) == 5
-
- response = client_with_auth.get("/url/search?url=https://example.com&skip=5&limit=2")
- assert response.status_code == 200
- assert len(response.json()) == 2
-
- response = client_with_auth.get("/url/search?url=https://example.com&archived_before=2010-01-01")
- assert response.status_code == 200
- assert len(response.json()) == 0
-
- response = client_with_auth.get("/url/search?url=https://example.com&archived_after=2010-01-01")
- assert response.status_code == 200
- assert len(response.json()) == 10
-
- # API token will also work
- response = client_with_token.get("/url/search?url=https://example.com&archived_after=2010-01-01")
- assert response.status_code == 200
- assert len(response.json()) == 10
-
-
-@patch("app.web.endpoints.url.UserState")
-def test_search_no_read_access(mock_user_state, client_with_auth):
- mock_user_state.return_value.read = False
- mock_user_state.return_value.read_public = False
-
- response = client_with_auth.get("/url/search?url=https://example.com")
- assert response.status_code == 403
- assert response.json() == {"detail": "User does not have read access."}
-
-
-def test_delete_task_unauthenticated(client, test_no_auth):
- test_no_auth(client.delete, "/url/123-456-789")
-
-
-def test_delete_task(client_with_auth, db_session):
- response = client_with_auth.delete("/url/delete-123-456-789")
- assert response.status_code == 200
- assert response.json() == {"id": "delete-123-456-789", "deleted": False}
-
- from app.shared.db import worker_crud
- worker_crud.create_archive(db_session, ArchiveCreate(id="delete-123-456-789", url="https://example.com", result={}, public=True, author_id="morty@example.com"), [], [])
-
- response = client_with_auth.delete("/url/delete-123-456-789")
- assert response.status_code == 200
- assert response.json() == {"id": "delete-123-456-789", "deleted": True}
diff --git a/app/tests/web/endpoints/test_default.py b/app/tests/web/routers/test_default.py
similarity index 73%
rename from app/tests/web/endpoints/test_default.py
rename to app/tests/web/routers/test_default.py
index 401a164..970e705 100644
--- a/app/tests/web/endpoints/test_default.py
+++ b/app/tests/web/routers/test_default.py
@@ -1,15 +1,20 @@
+from http import HTTPStatus
from unittest.mock import MagicMock
-from fastapi.testclient import TestClient
+
import pytest
+from fastapi.testclient import TestClient
+from loguru import logger
+
from app.shared.schemas import Usage, UsageResponse
from app.shared.user_groups import GroupInfo
from app.web.config import VERSION
-from app.tests.web.db.test_crud import test_data
+from app.web.security import get_user_state
+from app.web.utils.metrics import measure_regular_metrics
def test_endpoint_home(client_with_auth):
r = client_with_auth.get("/")
- assert r.status_code == 200
+ assert r.status_code == HTTPStatus.OK
j = r.json()
assert "version" in j and j["version"] == VERSION
assert "breakingChanges" in j
@@ -18,7 +23,7 @@ def test_endpoint_home(client_with_auth):
def test_endpoint_health(client_with_auth):
r = client_with_auth.get("/health")
- assert r.status_code == 200
+ assert r.status_code == HTTPStatus.OK
assert r.json() == {"status": "ok"}
@@ -29,32 +34,31 @@ def test_endpoint_active_no_auth(client, test_no_auth):
def test_endpoint_active(app):
m_user_state = MagicMock()
- from app.web.security import get_user_state
app.dependency_overrides[get_user_state] = lambda: m_user_state
# inactive user
m_user_state.active = False
client = TestClient(app)
r = client.get("/user/active")
- assert r.status_code == 200
+ assert r.status_code == HTTPStatus.OK
assert r.json() == {"active": False}
# active user
m_user_state.active = True
client = TestClient(app)
r = client.get("/user/active")
- assert r.status_code == 200
+ assert r.status_code == HTTPStatus.OK
assert r.json() == {"active": True}
def test_no_serve_local_archive_by_default(client_with_auth):
r = client_with_auth.get("/app/local_archive_test/temp.txt")
- assert r.status_code == 404
+ assert r.status_code == HTTPStatus.NOT_FOUND
def test_favicon(client_with_auth):
r = client_with_auth.get("/favicon.ico")
- assert r.status_code == 200
+ assert r.status_code == HTTPStatus.OK
assert r.headers["content-type"] == "image/vnd.microsoft.icon"
@@ -70,8 +74,10 @@ def test_endpoint_test_prometheus_no_user_auth(client_with_auth, test_no_auth):
async def test_prometheus_metrics(test_data, client_with_token, get_settings):
# before metrics calculation
r = client_with_token.get("/metrics")
- assert r.status_code == 200
- assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
+ assert r.status_code == HTTPStatus.OK
+ assert (
+ r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
+ )
assert "disk_utilization" in r.text
assert "database_metrics" in r.text
assert "exceptions" in r.text
@@ -79,8 +85,9 @@ async def test_prometheus_metrics(test_data, client_with_token, get_settings):
assert 'disk_utilization{type="used"}' not in r.text
# after metrics calculation
- from app.web.utils.metrics import measure_regular_metrics
- await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100)
+ await measure_regular_metrics(
+ get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100
+ )
r2 = client_with_token.get("/metrics")
assert 'disk_utilization{type="used"}' in r2.text
assert 'disk_utilization{type="free"}' in r2.text
@@ -88,20 +95,37 @@ async def test_prometheus_metrics(test_data, client_with_token, get_settings):
assert 'database_metrics{query="count_archives"} 100.0' in r2.text
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r2.text
assert 'database_metrics{query="count_users"} 3.0' in r2.text
- assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r2.text
- assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r2.text
- assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r2.text
+ assert (
+ 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0'
+ in r2.text
+ )
+ assert (
+ 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0'
+ in r2.text
+ )
+ assert (
+ 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0'
+ in r2.text
+ )
# 30s window, should not change the gauges nor the total in the counters
- from app.web.utils.metrics import measure_regular_metrics
await measure_regular_metrics(get_settings.DATABASE_PATH, 30)
r3 = client_with_token.get("/metrics")
assert 'database_metrics{query="count_archives"} 100.0' in r3.text
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text
assert 'database_metrics{query="count_users"} 3.0' in r3.text
- assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r3.text
- assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r3.text
- assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r3.text
+ assert (
+ 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0'
+ in r3.text
+ )
+ assert (
+ 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0'
+ in r3.text
+ )
+ assert (
+ 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0'
+ in r3.text
+ )
def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth):
@@ -109,14 +133,12 @@ def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth):
def test_endpoint_get_user_permissions(app):
- from app.web.security import get_user_state
-
m_user_state = MagicMock()
rv = {
"all": GroupInfo(read=True),
"group1": GroupInfo(archive_url=True),
}
- from loguru import logger
+
logger.info(rv)
m_user_state.permissions = rv
@@ -124,13 +146,13 @@ def test_endpoint_get_user_permissions(app):
client = TestClient(app)
r = client.get("/user/permissions")
- assert r.status_code == 200
+ assert r.status_code == HTTPStatus.OK
response = r.json()
assert response.keys() == {"all", "group1"}
assert response["all"]["read"]
assert response["group1"]["read"] == []
assert response["group1"]["archive_url"]
- assert response["all"]["archive_url"] == False
+ assert response["all"]["archive_url"] is False
def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth):
@@ -138,8 +160,6 @@ def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth):
def test_endpoint_get_user_usage_inactive(app):
- from app.web.security import get_user_state
-
m_user_state = MagicMock()
m_user_state.active = False
@@ -147,13 +167,11 @@ def test_endpoint_get_user_usage_inactive(app):
client = TestClient(app)
r = client.get("/user/usage")
- assert r.status_code == 403
+ assert r.status_code == HTTPStatus.FORBIDDEN
assert r.json() == {"detail": "User is not active."}
def test_endpoint_get_user_usage_active(app):
- from app.web.security import get_user_state
-
m_user_state = MagicMock()
m_user_state.active = True
mock_usage = UsageResponse(
@@ -162,8 +180,8 @@ def test_endpoint_get_user_usage_active(app):
total_sheets=3,
groups={
"group1": Usage(monthly_urls=4, monthly_mbs=5, total_sheets=6),
- "group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9)
- }
+ "group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9),
+ },
)
m_user_state.usage.return_value = mock_usage
@@ -171,5 +189,5 @@ def test_endpoint_get_user_usage_active(app):
client = TestClient(app)
r = client.get("/user/usage")
- assert r.status_code == 200
+ assert r.status_code == HTTPStatus.OK
assert UsageResponse(**r.json()) == mock_usage
diff --git a/app/tests/web/routers/test_interoperability.py b/app/tests/web/routers/test_interoperability.py
new file mode 100644
index 0000000..d53c289
--- /dev/null
+++ b/app/tests/web/routers/test_interoperability.py
@@ -0,0 +1,147 @@
+import json
+from datetime import datetime
+from http import HTTPStatus
+from unittest.mock import MagicMock, patch
+
+from app.shared.db import models
+
+
+def test_submit_manual_archive_unauthenticated(client, test_no_auth):
+ test_no_auth(client.post, "/interop/submit-archive")
+
+
+def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
+ test_no_auth(client_with_auth.post, "/interop/submit-archive")
+
+
+@patch(
+ "app.web.routers.interoperability.business_logic",
+ return_value=MagicMock(
+ get_store_archive_until=MagicMock(return_value=datetime)
+ ),
+)
+def test_submit_manual_archive(m1, client_with_token, db_session):
+ # normal workflow
+ aa_metadata = json.dumps(
+ {
+ "status": "test: success",
+ "metadata": {"url": "http://example.com"},
+ "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}],
+ }
+ )
+ r = client_with_token.post(
+ "/interop/submit-archive",
+ json={
+ "result": aa_metadata,
+ "public": True,
+ "author_id": "jerry@gmail.com",
+ "group_id": "spaceship",
+ "tags": ["test"],
+ "url": "http://example.com",
+ },
+ )
+ assert r.status_code == HTTPStatus.CREATED
+ assert "id" in r.json()
+
+ inserted = (
+ db_session.query(models.Archive)
+ .filter(models.Archive.id == r.json()["id"])
+ .first()
+ )
+ assert inserted.url == "http://example.com"
+ assert inserted.group_id == "spaceship"
+ assert inserted.author_id == "jerry@gmail.com"
+ assert sorted([t.id for t in inserted.tags]) == sorted(["test", "manual"])
+ assert inserted.public
+ assert isinstance(inserted.result, dict)
+ assert [u.url for u in inserted.urls] == ["http://example.s3.com"]
+ assert isinstance(inserted.store_until, datetime)
+
+ # cannot have the same URL twice
+ aa_metadata = json.dumps(
+ {
+ "status": "test: success",
+ "metadata": {"url": "http://example.com"},
+ "media": [
+ {
+ "filename": "fn1",
+ "urls": ["http://example.com", "http://example.com"],
+ }
+ ],
+ }
+ )
+ r = client_with_token.post(
+ "/interop/submit-archive",
+ json={
+ "result": aa_metadata,
+ "public": False,
+ "author_id": "jerry@gmail.com",
+ "tags": ["test"],
+ "url": "http://example.com",
+ },
+ )
+ assert r.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
+ assert r.json() == {
+ "detail": "Cannot insert into DB due to integrity error, likely duplicate urls."
+ }
+
+
+# test with invalid JSON
+def test_submit_manual_archive_invalid_json(client_with_token):
+ r = client_with_token.post(
+ "/interop/submit-archive",
+ json={
+ "result": "invalid json",
+ "public": False,
+ "author_id": "jer",
+ "tags": ["test"],
+ "url": "http://example.com",
+ },
+ )
+ assert r.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
+ assert r.json() == {"detail": "Invalid JSON in result field."}
+
+
+@patch(
+ "app.web.routers.interoperability.business_logic.get_store_archive_until",
+ side_effect=AssertionError("AssertionError"),
+)
+def test_submit_manual_archive_no_store_until(
+ m_sau, client_with_token, db_session
+):
+ aa_metadata = json.dumps(
+ {
+ "status": "test: success",
+ "metadata": {"url": "http://example.com"},
+ "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}],
+ }
+ )
+ r = client_with_token.post(
+ "/interop/submit-archive",
+ json={
+ "result": aa_metadata,
+ "public": True,
+ "author_id": "jerry@gmail.com",
+ "group_id": "spaceship",
+ "tags": ["test"],
+ "url": "http://example.com",
+ },
+ )
+ assert r.status_code == HTTPStatus.CREATED
+ assert len(r.json()["id"]) == 36
+ res = (
+ db_session.query(models.Archive)
+ .filter(models.Archive.id == r.json()["id"])
+ .first()
+ )
+ assert res.store_until is None
+ # testing that store_until = None is not comparable with datetime, and will always return False
+ res = (
+ db_session.query(models.Archive)
+ .filter(
+ models.Archive.id == r.json()["id"],
+ models.Archive.store_until < datetime.now(),
+ )
+ .first()
+ )
+ assert res is None
diff --git a/app/tests/web/routers/test_sheet.py b/app/tests/web/routers/test_sheet.py
new file mode 100644
index 0000000..41c6fd6
--- /dev/null
+++ b/app/tests/web/routers/test_sheet.py
@@ -0,0 +1,268 @@
+from datetime import datetime
+from http import HTTPStatus
+from unittest.mock import MagicMock, patch
+
+from fastapi.testclient import TestClient
+
+from app.shared.constants import STATUS_PENDING
+from app.shared.db import models
+from app.shared.schemas import TaskResult
+from app.web.db.user_state import UserState
+from app.web.security import get_user_state
+
+
+def test_endpoints_no_auth(client, test_no_auth):
+ test_no_auth(client.post, "/sheet/create")
+ test_no_auth(client.get, "/sheet/mine")
+ test_no_auth(client.delete, "/sheet/123-sheet-id")
+ test_no_auth(client.post, "/sheet/123-sheet-id/archive")
+
+
+def test_create_sheet_endpoint(app_with_auth, db_session):
+ client_with_auth = TestClient(app_with_auth)
+ good_data = {
+ "id": "123-sheet-id",
+ "name": "Test Sheet",
+ "group_id": "spaceship",
+ "frequency": "daily",
+ }
+
+ # with good data
+ response = client_with_auth.post("/sheet/create", json=good_data)
+ assert response.status_code == HTTPStatus.CREATED
+ j = response.json()
+ assert datetime.fromisoformat(j.pop("created_at"))
+ assert datetime.fromisoformat(j.pop("last_url_archived_at"))
+ assert j.pop("author_id") == "morty@example.com"
+ assert j == good_data
+
+ # already exists
+ response = client_with_auth.post("/sheet/create", json=good_data)
+ assert response.status_code == HTTPStatus.BAD_REQUEST
+ assert response.json() == {
+ "detail": "Sheet with this ID is already being archived."
+ }
+
+ # bad group
+ bad_data = good_data.copy()
+ bad_data["group_id"] = "not a group"
+ response = client_with_auth.post("/sheet/create", json=bad_data)
+ assert response.status_code == HTTPStatus.FORBIDDEN
+ assert response.json() == {
+ "detail": "User does not have access to this group."
+ }
+
+ # switch to jerry who's got less quota/permissions
+ app_with_auth.dependency_overrides[get_user_state] = lambda: UserState(
+ db_session, "jerry@example.com"
+ )
+ client_jerry = TestClient(app_with_auth)
+
+ # frequency not allowed
+ jerry_data = good_data.copy()
+ jerry_data["group_id"] = "animated-characters"
+ jerry_data["frequency"] = "hourly"
+ jerry_data["id"] = "jerry-sheet-id"
+ response = client_jerry.post("/sheet/create", json=jerry_data)
+ assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
+ assert response.json() == {
+ "detail": "Invalid frequency selected for this group."
+ }
+
+ jerry_data["frequency"] = "daily"
+ # success for the first sheet, bad quota on second
+ response = client_jerry.post("/sheet/create", json=jerry_data)
+ assert response.status_code == HTTPStatus.CREATED
+
+ response = client_jerry.post("/sheet/create", json=jerry_data)
+ assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
+ assert response.json() == {
+ "detail": "User has reached their sheet quota for this group."
+ }
+
+
+def test_get_user_sheets_endpoint(client_with_auth, db_session):
+ # no data
+ response = client_with_auth.get("/sheet/mine")
+ assert response.status_code == HTTPStatus.OK
+ assert response.json() == []
+
+ # with data
+ db_session.add(
+ models.Sheet(
+ id="123",
+ name="Test Sheet 1",
+ author_id="morty@example.com",
+ group_id="spaceship",
+ frequency="hourly",
+ )
+ )
+ db_session.commit()
+ db_session.add_all(
+ [
+ models.Sheet(
+ id="456",
+ name="Test Sheet 2",
+ author_id="morty@example.com",
+ group_id="interdimensional",
+ frequency="daily",
+ ),
+ models.Sheet(
+ id="789",
+ name="Test Sheet 3",
+ author_id="rick@example.com",
+ group_id="interdimensional",
+ frequency="hourly",
+ ),
+ ]
+ )
+ db_session.commit()
+
+ response = client_with_auth.get("/sheet/mine")
+ assert response.status_code == HTTPStatus.OK
+ r = response.json()
+ assert isinstance(r, list)
+ assert len(r) == 2
+ assert datetime.fromisoformat(r[0].pop("created_at"))
+ assert datetime.fromisoformat(r[0].pop("last_url_archived_at"))
+ assert datetime.fromisoformat(r[1].pop("created_at"))
+ assert datetime.fromisoformat(r[1].pop("last_url_archived_at"))
+ assert r[0] == {
+ "id": "123",
+ "author_id": "morty@example.com",
+ "frequency": "hourly",
+ "group_id": "spaceship",
+ "name": "Test Sheet 1",
+ }
+ assert r[1] == {
+ "id": "456",
+ "author_id": "morty@example.com",
+ "frequency": "daily",
+ "group_id": "interdimensional",
+ "name": "Test Sheet 2",
+ }
+
+
+def test_delete_sheet_endpoint(client_with_auth, db_session):
+ # missing sheet
+ response = client_with_auth.delete("/sheet/123-sheet-id")
+ assert response.status_code == HTTPStatus.OK
+ assert response.json() == {"id": "123-sheet-id", "deleted": False}
+
+ # add sheets for deletion
+ db_session.add_all(
+ [
+ models.Sheet(
+ id="123-sheet-id",
+ name="Test Sheet 1",
+ author_id="morty@example.com",
+ group_id="interdimensional",
+ frequency="daily",
+ ),
+ models.Sheet(
+ id="456-sheet-id",
+ name="Test Sheet 2",
+ author_id="rick@example.com",
+ group_id="spaceship",
+ frequency="hourly",
+ ),
+ ]
+ )
+ db_session.commit()
+
+ # morty can delete his
+ response = client_with_auth.delete("/sheet/123-sheet-id")
+ assert response.status_code == HTTPStatus.OK
+ assert response.json() == {"id": "123-sheet-id", "deleted": True}
+ # but only once
+ response = client_with_auth.delete("/sheet/123-sheet-id")
+ assert response.status_code == HTTPStatus.OK
+ assert response.json() == {"id": "123-sheet-id", "deleted": False}
+ # and not Rick's
+ response = client_with_auth.delete("/sheet/456-sheet-id")
+ assert response.status_code == HTTPStatus.OK
+ assert response.json() == {"id": "456-sheet-id", "deleted": False}
+
+
+class TestArchiveUserSheetEndpoint:
+ @patch("app.web.routers.sheet.celery", return_value=MagicMock())
+ def test_normal_flow(self, m_celery, client_with_auth, db_session):
+ db_session.add(
+ models.Sheet(
+ id="123-sheet-id",
+ name="Test Sheet 1",
+ author_id="morty@example.com",
+ group_id="spaceship",
+ frequency="hourly",
+ )
+ )
+ db_session.commit()
+
+ m_signature = MagicMock()
+ m_signature.apply_async.return_value = TaskResult(
+ id="123-taskid", status=STATUS_PENDING, result=""
+ )
+ m_celery.signature.return_value = m_signature
+
+ r = client_with_auth.post("/sheet/123-sheet-id/archive")
+ assert r.status_code == HTTPStatus.CREATED
+ assert r.json() == {"id": "123-taskid"}
+ m_celery.signature.assert_called_once()
+ m_signature.apply_async.assert_called_once()
+
+ def test_token_auth(self, client_with_token, test_no_auth):
+ test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
+
+ def test_missing_data(self, client_with_auth):
+ r = client_with_auth.post("/sheet/123-sheet-id/archive")
+ assert r.status_code == HTTPStatus.FORBIDDEN
+ assert r.json() == {"detail": "No access to this sheet."}
+
+ def test_no_access(self, client_with_auth, db_session):
+ db_session.add(
+ models.Sheet(
+ id="123-sheet-id",
+ name="Test Sheet 1",
+ author_id="rick@example.com",
+ group_id="spaceship",
+ frequency="hourly",
+ )
+ )
+ db_session.commit()
+ r = client_with_auth.post("/sheet/123-sheet-id/archive")
+ assert r.status_code == HTTPStatus.FORBIDDEN
+ assert r.json() == {"detail": "No access to this sheet."}
+
+ def test_user_not_in_group(self, client_with_auth, db_session):
+ db_session.add(
+ models.Sheet(
+ id="123-sheet-id",
+ name="Test Sheet 1",
+ author_id="morty@example.com",
+ group_id="interdimensional",
+ frequency="hourly",
+ )
+ )
+ db_session.commit()
+ r = client_with_auth.post("/sheet/123-sheet-id/archive")
+ assert r.status_code == HTTPStatus.FORBIDDEN
+ assert r.json() == {
+ "detail": "User does not have access to this group."
+ }
+
+ def test_user_cannot_manually_trigger(self, client_with_auth, db_session):
+ db_session.add(
+ models.Sheet(
+ id="123-sheet-id",
+ name="Test Sheet 1",
+ author_id="morty@example.com",
+ group_id="default",
+ frequency="hourly",
+ )
+ )
+ db_session.commit()
+ r = client_with_auth.post("/sheet/123-sheet-id/archive")
+ assert r.status_code == HTTPStatus.TOO_MANY_REQUESTS
+ assert r.json() == {
+ "detail": "User cannot manually trigger sheet archiving in this group."
+ }
diff --git a/app/tests/web/endpoints/test_task.py b/app/tests/web/routers/test_task.py
similarity index 54%
rename from app/tests/web/endpoints/test_task.py
rename to app/tests/web/routers/test_task.py
index 937ad46..8165d58 100644
--- a/app/tests/web/endpoints/test_task.py
+++ b/app/tests/web/routers/test_task.py
@@ -1,51 +1,53 @@
+from http import HTTPStatus
from unittest.mock import patch
+from app.shared.constants import STATUS_FAILURE, STATUS_PENDING, STATUS_SUCCESS
+
def test_endpoint_task_status_no_auth(client, test_no_auth):
test_no_auth(client.get, "/task/test-task-id")
-@patch("app.web.endpoints.task.AsyncResult")
+@patch("app.web.routers.task.AsyncResult")
def test_get_status_success(mock_async_result, client_with_auth):
- mock_async_result.return_value.status = "SUCCESS"
+ mock_async_result.return_value.status = STATUS_SUCCESS
mock_async_result.return_value.result = {"data": "some result"}
response = client_with_auth.get("/task/test-task-id")
- assert response.status_code == 200
+ assert response.status_code == HTTPStatus.OK
assert response.json() == {
"id": "test-task-id",
- "status": "SUCCESS",
- "result": {"data": "some result"}
+ "status": STATUS_SUCCESS,
+ "result": {"data": "some result"},
}
-@patch("app.web.endpoints.task.AsyncResult")
+@patch("app.web.routers.task.AsyncResult")
def test_get_status_failure(mock_async_result, client_with_auth):
-
- mock_async_result.return_value.status = "FAILURE"
+ mock_async_result.return_value.status = STATUS_FAILURE
mock_async_result.return_value.result = Exception("Some error")
response = client_with_auth.get("/task/test-task-id")
- assert response.status_code == 200
+ assert response.status_code == HTTPStatus.OK
assert response.json() == {
"id": "test-task-id",
- "status": "FAILURE",
- "result": {"error": "Some error"}
+ "status": STATUS_FAILURE,
+ "result": {"error": "Some error"},
}
-@patch("app.web.endpoints.task.AsyncResult")
+@patch("app.web.routers.task.AsyncResult")
def test_get_status_pending(mock_async_result, client_with_auth):
- mock_async_result.return_value.status = "PENDING"
+ mock_async_result.return_value.status = STATUS_PENDING
mock_async_result.return_value.result = None
response = client_with_auth.get("/task/test-task-id")
- assert response.status_code == 200
+ assert response.status_code == HTTPStatus.OK
assert response.json() == {
"id": "test-task-id",
- "status": "PENDING",
- "result": None
+ "status": STATUS_PENDING,
+ "result": None,
}
diff --git a/app/tests/web/routers/test_url.py b/app/tests/web/routers/test_url.py
new file mode 100644
index 0000000..0d1d452
--- /dev/null
+++ b/app/tests/web/routers/test_url.py
@@ -0,0 +1,312 @@
+import json
+from http import HTTPStatus
+from unittest.mock import MagicMock, patch
+
+from app.shared import schemas
+from app.shared.constants import STATUS_PENDING
+from app.shared.db import worker_crud
+from app.shared.schemas import ArchiveCreate, TaskResult
+from app.web.config import ALLOW_ANY_EMAIL
+
+
+def test_archive_url_unauthenticated(client, test_no_auth):
+ test_no_auth(client.post, "/url/archive")
+
+
+@patch("app.web.routers.url.UserState")
+@patch("app.web.routers.url.celery", return_value=MagicMock())
+def test_archive_url(m_celery, m2, client_with_auth):
+ m_signature = MagicMock()
+ m_signature.apply_async.return_value = TaskResult(
+ id="123-456-789", status=STATUS_PENDING, result=""
+ )
+ m_celery.signature.return_value = m_signature
+
+ m_user_state = MagicMock()
+ m2.return_value = m_user_state
+
+ # url is too short
+ response = client_with_auth.post("/url/archive", json={"url": "bad"})
+ assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
+ assert (
+ response.json()["detail"][0]["msg"]
+ == "String should have at least 5 characters"
+ )
+ m_celery.signature.assert_not_called()
+
+ # url is invalid
+ response = client_with_auth.post(
+ "/url/archive", json={"url": "example.com"}
+ )
+ assert response.status_code == HTTPStatus.BAD_REQUEST
+ assert response.json()["detail"] == "Invalid URL received."
+
+ # valid request
+ m_user_state.has_quota_max_monthly_urls.return_value = True
+ m_user_state.has_quota_max_monthly_mbs.return_value = True
+ response = client_with_auth.post(
+ "/url/archive", json={"url": "https://example.com"}
+ )
+ assert response.status_code == HTTPStatus.CREATED
+ assert response.json() == {"id": "123-456-789"}
+ m_celery.signature.assert_called_once()
+ m_signature.apply_async.assert_called_once()
+ called_val = m_celery.signature.call_args
+ assert called_val[0][0] == "create_archive_task"
+ assert json.loads(called_val[1]["args"][0]) == {
+ "id": None,
+ "url": "https://example.com",
+ "result": None,
+ "public": False,
+ "author_id": "rick@example.com",
+ "group_id": "default",
+ "tags": None,
+ "sheet_id": None,
+ "store_until": None,
+ "urls": None,
+ }
+ m_user_state.has_quota_max_monthly_urls.assert_called_once()
+ m_user_state.has_quota_max_monthly_mbs.assert_called_once()
+ m_user_state.in_group.assert_called_once_with("default")
+
+ # user is not in group
+ m_user_state.in_group.return_value = False
+ response = client_with_auth.post(
+ "/url/archive",
+ json={"url": "https://example.com", "group_id": "new-group"},
+ )
+ assert response.status_code == HTTPStatus.FORBIDDEN
+ assert (
+ response.json()["detail"] == "User does not have access to this group."
+ )
+ m_user_state.in_group.assert_called_with("new-group")
+
+ # user is in group
+ m_user_state.in_group.return_value = True
+ response = client_with_auth.post(
+ "/url/archive",
+ json={"url": "https://example.com", "group_id": "spaceship"},
+ )
+ assert response.status_code == HTTPStatus.CREATED
+ assert response.json() == {"id": "123-456-789"}
+ assert m_celery.signature.call_count == 2
+ assert m_signature.apply_async.call_count == 2
+ called_val = m_celery.signature.call_args
+ assert json.loads(called_val[1]["args"][0])["group_id"] == "spaceship"
+ m_user_state.in_group.assert_called_with("spaceship")
+
+ # user is over monthly URL quota
+ m_user_state.has_quota_max_monthly_urls.return_value = False
+ m_user_state.has_quota_max_monthly_mbs.return_value = True
+ response = client_with_auth.post(
+ "/url/archive",
+ json={"url": "https://example.com", "group_id": "spaceship"},
+ )
+ assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
+ assert (
+ response.json()["detail"] == "User has reached their monthly URL quota."
+ )
+ m_user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")
+
+ # user is over monthly MB quota
+ m_user_state.has_quota_max_monthly_urls.return_value = True
+ m_user_state.has_quota_max_monthly_mbs.return_value = False
+ response = client_with_auth.post(
+ "/url/archive",
+ json={"url": "https://example.com", "group_id": "spacesuit"},
+ )
+ assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
+ assert (
+ response.json()["detail"] == "User has reached their monthly MB quota."
+ )
+ m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
+ assert m_celery.signature.call_count == 2
+ assert m_signature.apply_async.call_count == 2
+
+
+@patch("app.web.routers.url.UserState")
+def test_archive_url_quotas(m1, client_with_auth):
+ m_user_state = MagicMock()
+ m1.return_value = m_user_state
+
+ # misses on monthly URLs quota
+ m_user_state.has_quota_max_monthly_urls.return_value = False
+ response = client_with_auth.post(
+ "/url/archive", json={"url": "https://example.com"}
+ )
+ assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
+ assert (
+ response.json()["detail"] == "User has reached their monthly URL quota."
+ )
+ m_user_state.has_quota_max_monthly_urls.assert_called_once()
+
+ # misses on monthly MBs quota
+ m_user_state.has_quota_max_monthly_urls.return_value = True
+ m_user_state.has_quota_max_monthly_mbs.return_value = False
+ response = client_with_auth.post(
+ "/url/archive", json={"url": "https://example.com"}
+ )
+ assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
+ assert (
+ response.json()["detail"] == "User has reached their monthly MB quota."
+ )
+ m_user_state.has_quota_max_monthly_mbs.assert_called_once()
+
+
+@patch("app.web.routers.url.celery", return_value=MagicMock())
+def test_archive_url_with_api_token(m_celery, client_with_token):
+ m_signature = MagicMock()
+ m_signature.apply_async.return_value = TaskResult(
+ id="123-456-789", status=STATUS_PENDING, result=""
+ )
+ m_celery.signature.return_value = m_signature
+ response = client_with_token.post(
+ "/url/archive",
+ json={"url": "https://example.com", "author_id": "someone@example.com"},
+ )
+ assert response.status_code == HTTPStatus.CREATED
+ assert response.json() == {"id": "123-456-789"}
+ m_celery.signature.assert_called_once()
+ m_signature.apply_async.assert_called_once()
+ called_val = m_celery.signature.call_args
+ assert called_val[0][0] == "create_archive_task"
+ assert json.loads(called_val[1]["args"][0]) == {
+ "id": None,
+ "url": "https://example.com",
+ "result": None,
+ "public": False,
+ "author_id": "someone@example.com",
+ "group_id": "default",
+ "tags": None,
+ "sheet_id": None,
+ "store_until": None,
+ "urls": None,
+ }
+
+ # missing id should use ALLOW_ANY_EMAIL
+ response = client_with_token.post(
+ "/url/archive", json={"url": "https://example.com", "author_id": None}
+ )
+ assert response.status_code == HTTPStatus.CREATED
+ called_val = m_celery.signature.call_args
+ assert called_val[0][0] == "create_archive_task"
+ assert json.loads(called_val[1]["args"][0]) == {
+ "id": None,
+ "url": "https://example.com",
+ "result": None,
+ "public": False,
+ "author_id": ALLOW_ANY_EMAIL,
+ "group_id": "default",
+ "tags": None,
+ "sheet_id": None,
+ "store_until": None,
+ "urls": None,
+ }
+
+
+def test_search_by_url_unauthenticated(client, test_no_auth):
+ test_no_auth(client.get, "/url/search")
+
+
+def test_search_by_url(client_with_auth, client_with_token, db_session):
+ # tests the search endpoint, including through some db data for the endpoint params
+ response = client_with_auth.get("/url/search")
+ assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
+ assert response.json()["detail"][0]["msg"] == "Field required"
+
+ response = client_with_auth.get("/url/search?url=https://example.com")
+ assert response.status_code == HTTPStatus.OK
+ assert response.json() == []
+
+ for i in range(11):
+ worker_crud.create_archive(
+ db_session,
+ ArchiveCreate(
+ id=f"url-456-{i}",
+ url="https://example.com"
+ if i < 10
+ else "https://something-else.com",
+ result={},
+ public=True,
+ author_id="rick@example.com",
+ ),
+ [],
+ [],
+ )
+ # NB: this insertion is too fast for the ordering to be correct as they are within the same second
+
+ response = client_with_auth.get("/url/search?url=https://example.com")
+ assert response.status_code == HTTPStatus.OK
+ assert len(j := response.json()) == 10
+ assert "url-456-0" in [i["id"] for i in j]
+ assert "url-456-9" in [i["id"] for i in j]
+ assert "url-456-10" not in [i["id"] for i in j]
+ assert j[0].keys() == schemas.ArchiveResult.model_fields.keys()
+
+ response = client_with_auth.get(
+ "/url/search?url=https://example.com&limit=5"
+ )
+ assert response.status_code == HTTPStatus.OK
+ assert len(response.json()) == 5
+
+ response = client_with_auth.get(
+ "/url/search?url=https://example.com&skip=5&limit=2"
+ )
+ assert response.status_code == HTTPStatus.OK
+ assert len(response.json()) == 2
+
+ response = client_with_auth.get(
+ "/url/search?url=https://example.com&archived_before=2010-01-01"
+ )
+ assert response.status_code == HTTPStatus.OK
+ assert len(response.json()) == 0
+
+ response = client_with_auth.get(
+ "/url/search?url=https://example.com&archived_after=2010-01-01"
+ )
+ assert response.status_code == HTTPStatus.OK
+ assert len(response.json()) == 10
+
+ # API token will also work
+ response = client_with_token.get(
+ "/url/search?url=https://example.com&archived_after=2010-01-01"
+ )
+ assert response.status_code == HTTPStatus.OK
+ assert len(response.json()) == 10
+
+
+@patch("app.web.routers.url.UserState")
+def test_search_no_read_access(mock_user_state, client_with_auth):
+ mock_user_state.return_value.read = False
+ mock_user_state.return_value.read_public = False
+
+ response = client_with_auth.get("/url/search?url=https://example.com")
+ assert response.status_code == HTTPStatus.FORBIDDEN
+ assert response.json() == {"detail": "User does not have read access."}
+
+
+def test_delete_task_unauthenticated(client, test_no_auth):
+ test_no_auth(client.delete, "/url/123-456-789")
+
+
+def test_delete_task(client_with_auth, db_session):
+ response = client_with_auth.delete("/url/delete-123-456-789")
+ assert response.status_code == HTTPStatus.OK
+ assert response.json() == {"id": "delete-123-456-789", "deleted": False}
+
+ worker_crud.create_archive(
+ db_session,
+ ArchiveCreate(
+ id="delete-123-456-789",
+ url="https://example.com",
+ result={},
+ public=True,
+ author_id="morty@example.com",
+ ),
+ [],
+ [],
+ )
+
+ response = client_with_auth.delete("/url/delete-123-456-789")
+ assert response.status_code == HTTPStatus.OK
+ assert response.json() == {"id": "delete-123-456-789", "deleted": True}
diff --git a/app/tests/web/test_main.py b/app/tests/web/test_main.py
index f77d368..a8a2c4f 100644
--- a/app/tests/web/test_main.py
+++ b/app/tests/web/test_main.py
@@ -1,31 +1,55 @@
import os
+import shutil
+from http import HTTPStatus
from unittest.mock import patch
+
+import alembic.config
+import pytest
from fastapi.testclient import TestClient
-import shutil
+from app.web.main import app_factory
+from app.web.utils.metrics import EXCEPTION_COUNTER
-import pytest
def test_lifespan(app):
with TestClient(app) as client:
r = client.get("/health")
- assert r.status_code == 200
+ assert r.status_code == HTTPStatus.OK
assert r.json() == {"status": "ok"}
-def test_alembic(db_session):
- import alembic.config
- alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
- alembic.config.main(argv=['--raiseerr', 'downgrade', 'base'])
-@patch("app.web.endpoints.url.crud.soft_delete_archive", side_effect=Exception('mocked error'))
+def test_alembic(db_session):
+ alembic.config.main(
+ argv=[
+ "-c",
+ "./app/migrations/alembic.ini",
+ "--raiseerr",
+ "upgrade",
+ "head",
+ ]
+ )
+ alembic.config.main(
+ argv=[
+ "-c",
+ "./app/migrations/alembic.ini",
+ "--raiseerr",
+ "downgrade",
+ "base",
+ ]
+ )
+
+
+@patch(
+ "app.web.routers.url.crud.soft_delete_archive",
+ side_effect=Exception("mocked error"),
+)
def test_logging_middleware(m1, client_with_auth):
- from app.web.utils.metrics import EXCEPTION_COUNTER
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 0
with pytest.raises(Exception, match="mocked error"):
client_with_auth.delete("/url/123")
# creates one empty and one from above
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 2
-
+
def test_serve_local_archive_logic(get_settings):
# create a test file first
@@ -36,13 +60,13 @@ def test_serve_local_archive_logic(get_settings):
try:
# modify the settings
get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test"
- from app.web.main import app_factory
+
app = app_factory(get_settings)
-
+
# test
client = TestClient(app)
r = client.get("/app/local_archive_test/temp.txt")
- assert r.status_code == 200
+ assert r.status_code == HTTPStatus.OK
assert r.text == "test"
finally:
# cleanup
diff --git a/app/tests/web/test_security.py b/app/tests/web/test_security.py
index d723e9f..2b2d028 100644
--- a/app/tests/web/test_security.py
+++ b/app/tests/web/test_security.py
@@ -1,101 +1,168 @@
+from http import HTTPStatus
from unittest import mock
from unittest.mock import Mock, patch
+import pytest
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
-import pytest
from app.web.config import ALLOW_ANY_EMAIL
+from app.web.db.user_state import UserState
+from app.web.security import (
+ authenticate_user,
+ get_token_or_user_auth,
+ get_user_auth,
+ get_user_state,
+ secure_compare,
+ token_api_key_auth,
+)
def test_secure_compare():
- from app.web.security import secure_compare
-
assert secure_compare("test", "test")
assert not secure_compare("test", "test2")
@pytest.mark.asyncio
async def test_get_token_or_user_auth_with_api():
- from app.web.security import get_token_or_user_auth
- mock_api = HTTPAuthorizationCredentials(scheme="lorem", credentials="this_is_the_test_api_token")
+ mock_api = HTTPAuthorizationCredentials(
+ scheme="lorem", credentials="this_is_the_test_api_token"
+ )
assert await get_token_or_user_auth(mock_api) == ALLOW_ANY_EMAIL
@pytest.mark.asyncio
async def test_get_token_or_user_auth_with_user():
- from app.web.security import get_token_or_user_auth
- bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="invalid")
- e: pytest.ExceptionInfo = None
+ bad_user = HTTPAuthorizationCredentials(
+ scheme="ipsum", credentials="invalid"
+ )
with pytest.raises(HTTPException) as e:
await get_token_or_user_auth(bad_user)
- assert e.value.status_code == 401
+ assert e.value.status_code == HTTPStatus.UNAUTHORIZED
assert e.value.detail == "invalid access_token"
-@patch("app.web.security.authenticate_user", return_value=(True, "summer@example.com"))
+@patch(
+ "app.web.security.authenticate_user",
+ return_value=(True, "summer@example.com"),
+)
@pytest.mark.asyncio
async def test_get_user_auth(m1):
- from app.web.security import get_user_auth
- good_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good")
+ good_user = HTTPAuthorizationCredentials(
+ scheme="ipsum", credentials="valid-and-good"
+ )
assert await get_user_auth(good_user) == "summer@example.com"
@patch("app.web.security.secure_compare", return_value=False)
@pytest.mark.asyncio
async def test_token_api_key_auth_exception(m1):
- from app.web.security import token_api_key_auth
-
- e: pytest.ExceptionInfo = None
with pytest.raises(HTTPException) as e:
- await token_api_key_auth(HTTPAuthorizationCredentials(scheme="ipsum", credentials="does-not-matter"), auto_error=True)
- assert e.value.status_code == 401
+ await token_api_key_auth(
+ HTTPAuthorizationCredentials(
+ scheme="ipsum", credentials="does-not-matter"
+ ),
+ auto_error=True,
+ )
+ assert e.value.status_code == HTTPStatus.UNAUTHORIZED
assert e.value.detail == "Wrong auth credentials"
@pytest.mark.asyncio
async def test_authenticate_user():
- from app.web.security import authenticate_user
-
assert authenticate_user("test") == (False, "invalid access_token")
assert authenticate_user(123) == (False, "invalid access_token")
with patch("app.web.security.requests.get") as mock_get:
# bad response from oauth2
- mock_get.return_value.status_code = 403
- assert authenticate_user("this-will-call-requests") == (False, "invalid token")
+ mock_get.return_value.status_code = HTTPStatus.FORBIDDEN
+ assert authenticate_user("this-will-call-requests") == (
+ False,
+ "invalid token",
+ )
assert mock_get.call_count == 1
# 200 but invalid json
- mock_get.return_value.status_code = 200
- assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
+ mock_get.return_value.status_code = HTTPStatus.OK
+ assert authenticate_user("this-will-call-requests") == (
+ False,
+ "token does not belong to valid APP_ID",
+ )
assert mock_get.call_count == 2
# 200 but invalid azp and aud
- mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app"}
- assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
+ mock_get.return_value.json.return_value = {
+ "email": "summer@example.com",
+ "azp": "not_an_app",
+ }
+ assert authenticate_user("this-will-call-requests") == (
+ False,
+ "token does not belong to valid APP_ID",
+ )
- mock_get.return_value.json.return_value = {"email": "summer@example.com", "aud": "not_an_app"}
- assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
+ mock_get.return_value.json.return_value = {
+ "email": "summer@example.com",
+ "aud": "not_an_app",
+ }
+ assert authenticate_user("this-will-call-requests") == (
+ False,
+ "token does not belong to valid APP_ID",
+ )
- mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "not_an_app"}
- assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
+ mock_get.return_value.json.return_value = {
+ "email": "summer@example.com",
+ "azp": "not_an_app",
+ "aud": "not_an_app",
+ }
+ assert authenticate_user("this-will-call-requests") == (
+ False,
+ "token does not belong to valid APP_ID",
+ )
# blocked email
- mock_get.return_value.json.return_value = {"email": "blocked@example.com", "azp": "test_app_id_1", "aud": "not_an_app"}
- assert authenticate_user("this-will-call-requests") == (False, "email 'blocked@example.com' not allowed")
+ mock_get.return_value.json.return_value = {
+ "email": "blocked@example.com",
+ "azp": "test_app_id_1",
+ "aud": "not_an_app",
+ }
+ assert authenticate_user("this-will-call-requests") == (
+ False,
+ "email 'blocked@example.com' not allowed",
+ )
# not verified
- mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "test_app_id_1"}
- assert authenticate_user("this-will-call-requests") == (False, "email 'summer@example.com' not verified")
+ mock_get.return_value.json.return_value = {
+ "email": "summer@example.com",
+ "azp": "not_an_app",
+ "aud": "test_app_id_1",
+ }
+ assert authenticate_user("this-will-call-requests") == (
+ False,
+ "email 'summer@example.com' not verified",
+ )
# token expired
- mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true"}
- assert authenticate_user("this-will-call-requests") == (False, "Token expired")
+ mock_get.return_value.json.return_value = {
+ "email": "summer@example.com",
+ "azp": "test_app_id_2",
+ "email_verified": "true",
+ }
+ assert authenticate_user("this-will-call-requests") == (
+ False,
+ "Token expired",
+ )
# 200 and valid azp and aup and verified
- mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true", "expires_in": 100}
- assert authenticate_user("this-will-call-requests") == (True, "summer@example.com")
+ mock_get.return_value.json.return_value = {
+ "email": "summer@example.com",
+ "azp": "test_app_id_2",
+ "email_verified": "true",
+ "expires_in": 100,
+ }
+ assert authenticate_user("this-will-call-requests") == (
+ True,
+ "summer@example.com",
+ )
assert mock_get.call_count == 9
@@ -104,6 +171,7 @@ async def test_authenticate_user():
@pytest.mark.asyncio
async def test_authenticate_user_with_id_token(m_init):
from firebase_admin import exceptions
+
from app.web.security import authenticate_user
with pytest.raises(ValueError) as e:
@@ -113,12 +181,20 @@ async def test_authenticate_user_with_id_token(m_init):
with patch("app.web.security.auth.verify_id_token") as mock_verify:
# missing email
mock_verify.return_value = {"email": None}
- assert authenticate_user("fake_token") == (False, "email not found in token")
+ assert authenticate_user("fake_token") == (
+ False,
+ "email not found in token",
+ )
assert mock_verify.call_count == 1
# blocked email
- mock_verify.return_value = {"email": "blocked@example.com", }
- assert authenticate_user("fake_token") == (False, "email 'blocked@example.com' not allowed")
+ mock_verify.return_value = {
+ "email": "blocked@example.com",
+ }
+ assert authenticate_user("fake_token") == (
+ False,
+ "email 'blocked@example.com' not allowed",
+ )
assert mock_verify.call_count == 2
# valid email
@@ -132,17 +208,16 @@ async def test_authenticate_user_with_id_token(m_init):
@pytest.mark.asyncio
async def test_authenticate_user_exception():
- from app.web.security import authenticate_user
with patch("app.web.security.requests.get") as mock_get:
- mock_get.return_value.status_code = 200
+ mock_get.return_value.status_code = HTTPStatus.OK
mock_get.return_value.json.side_effect = Exception("mocked error")
- assert authenticate_user("this-will-call-requests") == (False, "exception occurred")
+ assert authenticate_user("this-will-call-requests") == (
+ False,
+ "exception occurred",
+ )
def test_get_user_state():
- from app.web.security import get_user_state
- from app.web.db.user_state import UserState
-
mock_session = Mock()
test_email = "test@example.com"
diff --git a/app/tests/worker/test_worker_main.py b/app/tests/worker/test_worker_main.py
index d40c457..39b2b17 100644
--- a/app/tests/worker/test_worker_main.py
+++ b/app/tests/worker/test_worker_main.py
@@ -1,29 +1,47 @@
from datetime import datetime
-
from unittest.mock import patch
import pytest
-
-from app.shared.db import models
-from app.shared import schemas
from auto_archiver.core import Media, Metadata
+from app.shared import constants, schemas
+from app.shared.db import models
+from app.web.utils.misc import get_all_urls
+from app.worker.main import create_archive_task, create_sheet_task
-class Test_create_archive_task():
+
+class TestCreateArchiveTask:
URL = "https://example-live.com"
- archive = schemas.ArchiveCreate(url=URL, tags=["tag-celery"], public=True, author_id="rick@example.com", group_id="interstellar")
+ archive = schemas.ArchiveCreate(
+ url=URL,
+ tags=["tag-celery"],
+ public=True,
+ author_id="rick@example.com",
+ group_id="interstellar",
+ )
@patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.get_all_urls", return_value=[])
@patch("app.worker.main.insert_result_into_db")
@patch("app.worker.main.get_store_until", return_value=datetime.now())
- @patch("app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"])
+ @patch(
+ "app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"]
+ )
@patch("celery.app.task.Task.request")
- def test_success(self, m_req, m_args, m_store, m_insert, m_urls, m_orchestrator, db_session):
- from app.worker.main import create_archive_task
-
+ def test_success(
+ self,
+ m_req,
+ m_args,
+ m_store,
+ m_insert,
+ m_urls,
+ m_orchestrator,
+ db_session,
+ ):
m_req.id = "this-just-in"
- m_orchestrator.return_value.feed.return_value = iter([Metadata().set_url(self.URL).success()])
+ m_orchestrator.return_value.feed.return_value = iter(
+ [Metadata().set_url(self.URL).success()]
+ )
task = create_archive_task(self.archive.model_dump_json())
@@ -39,15 +57,15 @@ class Test_create_archive_task():
assert len(task["media"]) == 0
def test_raise_invalid(self):
- from app.worker.main import create_archive_task
- with pytest.raises(Exception):
+ with pytest.raises(Exception) as _:
create_archive_task(self.archive.model_dump_json())
@patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.get_orchestrator_args")
def test_raise_db_error(self, m_args, m_orchestrator):
- from app.worker.main import create_archive_task
- m_orchestrator.return_value.feed.side_effect = Exception("Orchestrator failed")
+ m_orchestrator.return_value.feed.side_effect = Exception(
+ "Orchestrator failed"
+ )
with pytest.raises(Exception) as e:
create_archive_task(self.archive.model_dump_json())
@@ -59,7 +77,6 @@ class Test_create_archive_task():
@patch("app.worker.main.insert_result_into_db", return_value=None)
@patch("app.worker.main.get_orchestrator_args")
def test_raise_empty_result(self, m_args, m_insert, m_orchestrator):
- from app.worker.main import create_archive_task
m_orchestrator.return_value.feed.return_value = iter([None])
with pytest.raises(Exception) as e:
@@ -68,61 +85,83 @@ class Test_create_archive_task():
m_orchestrator.return_value.feed.assert_called_once()
-class Test_create_sheet_task():
+class TestCreateSheetTask:
URL = "https://example-live.com"
- sheet = schemas.SubmitSheet(sheet_id="123", author_id="rick@example.com", group_id="interstellar", tags=["spaceship"])
+ sheet = schemas.SubmitSheet(
+ sheet_id="123",
+ author_id="rick@example.com",
+ group_id="interstellar",
+ tags=["spaceship"],
+ )
@patch("app.worker.main.get_all_urls", return_value=[])
@patch("app.worker.main.ArchivingOrchestrator")
@patch("app.worker.main.models.generate_uuid", return_value="constant-uuid")
@patch("app.worker.main.get_store_until", return_value=datetime.now())
@patch("app.worker.main.get_orchestrator_args")
- def test_success(self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session):
- from app.worker.main import create_sheet_task
-
- assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
+ def test_success(
+ self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session
+ ):
+ assert (
+ db_session.query(models.Archive)
+ .filter(models.Archive.url == self.URL)
+ .count()
+ == 0
+ )
mock_metadata = Metadata().set_url(self.URL).success()
mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"]))
- m_orchestrator.return_value.feed.return_value = iter([False, mock_metadata, mock_metadata])
+ m_orchestrator.return_value.feed.return_value = iter(
+ [False, mock_metadata, mock_metadata]
+ )
res = create_sheet_task(self.sheet.model_dump_json())
- m_args.assert_called_once_with("interstellar", True, ["--gsheet_feeder.sheet_id", "123"])
+ m_args.assert_called_once_with(
+ "interstellar", True, [constants.SHEET_ID, "123"]
+ )
m_orchestrator.return_value.setup.assert_called_once()
m_orchestrator.return_value.feed.assert_called_once()
m_store.assert_called_with("interstellar")
- m_store.call_count == 2
- m_uuid.call_count == 2
- assert type(res) == dict
+ assert m_store.call_count == 2
+ assert m_uuid.call_count == 2
+ assert isinstance(res, dict)
assert res["stats"]["archived"] == 1
assert res["stats"]["failed"] == 1
assert len(res["stats"]["errors"]) == 1
assert res["sheet_id"] == "123"
assert res["success"]
- assert type(res["time"]) == datetime
+ assert isinstance(res["time"], datetime)
# query created archive entry
- inserted = db_session.query(models.Archive).filter(models.Archive.url == self.URL).one()
+ inserted = (
+ db_session.query(models.Archive)
+ .filter(models.Archive.url == self.URL)
+ .one()
+ )
assert inserted is not None
assert inserted.url == self.URL
assert len(inserted.tags) == 1
assert inserted.tags[0].id == "spaceship"
assert inserted.group_id == "interstellar"
assert inserted.author_id == "rick@example.com"
- assert inserted.public == False
+ assert inserted.public is False
def test_get_all_urls(db_session):
- from app.worker.main import get_all_urls
-
meta = Metadata().set_url("https://example.com")
m1 = meta.add_media(Media("fn1.txt", urls=["outcome1.com"]))
m2 = meta.add_media(Media("fn2.txt", urls=["outcome2.com"]))
m3 = meta.add_media(Media("fn3.txt", urls=["outcome3.com"]))
m1.set("screenshot", Media("screenshot.png", urls=["screenshot.com"]))
- m2.set("thumbnails", [Media("thumb1.png", urls=["thumb1.com"]), Media("thumb2.png", urls=["thumb2.com"])])
+ m2.set(
+ "thumbnails",
+ [
+ Media("thumb1.png", urls=["thumb1.com"]),
+ Media("thumb2.png", urls=["thumb2.com"]),
+ ],
+ )
m3.set("ssl_data", Media("ssl_data.txt", urls=["ssl_data.com"]).to_dict())
m3.set("bad_data", {"bad": "dict is ignored"})
diff --git a/app/web/__init__.py b/app/web/__init__.py
index a817e9e..98a139a 100644
--- a/app/web/__init__.py
+++ b/app/web/__init__.py
@@ -1,3 +1,4 @@
from app.web.main import app_factory
-app = app_factory
\ No newline at end of file
+
+app = app_factory
diff --git a/app/web/config.py b/app/web/config.py
index a6c668c..9f2e0fd 100644
--- a/app/web/config.py
+++ b/app/web/config.py
@@ -1,14 +1,17 @@
-VERSION = "0.9.4"
+VERSION = "0.10.0"
API_DESCRIPTION = """
#### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets.
**Usage notes:**
- The API requires a Bearer token for most operations, which you can obtain by logging in with your Google account.
-- You can use this API to archive single URLs or entire Google Sheets.
+- You can use this API to archive single URLs or entire Google Sheets.
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
"""
-BREAKING_CHANGES = {"minVersion": "0.4.0", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
+BREAKING_CHANGES = {
+ "minVersion": "0.4.0",
+ "message": "The latest update has breaking changes, please update the extension to the most recent version.",
+}
# changing this will corrupt the database logic
ALLOW_ANY_EMAIL = "*"
diff --git a/app/web/db/crud.py b/app/web/db/crud.py
index c16b09a..308c526 100644
--- a/app/web/db/crud.py
+++ b/app/web/db/crud.py
@@ -1,18 +1,30 @@
from collections import defaultdict
-from functools import lru_cache
-from sqlalchemy.orm import Session, load_only
-from sqlalchemy import Column, or_, func, select
-from loguru import logger
from datetime import datetime, timedelta
-from sqlalchemy.ext.asyncio import AsyncSession
+from typing import Any, Type
+
from cachetools import LRUCache, cached
from cachetools.keys import hashkey
+from loguru import logger
+from sqlalchemy import (
+ Column,
+ ColumnElement,
+ ScalarResult,
+ false,
+ func,
+ not_,
+ or_,
+ select,
+ true,
+)
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import Session, load_only
-from app.web.config import ALLOW_ANY_EMAIL
from app.shared.db import models
+from app.shared.db.models import Archive, Group
from app.shared.settings import get_settings
from app.shared.user_groups import UserGroups
from app.shared.utils.misc import fnv1a_hash_mod
+from app.web.config import ALLOW_ANY_EMAIL
from app.web.utils.misc import convert_priority_to_queue_dict
@@ -22,24 +34,48 @@ DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
def get_limit(user_limit: int):
return max(1, min(user_limit, DATABASE_QUERY_LIMIT))
+
# --------------- TASK = Archive
def base_query(db: Session):
- # NOTE: load_only is for optimization and not obfuscation, use .with_entities() if needed
- return db.query(models.Archive)\
- .filter(models.Archive.deleted == False)\
- .options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result, models.Archive.store_until))
+ # NOTE: load_only is for optimization and not obfuscation, use
+ # .with_entities() if needed
+ return (
+ db.query(models.Archive)
+ .filter(not_(models.Archive.deleted))
+ .options(
+ load_only(
+ models.Archive.id,
+ models.Archive.created_at,
+ models.Archive.url,
+ models.Archive.result,
+ models.Archive.store_until,
+ )
+ )
+ )
-def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool | set[str], read_public: bool, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False) -> list[models.Archive]:
- # searches for partial URLs, if email is * no ownership (or read/read_public) filtering happens
+def search_archives_by_url(
+ db: Session,
+ url: str,
+ email: str,
+ read_groups: bool | set[str],
+ read_public: bool,
+ skip: int = 0,
+ limit: int = 100,
+ archived_after: datetime = None,
+ archived_before: datetime = None,
+ absolute_search: bool = False,
+) -> list[Type[Archive]]:
+ # searches for partial URLs, if email is * no ownership
+ # (or read/read_public) filtering happens
query = base_query(db)
if email != ALLOW_ANY_EMAIL:
or_filters = [models.Archive.author_id == email]
if read_public:
- or_filters.append(models.Archive.public == True)
- if read_groups == True:
+ or_filters.append(models.Archive.public.is_(true()))
+ if read_groups is True:
or_filters.append(models.Archive.group_id.isnot(None))
else:
or_filters.append(models.Archive.group_id.in_(read_groups))
@@ -47,21 +83,43 @@ def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool
if absolute_search:
query = query.filter(models.Archive.url == url)
else:
- query = query.filter(models.Archive.url.like(f'%{url}%'))
+ query = query.filter(models.Archive.url.like(f"%{url}%"))
if archived_after:
query = query.filter(models.Archive.created_at > archived_after)
if archived_before:
query = query.filter(models.Archive.created_at < archived_before)
- return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
+ return (
+ query.order_by(models.Archive.created_at.desc())
+ .offset(skip)
+ .limit(get_limit(limit))
+ .all()
+ )
-def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
- return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
+def search_archives_by_email(
+ db: Session, email: str, skip: int = 0, limit: int = 100
+):
+ return (
+ base_query(db)
+ .filter(models.Archive.author_id == email)
+ .order_by(models.Archive.created_at.desc())
+ .offset(skip)
+ .limit(get_limit(limit))
+ .all()
+ )
def soft_delete_archive(db: Session, id: str, email: str) -> bool:
# TODO: implement hard-delete with cronjob that deletes from S3
- db_archive = db.query(models.Archive).filter(models.Archive.id == id, models.Archive.author_id == email, models.Archive.deleted == False).first()
+ db_archive = (
+ db.query(models.Archive)
+ .filter(
+ models.Archive.id == id,
+ models.Archive.author_id == email,
+ models.Archive.deleted.is_(false()),
+ )
+ .first()
+ )
if db_archive:
db_archive.deleted = True
db.commit()
@@ -82,22 +140,29 @@ def count_users(db: Session):
def count_by_user_since(db: Session, seconds_delta: int = 15):
time_threshold = datetime.now() - timedelta(seconds=seconds_delta)
- return db.query(models.Archive.author_id, func.count().label('total'))\
- .filter(models.Archive.created_at >= time_threshold)\
- .group_by(models.Archive.author_id)\
- .order_by(func.count().desc())\
- .limit(500).all()
+ return (
+ db.query(models.Archive.author_id, func.count().label("total"))
+ .filter(models.Archive.created_at >= time_threshold)
+ .group_by(models.Archive.author_id)
+ .order_by(func.count().desc())
+ .limit(500)
+ .all()
+ )
-async def find_by_store_until(db: AsyncSession, store_until_is_before: datetime) -> list[models.Archive]:
+async def find_by_store_until(
+ db: AsyncSession, store_until_is_before: datetime
+) -> ScalarResult[Archive]:
res = await db.execute(
- select(models.Archive)
- .filter(models.Archive.deleted == False, models.Archive.store_until < store_until_is_before)
+ select(models.Archive).filter(
+ models.Archive.deleted.is_(false()),
+ models.Archive.store_until < store_until_is_before,
+ )
)
return res.scalars()
-async def soft_delete_expired_archives(db: AsyncSession) -> dict:
+async def soft_delete_expired_archives(db: AsyncSession) -> int:
to_delete = await find_by_store_until(db, datetime.now())
counter = 0
for archive in to_delete:
@@ -105,47 +170,86 @@ async def soft_delete_expired_archives(db: AsyncSession) -> dict:
counter += 1
await db.commit()
return counter
+
+
# --------------- TAG
async def get_group_priority_async(db: AsyncSession, group_id: str) -> dict:
db_group = await db.get(models.Group, group_id)
- priority = db_group.permissions.get("priority", "low") if db_group else "low"
+ priority = (
+ db_group.permissions.get("priority", "low") if db_group else "low"
+ )
return convert_priority_to_queue_dict(priority)
@cached(cache=LRUCache(maxsize=128), key=lambda db, email: hashkey(email))
-def get_user_group_names(db: Session, email: str) -> list[str]:
+def get_user_group_names(
+ db: Session, email: str
+) -> list[Any] | list[ColumnElement[Any]]:
"""
- given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user.
+ given an email retrieves the user groups from the DB and then the
+ email-domain groups from a global variable, the email does not need to
+ belong to an existing user.
"""
# TODO: the read: [group1, group2] permissions don't currently work
- if not email or not len(email) or "@" not in email: return []
+ if not email or not len(email) or "@" not in email:
+ return []
# get user groups
- user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
+ user_groups = (
+ db.query(models.association_table_user_groups)
+ .filter_by(user_id=email)
+ .with_entities(Column("group_id"))
+ .all()
+ )
user_level_groups_names = [g[0] for g in user_groups]
# get domain groups
- domain = email.split('@')[1]
- domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
+ domain = email.split("@")[1]
+ domain_level_groups = (
+ db.query(models.Group.id)
+ .filter(models.Group.domains.contains(domain))
+ .with_entities(Column("id"))
+ .all()
+ )
domain_level_groups_names = [g[0] for g in domain_level_groups]
return list(set(user_level_groups_names + domain_level_groups_names))
-def get_user_groups_by_name(db: Session, groups: list[str]) -> list[models.Group]:
- return db.query(models.Group).filter(
- models.Group.id.in_(groups)
- ).all()
+def get_user_groups_by_name(
+ db: Session, groups: list[str]
+) -> list[Type[Group]]:
+ return db.query(models.Group).filter(models.Group.id.in_(groups)).all()
+
# --------------- INIT User-Groups
-def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, service_account_email: str, permissions: dict, domains: list) -> models.Group:
- db_group = db.query(models.Group).filter(models.Group.id == group_name).first()
+def upsert_group(
+ db: Session,
+ group_name: str,
+ description: str,
+ orchestrator: str,
+ orchestrator_sheet: str,
+ service_account_email: str,
+ permissions: dict,
+ domains: list,
+) -> models.Group:
+ db_group = (
+ db.query(models.Group).filter(models.Group.id == group_name).first()
+ )
if db_group is None:
- db_group = models.Group(id=group_name, description=description, orchestrator=orchestrator, orchestrator_sheet=orchestrator_sheet, service_account_email=service_account_email, permissions=permissions, domains=domains)
+ db_group = models.Group(
+ id=group_name,
+ description=description,
+ orchestrator=orchestrator,
+ orchestrator_sheet=orchestrator_sheet,
+ service_account_email=service_account_email,
+ permissions=permissions,
+ domains=domains,
+ )
db.add(db_group)
else:
db_group.description = description
@@ -172,8 +276,9 @@ def upsert_user(db: Session, email: str):
def upsert_user_groups(db: Session):
def display_email_pii(email: str):
return f"'{email[0:3]}...@{email.split('@')[1]}'"
+
"""
- reads the user_groups yaml file and inserts any new users, groups,
+ reads the user_groups yaml file and inserts any new users, groups,
along with new participation of users in groups
"""
filename = get_settings().USER_GROUPS_FILENAME
@@ -192,18 +297,33 @@ def upsert_user_groups(db: Session):
for group in explicit_groups:
group_domains[group].add(domain)
import json
+
# upsert groups and save a map of groupid -> dbobject
for group_id, g in ug.groups.items():
- upsert_group(db, group_id, g.description, g.orchestrator, g.orchestrator_sheet, g.service_account_email, json.loads(g.permissions.model_dump_json()), list(group_domains.get(group_id, [])))
- db_groups: dict[str, models.Group] = {g.id: g for g in db.query(models.Group).all()}
+ upsert_group(
+ db,
+ group_id,
+ g.description,
+ g.orchestrator,
+ g.orchestrator_sheet,
+ g.service_account_email,
+ json.loads(g.permissions.model_dump_json()),
+ list(group_domains.get(group_id, [])),
+ )
+ db_groups: dict[str, models.Group] = {
+ g.id: g for g in db.query(models.Group).all()
+ }
# integrity checks
for group_in_domains in group_domains:
if group_in_domains not in db_groups:
- logger.warning(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.")
+ logger.warning(
+ f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work."
+ )
# reinsert users in their EXPLICITLY DEFINED groups
- # domain groups are check live, as there may be new users that are not explicitly registered but belong to a domain
+ # domain groups are check live, as there may be new users that are not
+ # explicitly registered but belong to a domain
for email, explicit_groups in ug.users.items():
explicit_groups = explicit_groups or []
logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}")
@@ -213,7 +333,9 @@ def upsert_user_groups(db: Session):
# connect users to groups
for group_id in explicit_groups:
if group_id not in db_groups:
- logger.warning(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.")
+ logger.warning(
+ f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}."
+ )
continue
db_groups[group_id].users.append(db_user)
@@ -221,12 +343,27 @@ def upsert_user_groups(db: Session):
count_user_groups = db.query(models.association_table_user_groups).count()
count_groups = db.query(func.count(models.Group.id)).scalar()
- logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
+ logger.success(
+ f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}]."
+ )
# --------------- SHEET
-def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: str, frequency: str):
- db_sheet = models.Sheet(id=sheet_id, name=name, author_id=email, group_id=group_id, frequency=frequency)
+def create_sheet(
+ db: Session,
+ sheet_id: str,
+ name: str,
+ email: str,
+ group_id: str,
+ frequency: str,
+):
+ db_sheet = models.Sheet(
+ id=sheet_id,
+ name=name,
+ author_id=email,
+ group_id=group_id,
+ frequency=frequency,
+ )
db.add(db_sheet)
db.commit()
db.refresh(db_sheet)
@@ -234,20 +371,31 @@ def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: st
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
- return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
+ return (
+ db.query(models.Sheet)
+ .filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id)
+ .first()
+ )
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
- return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_url_archived_at.desc()).all()
+ return (
+ db.query(models.Sheet)
+ .filter(models.Sheet.author_id == email)
+ .order_by(models.Sheet.last_url_archived_at.desc())
+ .all()
+ )
-async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: int) -> list[models.Sheet]:
+async def get_sheets_by_id_hash(
+ db: AsyncSession, frequency: str, modulo: str, id_hash: int
+) -> list[models.Sheet]:
result = await db.execute(
select(models.Sheet).filter(models.Sheet.frequency == frequency)
)
filtered = []
for sheet in result.scalars():
- if fnv1a_hash_mod(sheet.id, modulo) == id_hash:
+ if fnv1a_hash_mod(sheet.id, int(modulo)) == id_hash:
filtered.append(sheet)
return filtered
@@ -255,7 +403,9 @@ async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, i
async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
time_threshold = datetime.now() - timedelta(days=inactivity_days)
result = await db.execute(
- select(models.Sheet).filter(models.Sheet.last_url_archived_at < time_threshold)
+ select(models.Sheet).filter(
+ models.Sheet.last_url_archived_at < time_threshold
+ )
)
deleted = defaultdict(list)
for sheet in result.scalars():
@@ -266,7 +416,11 @@ async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
- db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
+ db_sheet = (
+ db.query(models.Sheet)
+ .filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email)
+ .first()
+ )
if db_sheet:
db.delete(db_sheet)
db.commit()
diff --git a/app/web/db/user_state.py b/app/web/db/user_state.py
index 968e1bd..67160db 100644
--- a/app/web/db/user_state.py
+++ b/app/web/db/user_state.py
@@ -1,13 +1,13 @@
-
-from typing import Dict, Set
-import sqlalchemy
-from sqlalchemy.orm import Session
-from sqlalchemy import func
from datetime import datetime
+from typing import Dict, Set
+
+import sqlalchemy
+from sqlalchemy import func
+from sqlalchemy.orm import Session
from app.shared.db import models
-from app.shared.user_groups import GroupInfo, GroupPermissions
from app.shared.schemas import Usage, UsageResponse
+from app.shared.user_groups import GroupInfo, GroupPermissions
from app.web.db import crud
from app.web.utils.misc import convert_priority_to_queue_dict
@@ -20,14 +20,15 @@ class UserState:
def __init__(self, db: Session, email: str):
self.db = db
self.email = email.lower()
+ self._permissions = {}
@property
def permissions(self) -> Dict[str, GroupInfo]:
"""
- Returns a dict of all group permissions and a special {"all": read/archive_url/archive_sheet} key
+ Returns a dict of all group permissions and a special
+ {"all": read/archive_url/archive_sheet} key
"""
- if not hasattr(self, '_permissions'):
- self._permissions = {}
+ if not self._permissions:
self._permissions["all"] = GroupInfo(
read=self.read,
read_public=self.read_public,
@@ -37,23 +38,33 @@ class UserState:
max_archive_lifespan_months=self.max_archive_lifespan_months,
max_monthly_urls=self.max_monthly_urls,
max_monthly_mbs=self.max_monthly_mbs,
- priority=self.priority
+ priority=self.priority,
)
for group in self.user_groups:
- if not group.permissions: continue
- self._permissions[group.id] = GroupInfo(**group.permissions, description=group.description, service_account_email=group.service_account_email)
+ if not group.permissions:
+ continue
+ self._permissions[group.id] = GroupInfo(
+ **group.permissions,
+ description=group.description,
+ service_account_email=group.service_account_email,
+ )
return self._permissions
@property
def user_groups_names(self):
- if not hasattr(self, '_user_groups_names'):
- self._user_groups_names = crud.get_user_group_names(self.db, self.email) + ["default"]
+ if not hasattr(self, "_user_groups_names"):
+ # TODO: Define hidden properties in __init__ method
+ self._user_groups_names = crud.get_user_group_names(
+ self.db, self.email
+ ) + ["default"]
return self._user_groups_names
@property
def user_groups(self):
- if not hasattr(self, '_user_groups'):
- self._user_groups = crud.get_user_groups_by_name(self.db, self.user_groups_names)
+ if not hasattr(self, "_user_groups"):
+ self._user_groups = crud.get_user_groups_by_name(
+ self.db, self.user_groups_names
+ )
return self._user_groups
@property
@@ -61,10 +72,11 @@ class UserState:
"""
Read can be a list of group names or True, if all can be read.
"""
- if not hasattr(self, '_read'):
+ if not hasattr(self, "_read"):
self._read = set()
for group in self.user_groups:
- if not group.permissions: continue
+ if not group.permissions:
+ continue
group_read_permissions = group.permissions.get("read", [])
if "all" in group_read_permissions:
self._read = True
@@ -78,10 +90,11 @@ class UserState:
"""
Read public permission
"""
- if not hasattr(self, '_read_public'):
+ if not hasattr(self, "_read_public"):
self._read_public = False
for group in self.user_groups:
- if not group.permissions: continue
+ if not group.permissions:
+ continue
if group.permissions.get("read_public", False):
self._read_public = True
return self._read_public
@@ -92,10 +105,11 @@ class UserState:
"""
Archive URL permission
"""
- if not hasattr(self, '_archive_url'):
+ if not hasattr(self, "_archive_url"):
self._archive_url = False
for group in self.user_groups:
- if not group.permissions: continue
+ if not group.permissions:
+ continue
if group.permissions.get("archive_url", False):
self._archive_url = True
return self._archive_url
@@ -106,10 +120,11 @@ class UserState:
"""
Archive sheet permission
"""
- if not hasattr(self, '_archive_sheet'):
+ if not hasattr(self, "_archive_sheet"):
self._archive_sheet = False
for group in self.user_groups:
- if not group.permissions: continue
+ if not group.permissions:
+ continue
if group.permissions.get("archive_sheet", False):
self._archive_sheet = True
return self._archive_sheet
@@ -117,37 +132,53 @@ class UserState:
@property
def sheet_frequency(self):
- if not hasattr(self, '_sheet_frequency'):
+ if not hasattr(self, "_sheet_frequency"):
self._sheet_frequency = set()
for group in self.user_groups:
- if not group.permissions: continue
- self._sheet_frequency.update(group.permissions.get("sheet_frequency", None))
+ if not group.permissions:
+ continue
+ self._sheet_frequency.update(
+ group.permissions.get("sheet_frequency", None)
+ )
return self._sheet_frequency
@property
def max_archive_lifespan_months(self) -> int:
- if not hasattr(self, '_max_archive_lifespan_months'):
- self._max_archive_lifespan_months = self._helper_for_grouping_max_numerical_permissions("max_archive_lifespan_months")
+ if not hasattr(self, "_max_archive_lifespan_months"):
+ self._max_archive_lifespan_months = (
+ self._helper_for_grouping_max_numerical_permissions(
+ "max_archive_lifespan_months"
+ )
+ )
return self._max_archive_lifespan_months
@property
def max_monthly_urls(self) -> int:
- if not hasattr(self, '_max_monthly_urls'):
- self._max_monthly_urls = self._helper_for_grouping_max_numerical_permissions("max_monthly_urls")
+ if not hasattr(self, "_max_monthly_urls"):
+ self._max_monthly_urls = (
+ self._helper_for_grouping_max_numerical_permissions(
+ "max_monthly_urls"
+ )
+ )
return self._max_monthly_urls
@property
def max_monthly_mbs(self) -> int:
- if not hasattr(self, '_max_monthly_mbs'):
- self._max_monthly_mbs = self._helper_for_grouping_max_numerical_permissions("max_monthly_mbs")
+ if not hasattr(self, "_max_monthly_mbs"):
+ self._max_monthly_mbs = (
+ self._helper_for_grouping_max_numerical_permissions(
+ "max_monthly_mbs"
+ )
+ )
return self._max_monthly_mbs
@property
def priority(self) -> str:
- if not hasattr(self, '_priority'):
+ if not hasattr(self, "_priority"):
self._priority = "low"
for group in self.user_groups:
- if not group.permissions: continue
+ if not group.permissions:
+ continue
if group.permissions.get("priority", self._priority) == "high":
self._priority = "high"
break
@@ -158,18 +189,28 @@ class UserState:
"""
A user is active if they can read/archive anything
"""
- if not hasattr(self, '_active'):
- self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet)
+ if not hasattr(self, "_active"):
+ self._active = bool(
+ self.read
+ or self.read_public
+ or self.archive_url
+ or self.archive_sheet
+ )
return self._active
- def _helper_for_grouping_max_numerical_permissions(self, permission_name: str) -> int:
+ def _helper_for_grouping_max_numerical_permissions(
+ self, permission_name: str
+ ) -> int:
"""
- Iterates one of the numerical permissions where -1 means no restrictions and returns either -1 or the maximum value, defaults according to GroupPermissions
+ Iterates one of the numerical permissions where -1 means no restrictions
+ and returns either -1 or the maximum value, defaults according to
+ GroupPermissions
"""
default = GroupPermissions.model_fields[permission_name].default
max_value = default
for group in self.user_groups:
- if not group.permissions: continue
+ if not group.permissions:
+ continue
group_value = group.permissions.get(permission_name, default)
if group_value == -1:
max_value = -1
@@ -180,43 +221,65 @@ class UserState:
def in_group(self, group_id: str) -> bool:
return group_id in self.user_groups_names
- def usage(self) -> Dict:
+ def usage(self) -> UsageResponse:
"""
- returns the monthly quotas for the URLs/MBs and the totals for Sheets
+ Returns the monthly quotas for the URLs/MBs and the totals for Sheets
"""
current_month = datetime.now().month
current_year = datetime.now().year
# find and sum all user sheets over this month
- user_sheets = self.db.query(
- models.Sheet.group_id,
- func.count(models.Sheet.id).label('sheet_count')
- ).filter(models.Sheet.author_id == self.email).group_by(models.Sheet.group_id).all()
+ user_sheets = (
+ self.db.query(
+ models.Sheet.group_id,
+ func.count(models.Sheet.id).label("sheet_count"),
+ )
+ .filter(models.Sheet.author_id == self.email)
+ .group_by(models.Sheet.group_id)
+ .all()
+ )
- sheets_by_group = {sheet.group_id: sheet.sheet_count for sheet in user_sheets}
+ sheets_by_group = {
+ sheet.group_id: sheet.sheet_count for sheet in user_sheets
+ }
# find and sum all user urls over this month
- urls_by_group = self.db.query(
- models.Archive.group_id,
- func.count(models.Archive.id).label('url_count'),
- func.coalesce(func.sum(
+ urls_by_group = (
+ self.db.query(
+ models.Archive.group_id,
+ func.count(models.Archive.id).label("url_count"),
func.coalesce(
- func.cast(
- func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
- sqlalchemy.Integer
- ), 0
- )
- ), 0).label('total_bytes')
- ).filter(
- models.Archive.author_id == self.email,
- func.extract('month', models.Archive.created_at) == current_month,
- func.extract('year', models.Archive.created_at) == current_year
- ).group_by(models.Archive.group_id).all()
+ func.sum(
+ func.coalesce(
+ func.cast(
+ func.json_extract(
+ models.Archive.result,
+ "$.metadata.total_bytes",
+ ),
+ sqlalchemy.Integer,
+ ),
+ 0,
+ )
+ ),
+ 0,
+ ).label("total_bytes"),
+ )
+ .filter(
+ models.Archive.author_id == self.email,
+ func.extract("month", models.Archive.created_at)
+ == current_month,
+ func.extract("year", models.Archive.created_at) == current_year,
+ )
+ .group_by(models.Archive.group_id)
+ .all()
+ )
# merge the two queries
usage_by_group: Dict[str, Usage] = {
- (url.group_id or ""):
- Usage(monthly_urls=url.url_count, monthly_mbs=int(url.total_bytes / 1024 / 1024))
+ (url.group_id or ""): Usage(
+ monthly_urls=url.url_count,
+ monthly_mbs=int(url.total_bytes / 1024 / 1024),
+ )
for url in urls_by_group
}
for group_id, sheet_count in sheets_by_group.items():
@@ -235,7 +298,7 @@ class UserState:
monthly_urls=total_urls,
monthly_mbs=int(total_bytes / 1024 / 1024),
total_sheets=total_sheets,
- groups=usage_by_group
+ groups=usage_by_group,
)
def has_quota_monthly_sheets(self, group_id: str) -> bool:
@@ -245,7 +308,14 @@ class UserState:
if group_id not in self.permissions:
return False
- user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email, models.Sheet.group_id == group_id).count()
+ user_sheets = (
+ self.db.query(models.Sheet)
+ .filter(
+ models.Sheet.author_id == self.email,
+ models.Sheet.group_id == group_id,
+ )
+ .count()
+ )
sheet_quota = self.permissions[group_id].max_sheets
if sheet_quota == -1:
@@ -254,13 +324,15 @@ class UserState:
def has_quota_max_monthly_urls(self, group_id: str) -> bool:
"""
- checks if a user has reached their monthly url quota for a group, if global then group should be empty string
+ Checks if a user has reached their monthly url quota for a group, if
+ global then group should be empty string
"""
quota = 0
if not group_id:
quota = self.max_monthly_urls
else:
- if group_id not in self.permissions: return False
+ if group_id not in self.permissions:
+ return False
quota = self.permissions[group_id].max_monthly_urls
if quota == -1:
@@ -268,24 +340,31 @@ class UserState:
current_month = datetime.now().month
current_year = datetime.now().year
- user_urls = self.db.query(models.Archive).filter(
- models.Archive.author_id == self.email,
- models.Archive.group_id == group_id,
- func.extract('month', models.Archive.created_at) == current_month,
- func.extract('year', models.Archive.created_at) == current_year
- ).count()
+ user_urls = (
+ self.db.query(models.Archive)
+ .filter(
+ models.Archive.author_id == self.email,
+ models.Archive.group_id == group_id,
+ func.extract("month", models.Archive.created_at)
+ == current_month,
+ func.extract("year", models.Archive.created_at) == current_year,
+ )
+ .count()
+ )
return user_urls < quota
def has_quota_max_monthly_mbs(self, group_id: str) -> bool:
"""
- checks if a user has reached their monthly MBs quota for a group, if global then group should be empty string
+ Checks if a user has reached their monthly MBs quota for a group, if
+ global then group should be empty string
"""
quota = 0
if not group_id:
quota = self.max_monthly_mbs
else:
- if group_id not in self.permissions: return False
+ if group_id not in self.permissions:
+ return False
quota = self.permissions[group_id].max_monthly_mbs
if quota == -1:
@@ -295,19 +374,34 @@ class UserState:
current_year = datetime.now().year
# find and sum all user bytes over this month
- user_bytes = self.db.query(models.Archive).filter(
- models.Archive.author_id == self.email,
- models.Archive.group_id == group_id,
- func.extract('month', models.Archive.created_at) == current_month,
- func.extract('year', models.Archive.created_at) == current_year
- ).with_entities(func.coalesce(func.sum(
- func.coalesce(
- func.cast(
- func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
- sqlalchemy.Integer
- ), 0
+ user_bytes = (
+ self.db.query(models.Archive)
+ .filter(
+ models.Archive.author_id == self.email,
+ models.Archive.group_id == group_id,
+ func.extract("month", models.Archive.created_at)
+ == current_month,
+ func.extract("year", models.Archive.created_at) == current_year,
)
- ), 0).label('total')).scalar()
+ .with_entities(
+ func.coalesce(
+ func.sum(
+ func.coalesce(
+ func.cast(
+ func.json_extract(
+ models.Archive.result,
+ "$.metadata.total_bytes",
+ ),
+ sqlalchemy.Integer,
+ ),
+ 0,
+ )
+ ),
+ 0,
+ ).label("total")
+ )
+ .scalar()
+ )
# convert bytes to mb
user_mbs = int(user_bytes / 1024 / 1024)
@@ -315,7 +409,7 @@ class UserState:
def can_manually_trigger(self, group_id: str) -> bool:
"""
- checks if a user is allowed to manually trigger a sheet
+ Checks if a user is allowed to manually trigger a sheet
"""
if group_id not in self.permissions:
return False
@@ -324,18 +418,21 @@ class UserState:
def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool:
"""
- checks if a user is allowed to create a sheet with this frequency for this group
+ Checks if a user is allowed to create a sheet with this frequency for
+ this group
"""
if group_id not in self.permissions:
return False
return frequency in self.permissions[group_id].sheet_frequency
- def priority_group(self, group_id: str) -> str:
+ def priority_group(self, group_id: str) -> dict:
priority = "low"
for group in self.user_groups:
- if group.id != group_id: continue
- if not group.permissions: continue
+ if group.id != group_id:
+ continue
+ if not group.permissions:
+ continue
priority = group.permissions.get("priority", priority)
break
return convert_priority_to_queue_dict(priority)
diff --git a/app/web/endpoints/default.py b/app/web/endpoints/default.py
deleted file mode 100644
index 9271992..0000000
--- a/app/web/endpoints/default.py
+++ /dev/null
@@ -1,50 +0,0 @@
-
-from typing import Dict
-from fastapi import APIRouter, Depends, HTTPException
-from fastapi.responses import FileResponse, JSONResponse
-
-from app.web.config import VERSION, BREAKING_CHANGES
-from app.shared.schemas import ActiveUser, UsageResponse
-from app.web.db.user_state import UserState
-from app.web.security import get_user_state
-from app.shared.user_groups import GroupInfo
-
-default_router = APIRouter()
-
-
-@default_router.get("/")
-async def home():
- return JSONResponse({"version": VERSION, "breakingChanges": BREAKING_CHANGES})
-
-
-@default_router.get("/health")
-async def health():
- return JSONResponse({"status": "ok"})
-
-
-@default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
-async def active(
- user: UserState = Depends(get_user_state),
-) -> ActiveUser:
- return {"active": user.active}
-
-
-@default_router.get("/user/permissions", summary="Get the user's global 'all' permissions and the permissions for each group they belong to.")
-def get_user_permissions(
- user: UserState = Depends(get_user_state),
-) -> Dict[str, GroupInfo]:
- return user.permissions
-
-@default_router.get("/user/usage", summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.")
-def get_user_usage(
- user: UserState = Depends(get_user_state),
-) -> UsageResponse:
- if not user.active:
- raise HTTPException(status_code=403, detail="User is not active.")
- return user.usage()
-
-
-
-@default_router.get('/favicon.ico', include_in_schema=False)
-async def favicon() -> FileResponse:
- return FileResponse("app/web/static/favicon.ico")
diff --git a/app/web/endpoints/sheet.py b/app/web/endpoints/sheet.py
deleted file mode 100644
index 7848b5e..0000000
--- a/app/web/endpoints/sheet.py
+++ /dev/null
@@ -1,81 +0,0 @@
-
-from fastapi import APIRouter, Depends, HTTPException
-from fastapi.responses import JSONResponse
-
-from sqlalchemy import exc
-from sqlalchemy.orm import Session
-
-from app.web.db.user_state import UserState
-from app.shared import schemas
-from app.shared.task_messaging import get_celery
-from app.web.security import get_user_state
-from app.web.db import crud
-from app.shared.db.database import get_db_dependency
-
-sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
-
-celery = get_celery()
-
-@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
-def create_sheet(
- sheet: schemas.SheetAdd,
- user: UserState = Depends(get_user_state),
- db: Session = Depends(get_db_dependency),
-) -> schemas.SheetResponse:
-
- if not user.in_group(sheet.group_id):
- raise HTTPException(status_code=403, detail="User does not have access to this group.")
-
- if not user.has_quota_monthly_sheets(sheet.group_id):
- raise HTTPException(status_code=429, detail="User has reached their sheet quota for this group.")
-
- if not user.is_sheet_frequency_allowed(sheet.group_id, sheet.frequency):
- raise HTTPException(status_code=422, detail="Invalid frequency selected for this group.")
-
- try:
- return crud.create_sheet(db, sheet.id, sheet.name, user.email, sheet.group_id, sheet.frequency)
- except exc.IntegrityError as e:
- raise HTTPException(status_code=400, detail="Sheet with this ID is already being archived.") from e
-
-
-@sheet_router.get("/mine", status_code=200, summary="Get the authenticated user's Google Sheets.")
-def get_user_sheets(
- user: UserState = Depends(get_user_state),
- db: Session = Depends(get_db_dependency)
-) -> list[schemas.SheetResponse]:
- return crud.get_user_sheets(db, user.email)
-
-
-@sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.")
-def delete_sheet(
- id: str,
- user: UserState = Depends(get_user_state),
- db: Session = Depends(get_db_dependency),
-) -> schemas.DeleteResponse:
- return JSONResponse({
- "id": id,
- "deleted": crud.delete_sheet(db, id, user.email)
- })
-
-
-@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
-def archive_user_sheet(
- id: str,
- user: UserState = Depends(get_user_state),
- db: Session = Depends(get_db_dependency),
-) -> schemas.Task:
-
- sheet = crud.get_user_sheet(db, user.email, sheet_id=id)
- if not sheet:
- raise HTTPException(status_code=403, detail="No access to this sheet.")
-
- if not user.in_group(sheet.group_id):
- raise HTTPException(status_code=403, detail="User does not have access to this group.")
-
- if not user.can_manually_trigger(sheet.group_id):
- raise HTTPException(status_code=429, detail="User cannot manually trigger sheet archiving in this group.")
-
- group_queue = user.priority_group(sheet.group_id)
- task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=id, author_id=user.email, group_id=sheet.group_id).model_dump_json()]).apply_async(**group_queue)
-
- return JSONResponse({"id": task.id}, status_code=201)
\ No newline at end of file
diff --git a/app/web/endpoints/task.py b/app/web/endpoints/task.py
deleted file mode 100644
index 610c579..0000000
--- a/app/web/endpoints/task.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from celery.result import AsyncResult
-from fastapi import APIRouter, Depends
-from fastapi.encoders import jsonable_encoder
-from fastapi.responses import JSONResponse
-
-from app.shared.task_messaging import get_celery
-from app.web.security import get_token_or_user_auth
-from app.shared import schemas
-from app.shared.log import log_error
-from app.web.utils.misc import custom_jsonable_encoder
-
-
-task_router = APIRouter(prefix="/task", tags=["Async task operations"])
-
-celery = get_celery()
-
-@task_router.get("/{task_id}", summary="Check the status of an async task by its id, works for URLs and Sheet tasks.")
-def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult:
- task = AsyncResult(task_id, app=celery)
- try:
- if task.status == "FAILURE":
- # *FAILURE* The task raised an exception, or has exceeded the retry limit.
- # The :attr:`result` attribute then contains the exception raised by the task.
- # https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
- raise task.result
-
- response = {
- "id": task_id,
- "status": task.status,
- "result": task.result
- }
- return JSONResponse(jsonable_encoder(response, exclude_unset=True, custom_encoder={bytes: custom_jsonable_encoder}))
-
- except Exception as e:
- log_error(e)
- return JSONResponse({
- "id": task_id,
- "status": "FAILURE",
- "result": {"error": str(e)}
- })
diff --git a/app/web/endpoints/url.py b/app/web/endpoints/url.py
deleted file mode 100644
index a7ac4b4..0000000
--- a/app/web/endpoints/url.py
+++ /dev/null
@@ -1,85 +0,0 @@
-
-from fastapi import APIRouter, Depends, HTTPException
-from fastapi.responses import JSONResponse
-from datetime import datetime
-from loguru import logger
-from sqlalchemy.orm import Session
-
-from app.web.config import ALLOW_ANY_EMAIL
-from app.shared import schemas
-from app.shared.task_messaging import get_celery
-from app.web.security import get_token_or_user_auth, get_user_state
-from app.web.db import crud
-from app.web.db.user_state import UserState
-from app.shared.db.database import get_db_dependency
-
-from urllib.parse import urlparse
-
-from app.web.utils.misc import convert_priority_to_queue_dict
-
-url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
-
-celery = get_celery()
-
-@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
-def archive_url(
- archive: schemas.ArchiveTrigger,
- email=Depends(get_token_or_user_auth),
- db: Session = Depends(get_db_dependency)
-) -> schemas.Task:
- logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
-
- parsed_url = urlparse(archive.url)
- if not all([parsed_url.scheme, parsed_url.netloc]):
- raise HTTPException(status_code=400, detail="Invalid URL received.")
-
- archive_create = schemas.ArchiveCreate(**archive.model_dump())
- if email != ALLOW_ANY_EMAIL:
- archive_create.author_id = email
- user = UserState(db, email)
- if archive.group_id and not user.in_group(archive.group_id):
- raise HTTPException(status_code=403, detail="User does not have access to this group.")
- if not user.has_quota_max_monthly_urls(archive.group_id):
- raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.")
- if not user.has_quota_max_monthly_mbs(archive.group_id):
- raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.")
- group_queue = user.priority_group(archive_create.group_id)
- else:
- archive_create.author_id = archive.author_id or email
- group_queue = convert_priority_to_queue_dict("high")
-
-
- task = celery.signature("create_archive_task", args=[archive_create.model_dump_json()]).apply_async(**group_queue)
- task_response = schemas.Task(id=task.id)
- return JSONResponse(task_response.model_dump(), status_code=201)
-
-
-@url_router.get("/search", summary="Search for archive entries by URL.")
-def search_by_url(
- url: str, skip: int = 0, limit: int = 25,
- archived_after: datetime = None, archived_before: datetime = None,
- db: Session = Depends(get_db_dependency),
- email: str = Depends(get_token_or_user_auth)
-) -> list[schemas.ArchiveResult]:
-
- read_groups, read_public = False, False
- if email != ALLOW_ANY_EMAIL:
- user = UserState(db, email)
- if not user.read and not user.read_public:
- raise HTTPException(status_code=403, detail="User does not have read access.")
- read_groups = user.read
- read_public = user.read_public
- return crud.search_archives_by_url(db, url.strip(), email, read_groups, read_public, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
-
-
-@url_router.delete("/{id}", summary="Delete a single URL archive by id.")
-def delete_archive(
- id:str,
- user: UserState = Depends(get_user_state),
- db: Session = Depends(get_db_dependency)
-) -> schemas.DeleteResponse:
- logger.info(f"deleting url archive task {id} request by {user.email}")
- return JSONResponse({
- "id": id,
- "deleted": crud.soft_delete_archive(db, id, user.email)
- })
diff --git a/app/web/events.py b/app/web/events.py
index 625731a..e9af845 100644
--- a/app/web/events.py
+++ b/app/web/events.py
@@ -1,22 +1,32 @@
import asyncio
-from collections import defaultdict
import datetime
import logging
+from collections import defaultdict
+from contextlib import asynccontextmanager
+
import alembic.config
from fastapi import FastAPI
-from contextlib import asynccontextmanager
+from fastapi_mail import FastMail, MessageSchema, MessageType
from fastapi_utils.tasks import repeat_every
from loguru import logger
-from fastapi_mail import FastMail, MessageSchema, MessageType
-from app.shared.db import models
-from app.shared.db.database import get_db, get_db_async, make_engine, wal_checkpoint
from app.shared import schemas
+from app.shared.db import models
+from app.shared.db.database import (
+ get_db,
+ get_db_async,
+ make_engine,
+ wal_checkpoint,
+)
from app.shared.settings import get_settings
from app.shared.task_messaging import get_celery
from app.web.db import crud
from app.web.middleware import increase_exceptions_counter
-from app.web.utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
+from app.web.utils.metrics import (
+ measure_regular_metrics,
+ redis_subscribe_worker_exceptions,
+)
+
celery = get_celery()
@@ -28,9 +38,22 @@ async def lifespan(app: FastAPI):
# STARTUP
engine = make_engine(get_settings().DATABASE_PATH)
models.Base.metadata.create_all(bind=engine)
- alembic.config.main(prog="alembic", argv=['--raiseerr', 'upgrade', 'head'])
+ alembic.config.main(
+ prog="alembic",
+ argv=[
+ "-c",
+ "./app/migrations/alembic.ini",
+ "--raiseerr",
+ "upgrade",
+ "head",
+ ],
+ )
logging.getLogger("uvicorn.access").disabled = True # loguru
- asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL))
+ asyncio.create_task(
+ redis_subscribe_worker_exceptions(
+ get_settings().REDIS_EXCEPTIONS_CHANNEL
+ )
+ )
asyncio.create_task(repeat_measure_regular_metrics())
with get_db() as db:
crud.upsert_user_groups(db)
@@ -61,41 +84,74 @@ async def lifespan(app: FastAPI):
# CRON JOBS
-@repeat_every(seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS, on_exception=increase_exceptions_counter)
+@repeat_every(
+ seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS,
+ on_exception=increase_exceptions_counter,
+)
async def repeat_measure_regular_metrics():
- await measure_regular_metrics(get_settings().DATABASE_PATH, get_settings().REPEAT_COUNT_METRICS_SECONDS)
+ await measure_regular_metrics(
+ get_settings().DATABASE_PATH,
+ get_settings().REPEAT_COUNT_METRICS_SECONDS,
+ )
-@repeat_every(seconds=60, wait_first=120, on_exception=increase_exceptions_counter)
+@repeat_every(
+ seconds=60, wait_first=120, on_exception=increase_exceptions_counter
+)
async def archive_hourly_sheets_cronjob():
await archive_sheets_cronjob("hourly", 60, datetime.datetime.now().minute)
-@repeat_every(seconds=3600, wait_first=120, on_exception=increase_exceptions_counter)
+@repeat_every(
+ seconds=3600, wait_first=120, on_exception=increase_exceptions_counter
+)
async def archive_daily_sheets_cronjob():
await archive_sheets_cronjob("daily", 24, datetime.datetime.now().hour)
-async def archive_sheets_cronjob(frequency: str, interval: int, current_time_unit: int):
+async def archive_sheets_cronjob(
+ frequency: str, interval: int, current_time_unit: int
+):
triggered_jobs = []
async with get_db_async() as db:
- sheets = await crud.get_sheets_by_id_hash(db, frequency, interval, current_time_unit)
+ sheets = await crud.get_sheets_by_id_hash(
+ db, frequency, str(interval), current_time_unit
+ )
for s in sheets:
group_queue = await crud.get_group_priority_async(db, s.group_id)
- task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group_id=s.group_id).model_dump_json()]).apply_async(**group_queue)
+ task = celery.signature(
+ "create_sheet_task",
+ args=[
+ schemas.SubmitSheet(
+ sheet_id=s.id,
+ author_id=s.author_id,
+ group_id=s.group_id,
+ ).model_dump_json()
+ ],
+ ).apply_async(**group_queue)
triggered_jobs.append({"sheet_id": s.id, "task_id": task.id})
- logger.debug(f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}")
+ logger.debug(
+ f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}"
+ )
# TODO: on exception should logerror but also prometheus counter
-DELETE_WINDOW = get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS * 24 * 60 * 60
+DELETE_WINDOW = (
+ get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS * 24 * 60 * 60
+)
-@repeat_every(seconds=DELETE_WINDOW, wait_first=180, on_exception=increase_exceptions_counter)
+@repeat_every(
+ seconds=DELETE_WINDOW,
+ wait_first=180,
+ on_exception=increase_exceptions_counter,
+)
async def notify_about_expired_archives():
- notify_from = datetime.datetime.now() + datetime.timedelta(days=get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS)
+ notify_from = datetime.datetime.now() + datetime.timedelta(
+ days=get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS
+ )
async with get_db_async() as db:
scheduled_deletions = await crud.find_by_store_until(db, notify_from)
@@ -104,10 +160,15 @@ async def notify_about_expired_archives():
user_archives[archive.author_id].append(archive)
if user_archives:
- fastmail = FastMail(get_settings().MAIL_CONFIG)
+ fastmail = FastMail(get_settings().mail_config)
# notify users
for email in user_archives:
- list_of_archives = "\n".join([f'{a.url}, {a.id}, {a.store_until.isoformat()}
' for a in user_archives[email]])
+ list_of_archives = "\n".join(
+ [
+ f"{a.url}, {a.id}, {a.store_until.isoformat()}
"
+ for a in user_archives[email]
+ ]
+ )
# TODO: how can users download them in bulk?
message = MessageSchema(
subject="Auto Archiver: Archives Scheduled for Deletion",
@@ -127,16 +188,23 @@ async def notify_about_expired_archives():