mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-07 19:18:34 +03:00
Merge pull request #72 from bellingcat/dev
Community contributions, code standardization, and AA v1.0.0
This commit is contained in:
@@ -1,3 +0,0 @@
|
||||
[run]
|
||||
omit =
|
||||
app/migrations/*
|
||||
@@ -2,4 +2,4 @@ CHROME_APP_IDS='["1234567890"]'
|
||||
ALLOWED_ORIGINS='["allowed"]'
|
||||
BLOCKED_EMAILS='[]'
|
||||
DATABASE_PATH="sqlite:///./database/auto-archiver.db"
|
||||
API_BEARER_TOKEN=THIS_API_TOKEN_SHOULD_NEVER_BE_USED
|
||||
API_BEARER_TOKEN=THIS_API_TOKEN_SHOULD_NEVER_BE_USED
|
||||
|
||||
@@ -35,4 +35,4 @@ MAIL_SSL_TLS=True
|
||||
|
||||
|
||||
# celery workers config
|
||||
CONCURRENCY=2
|
||||
CONCURRENCY=2
|
||||
|
||||
@@ -5,4 +5,4 @@ BLOCKED_EMAILS='["blocked@example.com"]'
|
||||
|
||||
DATABASE_PATH="sqlite:///auto-archiver.test.db"
|
||||
API_BEARER_TOKEN=this_is_the_test_api_token
|
||||
USER_GROUPS_FILENAME=app/tests/user-groups.test.yaml
|
||||
USER_GROUPS_FILENAME=app/tests/user-groups.test.yaml
|
||||
|
||||
23
.github/pull_request_template.md
vendored
Normal file
23
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
<!---
|
||||
Please write your PR name in the present imperative tense. Examples of that tense are:
|
||||
"Fix issue in the dispatcher where…", "Improve our handling of…", etc.
|
||||
|
||||
For more information on Pull Requests, you can reference here:
|
||||
https://success.vanillaforums.com/kb/articles/228-using-pull-requests-to-contribute
|
||||
-->
|
||||
## Describe your changes
|
||||
|
||||
|
||||
## Non-obvious technical information
|
||||
|
||||
|
||||
## Checklist before requesting a review
|
||||
<!---
|
||||
These are suggested things you could add, but what you add will be dependent on
|
||||
your repository's standards.
|
||||
--->
|
||||
- [ ] The code runs successfully.
|
||||
|
||||
```commandline
|
||||
HERE IS SOME COMMAND LINE OUTPUT
|
||||
```
|
||||
12
.github/workflows/ci.yml
vendored
12
.github/workflows/ci.yml
vendored
@@ -1,16 +1,12 @@
|
||||
name: CI
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
branches: [ main, dev ]
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
branches: [ main, dev ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
test-with-coverage:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
@@ -41,4 +37,4 @@ jobs:
|
||||
run: poetry run coverage run -m pytest -v -ra --color=yes app/tests/
|
||||
|
||||
- name: Report coverage
|
||||
run: poetry run coverage report
|
||||
run: poetry run coverage report
|
||||
|
||||
16
.github/workflows/format-and-fail.yml
vendored
Normal file
16
.github/workflows/format-and-fail.yml
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
name: Format and Fail
|
||||
on:
|
||||
push:
|
||||
branches: [ main, dev ]
|
||||
pull_request:
|
||||
branches: [ main, dev ]
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
40
.github/workflows/test.yml
vendored
Normal file
40
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
name: Run Tests
|
||||
on:
|
||||
push:
|
||||
branches: [ main, dev ]
|
||||
pull_request:
|
||||
branches: [ main, dev ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.10', '3.11', '3.12']
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: redis:6-alpine
|
||||
ports:
|
||||
- 6379:6379
|
||||
|
||||
steps:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Poetry
|
||||
run: pipx install poetry
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install --no-interaction --with dev
|
||||
|
||||
- name: Set dev environment variable
|
||||
run: echo "ENVIRONMENT_FILE=.env.test" >> $GITHUB_ENV
|
||||
|
||||
- name: Run tests
|
||||
run: poetry run pytest app/tests
|
||||
141
.gitignore
vendored
141
.gitignore
vendored
@@ -1,16 +1,10 @@
|
||||
# Misc.
|
||||
user-groups.dev.yaml
|
||||
user-groups.yaml
|
||||
orchestration.yaml
|
||||
my-archives
|
||||
*.pyc
|
||||
.DS_Store
|
||||
secrets/*
|
||||
*.log
|
||||
__pycache__
|
||||
.pytest_cache
|
||||
.env
|
||||
.env.dev
|
||||
.env.prod
|
||||
*.db
|
||||
redis/data/*
|
||||
.ipynb_checkpoints*
|
||||
@@ -18,8 +12,6 @@ app/user-groups.yaml
|
||||
app/user-groups.dev.yaml
|
||||
wit*
|
||||
app/crawls
|
||||
.coverage
|
||||
.pytest_cache/
|
||||
htmlcov
|
||||
local_archive
|
||||
local_archive_test
|
||||
@@ -27,6 +19,133 @@ local_archive_test
|
||||
*db-shm
|
||||
copy-files.sh
|
||||
temp/
|
||||
.python-version
|
||||
orchestration2.yaml
|
||||
database
|
||||
database
|
||||
|
||||
# IDE files
|
||||
.idea
|
||||
.vscode
|
||||
**/.DS_Store
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env*
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
78
.pre-commit-config.yaml
Normal file
78
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,78 @@
|
||||
repos:
|
||||
- repo: https://github.com/nbQA-dev/nbQA
|
||||
rev: 1.9.1
|
||||
hooks:
|
||||
- id: nbqa-ruff
|
||||
args:
|
||||
- --fix
|
||||
- --target-version=py310
|
||||
- --ignore=E721,E722
|
||||
- --line-length=80
|
||||
- id: nbqa-black
|
||||
args:
|
||||
- --line-length=80
|
||||
- id: nbqa-isort
|
||||
args:
|
||||
- --float-to-top
|
||||
- --profile=black
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: check-docstring-first
|
||||
- id: check-executables-have-shebangs
|
||||
- id: check-json
|
||||
- id: check-case-conflict
|
||||
- id: check-toml
|
||||
- id: check-merge-conflict
|
||||
- id: check-xml
|
||||
- id: check-yaml
|
||||
exclude: app/tests/user-groups.test.broken.yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: check-symlinks
|
||||
- id: mixed-line-ending
|
||||
- id: sort-simple-yaml
|
||||
- id: fix-encoding-pragma
|
||||
args:
|
||||
- --remove
|
||||
- id: pretty-format-json
|
||||
args:
|
||||
- --autofix
|
||||
|
||||
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||
rev: v1.10.0
|
||||
hooks:
|
||||
- id: python-check-blanket-noqa
|
||||
- id: python-check-mock-methods
|
||||
- id: python-no-eval
|
||||
- id: python-no-log-warn
|
||||
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 6.0.1
|
||||
hooks:
|
||||
- id: isort
|
||||
name: Run isort to sort imports
|
||||
files: \.py$
|
||||
# To keep consistent with the global isort skip config defined in setup.cfg
|
||||
exclude: ^build/.*$|^.tox/.*$|^venv/.*$
|
||||
args:
|
||||
- --lines-after-imports=2
|
||||
- --profile=black
|
||||
- --line-length=80
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.9.7
|
||||
hooks:
|
||||
- id: ruff
|
||||
types_or: [python,pyi]
|
||||
args:
|
||||
- --fix
|
||||
- --select=B,C,E,F,W,B9
|
||||
- --line-length=80
|
||||
- --ignore=E203,E402,E501,E261
|
||||
- id: ruff-format
|
||||
types_or: [ python,pyi]
|
||||
args:
|
||||
- --target-version=py310
|
||||
- --line-length=80
|
||||
1
CODEOWNERS
Normal file
1
CODEOWNERS
Normal file
@@ -0,0 +1 @@
|
||||
* @msramalho
|
||||
2
LICENSE
2
LICENSE
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
SOFTWARE.
|
||||
|
||||
19
Makefile
19
Makefile
@@ -1,19 +1,33 @@
|
||||
.PHONY: lint
|
||||
lint:
|
||||
poetry run pre-commit run --all-files
|
||||
|
||||
.PHONY: test
|
||||
test:
|
||||
export ENVIRONMENT_FILE=.env.test
|
||||
poetry run coverage run -m pytest -v --disable-warnings --color=yes app/tests/
|
||||
poetry run coverage report
|
||||
|
||||
.PHONY: clean-dev
|
||||
clean-dev:
|
||||
@echo -n "Are you sure? [yes/N] (this will delete volumes) " && read ans && [ $${ans:-N} = yes ]
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes --remove-orphans
|
||||
|
||||
.PHONY: dev
|
||||
dev:
|
||||
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build
|
||||
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans
|
||||
|
||||
|
||||
.PHONY: dev-redis-only
|
||||
dev-redis-only:
|
||||
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml build redis
|
||||
docker compose --env-file .env.dev -f docker-compose.yml -f docker-compose.dev.yml up --remove-orphans redis
|
||||
|
||||
.PHONY: stop-dev
|
||||
stop-dev:
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml down --volumes
|
||||
|
||||
.PHONY: prod
|
||||
prod:
|
||||
docker compose --env-file .env.prod build
|
||||
docker compose --env-file .env.prod up -d --remove-orphans
|
||||
@@ -21,5 +35,6 @@ prod:
|
||||
docker image prune -f
|
||||
docker system df
|
||||
|
||||
.PHONY: stop-prod
|
||||
stop-prod:
|
||||
docker compose down
|
||||
docker compose down
|
||||
|
||||
25
README.md
25
README.md
@@ -12,9 +12,9 @@ To properly set up the API you need to install `docker` and to have these files,
|
||||
2. a `user-groups.yaml` to manage user permissions
|
||||
1. note that all local files referenced in `user-groups.yaml` and any orchestration.yaml files should be relative to the home directory so if your service account is in `secrets/orchestration.yaml` use that path and not just `orchestration.yaml`.
|
||||
2. go through the example file and configure it according to your needs.
|
||||
3. you will need to create and reference at least one `secrets/orchestration.yaml` file, you can do so by following the instructions in the [auto-archiver](https://github.com/bellingcat/auto-archiver#installation) that automatically generates one for you. If you use the archive sheets feature you will need to create a `orchestrationsheets-sheets.yaml` file as well that should have the `gsheet_feeder` and `gsheet_db` enabled and configured, the auto-archiver has [extensive documentation](https://auto-archiver.readthedocs.io/en/latest/) on how to set this up.
|
||||
3. you will need to create and reference at least one `secrets/orchestration.yaml` file, you can do so by following the instructions in the [auto-archiver](https://github.com/bellingcat/auto-archiver#installation) that automatically generates one for you. If you use the archive sheets feature you will need to create a `orchestrationsheets-sheets.yaml` file as well that should have the `gsheet_feeder_db` feeder and database enabled and configured, the auto-archiver has [extensive documentation](https://auto-archiver.readthedocs.io/en/latest/) on how to set this up.
|
||||
|
||||
Do not commit those files, they are .gitignored by default.
|
||||
Do not commit those files, they are .gitignored by default.
|
||||
We also advise you to keep any sensitive files in the `secrets/` folder which is pinned and gitignored.
|
||||
|
||||
We have examples for both of those files (`.env.example` and `user-groups.example.yaml`), and here's how to set them up whether you're in development or production:
|
||||
@@ -108,6 +108,27 @@ Make sure environment and user-groups files are up to date.
|
||||
Then `make prod`.
|
||||
|
||||
|
||||
## Development
|
||||
```bash
|
||||
# make sure all development dependencies are installed
|
||||
poetry install --with dev
|
||||
|
||||
# this project uses pre-commit to enforce code style and formatting, set that up locally
|
||||
poetry run pre-commit install
|
||||
|
||||
# you can test pre-commit with
|
||||
poetry run pre-commit run --all-files
|
||||
|
||||
# this means pre-commit will always run with git commit, to skip it use
|
||||
git commit --no-verify
|
||||
|
||||
# see the Makefile for more commands, but linting and formatting can be done with
|
||||
make lint
|
||||
|
||||
# run all tests
|
||||
make test
|
||||
```
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
# set the testing environment variables
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = app/migrations
|
||||
script_location = ./app/migrations
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
@@ -10,10 +10,6 @@ script_location = app/migrations
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python-dateutil library that can be
|
||||
@@ -1,19 +1,20 @@
|
||||
from logging.config import fileConfig
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from app.shared.settings import get_settings
|
||||
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
config.set_main_option('sqlalchemy.url', get_settings().DATABASE_PATH)
|
||||
config.set_main_option("sqlalchemy.url", get_settings().DATABASE_PATH)
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name, disable_existing_loggers=False) # disable_existing_loggers prevents loguru disabling
|
||||
# disable_existing_loggers prevents loguru disabling
|
||||
fileConfig(config.config_file_name, disable_existing_loggers=False)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
|
||||
@@ -5,13 +5,14 @@ Revises: 1636724ec4b1
|
||||
Create Date: 2025-02-08 15:22:20.392522
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '02b2f6d17ed0'
|
||||
down_revision = '1636724ec4b1'
|
||||
revision = "02b2f6d17ed0"
|
||||
down_revision = "1636724ec4b1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
STORE_UNTIL_COL = "store_until"
|
||||
@@ -20,15 +21,20 @@ STORE_UNTIL_COL = "store_until"
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('archives')]
|
||||
columns = [col["name"] for col in inspector.get_columns("archives")]
|
||||
|
||||
if STORE_UNTIL_COL not in columns:
|
||||
op.add_column('archives', sa.Column(STORE_UNTIL_COL, sa.DateTime(), nullable=True, default=None))
|
||||
op.add_column(
|
||||
"archives",
|
||||
sa.Column(
|
||||
STORE_UNTIL_COL, sa.DateTime(), nullable=True, default=None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('archives')]
|
||||
columns = [col["name"] for col in inspector.get_columns("archives")]
|
||||
if STORE_UNTIL_COL in columns:
|
||||
op.drop_column('archives', STORE_UNTIL_COL)
|
||||
op.drop_column("archives", STORE_UNTIL_COL)
|
||||
|
||||
@@ -5,13 +5,14 @@ Revises: a23aaf3ae930
|
||||
Create Date: 2025-02-05 19:19:01.984396
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '1636724ec4b1'
|
||||
down_revision = 'a23aaf3ae930'
|
||||
revision = "1636724ec4b1"
|
||||
down_revision = "a23aaf3ae930"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -19,14 +20,18 @@ depends_on = None
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('sheets')]
|
||||
if 'last_archived_at' in columns:
|
||||
op.alter_column('sheets', 'last_archived_at', new_column_name='last_url_archived_at')
|
||||
columns = [col["name"] for col in inspector.get_columns("sheets")]
|
||||
if "last_archived_at" in columns:
|
||||
op.alter_column(
|
||||
"sheets", "last_archived_at", new_column_name="last_url_archived_at"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('sheets')]
|
||||
if 'last_url_archived_at' in columns:
|
||||
op.alter_column('sheets', 'last_url_archived_at', new_column_name='last_archived_at')
|
||||
columns = [col["name"] for col in inspector.get_columns("sheets")]
|
||||
if "last_url_archived_at" in columns:
|
||||
op.alter_column(
|
||||
"sheets", "last_url_archived_at", new_column_name="last_archived_at"
|
||||
)
|
||||
|
||||
@@ -5,13 +5,14 @@ Revises: 02b2f6d17ed0
|
||||
Create Date: 2025-02-11 21:53:23.293274
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '63ac79df4ad0'
|
||||
down_revision = '02b2f6d17ed0'
|
||||
revision = "63ac79df4ad0"
|
||||
down_revision = "02b2f6d17ed0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -22,15 +23,17 @@ TABLE = "groups"
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns(TABLE)]
|
||||
columns = [col["name"] for col in inspector.get_columns(TABLE)]
|
||||
|
||||
if NEW_COL not in columns:
|
||||
op.add_column(TABLE, sa.Column(NEW_COL, sa.String, nullable=True, default=None))
|
||||
op.add_column(
|
||||
TABLE, sa.Column(NEW_COL, sa.String, nullable=True, default=None)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns(TABLE)]
|
||||
columns = [col["name"] for col in inspector.get_columns(TABLE)]
|
||||
if NEW_COL in columns:
|
||||
op.drop_column(TABLE, NEW_COL)
|
||||
|
||||
@@ -5,14 +5,14 @@ Revises: fa012ec405b8
|
||||
Create Date: 2024-11-04 11:12:30.237299
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '89121d2c96d8'
|
||||
down_revision = 'fa012ec405b8'
|
||||
revision = "89121d2c96d8"
|
||||
down_revision = "fa012ec405b8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -20,23 +20,27 @@ depends_on = None
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('archives')]
|
||||
columns = [col["name"] for col in inspector.get_columns("archives")]
|
||||
|
||||
if 'sheet_id' not in columns:
|
||||
with op.batch_alter_table('archives') as batch_op:
|
||||
batch_op.add_column(sa.Column('sheet_id', sa.String(), nullable=True, default=None))
|
||||
batch_op.create_foreign_key('fk_sheet_id', 'sheets', ['sheet_id'], ['id'])
|
||||
if "sheet_id" not in columns:
|
||||
with op.batch_alter_table("archives") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("sheet_id", sa.String(), nullable=True, default=None)
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
"fk_sheet_id", "sheets", ["sheet_id"], ["id"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
foreign_keys = [fk['name'] for fk in inspector.get_foreign_keys('archives')]
|
||||
columns = [col['name'] for col in inspector.get_columns('archives')]
|
||||
foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("archives")]
|
||||
columns = [col["name"] for col in inspector.get_columns("archives")]
|
||||
|
||||
with op.batch_alter_table('archives') as batch_op:
|
||||
if 'fk_sheet_id' in foreign_keys:
|
||||
batch_op.drop_constraint('fk_sheet_id', type_='foreignkey')
|
||||
with op.batch_alter_table("archives") as batch_op:
|
||||
if "fk_sheet_id" in foreign_keys:
|
||||
batch_op.drop_constraint("fk_sheet_id", type_="foreignkey")
|
||||
|
||||
if 'sheet_id' in columns:
|
||||
batch_op.drop_column('sheet_id')
|
||||
if "sheet_id" in columns:
|
||||
batch_op.drop_column("sheet_id")
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
"""modify archive url to have uuid id instead of url unique constraint
|
||||
|
||||
Revision ID: 9369a264945b
|
||||
Revises:
|
||||
Revises:
|
||||
Create Date: 2023-12-20 17:24:59.320691
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '9369a264945b'
|
||||
revision = "9369a264945b"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
@@ -18,10 +20,11 @@ def upgrade() -> None:
|
||||
# since the primary key constraint is not named, we have to recreate it first
|
||||
with op.batch_alter_table("archive_urls") as batch_op:
|
||||
batch_op.create_primary_key("pk_url", ["url"])
|
||||
batch_op.drop_constraint("pk_url", type_='primary')
|
||||
batch_op.drop_constraint("pk_url", type_="primary")
|
||||
batch_op.create_primary_key("pk_url_archive_id", ["url", "archive_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table("archive_urls") as batch_op:
|
||||
batch_op.drop_constraint("pk_url_archive_id", type_='primary')
|
||||
batch_op.drop_constraint("pk_url_archive_id", type_="primary")
|
||||
batch_op.create_primary_key("url", ["url"])
|
||||
|
||||
@@ -5,12 +5,13 @@ Revises: 9369a264945b
|
||||
Create Date: 2023-12-20 18:33:27.132566
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '93a611e4c066'
|
||||
down_revision = '9369a264945b'
|
||||
revision = "93a611e4c066"
|
||||
down_revision = "9369a264945b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -20,7 +21,9 @@ def upgrade() -> None:
|
||||
with op.get_context().autocommit_block():
|
||||
op.execute("VACUUM")
|
||||
except Exception as e:
|
||||
print("Unable to run vacuum, maybe there's not enough disk space. it should be 2x the size of the database")
|
||||
print(
|
||||
"Unable to run vacuum, maybe there's not enough disk space. it should be 2x the size of the database"
|
||||
)
|
||||
print(e)
|
||||
|
||||
|
||||
|
||||
@@ -5,13 +5,14 @@ Revises: 89121d2c96d8
|
||||
Create Date: 2025-02-04 12:19:20.753570
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'a23aaf3ae930'
|
||||
down_revision = '89121d2c96d8'
|
||||
revision = "a23aaf3ae930"
|
||||
down_revision = "89121d2c96d8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -19,16 +20,24 @@ depends_on = None
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('users')]
|
||||
columns = [col["name"] for col in inspector.get_columns("users")]
|
||||
|
||||
if 'is_active' in columns:
|
||||
op.drop_column('users', 'is_active')
|
||||
if "is_active" in columns:
|
||||
op.drop_column("users", "is_active")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('users')]
|
||||
columns = [col["name"] for col in inspector.get_columns("users")]
|
||||
|
||||
if 'is_active' not in columns:
|
||||
op.add_column('users', sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.false()))
|
||||
if "is_active" not in columns:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column(
|
||||
"is_active",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -5,14 +5,14 @@ Revises: 93a611e4c066
|
||||
Create Date: 2024-10-31 09:36:50.360710
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'fa012ec405b8'
|
||||
down_revision = '93a611e4c066'
|
||||
revision = "fa012ec405b8"
|
||||
down_revision = "93a611e4c066"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -20,26 +20,41 @@ depends_on = None
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('groups')]
|
||||
columns = [col["name"] for col in inspector.get_columns("groups")]
|
||||
|
||||
if 'description' not in columns:
|
||||
op.add_column('groups', sa.Column('description', sa.String(), nullable=True))
|
||||
if 'orchestrator' not in columns:
|
||||
op.add_column('groups', sa.Column('orchestrator', sa.String(), nullable=True))
|
||||
if 'orchestrator_sheet' not in columns:
|
||||
op.add_column('groups', sa.Column('orchestrator_sheet', sa.String(), nullable=True))
|
||||
if 'permissions' not in columns:
|
||||
op.add_column('groups', sa.Column('permissions', sa.JSON(), nullable=True))
|
||||
if 'domains' not in columns:
|
||||
op.add_column('groups', sa.Column('domains', sa.JSON(), nullable=True))
|
||||
if "description" not in columns:
|
||||
op.add_column(
|
||||
"groups", sa.Column("description", sa.String(), nullable=True)
|
||||
)
|
||||
if "orchestrator" not in columns:
|
||||
op.add_column(
|
||||
"groups", sa.Column("orchestrator", sa.String(), nullable=True)
|
||||
)
|
||||
if "orchestrator_sheet" not in columns:
|
||||
op.add_column(
|
||||
"groups",
|
||||
sa.Column("orchestrator_sheet", sa.String(), nullable=True),
|
||||
)
|
||||
if "permissions" not in columns:
|
||||
op.add_column(
|
||||
"groups", sa.Column("permissions", sa.JSON(), nullable=True)
|
||||
)
|
||||
if "domains" not in columns:
|
||||
op.add_column("groups", sa.Column("domains", sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('groups')]
|
||||
columns = [col["name"] for col in inspector.get_columns("groups")]
|
||||
|
||||
column_names = ['description', 'orchestrator', 'orchestrator_sheet', 'permissions', 'domains']
|
||||
column_names = [
|
||||
"description",
|
||||
"orchestrator",
|
||||
"orchestrator_sheet",
|
||||
"permissions",
|
||||
"domains",
|
||||
]
|
||||
for column_name in column_names:
|
||||
if column_name in columns:
|
||||
op.drop_column('groups', column_name)
|
||||
op.drop_column("groups", column_name)
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
# TODO: code in this file should eventually be moved to the auto-archiver code base
|
||||
|
||||
from typing import List
|
||||
from loguru import logger
|
||||
from auto_archiver.core import Media, Metadata
|
||||
|
||||
from app.shared.db import models
|
||||
|
||||
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
|
||||
db_urls = []
|
||||
for m in result.media:
|
||||
for i, url in enumerate(m.urls): db_urls.append(models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}")))
|
||||
for k, prop in m.properties.items():
|
||||
if prop_converted := convert_if_media(prop):
|
||||
for i, url in enumerate(prop_converted.urls): db_urls.append(models.ArchiveUrl(url=url, key=prop_converted.get("id", f"{k}_{i}")))
|
||||
if isinstance(prop, list):
|
||||
for i, prop_media in enumerate(prop):
|
||||
if prop_media := convert_if_media(prop_media):
|
||||
for j, url in enumerate(prop_media.urls):
|
||||
db_urls.append(models.ArchiveUrl(url=url, key=prop_media.get("id", f"{k}{prop_media.key}_{i}.{j}")))
|
||||
return db_urls
|
||||
|
||||
|
||||
|
||||
def convert_if_media(media):
|
||||
if isinstance(media, Media): return media
|
||||
elif isinstance(media, dict):
|
||||
try: return Media.from_dict(media)
|
||||
except Exception as e:
|
||||
logger.debug(f"error parsing {media} : {e}")
|
||||
return False
|
||||
|
||||
@@ -1,25 +1,35 @@
|
||||
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do decide
|
||||
|
||||
|
||||
import datetime
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.db import worker_crud
|
||||
|
||||
|
||||
def get_store_archive_until(db: Session, group_id: str) -> datetime.datetime:
|
||||
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do
|
||||
# decide
|
||||
|
||||
|
||||
def get_store_archive_until(
|
||||
db: Session, group_id: str
|
||||
) -> Union[datetime.datetime, None]:
|
||||
group = worker_crud.get_group(db, group_id)
|
||||
assert group, f"Group {group_id} not found."
|
||||
assert group.permissions and type(group.permissions) == dict, f"Group {group_id} has no permissions."
|
||||
assert group.permissions and isinstance(group.permissions, dict), (
|
||||
f"Group {group_id} has no permissions."
|
||||
)
|
||||
|
||||
max_lifespan = group.permissions.get("max_archive_lifespan_months", -1)
|
||||
if max_lifespan == -1: return None
|
||||
if max_lifespan == -1:
|
||||
return None
|
||||
|
||||
return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan)
|
||||
|
||||
|
||||
def get_store_archive_until_or_never(db: Session, group_id: str) -> datetime.datetime:
|
||||
def get_store_archive_until_or_never(
|
||||
db: Session, group_id: str
|
||||
) -> Union[datetime.datetime, None]:
|
||||
try:
|
||||
return get_store_archive_until(db, group_id)
|
||||
except AssertionError as e:
|
||||
except AssertionError:
|
||||
return None
|
||||
|
||||
7
app/shared/constants.py
Normal file
7
app/shared/constants.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Statuses
|
||||
STATUS_FAILURE = "FAILURE"
|
||||
STATUS_PENDING = "PENDING"
|
||||
STATUS_SUCCESS = "SUCCESS"
|
||||
|
||||
# AA CLI CONFIGS
|
||||
SHEET_ID = "--gsheet_feeder_db.sheet_id"
|
||||
@@ -1,8 +1,14 @@
|
||||
from functools import lru_cache
|
||||
from sqlalchemy import Engine, create_engine, event, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine, async_sessionmaker
|
||||
from functools import lru_cache
|
||||
|
||||
from sqlalchemy import Engine, create_engine, event, text
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.shared.settings import get_settings
|
||||
|
||||
@@ -12,9 +18,9 @@ def make_engine(database_url: str):
|
||||
engine = create_engine(
|
||||
database_url,
|
||||
connect_args={"check_same_thread": False},
|
||||
pool_size=15, # Increase pool size
|
||||
max_overflow=20, # Allow more temporary connections
|
||||
pool_recycle=1800 # Recycle connections every 30 minutes
|
||||
pool_size=15, # Increase pool size
|
||||
max_overflow=20, # Allow more temporary connections
|
||||
pool_recycle=1800, # Recycle connections every 30 minutes
|
||||
)
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
@@ -34,8 +40,10 @@ def make_session_local(engine: Engine):
|
||||
@contextmanager
|
||||
def get_db():
|
||||
session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
|
||||
try: yield session
|
||||
finally: session.close()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def get_db_dependency():
|
||||
@@ -53,22 +61,32 @@ def wal_checkpoint():
|
||||
|
||||
# ASYNC connections
|
||||
async def make_async_engine(database_url: str) -> AsyncEngine:
|
||||
engine = create_async_engine(database_url, connect_args={"check_same_thread": False})
|
||||
engine = create_async_engine(
|
||||
database_url, connect_args={"check_same_thread": False}
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(lambda sync_conn: sync_conn.execute(text("PRAGMA journal_mode=WAL;")))
|
||||
await conn.run_sync(
|
||||
lambda sync_conn: sync_conn.execute(
|
||||
text("PRAGMA journal_mode=WAL;")
|
||||
)
|
||||
)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
async def make_async_session_local(engine: AsyncEngine) -> AsyncSession:
|
||||
return async_sessionmaker(engine, expire_on_commit=False, autoflush=False, autocommit=False)
|
||||
return async_sessionmaker(
|
||||
engine, expire_on_commit=False, autoflush=False, autocommit=False
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_async():
|
||||
engine = await make_async_engine(get_settings().ASYNC_DATABASE_PATH)
|
||||
engine = await make_async_engine(get_settings().async_database_path)
|
||||
async_session = await make_async_session_local(engine)
|
||||
async with async_session() as session:
|
||||
try: yield session
|
||||
finally: await engine.dispose()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
String,
|
||||
Table,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
@@ -11,7 +20,7 @@ def generate_uuid():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
# many to many association tables
|
||||
# many-to-many association tables
|
||||
association_table_archive_tags = Table(
|
||||
"mtm_archives_tags",
|
||||
Base.metadata,
|
||||
@@ -33,7 +42,9 @@ class Archive(Base):
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
url = Column(String, index=True)
|
||||
result = Column(JSON, default=None)
|
||||
public = Column(Boolean, default=True) # if public=false, access by group and author
|
||||
public = Column(
|
||||
Boolean, default=True
|
||||
) # if public=false, access by group and author
|
||||
deleted = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
@@ -43,7 +54,11 @@ class Archive(Base):
|
||||
author_id = Column(String, ForeignKey("users.email"))
|
||||
sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
|
||||
|
||||
tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags)
|
||||
tags = relationship(
|
||||
"Tag",
|
||||
back_populates="archives",
|
||||
secondary=association_table_archive_tags,
|
||||
)
|
||||
group = relationship("Group", back_populates="archives")
|
||||
author = relationship("User", back_populates="archives")
|
||||
urls = relationship("ArchiveUrl", back_populates="archive")
|
||||
@@ -66,7 +81,11 @@ class Tag(Base):
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
|
||||
archives = relationship(
|
||||
"Archive",
|
||||
back_populates="tags",
|
||||
secondary=association_table_archive_tags,
|
||||
)
|
||||
|
||||
|
||||
class User(Base):
|
||||
@@ -76,7 +95,9 @@ class User(Base):
|
||||
|
||||
archives = relationship("Archive", back_populates="author")
|
||||
sheets = relationship("Sheet", back_populates="author")
|
||||
groups = relationship("Group", back_populates="users", secondary=association_table_user_groups)
|
||||
groups = relationship(
|
||||
"Group", back_populates="users", secondary=association_table_user_groups
|
||||
)
|
||||
|
||||
|
||||
class Group(Base):
|
||||
@@ -92,7 +113,9 @@ class Group(Base):
|
||||
|
||||
archives = relationship("Archive", back_populates="group")
|
||||
sheets = relationship("Sheet", back_populates="group")
|
||||
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
|
||||
users = relationship(
|
||||
"User", back_populates="groups", secondary=association_table_user_groups
|
||||
)
|
||||
|
||||
|
||||
class Sheet(Base):
|
||||
@@ -101,11 +124,27 @@ class Sheet(Base):
|
||||
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
|
||||
name = Column(String, default=None)
|
||||
author_id = Column(String, ForeignKey("users.email"))
|
||||
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
|
||||
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
|
||||
group_id = Column(
|
||||
String,
|
||||
ForeignKey("groups.id"),
|
||||
doc="Group ID, user must be in a group to create a sheet.",
|
||||
)
|
||||
frequency = Column(
|
||||
String,
|
||||
default="daily",
|
||||
doc="Frequency of archiving: hourly, daily, weekly.",
|
||||
)
|
||||
# TODO: stats is not being used, consider removing
|
||||
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...")
|
||||
last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
|
||||
stats = Column(
|
||||
JSON,
|
||||
default={},
|
||||
doc="Sheet statistics like total links, total rows, ...",
|
||||
)
|
||||
last_url_archived_at = Column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
doc="Last time a new link was archived.",
|
||||
)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
|
||||
from app.shared.db import models
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared import schemas
|
||||
from app.shared.db import models
|
||||
|
||||
|
||||
# TODO: isolate database operations away from worker and into WEB
|
||||
# ONLY WORKER
|
||||
def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
|
||||
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
|
||||
db_sheet = (
|
||||
db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
|
||||
)
|
||||
if db_sheet:
|
||||
db_sheet.last_url_archived_at = datetime.now()
|
||||
db.commit()
|
||||
@@ -17,12 +21,17 @@ def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
|
||||
|
||||
# ONLY WORKER and INTEROP
|
||||
|
||||
|
||||
def get_group(db: Session, group_name: str) -> models.Group:
|
||||
return db.query(models.Group).filter(models.Group.id == group_name).first()
|
||||
|
||||
|
||||
def create_or_get_user(db: Session, author_id: str) -> models.User:
|
||||
if type(author_id) == str: author_id = author_id.lower()
|
||||
db_user = db.query(models.User).filter(models.User.email == author_id).first()
|
||||
if isinstance(author_id, str):
|
||||
author_id = author_id.lower()
|
||||
db_user = (
|
||||
db.query(models.User).filter(models.User.email == author_id).first()
|
||||
)
|
||||
if not db_user:
|
||||
db_user = models.User(email=author_id)
|
||||
db.add(db_user)
|
||||
@@ -41,8 +50,22 @@ def create_tag(db: Session, tag: str) -> models.Tag:
|
||||
return db_tag
|
||||
|
||||
|
||||
def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]) -> models.Archive:
|
||||
db_archive = models.Archive(id=archive.id, url=archive.url, result=archive.result, public=archive.public, author_id=archive.author_id, group_id=archive.group_id, sheet_id=archive.sheet_id, store_until=archive.store_until)
|
||||
def create_archive(
|
||||
db: Session,
|
||||
archive: schemas.ArchiveCreate,
|
||||
tags: list[models.Tag],
|
||||
urls: list[models.ArchiveUrl],
|
||||
) -> models.Archive:
|
||||
db_archive = models.Archive(
|
||||
id=archive.id,
|
||||
url=archive.url,
|
||||
result=archive.result,
|
||||
public=archive.public,
|
||||
author_id=archive.author_id,
|
||||
group_id=archive.group_id,
|
||||
sheet_id=archive.sheet_id,
|
||||
store_until=archive.store_until,
|
||||
)
|
||||
db_archive.tags = tags
|
||||
db_archive.urls = urls
|
||||
db.add(db_archive)
|
||||
@@ -51,10 +74,14 @@ def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[model
|
||||
return db_archive
|
||||
|
||||
|
||||
def store_archived_url(db: Session, archive: schemas.ArchiveCreate) -> models.Archive:
|
||||
def store_archived_url(
|
||||
db: Session, archive: schemas.ArchiveCreate
|
||||
) -> models.Archive:
|
||||
# create and load user, tags, if needed
|
||||
create_or_get_user(db, archive.author_id)
|
||||
db_tags = [create_tag(db, tag) for tag in (archive.tags or [])]
|
||||
# insert everything
|
||||
db_archive = create_archive(db, archive=archive, tags=db_tags, urls=archive.urls)
|
||||
db_archive = create_archive(
|
||||
db, archive=archive, tags=db_tags, urls=archive.urls
|
||||
)
|
||||
return db_archive
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import traceback
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@@ -6,8 +7,10 @@ from loguru import logger
|
||||
logger.add("logs/api_logs.log", retention="30 days")
|
||||
logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
|
||||
|
||||
|
||||
def log_error(e: Exception, traceback_str: str = None, extra:str = ""):
|
||||
if not traceback_str: traceback_str = traceback.format_exc()
|
||||
if extra: extra = f"{extra}\n"
|
||||
|
||||
def log_error(e: Exception, traceback_str: str = None, extra: str = ""):
|
||||
if not traceback_str:
|
||||
traceback_str = traceback.format_exc()
|
||||
if extra:
|
||||
extra = f"{extra}\n"
|
||||
logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from annotated_types import Len
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class SubmitSheet(BaseModel):
|
||||
@@ -10,6 +11,7 @@ class SubmitSheet(BaseModel):
|
||||
group_id: str = "default"
|
||||
tags: set[str] | None = set()
|
||||
|
||||
|
||||
class ArchiveUrl(BaseModel):
|
||||
url: str
|
||||
public: bool = False
|
||||
@@ -17,6 +19,7 @@ class ArchiveUrl(BaseModel):
|
||||
group_id: str | None
|
||||
tags: set[str] | None = set()
|
||||
|
||||
|
||||
class ArchiveResult(BaseModel):
|
||||
id: str
|
||||
url: str
|
||||
|
||||
@@ -1,31 +1,38 @@
|
||||
|
||||
from functools import lru_cache
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Annotated, Set
|
||||
|
||||
from annotated_types import Len
|
||||
from fastapi_mail import ConnectionConfig
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from typing import Annotated, Set
|
||||
from annotated_types import Len
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
|
||||
model_config = SettingsConfigDict(env_file=os.environ.get("ENVIRONMENT_FILE") , env_file_encoding='utf-8', extra='ignore', str_strip_whitespace=True)
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=os.environ.get("ENVIRONMENT_FILE"),
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
str_strip_whitespace=True,
|
||||
)
|
||||
|
||||
# general
|
||||
# general
|
||||
SERVE_LOCAL_ARCHIVE: str | None = None
|
||||
USER_GROUPS_FILENAME: str = "app/user-groups.yaml"
|
||||
|
||||
# database
|
||||
# database
|
||||
DATABASE_PATH: str
|
||||
DATABASE_QUERY_LIMIT: int = 100
|
||||
|
||||
@property
|
||||
def ASYNC_DATABASE_PATH(self) -> str:
|
||||
def async_database_path(self) -> str:
|
||||
return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://")
|
||||
|
||||
# security
|
||||
# security
|
||||
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
|
||||
ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)]
|
||||
CHROME_APP_IDS: Annotated[Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
|
||||
CHROME_APP_IDS: Annotated[
|
||||
Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)
|
||||
]
|
||||
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
|
||||
# if not provided only OAUTH access_tokens are allowed
|
||||
FIREBASE_SERVICE_ACCOUNT_JSON: str = ""
|
||||
@@ -34,20 +41,21 @@ class Settings(BaseSettings):
|
||||
REDIS_PASSWORD: str = ""
|
||||
REDIS_HOSTNAME: str = "localhost"
|
||||
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
|
||||
|
||||
@property
|
||||
def CELERY_BROKER_URL(self)-> str:
|
||||
def celery_broker_url(self) -> str:
|
||||
if self.REDIS_PASSWORD:
|
||||
return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379"
|
||||
return f"redis://{self.REDIS_HOSTNAME}:6379"
|
||||
|
||||
|
||||
# cronjobs
|
||||
CRON_ARCHIVE_SHEETS: bool = False
|
||||
CRON_DELETE_STALE_SHEETS: bool = False
|
||||
DELETE_STALE_SHEETS_DAYS: int = 14
|
||||
CRON_DELETE_SCHEDULED_ARCHIVES: bool = False
|
||||
DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7
|
||||
|
||||
# observability
|
||||
|
||||
# observability
|
||||
REPEAT_COUNT_METRICS_SECONDS: int = 30
|
||||
|
||||
# email configuration, if needed
|
||||
@@ -59,8 +67,9 @@ class Settings(BaseSettings):
|
||||
MAIL_PORT: int = 587
|
||||
MAIL_STARTTLS: bool = False
|
||||
MAIL_SSL_TLS: bool = True
|
||||
|
||||
@property
|
||||
def MAIL_CONFIG(self) -> str:
|
||||
def mail_config(self) -> ConnectionConfig:
|
||||
return ConnectionConfig(
|
||||
MAIL_FROM=self.MAIL_FROM,
|
||||
MAIL_FROM_NAME=self.MAIL_FROM_NAME,
|
||||
@@ -75,4 +84,4 @@ class Settings(BaseSettings):
|
||||
|
||||
@lru_cache
|
||||
def get_settings():
|
||||
return Settings()
|
||||
return Settings()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
|
||||
from functools import lru_cache
|
||||
from celery import Celery
|
||||
import redis
|
||||
|
||||
from celery import Celery
|
||||
|
||||
import redis
|
||||
from app.shared.settings import get_settings
|
||||
|
||||
|
||||
@@ -10,14 +10,14 @@ from app.shared.settings import get_settings
|
||||
def get_celery(name: str = "") -> Celery:
|
||||
return Celery(
|
||||
name,
|
||||
broker_url=get_settings().CELERY_BROKER_URL,
|
||||
result_backend=get_settings().CELERY_BROKER_URL,
|
||||
broker_url=get_settings().celery_broker_url,
|
||||
result_backend=get_settings().celery_broker_url,
|
||||
broker_connection_retry_on_startup=False,
|
||||
broker_transport_options={
|
||||
'queue_order_strategy': 'priority',
|
||||
}
|
||||
"queue_order_strategy": "priority",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_redis() -> redis.Redis:
|
||||
return redis.Redis.from_url(get_settings().CELERY_BROKER_URL)
|
||||
return redis.Redis.from_url(get_settings().celery_broker_url)
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Set
|
||||
|
||||
import yaml
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, computed_field, field_validator, Field, model_validator
|
||||
from typing import Dict, List, Set
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
computed_field,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
@@ -12,13 +19,16 @@ class UserGroups:
|
||||
user_groups = self.read_yaml(filename)
|
||||
self.validate_and_load(user_groups)
|
||||
|
||||
def read_yaml(self, user_groups_filename):
|
||||
@staticmethod
|
||||
def read_yaml(user_groups_filename):
|
||||
# read yaml safely
|
||||
with open(user_groups_filename) as inf:
|
||||
try:
|
||||
return yaml.safe_load(inf)
|
||||
except yaml.YAMLError as e:
|
||||
logger.error(f"could not open user groups filename {user_groups_filename}: {e}")
|
||||
logger.error(
|
||||
f"could not open user groups filename {user_groups_filename}: {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
def validate_and_load(self, user_groups):
|
||||
@@ -45,22 +55,36 @@ class GroupPermissions(BaseModel):
|
||||
max_monthly_mbs: int = 0
|
||||
priority: str = "low"
|
||||
|
||||
@field_validator('max_sheets', 'max_archive_lifespan_months', 'max_monthly_urls', 'max_monthly_mbs', mode='before')
|
||||
@classmethod
|
||||
@field_validator(
|
||||
"max_sheets",
|
||||
"max_archive_lifespan_months",
|
||||
"max_monthly_urls",
|
||||
"max_monthly_mbs",
|
||||
mode="before",
|
||||
)
|
||||
def validate_max_values(cls, v):
|
||||
if v < -1:
|
||||
raise ValueError("max_* values should be positive integers or -1 (for no limit).")
|
||||
raise ValueError(
|
||||
"max_* values should be positive integers or -1 (for no limit)."
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator('sheet_frequency', mode='before')
|
||||
@classmethod
|
||||
@field_validator("sheet_frequency", mode="before")
|
||||
def validate_sheet_frequency(cls, v):
|
||||
if not v: return []
|
||||
if not v:
|
||||
return []
|
||||
allowed = ["daily", "hourly"]
|
||||
for k in v:
|
||||
if k not in allowed:
|
||||
raise ValueError(f"Invalid sheet frequency: '{k}', expected one of {allowed}")
|
||||
raise ValueError(
|
||||
f"Invalid sheet frequency: '{k}', expected one of {allowed}"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator('priority', mode='before')
|
||||
@classmethod
|
||||
@field_validator("priority", mode="before")
|
||||
def validate_priority(cls, v):
|
||||
v = v.lower()
|
||||
if v not in ["low", "high"]:
|
||||
@@ -70,19 +94,31 @@ class GroupPermissions(BaseModel):
|
||||
|
||||
class GroupModel(BaseModel):
|
||||
description: str
|
||||
orchestrator: str
|
||||
orchestrator_sheet: str
|
||||
orchestrator: str | None = None
|
||||
orchestrator_sheet: str | None = None
|
||||
permissions: GroupPermissions
|
||||
|
||||
@field_validator('orchestrator', 'orchestrator_sheet', mode='before')
|
||||
@classmethod
|
||||
@field_validator("orchestrator", mode="before")
|
||||
def validate_orchestrator(cls, v):
|
||||
if not os.path.exists(v):
|
||||
# orchestrator is only needed if the group has archive_url permission
|
||||
if cls.permissions.archive_url and not os.path.exists(v):
|
||||
raise ValueError(f"Orchestrator file not found with this path: {v}")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
@field_validator("orchestrator_sheet", mode="before")
|
||||
def validate_orchestrator_sheet(cls, v):
|
||||
# orchestrator_sheet is only needed if the group has archive_sheet permission
|
||||
if cls.permissions.archive_sheet and not os.path.exists(v):
|
||||
raise ValueError(f"Orchestrator file not found with this path: {v}")
|
||||
return v
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def service_account_email(self) -> str:
|
||||
if self.orchestrator_sheet is None:
|
||||
return ""
|
||||
if hasattr(self, "_service_account_email"):
|
||||
return self._service_account_email
|
||||
orch = yaml.safe_load(open(self.orchestrator_sheet))
|
||||
@@ -98,13 +134,17 @@ class GroupModel(BaseModel):
|
||||
|
||||
service_account_json = find_service_account_email(orch)
|
||||
if not service_account_json:
|
||||
raise ValueError(f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}.")
|
||||
raise ValueError(
|
||||
f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}."
|
||||
)
|
||||
|
||||
with open(service_account_json) as f:
|
||||
self._service_account_email = json.load(f).get("client_email")
|
||||
|
||||
if not self._service_account_email:
|
||||
raise ValueError(f"Service account email not found in {service_account_json}.")
|
||||
raise ValueError(
|
||||
f"Service account email not found in {service_account_json}."
|
||||
)
|
||||
|
||||
return self._service_account_email
|
||||
|
||||
@@ -114,29 +154,45 @@ class UserGroupModel(BaseModel):
|
||||
domains: Dict[str, List[str]] = Field(default_factory=dict)
|
||||
groups: Dict[str, GroupModel] = Field(default_factory=dict)
|
||||
|
||||
@field_validator('users', mode='before')
|
||||
@classmethod
|
||||
@field_validator("users", mode="before")
|
||||
def validate_emails(cls, v):
|
||||
for email in v.keys():
|
||||
if '@' not in email:
|
||||
raise ValueError(f"Invalid user, it should be an address: {email}")
|
||||
if "@" not in email:
|
||||
raise ValueError(
|
||||
f"Invalid user, it should be an address: {email}"
|
||||
)
|
||||
if not v[email]:
|
||||
raise ValueError(f"User {email} has no explicitly listed groups, only include them here if they should be in a group.")
|
||||
raise ValueError(
|
||||
f"User {email} has no explicitly listed groups, only include them here if they should be in a group."
|
||||
)
|
||||
# all users belong to the default group
|
||||
return {k.lower().strip(): list(set(["default"] + [g.lower().strip() for g in v])) for k, v in v.items()}
|
||||
return {
|
||||
k.lower().strip(): list(
|
||||
set(["default"] + [g.lower().strip() for g in v])
|
||||
)
|
||||
for k, v in v.items()
|
||||
}
|
||||
|
||||
@field_validator('domains', mode='before')
|
||||
@classmethod
|
||||
@field_validator("domains", mode="before")
|
||||
def validate_domains(cls, v):
|
||||
for domain, members in v.items():
|
||||
if '.' not in domain:
|
||||
raise ValueError(f"Invalid domain, it should contain a dot: {domain}")
|
||||
if "." not in domain:
|
||||
raise ValueError(
|
||||
f"Invalid domain, it should contain a dot: {domain}"
|
||||
)
|
||||
if not members:
|
||||
raise ValueError(f"Domain {domain} should have at least one member.")
|
||||
return {k.lower().strip(): list(set([g.lower().strip() for g in v])) for k, v in v.items()}
|
||||
raise ValueError(
|
||||
f"Domain {domain} should have at least one member."
|
||||
)
|
||||
return {
|
||||
k.lower().strip(): list({[g.lower().strip() for g in v]})
|
||||
for k, v in v.items()
|
||||
}
|
||||
|
||||
@field_validator('groups', mode='before')
|
||||
@classmethod
|
||||
@field_validator("groups", mode="before")
|
||||
def validate_groups(cls, v):
|
||||
if "default" not in v.keys():
|
||||
raise ValueError("Please include a 'default' group.")
|
||||
@@ -147,20 +203,28 @@ class UserGroupModel(BaseModel):
|
||||
raise ValueError(f"Group names should be lowercase: {group}")
|
||||
return v
|
||||
|
||||
@model_validator(mode='after')
|
||||
@model_validator(mode="after")
|
||||
def check_groups_consistency(self) -> Self:
|
||||
groups_in_domains = set([g for domain in self.domains for g in self.domains[domain]])
|
||||
groups_in_users = set([g for user in self.users for g in self.users[user]])
|
||||
groups_in_domains = {
|
||||
g for domain in self.domains for g in self.domains[domain]
|
||||
}
|
||||
groups_in_users = {g for user in self.users for g in self.users[user]}
|
||||
configured_groups = set(self.groups.keys())
|
||||
|
||||
# groups mentioned in domains and users should be defined, but this is not a ValueError since historical DB data may require it
|
||||
# groups mentioned in domains and users should be defined, but this is
|
||||
# not a ValueError since historical DB data may require it
|
||||
if groups_in_domains - configured_groups:
|
||||
logger.warning(f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}")
|
||||
logger.warning(
|
||||
f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}"
|
||||
)
|
||||
if groups_in_users - configured_groups:
|
||||
logger.warning(f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}")
|
||||
logger.warning(
|
||||
f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
# for the API return values
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
|
||||
def fnv1a_hash_mod(s: str, modulo:int) -> int:
|
||||
# receives a string and returns a number in [0:modulo-1], ensures an even distribution over the modulo range
|
||||
hash = 0x811c9dc5 # FNV offset basis
|
||||
fnv_prime = 0x01000193 # FNV prime
|
||||
def fnv1a_hash_mod(s: str, modulo: int) -> int:
|
||||
# receives a string and returns a number in [0:modulo-1], ensures an even
|
||||
# distribution over the modulo range
|
||||
offset_basis_hash = 0x811C9DC5 # FNV offset basis
|
||||
fnv_prime = 0x01000193 # FNV prime
|
||||
for char in s:
|
||||
hash ^= ord(char)
|
||||
hash *= fnv_prime
|
||||
hash &= 0xFFFFFFFF # Keep it 32-bit
|
||||
return (hash if hash < 0x80000000 else hash - 0x100000000) % modulo
|
||||
offset_basis_hash ^= ord(char)
|
||||
offset_basis_hash *= fnv_prime
|
||||
offset_basis_hash &= 0xFFFFFFFF # Keep it 32-bit
|
||||
return (
|
||||
offset_basis_hash
|
||||
if offset_basis_hash < 0x80000000
|
||||
else offset_basis_hash - 0x100000000
|
||||
) % modulo
|
||||
|
||||
@@ -1,19 +1,39 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncGenerator
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared.db.database import (
|
||||
make_async_engine,
|
||||
make_async_session_local,
|
||||
make_engine,
|
||||
make_session_local,
|
||||
)
|
||||
from app.shared.settings import Settings
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.web.db import crud
|
||||
from app.web.db.crud import get_user_group_names
|
||||
from app.web.db.user_state import UserState
|
||||
from app.web.main import app_factory
|
||||
from app.web.security import (
|
||||
get_token_or_user_auth,
|
||||
get_user_auth,
|
||||
get_user_state,
|
||||
token_api_key_auth,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_logger_add():
|
||||
"""Fixture to mock loguru.logger.add for all tests."""
|
||||
with patch('loguru.logger.add') as mock_add:
|
||||
with patch("loguru.logger.add") as mock_add:
|
||||
yield mock_add # This makes the mock available to tests
|
||||
|
||||
|
||||
@@ -24,23 +44,22 @@ def get_settings():
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_settings():
|
||||
with patch('app.shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings:
|
||||
with patch(
|
||||
"app.shared.settings.Settings",
|
||||
return_value=Settings(_env_file=".env.test"),
|
||||
) as mock_settings:
|
||||
yield mock_settings
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_db(get_settings: Settings):
|
||||
from app.shared.db import models
|
||||
from app.shared.db.database import make_engine
|
||||
from app.web.db.crud import get_user_group_names
|
||||
|
||||
get_user_group_names.cache_clear()
|
||||
make_engine.cache_clear()
|
||||
engine = make_engine(get_settings.DATABASE_PATH)
|
||||
|
||||
fs = get_settings.DATABASE_PATH.replace("sqlite:///", "")
|
||||
if not os.path.exists(fs):
|
||||
open(fs, 'w').close()
|
||||
open(fs, "w").close()
|
||||
|
||||
models.Base.metadata.create_all(engine)
|
||||
|
||||
@@ -57,7 +76,6 @@ def test_db(get_settings: Settings):
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session(test_db):
|
||||
from app.shared.db.database import make_session_local
|
||||
session_local = make_session_local(test_db)
|
||||
with session_local() as session:
|
||||
yield session
|
||||
@@ -65,17 +83,12 @@ def db_session(test_db):
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def async_test_db(get_settings: Settings):
|
||||
from app.shared.db import models
|
||||
from app.shared.db.database import make_async_engine
|
||||
from app.web.db.crud import get_user_group_names
|
||||
import asyncio
|
||||
|
||||
get_user_group_names.cache_clear()
|
||||
engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH)
|
||||
engine = await make_async_engine(get_settings.async_database_path)
|
||||
|
||||
fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "")
|
||||
fs = get_settings.async_database_path.replace("sqlite+aiosqlite:///", "")
|
||||
if not os.path.exists(fs):
|
||||
open(fs, 'w').close()
|
||||
open(fs, "w").close()
|
||||
|
||||
async def create_all():
|
||||
async with engine.begin() as conn:
|
||||
@@ -99,8 +112,9 @@ async def async_test_db(get_settings: Settings):
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
|
||||
from app.shared.db.database import make_async_session_local
|
||||
async def async_db_session(
|
||||
async_test_db: AsyncEngine,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
session_local = await make_async_session_local(async_test_db)
|
||||
async with session_local() as session:
|
||||
yield session
|
||||
@@ -108,8 +122,6 @@ async def async_db_session(async_test_db: AsyncEngine) -> AsyncGenerator[AsyncSe
|
||||
|
||||
@pytest.fixture()
|
||||
def app(db_session):
|
||||
from app.web.main import app_factory
|
||||
from app.web.db import crud
|
||||
app = app_factory()
|
||||
crud.upsert_user_groups(db_session)
|
||||
return app
|
||||
@@ -123,10 +135,13 @@ def client(app):
|
||||
|
||||
@pytest.fixture()
|
||||
def app_with_auth(app, db_session):
|
||||
from app.web.security import get_token_or_user_auth, get_user_auth, get_user_state
|
||||
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
|
||||
app.dependency_overrides[get_token_or_user_auth] = (
|
||||
lambda: "rick@example.com"
|
||||
)
|
||||
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
|
||||
app.dependency_overrides[get_user_state] = lambda: UserState(db_session, "MORTY@example.com")
|
||||
app.dependency_overrides[get_user_state] = lambda: UserState(
|
||||
db_session, "MORTY@example.com"
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
@@ -138,7 +153,6 @@ def client_with_auth(app_with_auth):
|
||||
|
||||
@pytest.fixture()
|
||||
def app_with_token(app):
|
||||
from app.web.security import token_api_key_auth, get_token_or_user_auth
|
||||
app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL
|
||||
app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL
|
||||
return app
|
||||
@@ -155,6 +169,93 @@ def test_no_auth():
|
||||
# reusable code to ensure a method/endpoint combination is unauthorized
|
||||
def no_auth(http_method, endpoint):
|
||||
response = http_method(endpoint)
|
||||
assert response.status_code == 403
|
||||
assert response.status_code == HTTPStatus.FORBIDDEN
|
||||
assert response.json() == {"detail": "Not authenticated"}
|
||||
|
||||
return no_auth
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_data(db_session):
|
||||
author_emails = [
|
||||
"rick@example.com",
|
||||
"morty@example.com",
|
||||
"jerry@example.com",
|
||||
]
|
||||
|
||||
# creates 3 users
|
||||
for email in author_emails:
|
||||
db_session.add(models.User(email=email))
|
||||
db_session.commit()
|
||||
assert db_session.query(models.User).count() == 3
|
||||
|
||||
# creates 100 archives for 3 users over 2 months with repeating URLs
|
||||
for i in range(100):
|
||||
author = author_emails[i % 3]
|
||||
archive = models.Archive(
|
||||
id=f"archive-id-456-{i}",
|
||||
url=f"https://example-{i % 3}.com",
|
||||
result={},
|
||||
public=author == "jerry@example.com",
|
||||
author_id=author,
|
||||
group_id="spaceship"
|
||||
if author == "morty@example.com" and i % 2 == 0
|
||||
else None,
|
||||
created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1),
|
||||
)
|
||||
if i % 5 == 0:
|
||||
archive.tags.append(models.Tag(id=f"tag-{i}"))
|
||||
if i % 10 == 0:
|
||||
archive.tags.append(models.Tag(id=f"tag-second-{i}"))
|
||||
if i % 4 == 0:
|
||||
archive.tags.append(models.Tag(id=f"tag-third-{i}"))
|
||||
for j in range(10):
|
||||
archive.urls.append(
|
||||
models.ArchiveUrl(
|
||||
url=f"https://example-{i}.com/{j}", key=f"media_{j}"
|
||||
)
|
||||
)
|
||||
db_session.add(archive)
|
||||
|
||||
# creates a sheet for each user
|
||||
for i, email in enumerate(
|
||||
["rick@example.com", "morty@example.com", "jerry@example.com"]
|
||||
):
|
||||
db_session.add(
|
||||
models.Sheet(
|
||||
id=f"sheet-{i}",
|
||||
name=f"sheet-{i}",
|
||||
author_id=email,
|
||||
group_id=None,
|
||||
frequency="daily",
|
||||
)
|
||||
)
|
||||
if email == "rick@example.com":
|
||||
db_session.add(
|
||||
models.Sheet(
|
||||
id=f"sheet-{i}-2",
|
||||
name=f"sheet-{i}-2",
|
||||
author_id=email,
|
||||
group_id="spaceship",
|
||||
frequency="hourly",
|
||||
)
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
assert db_session.query(models.Archive).count() == 100
|
||||
assert db_session.query(models.Tag).count() == 20 + 10 + 25
|
||||
assert db_session.query(models.ArchiveUrl).count() == 1000
|
||||
assert (
|
||||
db_session.query(models.ArchiveUrl)
|
||||
.filter(models.ArchiveUrl.archive_id == "archive-id-456-0")
|
||||
.count()
|
||||
== 10
|
||||
)
|
||||
|
||||
# setup groups
|
||||
assert db_session.query(models.Group).count() == 0
|
||||
|
||||
crud.upsert_user_groups(db_session)
|
||||
assert db_session.query(models.Group).count() == 4
|
||||
assert db_session.query(models.User).count() == 3
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
{
|
||||
"client_email": "fake_service_account@fake_service_account.iam.gserviceaccount.com"
|
||||
}
|
||||
"client_email": "fake_service_account@fake_service_account.iam.gserviceaccount.com"
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
steps:
|
||||
feeder: cli_feeder
|
||||
feeders:
|
||||
- cli_feeder
|
||||
archivers: # order matters
|
||||
- youtubedl_archiver
|
||||
- generic_extractor
|
||||
enrichers:
|
||||
- hash_enricher
|
||||
|
||||
@@ -12,10 +13,10 @@ steps:
|
||||
- console_db
|
||||
|
||||
configurations:
|
||||
gsheet_feeder:
|
||||
gsheet_feeder_db:
|
||||
service_account: "app/tests/fake_service_account.json"
|
||||
cli_feeder:
|
||||
urls:
|
||||
urls:
|
||||
- "url1"
|
||||
hash_enricher:
|
||||
algorithm: "SHA-256"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
def test_generate_uuid():
|
||||
from app.shared.db.models import generate_uuid
|
||||
from app.shared.db.models import generate_uuid
|
||||
|
||||
assert generate_uuid() != generate_uuid()
|
||||
assert len(generate_uuid()) == 36
|
||||
assert generate_uuid().count("-") == 4
|
||||
|
||||
def test_generate_uuid():
|
||||
assert generate_uuid() != generate_uuid()
|
||||
assert len(generate_uuid()) == 36
|
||||
assert generate_uuid().count("-") == 4
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from app.shared.db import models
|
||||
from app.shared.db import worker_crud, models
|
||||
from datetime import datetime
|
||||
|
||||
from app.shared import schemas
|
||||
from app.shared.db import models, worker_crud
|
||||
|
||||
from app.tests.web.db.test_crud import test_data
|
||||
|
||||
def test_update_sheet_last_url_archived_at(db_session):
|
||||
|
||||
# Create test sheet
|
||||
test_sheet = models.Sheet(id="sheet-123")
|
||||
db_session.add(test_sheet)
|
||||
@@ -15,17 +13,24 @@ def test_update_sheet_last_url_archived_at(db_session):
|
||||
# Test updating existing sheet
|
||||
assert isinstance(test_sheet.last_url_archived_at, datetime)
|
||||
before = test_sheet.last_url_archived_at
|
||||
assert worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123") is True
|
||||
assert (
|
||||
worker_crud.update_sheet_last_url_archived_at(db_session, "sheet-123")
|
||||
is True
|
||||
)
|
||||
db_session.refresh(test_sheet)
|
||||
assert isinstance(test_sheet.last_url_archived_at, datetime)
|
||||
assert test_sheet.last_url_archived_at > before
|
||||
|
||||
|
||||
# Test non-existent sheet
|
||||
assert worker_crud.update_sheet_last_url_archived_at(db_session, "non-existent-sheet") is False
|
||||
assert (
|
||||
worker_crud.update_sheet_last_url_archived_at(
|
||||
db_session, "non-existent-sheet"
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_get_group(test_data, db_session):
|
||||
from app.shared.db import worker_crud
|
||||
|
||||
assert worker_crud.get_group(db_session, "spaceship") is not None
|
||||
assert worker_crud.get_group(db_session, "interdimensional") is not None
|
||||
assert worker_crud.get_group(db_session, "animated-characters") is not None
|
||||
@@ -33,24 +38,24 @@ def test_get_group(test_data, db_session):
|
||||
|
||||
|
||||
def test_create_or_get_user(test_data, db_session):
|
||||
from app.shared.db import worker_crud
|
||||
|
||||
assert db_session.query(models.User).count() == 3
|
||||
|
||||
# already exists
|
||||
assert (u1 := worker_crud.create_or_get_user(db_session, "rick@example.com")) is not None
|
||||
assert (
|
||||
u1 := worker_crud.create_or_get_user(db_session, "rick@example.com")
|
||||
) is not None
|
||||
assert u1.email == "rick@example.com"
|
||||
|
||||
# new user
|
||||
assert (u2 := worker_crud.create_or_get_user(db_session, "beth@example.com")) is not None
|
||||
assert (
|
||||
u2 := worker_crud.create_or_get_user(db_session, "beth@example.com")
|
||||
) is not None
|
||||
assert u2.email == "beth@example.com"
|
||||
|
||||
assert db_session.query(models.User).count() == 4
|
||||
|
||||
|
||||
def test_create_tag(db_session):
|
||||
from app.shared.db import worker_crud
|
||||
|
||||
assert db_session.query(models.Tag).count() == 0
|
||||
|
||||
# create first
|
||||
@@ -58,7 +63,10 @@ def test_create_tag(db_session):
|
||||
assert create_tag is not None
|
||||
assert create_tag.id == "tag-101"
|
||||
assert db_session.query(models.Tag).count() == 1
|
||||
assert db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() == create_tag
|
||||
assert (
|
||||
db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first()
|
||||
== create_tag
|
||||
)
|
||||
|
||||
# same id does not add new db entry
|
||||
existing_tag = worker_crud.create_tag(db_session, "tag-101")
|
||||
@@ -73,9 +81,6 @@ def test_create_tag(db_session):
|
||||
|
||||
|
||||
def test_create_task(db_session):
|
||||
from app.shared.db import worker_crud
|
||||
from app.shared import schemas
|
||||
|
||||
task = schemas.ArchiveCreate(
|
||||
id="archive-id-456-101",
|
||||
url="https://example-0.com",
|
||||
@@ -84,17 +89,22 @@ def test_create_task(db_session):
|
||||
author_id="rick@example.com",
|
||||
group_id="spaceship",
|
||||
tags=[],
|
||||
urls=[]
|
||||
urls=[],
|
||||
)
|
||||
|
||||
# with tags and urls
|
||||
nt = worker_crud.create_archive(db_session, task, [models.Tag(id="tag-101")], [models.ArchiveUrl(url="https://example-0.com/0", key="media_0")])
|
||||
nt = worker_crud.create_archive(
|
||||
db_session,
|
||||
task,
|
||||
[models.Tag(id="tag-101")],
|
||||
[models.ArchiveUrl(url="https://example-0.com/0", key="media_0")],
|
||||
)
|
||||
|
||||
assert nt is not None
|
||||
assert nt.id == "archive-id-456-101"
|
||||
assert nt.url == "https://example-0.com"
|
||||
assert nt.author_id == "rick@example.com"
|
||||
assert nt.public == False
|
||||
assert nt.public is False
|
||||
assert nt.group_id == "spaceship"
|
||||
assert len(nt.tags) == 1
|
||||
assert nt.tags[0].id == "tag-101"
|
||||
@@ -110,8 +120,8 @@ def test_create_task(db_session):
|
||||
assert nt.id == "archive-id-456-102"
|
||||
assert nt.url == "https://example-0.com"
|
||||
assert nt.author_id == "rick@example.com"
|
||||
assert nt.public == False
|
||||
assert nt.public is False
|
||||
assert nt.group_id == "spaceship"
|
||||
assert len(nt.tags) == 0
|
||||
assert len(nt.urls) == 0
|
||||
assert nt.created_at is not None
|
||||
assert nt.created_at is not None
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from app.shared.business_logic import get_store_archive_until, get_store_archive_until_or_never
|
||||
|
||||
from app.shared.business_logic import (
|
||||
get_store_archive_until,
|
||||
get_store_archive_until_or_never,
|
||||
)
|
||||
|
||||
|
||||
class Test_get_store_archive_until:
|
||||
class TestGetStoreArchiveUntil:
|
||||
GROUP_ID = "test-group"
|
||||
|
||||
def test_group_not_found(self, db_session):
|
||||
@@ -12,7 +17,10 @@ class Test_get_store_archive_until:
|
||||
get_store_archive_until(db_session, self.GROUP_ID)
|
||||
assert str(exc.value) == f"Group {self.GROUP_ID} not found."
|
||||
|
||||
@patch("app.shared.db.worker_crud.get_group", return_value=MagicMock(permissions=None))
|
||||
@patch(
|
||||
"app.shared.db.worker_crud.get_group",
|
||||
return_value=MagicMock(permissions=None),
|
||||
)
|
||||
def test_group_no_permissions(self, db_session):
|
||||
with pytest.raises(AssertionError) as exc:
|
||||
get_store_archive_until(db_session, self.GROUP_ID)
|
||||
@@ -43,14 +51,17 @@ class Test_get_store_archive_until:
|
||||
mock_get_group.assert_called_once_with(db_session, self.GROUP_ID)
|
||||
|
||||
|
||||
class Test_get_store_archive_until_or_never:
|
||||
class TestGetStoreArchiveUntilOrNever:
|
||||
GROUP_ID = "test-group"
|
||||
|
||||
def test_group_not_found(self, db_session):
|
||||
result = get_store_archive_until_or_never(db_session, self.GROUP_ID)
|
||||
assert result is None
|
||||
|
||||
@patch("app.shared.db.worker_crud.get_group", return_value=MagicMock(permissions=None))
|
||||
@patch(
|
||||
"app.shared.db.worker_crud.get_group",
|
||||
return_value=MagicMock(permissions=None),
|
||||
)
|
||||
def test_group_no_permissions(self, db_session):
|
||||
result = get_store_archive_until_or_never(db_session, self.GROUP_ID)
|
||||
assert result is None
|
||||
|
||||
@@ -11,7 +11,7 @@ def test_fnv1a_hash_mod():
|
||||
|
||||
# Test different modulos
|
||||
hash1 = fnv1a_hash_mod("test", 5)
|
||||
hash2 = fnv1a_hash_mod("test", 10)
|
||||
hash2 = fnv1a_hash_mod("test", 10)
|
||||
assert 0 <= hash1 < 5
|
||||
assert 0 <= hash2 < 10
|
||||
|
||||
@@ -28,4 +28,4 @@ def test_fnv1a_hash_mod():
|
||||
assert 0 <= fnv1a_hash_mod("测试", 10) < 10
|
||||
|
||||
# Test modulo = 1 edge case
|
||||
assert fnv1a_hash_mod("test", 1) == 0
|
||||
assert fnv1a_hash_mod("test", 1) == 0
|
||||
|
||||
@@ -3,4 +3,4 @@ This is just an invalid yaml for testing
|
||||
|
||||
still broken: True
|
||||
- one
|
||||
- two
|
||||
- two
|
||||
|
||||
@@ -84,4 +84,4 @@ groups:
|
||||
# max_archive_lifespan_months: 12
|
||||
max_monthly_urls: 1
|
||||
# max_monthly_mbs: 50
|
||||
priority: "low"
|
||||
priority: "low"
|
||||
|
||||
@@ -2,123 +2,379 @@ from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
import yaml
|
||||
from sqlalchemy import false, true
|
||||
from sqlalchemy.sql import select
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared.settings import Settings
|
||||
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.web.db import crud
|
||||
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_data(db_session):
|
||||
|
||||
# creates 3 users
|
||||
for email in authors:
|
||||
db_session.add(models.User(email=email))
|
||||
db_session.commit()
|
||||
assert db_session.query(models.User).count() == 3
|
||||
|
||||
# creates 100 archives for 3 users over 2 months with repeating URLs
|
||||
for i in range(100):
|
||||
author = authors[i % 3]
|
||||
archive = models.Archive(
|
||||
id=f"archive-id-456-{i}",
|
||||
url=f"https://example-{i%3}.com",
|
||||
result={},
|
||||
public=author == "jerry@example.com",
|
||||
author_id=author,
|
||||
group_id="spaceship" if author == "morty@example.com" and i % 2 == 0 else None,
|
||||
created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1)
|
||||
)
|
||||
if i % 5 == 0:
|
||||
archive.tags.append(models.Tag(id=f"tag-{i}"))
|
||||
if i % 10 == 0:
|
||||
archive.tags.append(models.Tag(id=f"tag-second-{i}"))
|
||||
if i % 4 == 0:
|
||||
archive.tags.append(models.Tag(id=f"tag-third-{i}"))
|
||||
for j in range(10):
|
||||
archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
|
||||
db_session.add(archive)
|
||||
|
||||
# creates a sheet for each user
|
||||
for i, email in enumerate(authors):
|
||||
db_session.add(models.Sheet(id=f"sheet-{i}", name=f"sheet-{i}", author_id=email, group_id=None, frequency="daily"))
|
||||
if email == "rick@example.com":
|
||||
db_session.add(models.Sheet(id=f"sheet-{i}-2", name=f"sheet-{i}-2", author_id=email, group_id="spaceship", frequency="hourly"))
|
||||
|
||||
db_session.commit()
|
||||
|
||||
assert db_session.query(models.Archive).count() == 100
|
||||
assert db_session.query(models.Tag).count() == 20 + 10 + 25
|
||||
assert db_session.query(models.ArchiveUrl).count() == 1000
|
||||
assert db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.archive_id == "archive-id-456-0").count() == 10
|
||||
|
||||
# setup groups
|
||||
assert db_session.query(models.Group).count() == 0
|
||||
from app.web.db import crud
|
||||
crud.upsert_user_groups(db_session)
|
||||
assert db_session.query(models.Group).count() == 4
|
||||
assert db_session.query(models.User).count() == 3
|
||||
|
||||
|
||||
def test_search_archives_by_url(test_data, db_session):
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
|
||||
# rick's archives are private
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", True, False)) == 34
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", [], False)) == 34
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com", [], True)) == 34
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL, [], False)) == 34
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL, True, False)) == 34
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com", [], False)) == 0
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com", [], True)) == 0
|
||||
# Rick's archives are private
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example-0.com",
|
||||
"rick@example.com",
|
||||
True,
|
||||
False,
|
||||
)
|
||||
)
|
||||
== 34
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example-0.com",
|
||||
"rick@example.com",
|
||||
[],
|
||||
False,
|
||||
)
|
||||
)
|
||||
== 34
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example-0.com",
|
||||
"rick@example.com",
|
||||
[],
|
||||
True,
|
||||
)
|
||||
)
|
||||
== 34
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session, "https://example-0.com", ALLOW_ANY_EMAIL, [], False
|
||||
)
|
||||
)
|
||||
== 34
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example-0.com",
|
||||
ALLOW_ANY_EMAIL,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
)
|
||||
== 34
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example-0.com",
|
||||
"morty@example.com",
|
||||
[],
|
||||
False,
|
||||
)
|
||||
)
|
||||
== 0
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example-0.com",
|
||||
"morty@example.com",
|
||||
[],
|
||||
True,
|
||||
)
|
||||
)
|
||||
== 0
|
||||
)
|
||||
|
||||
# morty's archives are public but half are in spaceship group
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com", ["spaceship"], False)) == 16
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com", True, False)) == 16
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "jerry@example.com", True, True)) == 16
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example-1.com",
|
||||
"rick@example.com",
|
||||
["spaceship"],
|
||||
False,
|
||||
)
|
||||
)
|
||||
== 16
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example-1.com",
|
||||
"rick@example.com",
|
||||
True,
|
||||
False,
|
||||
)
|
||||
)
|
||||
== 16
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example-1.com",
|
||||
"jerry@example.com",
|
||||
True,
|
||||
True,
|
||||
)
|
||||
)
|
||||
== 16
|
||||
)
|
||||
|
||||
# jerry's archives are public
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "jerry@example.com", [], True)) == 33
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "rick@example.com", [], True)) == 33
|
||||
# Jerry's archives are public
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example-2.com",
|
||||
"jerry@example.com",
|
||||
[],
|
||||
True,
|
||||
)
|
||||
)
|
||||
== 33
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example-2.com",
|
||||
"rick@example.com",
|
||||
[],
|
||||
True,
|
||||
)
|
||||
)
|
||||
== 33
|
||||
)
|
||||
|
||||
# fuzzy search
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False)) == 100
|
||||
assert len(crud.search_archives_by_url(db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL, False, False)) == 100
|
||||
assert len(crud.search_archives_by_url(db_session, "2.com", ALLOW_ANY_EMAIL, False, False)) == 33
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session, "https://example", ALLOW_ANY_EMAIL, False, False
|
||||
)
|
||||
)
|
||||
== 100
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL, False, False
|
||||
)
|
||||
)
|
||||
== 100
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session, "2.com", ALLOW_ANY_EMAIL, False, False
|
||||
)
|
||||
)
|
||||
== 33
|
||||
)
|
||||
|
||||
# absolute search
|
||||
assert len(crud.search_archives_by_url(db_session, "example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True)) == 0
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", ALLOW_ANY_EMAIL, [], False, absolute_search=True)) == 33
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"example-2.com",
|
||||
ALLOW_ANY_EMAIL,
|
||||
[],
|
||||
False,
|
||||
absolute_search=True,
|
||||
)
|
||||
)
|
||||
== 0
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example-2.com",
|
||||
ALLOW_ANY_EMAIL,
|
||||
[],
|
||||
False,
|
||||
absolute_search=True,
|
||||
)
|
||||
)
|
||||
== 33
|
||||
)
|
||||
|
||||
# archived_after
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, archived_after=datetime(2010, 1, 1))) == 100
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 15))) == 70
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2031, 1, 1))) == 0
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example",
|
||||
ALLOW_ANY_EMAIL,
|
||||
True,
|
||||
True,
|
||||
archived_after=datetime(2010, 1, 1),
|
||||
)
|
||||
)
|
||||
== 100
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example",
|
||||
ALLOW_ANY_EMAIL,
|
||||
False,
|
||||
False,
|
||||
archived_after=datetime(2021, 1, 15),
|
||||
)
|
||||
)
|
||||
== 70
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example",
|
||||
ALLOW_ANY_EMAIL,
|
||||
False,
|
||||
False,
|
||||
archived_after=datetime(2031, 1, 1),
|
||||
)
|
||||
)
|
||||
== 0
|
||||
)
|
||||
|
||||
# archived before
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2010, 1, 1))) == 0
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2021, 1, 15))) == 28
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_before=datetime(2031, 1, 1))) == 100
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example",
|
||||
ALLOW_ANY_EMAIL,
|
||||
False,
|
||||
False,
|
||||
archived_before=datetime(2010, 1, 1),
|
||||
)
|
||||
)
|
||||
== 0
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example",
|
||||
ALLOW_ANY_EMAIL,
|
||||
False,
|
||||
False,
|
||||
archived_before=datetime(2021, 1, 15),
|
||||
)
|
||||
)
|
||||
== 28
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example",
|
||||
ALLOW_ANY_EMAIL,
|
||||
False,
|
||||
False,
|
||||
archived_before=datetime(2031, 1, 1),
|
||||
)
|
||||
)
|
||||
== 100
|
||||
)
|
||||
|
||||
# archived before and after
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2001, 1, 1), archived_before=datetime(2031, 1, 11))) == 100
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, archived_after=datetime(2021, 1, 14), archived_before=datetime(2021, 1, 16))) == 2
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example",
|
||||
ALLOW_ANY_EMAIL,
|
||||
False,
|
||||
False,
|
||||
archived_after=datetime(2001, 1, 1),
|
||||
archived_before=datetime(2031, 1, 11),
|
||||
)
|
||||
)
|
||||
== 100
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example",
|
||||
ALLOW_ANY_EMAIL,
|
||||
False,
|
||||
False,
|
||||
archived_after=datetime(2021, 1, 14),
|
||||
archived_before=datetime(2021, 1, 16),
|
||||
)
|
||||
)
|
||||
== 2
|
||||
)
|
||||
|
||||
# limit
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, limit=10)) == 10
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, limit=-1)) == 1
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example",
|
||||
ALLOW_ANY_EMAIL,
|
||||
False,
|
||||
False,
|
||||
limit=10,
|
||||
)
|
||||
)
|
||||
== 10
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example",
|
||||
ALLOW_ANY_EMAIL,
|
||||
False,
|
||||
False,
|
||||
limit=-1,
|
||||
)
|
||||
)
|
||||
== 1
|
||||
)
|
||||
|
||||
# skip
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, False, False, skip=10)) == 90
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example",
|
||||
ALLOW_ANY_EMAIL,
|
||||
False,
|
||||
False,
|
||||
skip=10,
|
||||
)
|
||||
)
|
||||
== 90
|
||||
)
|
||||
|
||||
|
||||
def test_search_archives_by_email(test_data, db_session):
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
|
||||
# lower/upper case
|
||||
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34
|
||||
assert (
|
||||
len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34
|
||||
)
|
||||
|
||||
# ALLOW_ANY_EMAIL is not a user
|
||||
assert len(crud.search_archives_by_email(db_session, ALLOW_ANY_EMAIL)) == 0
|
||||
@@ -136,45 +392,108 @@ def test_search_archives_by_email(test_data, db_session):
|
||||
|
||||
@patch("app.web.db.crud.DATABASE_QUERY_LIMIT", new=25)
|
||||
def test_max_query_limit(test_data, db_session):
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session, "https://example", ALLOW_ANY_EMAIL, [], False
|
||||
)
|
||||
)
|
||||
== 25
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_url(
|
||||
db_session,
|
||||
"https://example",
|
||||
ALLOW_ANY_EMAIL,
|
||||
True,
|
||||
True,
|
||||
limit=1000,
|
||||
)
|
||||
)
|
||||
== 25
|
||||
)
|
||||
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, [], False)) == 25
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, True, True, limit=1000)) == 25
|
||||
|
||||
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25
|
||||
assert len(crud.search_archives_by_email(db_session, "rick@example.com", limit=1000)) == 25
|
||||
assert (
|
||||
len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
crud.search_archives_by_email(
|
||||
db_session, "rick@example.com", limit=1000
|
||||
)
|
||||
)
|
||||
== 25
|
||||
)
|
||||
|
||||
|
||||
def test_soft_delete(test_data, db_session):
|
||||
# none deleted yet
|
||||
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").first() is not None
|
||||
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 0
|
||||
assert (
|
||||
db_session.query(models.Archive)
|
||||
.filter(models.Archive.id == "archive-id-456-0")
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
assert (
|
||||
db_session.query(models.Archive)
|
||||
.filter(models.Archive.deleted.is_(true()))
|
||||
.count()
|
||||
== 0
|
||||
)
|
||||
|
||||
# delete
|
||||
assert crud.soft_delete_archive(db_session, "archive-id-456-0", "rick@example.com") == True
|
||||
assert (
|
||||
crud.soft_delete_archive(
|
||||
db_session, "archive-id-456-0", "rick@example.com"
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
# ensure soft delete
|
||||
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 1
|
||||
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").first() is None
|
||||
assert (
|
||||
db_session.query(models.Archive)
|
||||
.filter(models.Archive.deleted.is_(true()))
|
||||
.count()
|
||||
== 1
|
||||
)
|
||||
assert (
|
||||
db_session.query(models.Archive)
|
||||
.filter(models.Archive.id == "archive-id-456-0")
|
||||
.filter(models.Archive.deleted.is_(false()))
|
||||
.first()
|
||||
is None
|
||||
)
|
||||
|
||||
# already deleted
|
||||
assert crud.soft_delete_archive(db_session, "archive-id-456-0", "rick@example.com") == False
|
||||
assert (
|
||||
crud.soft_delete_archive(
|
||||
db_session, "archive-id-456-0", "rick@example.com"
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_count_archives(test_data, db_session):
|
||||
assert crud.count_archives(db_session) == 100
|
||||
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
|
||||
db_session.query(models.Archive).filter(
|
||||
models.Archive.id == "archive-id-456-0"
|
||||
).delete()
|
||||
db_session.commit()
|
||||
assert crud.count_archives(db_session) == 99
|
||||
|
||||
|
||||
def test_count_archive_urls(test_data, db_session):
|
||||
assert crud.count_archive_urls(db_session) == 1000
|
||||
db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.url == "https://example-0.com/0").delete()
|
||||
db_session.query(models.ArchiveUrl).filter(
|
||||
models.ArchiveUrl.url == "https://example-0.com/0"
|
||||
).delete()
|
||||
db_session.commit()
|
||||
assert crud.count_archive_urls(db_session) == 999
|
||||
|
||||
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
|
||||
db_session.query(models.Archive).filter(
|
||||
models.Archive.id == "archive-id-456-0"
|
||||
).delete()
|
||||
db_session.commit()
|
||||
# no Cascade is enabled
|
||||
assert crud.count_archives(db_session) == 99
|
||||
@@ -183,16 +502,23 @@ def test_count_archive_urls(test_data, db_session):
|
||||
|
||||
def test_count_users(test_data, db_session):
|
||||
assert crud.count_users(db_session) == 3
|
||||
db_session.query(models.User).filter(models.User.email == "rick@example.com").delete()
|
||||
db_session.query(models.User).filter(
|
||||
models.User.email == "rick@example.com"
|
||||
).delete()
|
||||
db_session.commit()
|
||||
assert crud.count_users(db_session) == 2
|
||||
|
||||
|
||||
def test_count_by_users_since(test_data, db_session):
|
||||
from app.web.db import crud
|
||||
|
||||
# 100y window
|
||||
assert len(cu := crud.count_by_user_since(db_session, 60 * 60 * 24 * 31 * 12 * 100)) == 3
|
||||
assert (
|
||||
len(
|
||||
cu := crud.count_by_user_since(
|
||||
db_session, 60 * 60 * 24 * 31 * 12 * 100
|
||||
)
|
||||
)
|
||||
== 3
|
||||
)
|
||||
assert cu[0].total == 34
|
||||
assert cu[1].total == 33
|
||||
assert cu[2].total == 33
|
||||
@@ -201,9 +527,18 @@ def test_count_by_users_since(test_data, db_session):
|
||||
def test_upsert_group(test_data, db_session):
|
||||
assert db_session.query(models.Group).count() == 4
|
||||
|
||||
repeatable_params = ["desc 1", "orch.yaml", "sheet.yaml", "service_account_email@example.com", {"read": ["all"]}, ["example.com"]]
|
||||
repeatable_params = [
|
||||
"desc 1",
|
||||
"orch.yaml",
|
||||
"sheet.yaml",
|
||||
"service_account_email@example.com",
|
||||
{"read": ["all"]},
|
||||
["example.com"],
|
||||
]
|
||||
|
||||
assert (g1 := crud.upsert_group(db_session, "spaceship", *repeatable_params)) is not None
|
||||
assert (
|
||||
g1 := crud.upsert_group(db_session, "spaceship", *repeatable_params)
|
||||
) is not None
|
||||
assert g1.id == "spaceship"
|
||||
assert g1.description == "desc 1"
|
||||
assert g1.orchestrator == "orch.yaml"
|
||||
@@ -212,14 +547,25 @@ def test_upsert_group(test_data, db_session):
|
||||
assert g1.permissions == {"read": ["all"]}
|
||||
assert g1.domains == ["example.com"]
|
||||
assert len(g1.users) == 2
|
||||
assert [u.email for u in g1.users] == ["rick@example.com", "morty@example.com"]
|
||||
assert [u.email for u in g1.users] == [
|
||||
"rick@example.com",
|
||||
"morty@example.com",
|
||||
]
|
||||
|
||||
assert (g2 := crud.upsert_group(db_session, "interdimensional", *repeatable_params)) is not None
|
||||
assert (
|
||||
g2 := crud.upsert_group(
|
||||
db_session, "interdimensional", *repeatable_params
|
||||
)
|
||||
) is not None
|
||||
assert g2.id == "interdimensional"
|
||||
assert len(g2.users) == 1
|
||||
assert [u.email for u in g2.users] == ["rick@example.com"]
|
||||
|
||||
assert (g3 := crud.upsert_group(db_session, "this-is-a-new-group", *repeatable_params)) is not None
|
||||
assert (
|
||||
g3 := crud.upsert_group(
|
||||
db_session, "this-is-a-new-group", *repeatable_params
|
||||
)
|
||||
) is not None
|
||||
assert g3.id == "this-is-a-new-group"
|
||||
assert len(g3.users) == 0
|
||||
|
||||
@@ -227,29 +573,38 @@ def test_upsert_group(test_data, db_session):
|
||||
|
||||
|
||||
def test_upsert_user_groups(db_session):
|
||||
@patch('app.web.db.crud.get_settings', new=lambda: bad_setings)
|
||||
@patch("app.web.db.crud.get_settings", new=lambda: bad_settings)
|
||||
def test_missing_yaml(db_session):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
crud.upsert_user_groups(db_session)
|
||||
|
||||
@patch('app.web.db.crud.get_settings', new=lambda: bad_setings)
|
||||
@patch("app.web.db.crud.get_settings", new=lambda: bad_settings)
|
||||
def test_broken_yaml(db_session):
|
||||
with pytest.raises(yaml.YAMLError):
|
||||
crud.upsert_user_groups(db_session)
|
||||
|
||||
bad_setings = Settings(_env_file=".env.test")
|
||||
bad_settings = Settings(_env_file=".env.test")
|
||||
|
||||
bad_setings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.missing.yaml"
|
||||
bad_settings.USER_GROUPS_FILENAME = (
|
||||
"app/tests/user-groups.test.missing.yaml"
|
||||
)
|
||||
test_missing_yaml(db_session)
|
||||
|
||||
bad_setings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.broken.yaml"
|
||||
bad_settings.USER_GROUPS_FILENAME = "app/tests/user-groups.test.broken.yaml"
|
||||
test_broken_yaml(db_session)
|
||||
|
||||
|
||||
def test_create_sheet(db_session):
|
||||
assert db_session.query(models.Sheet).count() == 0
|
||||
|
||||
s = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly")
|
||||
s = crud.create_sheet(
|
||||
db_session,
|
||||
"sheet-id-123",
|
||||
"sheet name",
|
||||
"email@example.com",
|
||||
"group-id",
|
||||
"hourly",
|
||||
)
|
||||
assert s is not None
|
||||
assert s.id == "sheet-id-123"
|
||||
assert s.name == "sheet name"
|
||||
@@ -259,19 +614,35 @@ def test_create_sheet(db_session):
|
||||
|
||||
assert db_session.query(models.Sheet).count() == 1
|
||||
|
||||
# duplicate id
|
||||
import sqlalchemy
|
||||
with pytest.raises(sqlalchemy.exc.IntegrityError):
|
||||
crud.create_sheet(db_session, "sheet-id-123", "I thought this was another sheet", "email", "group-id", "hourly")
|
||||
crud.create_sheet(
|
||||
db_session,
|
||||
"sheet-id-123",
|
||||
"I thought this was another sheet",
|
||||
"email",
|
||||
"group-id",
|
||||
"hourly",
|
||||
)
|
||||
|
||||
|
||||
def test_get_user_sheet(test_data, db_session):
|
||||
assert crud.get_user_sheet(db_session, "", "sheet-0") is None
|
||||
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None
|
||||
assert (
|
||||
crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None
|
||||
)
|
||||
|
||||
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0") is not None
|
||||
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2") is not None
|
||||
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-1") is not None
|
||||
assert (
|
||||
crud.get_user_sheet(db_session, "rick@example.com", "sheet-0")
|
||||
is not None
|
||||
)
|
||||
assert (
|
||||
crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2")
|
||||
is not None
|
||||
)
|
||||
assert (
|
||||
crud.get_user_sheet(db_session, "morty@example.com", "sheet-1")
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def test_get_user_sheets(test_data, db_session):
|
||||
@@ -283,9 +654,9 @@ def test_get_user_sheets(test_data, db_session):
|
||||
|
||||
|
||||
def test_delete_sheet(test_data, db_session):
|
||||
assert crud.delete_sheet(db_session, "sheet-0", "") == False
|
||||
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == True
|
||||
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == False
|
||||
assert crud.delete_sheet(db_session, "sheet-0", "") is False
|
||||
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") is True
|
||||
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -297,21 +668,21 @@ async def test_find_by_store_until(async_db_session):
|
||||
url="https://example-expired-1.com",
|
||||
result={},
|
||||
author_id="rick@example.com",
|
||||
store_until=now - timedelta(days=1)
|
||||
store_until=now - timedelta(days=1),
|
||||
)
|
||||
archive2 = models.Archive(
|
||||
id="archive-expired-2",
|
||||
url="https://example-expired-2.com",
|
||||
result={},
|
||||
author_id="rick@example.com",
|
||||
store_until=now - timedelta(hours=1)
|
||||
store_until=now - timedelta(hours=1),
|
||||
)
|
||||
archive3 = models.Archive(
|
||||
id="archive-active",
|
||||
url="https://example-active.com",
|
||||
result={},
|
||||
author_id="rick@example.com",
|
||||
store_until=now + timedelta(days=1)
|
||||
store_until=now + timedelta(days=1),
|
||||
)
|
||||
async_db_session.add_all([archive1, archive2, archive3])
|
||||
await async_db_session.commit()
|
||||
@@ -321,11 +692,15 @@ async def test_find_by_store_until(async_db_session):
|
||||
assert len(list(expired)) == 2
|
||||
|
||||
# Should find 1 archive expired before 2 hours ago
|
||||
expired = await crud.find_by_store_until(async_db_session, now - timedelta(hours=2))
|
||||
expired = await crud.find_by_store_until(
|
||||
async_db_session, now - timedelta(hours=2)
|
||||
)
|
||||
assert len(list(expired)) == 1
|
||||
|
||||
# Should find no archives expired before 2 days ago
|
||||
expired = await crud.find_by_store_until(async_db_session, now - timedelta(days=2))
|
||||
expired = await crud.find_by_store_until(
|
||||
async_db_session, now - timedelta(days=2)
|
||||
)
|
||||
assert len(list(expired)) == 0
|
||||
|
||||
# Should not find deleted archives
|
||||
@@ -337,44 +712,82 @@ async def test_find_by_store_until(async_db_session):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sheets_by_id_hash(async_db_session):
|
||||
author_emails = [
|
||||
"rick@example.com",
|
||||
"morty@example.com",
|
||||
"jerry@example.com",
|
||||
]
|
||||
|
||||
# Add test data
|
||||
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
|
||||
sheets = [
|
||||
models.Sheet(id="sheet-0", name="sheet-0", author_id=authors[0], group_id=None, frequency="daily"),
|
||||
models.Sheet(id="sheet-0-2", name="sheet-0-2", author_id=authors[0], group_id="spaceship", frequency="hourly"),
|
||||
models.Sheet(id="sheet-1", name="sheet-1", author_id=authors[1], group_id=None, frequency="daily"),
|
||||
models.Sheet(id="sheet-2", name="sheet-2", author_id=authors[2], group_id=None, frequency="daily")
|
||||
models.Sheet(
|
||||
id="sheet-0",
|
||||
name="sheet-0",
|
||||
author_id=author_emails[0],
|
||||
group_id=None,
|
||||
frequency="daily",
|
||||
),
|
||||
models.Sheet(
|
||||
id="sheet-0-2",
|
||||
name="sheet-0-2",
|
||||
author_id=author_emails[0],
|
||||
group_id="spaceship",
|
||||
frequency="hourly",
|
||||
),
|
||||
models.Sheet(
|
||||
id="sheet-1",
|
||||
name="sheet-1",
|
||||
author_id=author_emails[1],
|
||||
group_id=None,
|
||||
frequency="daily",
|
||||
),
|
||||
models.Sheet(
|
||||
id="sheet-2",
|
||||
name="sheet-2",
|
||||
author_id=author_emails[2],
|
||||
group_id=None,
|
||||
frequency="daily",
|
||||
),
|
||||
]
|
||||
async_db_session.add_all(sheets)
|
||||
await async_db_session.commit()
|
||||
|
||||
with patch("app.web.db.crud.fnv1a_hash_mod", return_value=1):
|
||||
# Test retrieving hourly sheets
|
||||
hourly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "hourly", 4, 1)
|
||||
hourly_sheets = await crud.get_sheets_by_id_hash(
|
||||
async_db_session, "hourly", 4, 1
|
||||
)
|
||||
assert len(hourly_sheets) == 1
|
||||
assert hourly_sheets[0].id == "sheet-0-2"
|
||||
assert hourly_sheets[0].frequency == "hourly"
|
||||
|
||||
# Test retrieving daily sheets
|
||||
daily_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 1)
|
||||
daily_sheets = await crud.get_sheets_by_id_hash(
|
||||
async_db_session, "daily", 4, 1
|
||||
)
|
||||
assert len(daily_sheets) == 3
|
||||
assert all(sheet.frequency == "daily" for sheet in daily_sheets)
|
||||
assert {sheet.id for sheet in daily_sheets} == {"sheet-0", "sheet-1", "sheet-2"}
|
||||
assert {sheet.id for sheet in daily_sheets} == {
|
||||
"sheet-0",
|
||||
"sheet-1",
|
||||
"sheet-2",
|
||||
}
|
||||
|
||||
# Test with non-matching hash
|
||||
no_sheets = await crud.get_sheets_by_id_hash(async_db_session, "daily", 4, 3)
|
||||
no_sheets = await crud.get_sheets_by_id_hash(
|
||||
async_db_session, "daily", 4, 3
|
||||
)
|
||||
assert len(no_sheets) == 0
|
||||
|
||||
# Test with non-existent frequency
|
||||
weekly_sheets = await crud.get_sheets_by_id_hash(async_db_session, "weekly", 4, 1)
|
||||
weekly_sheets = await crud.get_sheets_by_id_hash(
|
||||
async_db_session, "weekly", 4, 1
|
||||
)
|
||||
assert len(weekly_sheets) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_stale_sheets(async_db_session):
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.sql import select
|
||||
|
||||
now = datetime.now()
|
||||
active_date = now - timedelta(days=5)
|
||||
stale_date = now - timedelta(days=15)
|
||||
@@ -386,29 +799,29 @@ async def test_delete_stale_sheets(async_db_session):
|
||||
name="Active Sheet 1",
|
||||
author_id="rick@example.com",
|
||||
frequency="daily",
|
||||
last_url_archived_at=active_date
|
||||
last_url_archived_at=active_date,
|
||||
),
|
||||
models.Sheet(
|
||||
id="sheet-active-2",
|
||||
name="Active Sheet 2",
|
||||
author_id="morty@example.com",
|
||||
frequency="hourly",
|
||||
last_url_archived_at=active_date
|
||||
last_url_archived_at=active_date,
|
||||
),
|
||||
models.Sheet(
|
||||
id="sheet-stale-1",
|
||||
name="Stale Sheet 1",
|
||||
author_id="rick@example.com",
|
||||
frequency="daily",
|
||||
last_url_archived_at=stale_date
|
||||
last_url_archived_at=stale_date,
|
||||
),
|
||||
models.Sheet(
|
||||
id="sheet-stale-2",
|
||||
name="Stale Sheet 2",
|
||||
author_id="morty@example.com",
|
||||
frequency="daily",
|
||||
last_url_archived_at=stale_date
|
||||
)
|
||||
last_url_archived_at=stale_date,
|
||||
),
|
||||
]
|
||||
async_db_session.add_all(sheets)
|
||||
await async_db_session.commit()
|
||||
@@ -435,4 +848,4 @@ async def test_delete_stale_sheets(async_db_session):
|
||||
|
||||
# Running again should not delete anything
|
||||
deleted = await crud.delete_stale_sheets(async_db_session, 7)
|
||||
assert len(deleted) == 0
|
||||
assert len(deleted) == 0
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared.user_groups import GroupInfo, GroupPermissions
|
||||
from app.web.db.user_state import UserState
|
||||
from app.web.utils.misc import convert_priority_to_queue_dict
|
||||
|
||||
|
||||
def fresh_user_state():
|
||||
@@ -20,39 +21,73 @@ def user_state():
|
||||
def user_state_with_groups(user_state):
|
||||
user_groups = [
|
||||
models.Group(id="no-permissions", permissions={}),
|
||||
models.Group(id="group1", description="this is g1", service_account_email="sa1@example.com", permissions={"read": ["group1", "no-permissions"], "read_public": True, "archive_url": True, "archive_sheet": True, "max_archive_lifespan_months": 24, "max_monthly_urls": 100, "max_monthly_mbs": 1000, "priority": "high"}),
|
||||
models.Group(id="group2", description="this is g2", service_account_email="sa2@example.com", permissions={"read": ["all"], "read_public": True, "archive_url": False, "archive_sheet": False, "max_archive_lifespan_months": -1, "max_monthly_urls": -1, "max_monthly_mbs": -1, "priority": "low", "sheet_frequency": {"daily"}}),
|
||||
models.Group(
|
||||
id="group1",
|
||||
description="this is g1",
|
||||
service_account_email="sa1@example.com",
|
||||
permissions={
|
||||
"read": ["group1", "no-permissions"],
|
||||
"read_public": True,
|
||||
"archive_url": True,
|
||||
"archive_sheet": True,
|
||||
"max_archive_lifespan_months": 24,
|
||||
"max_monthly_urls": 100,
|
||||
"max_monthly_mbs": 1000,
|
||||
"priority": "high",
|
||||
},
|
||||
),
|
||||
models.Group(
|
||||
id="group2",
|
||||
description="this is g2",
|
||||
service_account_email="sa2@example.com",
|
||||
permissions={
|
||||
"read": ["all"],
|
||||
"read_public": True,
|
||||
"archive_url": False,
|
||||
"archive_sheet": False,
|
||||
"max_archive_lifespan_months": -1,
|
||||
"max_monthly_urls": -1,
|
||||
"max_monthly_mbs": -1,
|
||||
"priority": "low",
|
||||
"sheet_frequency": {"daily"},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=user_groups):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=user_groups,
|
||||
):
|
||||
yield user_state
|
||||
|
||||
|
||||
def test_permissions(user_state_with_groups):
|
||||
permissions = user_state_with_groups.permissions
|
||||
|
||||
assert permissions["all"].read == True
|
||||
assert permissions["all"].read_public == True
|
||||
assert permissions["all"].archive_url == True
|
||||
assert permissions["all"].archive_sheet == True
|
||||
assert permissions["all"].read is True
|
||||
assert permissions["all"].read_public is True
|
||||
assert permissions["all"].archive_url is True
|
||||
assert permissions["all"].archive_sheet is True
|
||||
assert permissions["all"].max_archive_lifespan_months == -1
|
||||
assert permissions["all"].max_monthly_urls == -1
|
||||
assert permissions["all"].max_monthly_mbs == -1
|
||||
assert permissions["all"].priority == "high"
|
||||
|
||||
assert permissions["group1"].read == set(["group1", "no-permissions"])
|
||||
assert permissions["group1"].read_public == True
|
||||
assert permissions["group1"].archive_url == True
|
||||
assert permissions["group1"].archive_sheet == True
|
||||
assert permissions["group1"].read == {"group1", "no-permissions"}
|
||||
assert permissions["group1"].read_public is True
|
||||
assert permissions["group1"].archive_url is True
|
||||
assert permissions["group1"].archive_sheet is True
|
||||
assert permissions["group1"].max_archive_lifespan_months == 24
|
||||
assert permissions["group1"].max_monthly_urls == 100
|
||||
assert permissions["group1"].max_monthly_mbs == 1000
|
||||
assert permissions["group1"].priority == "high"
|
||||
|
||||
assert permissions["group2"].read == set(["all"])
|
||||
assert permissions["group2"].read_public == True
|
||||
assert permissions["group2"].archive_url == False
|
||||
assert permissions["group2"].archive_sheet == False
|
||||
assert permissions["group2"].read == {"all"}
|
||||
assert permissions["group2"].read_public is True
|
||||
assert permissions["group2"].archive_url is False
|
||||
assert permissions["group2"].archive_sheet is False
|
||||
assert permissions["group2"].max_archive_lifespan_months == -1
|
||||
assert permissions["group2"].max_monthly_urls == -1
|
||||
assert permissions["group2"].max_monthly_mbs == -1
|
||||
@@ -62,13 +97,19 @@ def test_permissions(user_state_with_groups):
|
||||
|
||||
|
||||
def test_user_groups_names(user_state):
|
||||
with patch('app.web.db.crud.get_user_group_names', return_value=["group1", "group2"]) as mock:
|
||||
with patch(
|
||||
"app.web.db.crud.get_user_group_names",
|
||||
return_value=["group1", "group2"],
|
||||
) as mock:
|
||||
assert user_state.user_groups_names == ["group1", "group2", "default"]
|
||||
mock.assert_called_once_with(None, "test@example.com")
|
||||
|
||||
|
||||
def test_user_groups(user_state):
|
||||
with patch('app.web.db.crud.get_user_groups_by_name', return_value=[MagicMock(), MagicMock()]) as mock:
|
||||
with patch(
|
||||
"app.web.db.crud.get_user_groups_by_name",
|
||||
return_value=[MagicMock(), MagicMock()],
|
||||
) as mock:
|
||||
user_state._user_groups_names = ["group1", "group2"]
|
||||
assert len(user_state.user_groups) == 2
|
||||
mock.assert_called_once_with(None, ["group1", "group2"])
|
||||
@@ -77,85 +118,166 @@ def test_user_groups(user_state):
|
||||
def test_read():
|
||||
us = fresh_user_state()
|
||||
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[models.Group(id="no-permissions", permissions={})],
|
||||
) as mock:
|
||||
assert not hasattr(us, "_read")
|
||||
assert us.read == set()
|
||||
assert us._read == set()
|
||||
mock.assert_called_once()
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read": ["group1", "no-permissions"]})]):
|
||||
assert us.read == set(["group1", "no-permissions"])
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(
|
||||
id="group1", permissions={"read": ["group1", "no-permissions"]}
|
||||
)
|
||||
],
|
||||
):
|
||||
assert us.read == {"group1", "no-permissions"}
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read": ["all"]})]):
|
||||
assert us.read == True
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[models.Group(id="group1", permissions={"read": ["all"]})],
|
||||
):
|
||||
assert us.read is True
|
||||
|
||||
|
||||
def test_read_public():
|
||||
us = fresh_user_state()
|
||||
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[models.Group(id="no-permissions", permissions={})],
|
||||
) as mock:
|
||||
assert not hasattr(us, "_read_public")
|
||||
assert us.read_public == False
|
||||
assert us._read_public == False
|
||||
assert us.read_public is False
|
||||
assert us._read_public is False
|
||||
mock.assert_called_once()
|
||||
# no new calls
|
||||
assert us.read_public == False
|
||||
assert us.read_public is False
|
||||
mock.assert_called_once()
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read_public": True})]):
|
||||
assert us.read_public == True
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"read_public": True})
|
||||
],
|
||||
):
|
||||
assert us.read_public is True
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"read_public": False})]):
|
||||
assert us.read_public == False
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"read_public": False})
|
||||
],
|
||||
):
|
||||
assert us.read_public is False
|
||||
|
||||
|
||||
def test_archive_url():
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[models.Group(id="no-permissions", permissions={})],
|
||||
) as mock:
|
||||
assert not hasattr(us, "_archive_url")
|
||||
assert us.archive_url == False
|
||||
assert us._archive_url == False
|
||||
assert us.archive_url is False
|
||||
assert us._archive_url is False
|
||||
mock.assert_called_once()
|
||||
# no new calls
|
||||
assert us.archive_url == False
|
||||
assert us.archive_url is False
|
||||
mock.assert_called_once()
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_url": False})]):
|
||||
assert us.archive_url == False
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"archive_url": False})
|
||||
],
|
||||
):
|
||||
assert us.archive_url is False
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_url": True})]):
|
||||
assert us.archive_url == True
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"archive_url": True})
|
||||
],
|
||||
):
|
||||
assert us.archive_url is True
|
||||
|
||||
|
||||
def test_archive_sheet():
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[models.Group(id="no-permissions", permissions={})],
|
||||
) as mock:
|
||||
assert not hasattr(us, "_archive_sheet")
|
||||
assert us.archive_sheet == False
|
||||
assert us._archive_sheet == False
|
||||
assert us.archive_sheet is False
|
||||
assert us._archive_sheet is False
|
||||
mock.assert_called_once()
|
||||
# no new calls
|
||||
assert us.archive_sheet == False
|
||||
assert us.archive_sheet is False
|
||||
mock.assert_called_once()
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_sheet": False})]):
|
||||
assert us.archive_sheet == False
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"archive_sheet": False})
|
||||
],
|
||||
):
|
||||
assert us.archive_sheet is False
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"archive_sheet": True})]):
|
||||
assert us.archive_sheet == True
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"archive_sheet": True})
|
||||
],
|
||||
):
|
||||
assert us.archive_sheet is True
|
||||
|
||||
|
||||
def test_sheet_frequency():
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[models.Group(id="no-permissions", permissions={})],
|
||||
) as mock:
|
||||
assert not hasattr(us, "_sheet_frequency")
|
||||
assert us.sheet_frequency == set()
|
||||
assert us._sheet_frequency == set()
|
||||
@@ -165,18 +287,42 @@ def test_sheet_frequency():
|
||||
mock.assert_called_once()
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"sheet_frequency": ["daily", "hourly"]})]):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(
|
||||
id="group1",
|
||||
permissions={"sheet_frequency": ["daily", "hourly"]},
|
||||
)
|
||||
],
|
||||
):
|
||||
assert us.sheet_frequency == {"daily", "hourly"}
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"sheet_frequency": []})]):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"sheet_frequency": []})
|
||||
],
|
||||
):
|
||||
assert us.sheet_frequency == set()
|
||||
|
||||
|
||||
def test_max_archive_lifespan_months():
|
||||
us = fresh_user_state()
|
||||
default = GroupPermissions.model_fields["max_archive_lifespan_months"].default
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
|
||||
default = GroupPermissions.model_fields[
|
||||
"max_archive_lifespan_months"
|
||||
].default
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[models.Group(id="no-permissions", permissions={})],
|
||||
) as mock:
|
||||
assert not hasattr(us, "_max_archive_lifespan_months")
|
||||
assert us.max_archive_lifespan_months == default
|
||||
assert us._max_archive_lifespan_months == default
|
||||
@@ -186,18 +332,44 @@ def test_max_archive_lifespan_months():
|
||||
mock.assert_called_once()
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_archive_lifespan_months": 24})]):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(
|
||||
id="group1", permissions={"max_archive_lifespan_months": 24}
|
||||
)
|
||||
],
|
||||
):
|
||||
assert us.max_archive_lifespan_months == 24
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_archive_lifespan_months": 150}), models.Group(id="group2", permissions={"max_archive_lifespan_months": -1})]):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(
|
||||
id="group1", permissions={"max_archive_lifespan_months": 150}
|
||||
),
|
||||
models.Group(
|
||||
id="group2", permissions={"max_archive_lifespan_months": -1}
|
||||
),
|
||||
],
|
||||
):
|
||||
assert us.max_archive_lifespan_months == -1
|
||||
|
||||
|
||||
def test_max_monthly_urls():
|
||||
us = fresh_user_state()
|
||||
default = GroupPermissions.model_fields["max_monthly_urls"].default
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[models.Group(id="no-permissions", permissions={})],
|
||||
) as mock:
|
||||
assert not hasattr(us, "_max_monthly_urls")
|
||||
assert us.max_monthly_urls == default
|
||||
assert us._max_monthly_urls == default
|
||||
@@ -207,18 +379,38 @@ def test_max_monthly_urls():
|
||||
mock.assert_called_once()
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_urls": 100})]):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"max_monthly_urls": 100})
|
||||
],
|
||||
):
|
||||
assert us.max_monthly_urls == 100
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_urls": 150}), models.Group(id="group2", permissions={"max_monthly_urls": -1})]):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"max_monthly_urls": 150}),
|
||||
models.Group(id="group2", permissions={"max_monthly_urls": -1}),
|
||||
],
|
||||
):
|
||||
assert us.max_monthly_urls == -1
|
||||
|
||||
|
||||
def test_max_monthly_mbs():
|
||||
us = fresh_user_state()
|
||||
default = GroupPermissions.model_fields["max_monthly_mbs"].default
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[models.Group(id="no-permissions", permissions={})],
|
||||
) as mock:
|
||||
assert not hasattr(us, "_max_monthly_mbs")
|
||||
assert us.max_monthly_mbs == default
|
||||
assert us._max_monthly_mbs == default
|
||||
@@ -228,17 +420,37 @@ def test_max_monthly_mbs():
|
||||
mock.assert_called_once()
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_mbs": 1000})]):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"max_monthly_mbs": 1000})
|
||||
],
|
||||
):
|
||||
assert us.max_monthly_mbs == 1000
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"max_monthly_mbs": 1500}), models.Group(id="group2", permissions={"max_monthly_mbs": -1})]):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"max_monthly_mbs": 1500}),
|
||||
models.Group(id="group2", permissions={"max_monthly_mbs": -1}),
|
||||
],
|
||||
):
|
||||
assert us.max_monthly_mbs == -1
|
||||
|
||||
|
||||
def test_priority(user_state):
|
||||
default = GroupPermissions.model_fields["priority"].default
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="no-permissions", permissions={})]) as mock:
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[models.Group(id="no-permissions", permissions={})],
|
||||
) as mock:
|
||||
assert not hasattr(user_state, "_priority")
|
||||
assert user_state.priority == default
|
||||
assert user_state._priority == default
|
||||
@@ -248,11 +460,26 @@ def test_priority(user_state):
|
||||
mock.assert_called_once()
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"priority": "high"})]):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"priority": "high"})
|
||||
],
|
||||
):
|
||||
assert us.priority == "high"
|
||||
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[models.Group(id="group1", permissions={"priority": "low"}), models.Group(id="group2", permissions={"priority": "medium"})]):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"priority": "low"}),
|
||||
models.Group(id="group2", permissions={"priority": "medium"}),
|
||||
],
|
||||
):
|
||||
assert us.priority == "low"
|
||||
|
||||
|
||||
@@ -262,21 +489,45 @@ def test_active():
|
||||
(True, False, False, False, True),
|
||||
(False, True, False, False, True),
|
||||
(False, False, True, False, True),
|
||||
(False, False, False, True, True)
|
||||
(False, False, False, True, True),
|
||||
]:
|
||||
us = fresh_user_state()
|
||||
with patch.object(UserState, 'read', new_callable=PropertyMock, return_value=read), \
|
||||
patch.object(UserState, 'read_public', new_callable=PropertyMock, return_value=read_public), \
|
||||
patch.object(UserState, 'archive_url', new_callable=PropertyMock, return_value=archive_url), \
|
||||
patch.object(UserState, 'archive_sheet', new_callable=PropertyMock, return_value=archive_sheet):
|
||||
with (
|
||||
patch.object(
|
||||
UserState, "read", new_callable=PropertyMock, return_value=read
|
||||
),
|
||||
patch.object(
|
||||
UserState,
|
||||
"read_public",
|
||||
new_callable=PropertyMock,
|
||||
return_value=read_public,
|
||||
),
|
||||
patch.object(
|
||||
UserState,
|
||||
"archive_url",
|
||||
new_callable=PropertyMock,
|
||||
return_value=archive_url,
|
||||
),
|
||||
patch.object(
|
||||
UserState,
|
||||
"archive_sheet",
|
||||
new_callable=PropertyMock,
|
||||
return_value=archive_sheet,
|
||||
),
|
||||
):
|
||||
assert us.active == is_active
|
||||
|
||||
|
||||
def test_in_group(user_state):
|
||||
with patch.object(UserState, 'user_groups_names', new_callable=PropertyMock, return_value=["group1", "group2"]):
|
||||
assert user_state.in_group("group1") == True
|
||||
assert user_state.in_group("group2") == True
|
||||
assert user_state.in_group("group3") == False
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups_names",
|
||||
new_callable=PropertyMock,
|
||||
return_value=["group1", "group2"],
|
||||
):
|
||||
assert user_state.in_group("group1") is True
|
||||
assert user_state.in_group("group2") is True
|
||||
assert user_state.in_group("group3") is False
|
||||
|
||||
|
||||
def test_usage(db_session):
|
||||
@@ -294,10 +545,34 @@ def test_usage(db_session):
|
||||
]
|
||||
megabytes = int(sum(bytes) / 1024 / 1024)
|
||||
|
||||
with patch.object(db_session, 'query', side_effect=[
|
||||
MagicMock(filter=MagicMock(return_value=MagicMock(group_by=MagicMock(return_value=MagicMock(all=MagicMock(return_value=user_sheets)))))),
|
||||
MagicMock(filter=MagicMock(return_value=MagicMock(group_by=MagicMock(return_value=MagicMock(all=MagicMock(return_value=urls_by_group))))))
|
||||
]):
|
||||
with patch.object(
|
||||
db_session,
|
||||
"query",
|
||||
side_effect=[
|
||||
MagicMock(
|
||||
filter=MagicMock(
|
||||
return_value=MagicMock(
|
||||
group_by=MagicMock(
|
||||
return_value=MagicMock(
|
||||
all=MagicMock(return_value=user_sheets)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
MagicMock(
|
||||
filter=MagicMock(
|
||||
return_value=MagicMock(
|
||||
group_by=MagicMock(
|
||||
return_value=MagicMock(
|
||||
all=MagicMock(return_value=urls_by_group)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
],
|
||||
):
|
||||
usage_response = user_state.usage()
|
||||
|
||||
assert usage_response.monthly_urls == 155
|
||||
@@ -305,11 +580,15 @@ def test_usage(db_session):
|
||||
assert usage_response.total_sheets == 115
|
||||
|
||||
assert usage_response.groups["group1"].monthly_urls == 50
|
||||
assert usage_response.groups["group1"].monthly_mbs == int(bytes[0] / 1024 / 1024)
|
||||
assert usage_response.groups["group1"].monthly_mbs == int(
|
||||
bytes[0] / 1024 / 1024
|
||||
)
|
||||
assert usage_response.groups["group1"].total_sheets == 5
|
||||
|
||||
assert usage_response.groups["group2"].monthly_urls == 100
|
||||
assert usage_response.groups["group2"].monthly_mbs == int(bytes[1] / 1024 / 1024)
|
||||
assert usage_response.groups["group2"].monthly_mbs == int(
|
||||
bytes[1] / 1024 / 1024
|
||||
)
|
||||
assert usage_response.groups["group2"].total_sheets == 10
|
||||
|
||||
assert usage_response.groups["group3"].monthly_urls == 0
|
||||
@@ -317,7 +596,9 @@ def test_usage(db_session):
|
||||
assert usage_response.groups["group3"].total_sheets == 100
|
||||
|
||||
assert usage_response.groups["group4"].monthly_urls == 5
|
||||
assert usage_response.groups["group4"].monthly_mbs == int(bytes[2] / 1024 / 1024)
|
||||
assert usage_response.groups["group4"].monthly_mbs == int(
|
||||
bytes[2] / 1024 / 1024
|
||||
)
|
||||
assert usage_response.groups["group4"].total_sheets == 0
|
||||
|
||||
|
||||
@@ -333,8 +614,23 @@ def test_has_quota_monthly_sheets(db_session):
|
||||
]
|
||||
|
||||
for permissions, count, expected in test_cases:
|
||||
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
|
||||
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"permissions",
|
||||
new_callable=PropertyMock,
|
||||
return_value=permissions,
|
||||
):
|
||||
with patch.object(
|
||||
us.db,
|
||||
"query",
|
||||
return_value=MagicMock(
|
||||
filter=MagicMock(
|
||||
return_value=MagicMock(
|
||||
count=MagicMock(return_value=count)
|
||||
)
|
||||
)
|
||||
),
|
||||
):
|
||||
assert us.has_quota_monthly_sheets("group1") == expected
|
||||
|
||||
|
||||
@@ -349,8 +645,23 @@ def test_has_quota_max_monthly_urls(db_session):
|
||||
]
|
||||
|
||||
for permissions, count, expected in test_cases:
|
||||
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
|
||||
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"permissions",
|
||||
new_callable=PropertyMock,
|
||||
return_value=permissions,
|
||||
):
|
||||
with patch.object(
|
||||
us.db,
|
||||
"query",
|
||||
return_value=MagicMock(
|
||||
filter=MagicMock(
|
||||
return_value=MagicMock(
|
||||
count=MagicMock(return_value=count)
|
||||
)
|
||||
)
|
||||
),
|
||||
):
|
||||
assert us.has_quota_max_monthly_urls("group1") == expected
|
||||
test_cases = [
|
||||
(-1, 1000, True),
|
||||
@@ -360,8 +671,23 @@ def test_has_quota_max_monthly_urls(db_session):
|
||||
]
|
||||
|
||||
for max_urls, count, expected in test_cases:
|
||||
with patch.object(UserState, 'max_monthly_urls', new_callable=PropertyMock, return_value=max_urls):
|
||||
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(count=MagicMock(return_value=count))))):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"max_monthly_urls",
|
||||
new_callable=PropertyMock,
|
||||
return_value=max_urls,
|
||||
):
|
||||
with patch.object(
|
||||
us.db,
|
||||
"query",
|
||||
return_value=MagicMock(
|
||||
filter=MagicMock(
|
||||
return_value=MagicMock(
|
||||
count=MagicMock(return_value=count)
|
||||
)
|
||||
)
|
||||
),
|
||||
):
|
||||
assert us.has_quota_max_monthly_urls("") == expected
|
||||
|
||||
|
||||
@@ -376,8 +702,29 @@ def test_has_quota_max_monthly_mbs(db_session):
|
||||
]
|
||||
|
||||
for permissions, mbs, expected in test_cases:
|
||||
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
|
||||
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(with_entities=MagicMock(return_value=MagicMock(scalar=MagicMock(return_value=mbs * 1024 * 1024))))))):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"permissions",
|
||||
new_callable=PropertyMock,
|
||||
return_value=permissions,
|
||||
):
|
||||
with patch.object(
|
||||
us.db,
|
||||
"query",
|
||||
return_value=MagicMock(
|
||||
filter=MagicMock(
|
||||
return_value=MagicMock(
|
||||
with_entities=MagicMock(
|
||||
return_value=MagicMock(
|
||||
scalar=MagicMock(
|
||||
return_value=mbs * 1024 * 1024
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
):
|
||||
assert us.has_quota_max_monthly_mbs("group1") == expected
|
||||
|
||||
test_cases = [
|
||||
@@ -388,8 +735,29 @@ def test_has_quota_max_monthly_mbs(db_session):
|
||||
]
|
||||
|
||||
for max_mbs, mbs, expected in test_cases:
|
||||
with patch.object(UserState, 'max_monthly_mbs', new_callable=PropertyMock, return_value=max_mbs):
|
||||
with patch.object(us.db, 'query', return_value=MagicMock(filter=MagicMock(return_value=MagicMock(with_entities=MagicMock(return_value=MagicMock(scalar=MagicMock(return_value=mbs * 1024 * 1024))))))):
|
||||
with patch.object(
|
||||
UserState,
|
||||
"max_monthly_mbs",
|
||||
new_callable=PropertyMock,
|
||||
return_value=max_mbs,
|
||||
):
|
||||
with patch.object(
|
||||
us.db,
|
||||
"query",
|
||||
return_value=MagicMock(
|
||||
filter=MagicMock(
|
||||
return_value=MagicMock(
|
||||
with_entities=MagicMock(
|
||||
return_value=MagicMock(
|
||||
scalar=MagicMock(
|
||||
return_value=mbs * 1024 * 1024
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
):
|
||||
assert us.has_quota_max_monthly_mbs("") == expected
|
||||
|
||||
|
||||
@@ -399,10 +767,15 @@ def test_can_manually_trigger(user_state):
|
||||
"group2": GroupInfo(manually_trigger_sheet=False),
|
||||
}
|
||||
|
||||
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
|
||||
assert user_state.can_manually_trigger("group1") == True
|
||||
assert user_state.can_manually_trigger("group2") == False
|
||||
assert user_state.can_manually_trigger("group3") == False
|
||||
with patch.object(
|
||||
UserState,
|
||||
"permissions",
|
||||
new_callable=PropertyMock,
|
||||
return_value=permissions,
|
||||
):
|
||||
assert user_state.can_manually_trigger("group1") is True
|
||||
assert user_state.can_manually_trigger("group2") is False
|
||||
assert user_state.can_manually_trigger("group3") is False
|
||||
|
||||
|
||||
def test_is_sheet_frequency_allowed(user_state):
|
||||
@@ -411,23 +784,44 @@ def test_is_sheet_frequency_allowed(user_state):
|
||||
"group2": GroupInfo(sheet_frequency={"daily"}),
|
||||
}
|
||||
|
||||
with patch.object(UserState, 'permissions', new_callable=PropertyMock, return_value=permissions):
|
||||
assert user_state.is_sheet_frequency_allowed("group1", "daily") == True
|
||||
assert user_state.is_sheet_frequency_allowed("group1", "hourly") == True
|
||||
assert user_state.is_sheet_frequency_allowed("group1", "weekly") == False
|
||||
assert user_state.is_sheet_frequency_allowed("group2", "hourly") == False
|
||||
assert user_state.is_sheet_frequency_allowed("group2", "daily") == True
|
||||
assert user_state.is_sheet_frequency_allowed("group3", "daily") == False
|
||||
with patch.object(
|
||||
UserState,
|
||||
"permissions",
|
||||
new_callable=PropertyMock,
|
||||
return_value=permissions,
|
||||
):
|
||||
assert user_state.is_sheet_frequency_allowed("group1", "daily") is True
|
||||
assert user_state.is_sheet_frequency_allowed("group1", "hourly") is True
|
||||
assert (
|
||||
user_state.is_sheet_frequency_allowed("group1", "weekly") is False
|
||||
)
|
||||
assert (
|
||||
user_state.is_sheet_frequency_allowed("group2", "hourly") is False
|
||||
)
|
||||
assert user_state.is_sheet_frequency_allowed("group2", "daily") is True
|
||||
assert user_state.is_sheet_frequency_allowed("group3", "daily") is False
|
||||
|
||||
|
||||
def test_priority_group(user_state):
|
||||
from app.web.utils.misc import convert_priority_to_queue_dict
|
||||
with patch.object(UserState, 'user_groups', new_callable=PropertyMock, return_value=[
|
||||
models.Group(id="group1", permissions={"priority": "high"}),
|
||||
models.Group(id="group2", permissions={"priority": "medium"}),
|
||||
models.Group(id="group3", permissions={"priority": "low"}),
|
||||
]):
|
||||
assert user_state.priority_group("group1") == convert_priority_to_queue_dict("high")
|
||||
assert user_state.priority_group("group2") == convert_priority_to_queue_dict("medium")
|
||||
assert user_state.priority_group("group3") == convert_priority_to_queue_dict("low")
|
||||
assert user_state.priority_group("group4") == convert_priority_to_queue_dict("low")
|
||||
with patch.object(
|
||||
UserState,
|
||||
"user_groups",
|
||||
new_callable=PropertyMock,
|
||||
return_value=[
|
||||
models.Group(id="group1", permissions={"priority": "high"}),
|
||||
models.Group(id="group2", permissions={"priority": "medium"}),
|
||||
models.Group(id="group3", permissions={"priority": "low"}),
|
||||
],
|
||||
):
|
||||
assert user_state.priority_group(
|
||||
"group1"
|
||||
) == convert_priority_to_queue_dict("high")
|
||||
assert user_state.priority_group(
|
||||
"group2"
|
||||
) == convert_priority_to_queue_dict("medium")
|
||||
assert user_state.priority_group(
|
||||
"group3"
|
||||
) == convert_priority_to_queue_dict("low")
|
||||
assert user_state.priority_group(
|
||||
"group4"
|
||||
) == convert_priority_to_queue_dict("low")
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.shared.db import models
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.web.db import crud
|
||||
|
||||
|
||||
def test_submit_manual_archive_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.post, "/interop/submit-archive")
|
||||
|
||||
|
||||
def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
|
||||
test_no_auth(client_with_auth.post, "/interop/submit-archive")
|
||||
|
||||
|
||||
@patch("app.web.endpoints.interoperability.business_logic", return_value=MagicMock(get_store_archive_until=MagicMock(return_value=datetime)))
|
||||
def test_submit_manual_archive(m1, client_with_token, db_session):
|
||||
# normal workflow
|
||||
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]})
|
||||
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"})
|
||||
assert r.status_code == 201
|
||||
assert "id" in r.json()
|
||||
|
||||
inserted = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"]).first()
|
||||
assert inserted.url == "http://example.com"
|
||||
assert inserted.group_id == "spaceship"
|
||||
assert inserted.author_id == "jerry@gmail.com"
|
||||
assert sorted([t.id for t in inserted.tags]) == sorted(["test", "manual"])
|
||||
assert inserted.public
|
||||
assert type(inserted.result) == dict
|
||||
assert [u.url for u in inserted.urls] == ["http://example.s3.com"]
|
||||
assert type(inserted.store_until) == datetime
|
||||
|
||||
# cannot have the same URL twice
|
||||
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]})
|
||||
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "tags": ["test"], "url": "http://example.com"})
|
||||
assert r.status_code == 422
|
||||
assert r.json() == {"detail": "Cannot insert into DB due to integrity error, likely duplicate urls."}
|
||||
|
||||
|
||||
# test with invalid JSON
|
||||
def test_submit_manual_archive_invalid_json(client_with_token):
|
||||
r = client_with_token.post("/interop/submit-archive", json={"result": "invalid json", "public": False, "author_id": "jer", "tags": ["test"], "url": "http://example.com"})
|
||||
assert r.status_code == 422
|
||||
assert r.json() == {"detail": "Invalid JSON in result field."}
|
||||
|
||||
|
||||
@patch("app.web.endpoints.interoperability.business_logic.get_store_archive_until", side_effect=AssertionError("AssertionError"))
|
||||
def test_submit_manual_archive_no_store_until(m_sau, client_with_token, db_session):
|
||||
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]})
|
||||
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"], "url": "http://example.com"})
|
||||
assert r.status_code == 201
|
||||
assert len(r.json()["id"]) == 36
|
||||
res = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"]).first()
|
||||
assert res.store_until is None
|
||||
# testing that store_until = None is not comparable with datetime, and will always return False
|
||||
res = db_session.query(models.Archive).filter(models.Archive.id == r.json()["id"], models.Archive.store_until < datetime.now()).first()
|
||||
assert res is None
|
||||
@@ -1,193 +0,0 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.shared.schemas import TaskResult
|
||||
|
||||
|
||||
def test_endpoints_no_auth(client, test_no_auth):
|
||||
test_no_auth(client.post, "/sheet/create")
|
||||
test_no_auth(client.get, "/sheet/mine")
|
||||
test_no_auth(client.delete, "/sheet/123-sheet-id")
|
||||
test_no_auth(client.post, "/sheet/123-sheet-id/archive")
|
||||
|
||||
|
||||
def test_create_sheet_endpoint(app_with_auth, db_session):
|
||||
client_with_auth = TestClient(app_with_auth)
|
||||
good_data = {
|
||||
"id": "123-sheet-id",
|
||||
"name": "Test Sheet",
|
||||
"group_id": "spaceship",
|
||||
"frequency": "daily"
|
||||
}
|
||||
|
||||
# with good data
|
||||
response = client_with_auth.post("/sheet/create", json=good_data)
|
||||
assert response.status_code == 201
|
||||
j = response.json()
|
||||
assert datetime.fromisoformat(j.pop("created_at"))
|
||||
assert datetime.fromisoformat(j.pop("last_url_archived_at"))
|
||||
assert j.pop("author_id") == 'morty@example.com'
|
||||
assert j == good_data
|
||||
|
||||
# already exists
|
||||
response = client_with_auth.post("/sheet/create", json=good_data)
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "Sheet with this ID is already being archived."}
|
||||
|
||||
# bad group
|
||||
bad_data = good_data.copy()
|
||||
bad_data["group_id"] = "not a group"
|
||||
response = client_with_auth.post("/sheet/create", json=bad_data)
|
||||
assert response.status_code == 403
|
||||
assert response.json() == {"detail": "User does not have access to this group."}
|
||||
|
||||
# switch to jerry who's got less quota/permissions
|
||||
from app.web.security import get_user_state
|
||||
from app.web.db.user_state import UserState
|
||||
app_with_auth.dependency_overrides[get_user_state] = lambda: UserState(db_session, "jerry@example.com")
|
||||
client_jerry = TestClient(app_with_auth)
|
||||
|
||||
# frequency not allowed
|
||||
jerry_data = good_data.copy()
|
||||
jerry_data["group_id"] = "animated-characters"
|
||||
jerry_data["frequency"] = "hourly"
|
||||
jerry_data["id"] = "jerry-sheet-id"
|
||||
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||
assert response.status_code == 422
|
||||
assert response.json() == {"detail": "Invalid frequency selected for this group."}
|
||||
|
||||
jerry_data["frequency"] = "daily"
|
||||
# success for the first sheet, bad quota on second
|
||||
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||
assert response.status_code == 201
|
||||
|
||||
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||
assert response.status_code == 429
|
||||
assert response.json() == {"detail": "User has reached their sheet quota for this group."}
|
||||
|
||||
|
||||
def test_get_user_sheets_endpoint(client_with_auth, db_session):
|
||||
# no data
|
||||
response = client_with_auth.get("/sheet/mine")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
# with data
|
||||
from app.shared.db import models
|
||||
db_session.add(
|
||||
models.Sheet(id="123", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")
|
||||
)
|
||||
db_session.commit()
|
||||
db_session.add_all([
|
||||
models.Sheet(id="456", name="Test Sheet 2", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
|
||||
models.Sheet(id="789", name="Test Sheet 3", author_id="rick@example.com", group_id="interdimensional", frequency="hourly"),
|
||||
])
|
||||
db_session.commit()
|
||||
|
||||
response = client_with_auth.get("/sheet/mine")
|
||||
assert response.status_code == 200
|
||||
r = response.json()
|
||||
assert isinstance(r, list)
|
||||
assert len(r) == 2
|
||||
assert datetime.fromisoformat(r[0].pop("created_at"))
|
||||
assert datetime.fromisoformat(r[0].pop("last_url_archived_at"))
|
||||
assert datetime.fromisoformat(r[1].pop("created_at"))
|
||||
assert datetime.fromisoformat(r[1].pop("last_url_archived_at"))
|
||||
assert r[0] == {
|
||||
'id': '123',
|
||||
'author_id': 'morty@example.com',
|
||||
'frequency': 'hourly',
|
||||
'group_id': 'spaceship',
|
||||
'name': 'Test Sheet 1',
|
||||
}
|
||||
assert r[1] == {
|
||||
'id': '456',
|
||||
'author_id': 'morty@example.com',
|
||||
'frequency': 'daily',
|
||||
'group_id': 'interdimensional',
|
||||
'name': 'Test Sheet 2',
|
||||
}
|
||||
|
||||
|
||||
def test_delete_sheet_endpoint(client_with_auth, db_session):
|
||||
# missing sheet
|
||||
response = client_with_auth.delete("/sheet/123-sheet-id")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"id": "123-sheet-id",
|
||||
"deleted": False
|
||||
}
|
||||
|
||||
# add sheets for deletion
|
||||
from app.shared.db import models
|
||||
db_session.add_all([
|
||||
models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
|
||||
models.Sheet(id="456-sheet-id", name="Test Sheet 2", author_id="rick@example.com", group_id="spaceship", frequency="hourly"),
|
||||
])
|
||||
db_session.commit()
|
||||
|
||||
# morty can delete his
|
||||
response = client_with_auth.delete("/sheet/123-sheet-id")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": "123-sheet-id", "deleted": True}
|
||||
# but only once
|
||||
response = client_with_auth.delete("/sheet/123-sheet-id")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": "123-sheet-id", "deleted": False}
|
||||
# and not rick's
|
||||
response = client_with_auth.delete("/sheet/456-sheet-id")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": "456-sheet-id", "deleted": False}
|
||||
|
||||
|
||||
class TestArchiveUserSheetEndpoint:
|
||||
@patch("app.web.endpoints.sheet.celery", return_value=MagicMock())
|
||||
def test_normal_flow(self, m_celery, client_with_auth, db_session):
|
||||
from app.shared.db import models
|
||||
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly"))
|
||||
db_session.commit()
|
||||
|
||||
m_signature = MagicMock()
|
||||
m_signature.apply_async.return_value = TaskResult(id="123-taskid", status="PENDING", result="")
|
||||
m_celery.signature.return_value = m_signature
|
||||
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == 201
|
||||
assert r.json() == {"id": "123-taskid"}
|
||||
m_celery.signature.assert_called_once()
|
||||
m_signature.apply_async.assert_called_once()
|
||||
|
||||
def test_token_auth(self, client_with_token, test_no_auth):
|
||||
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
|
||||
|
||||
def test_missing_data(self, client_with_auth):
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == 403
|
||||
assert r.json() == {"detail": "No access to this sheet."}
|
||||
|
||||
def test_no_access(self, client_with_auth, db_session):
|
||||
from app.shared.db import models
|
||||
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="rick@example.com", group_id="spaceship", frequency="hourly"))
|
||||
db_session.commit()
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == 403
|
||||
assert r.json() == {"detail": "No access to this sheet."}
|
||||
|
||||
def test_user_not_in_group(self, client_with_auth, db_session):
|
||||
from app.shared.db import models
|
||||
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="hourly"))
|
||||
db_session.commit()
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == 403
|
||||
assert r.json() == {"detail": "User does not have access to this group."}
|
||||
|
||||
def test_user_cannot_manually_trigger(self, client_with_auth, db_session):
|
||||
from app.shared.db import models
|
||||
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="default", frequency="hourly"))
|
||||
db_session.commit()
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == 429
|
||||
assert r.json() == {"detail": "User cannot manually trigger sheet archiving in this group."}
|
||||
@@ -1,202 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.shared.schemas import ArchiveCreate, TaskResult
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
|
||||
|
||||
def test_archive_url_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.post, "/url/archive")
|
||||
|
||||
|
||||
@patch("app.web.endpoints.url.UserState")
|
||||
@patch("app.web.endpoints.url.celery", return_value=MagicMock())
|
||||
def test_archive_url(m_celery, m2, client_with_auth):
|
||||
m_signature = MagicMock()
|
||||
m_signature.apply_async.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
|
||||
m_celery.signature.return_value = m_signature
|
||||
|
||||
m_user_state = MagicMock()
|
||||
m2.return_value = m_user_state
|
||||
|
||||
# url is too short
|
||||
response = client_with_auth.post("/url/archive", json={"url": "bad"})
|
||||
assert response.status_code == 422
|
||||
assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
|
||||
m_celery.signature.assert_not_called()
|
||||
|
||||
# url is invalid
|
||||
response = client_with_auth.post("/url/archive", json={"url": "example.com"})
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Invalid URL received."
|
||||
|
||||
# valid request
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = True
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = True
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-456-789'}
|
||||
m_celery.signature.assert_called_once()
|
||||
m_signature.apply_async.assert_called_once()
|
||||
called_val = m_celery.signature.call_args
|
||||
assert called_val[0][0] == "create_archive_task"
|
||||
assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": "rick@example.com", "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None}
|
||||
m_user_state.has_quota_max_monthly_urls.assert_called_once()
|
||||
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
|
||||
m_user_state.in_group.assert_called_once_with("default")
|
||||
|
||||
# user is not in group
|
||||
m_user_state.in_group.return_value = False
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "new-group"})
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "User does not have access to this group."
|
||||
m_user_state.in_group.assert_called_with("new-group")
|
||||
|
||||
# user is in group
|
||||
m_user_state.in_group.return_value = True
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-456-789'}
|
||||
assert m_celery.signature.call_count == 2
|
||||
assert m_signature.apply_async.call_count == 2
|
||||
called_val = m_celery.signature.call_args
|
||||
assert json.loads(called_val[1]['args'][0])["group_id"] == "spaceship"
|
||||
m_user_state.in_group.assert_called_with("spaceship")
|
||||
|
||||
# user is over monthly URL quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = False
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = True
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
|
||||
assert response.status_code == 429
|
||||
assert response.json()["detail"] == "User has reached their monthly URL quota."
|
||||
m_user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")
|
||||
|
||||
# user is over monthly MB quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = True
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = False
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spacesuit"})
|
||||
assert response.status_code == 429
|
||||
assert response.json()["detail"] == "User has reached their monthly MB quota."
|
||||
m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
|
||||
assert m_celery.signature.call_count == 2
|
||||
assert m_signature.apply_async.call_count == 2
|
||||
|
||||
|
||||
@patch("app.web.endpoints.url.UserState")
|
||||
def test_archive_url_quotas(m1, client_with_auth):
|
||||
m_user_state = MagicMock()
|
||||
m1.return_value = m_user_state
|
||||
|
||||
# misses on monthly URLs quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = False
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
|
||||
assert response.status_code == 429
|
||||
assert response.json()["detail"] == "User has reached their monthly URL quota."
|
||||
m_user_state.has_quota_max_monthly_urls.assert_called_once()
|
||||
|
||||
# misses on monthly MBs quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = True
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = False
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
|
||||
assert response.status_code == 429
|
||||
assert response.json()["detail"] == "User has reached their monthly MB quota."
|
||||
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
|
||||
|
||||
|
||||
@patch("app.web.endpoints.url.celery", return_value=MagicMock())
|
||||
def test_archive_url_with_api_token(m_celery, client_with_token):
|
||||
m_signature = MagicMock()
|
||||
m_signature.apply_async.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
|
||||
m_celery.signature.return_value = m_signature
|
||||
response = client_with_token.post("/url/archive", json={"url": "https://example.com", "author_id": "someone@example.com"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-456-789'}
|
||||
m_celery.signature.assert_called_once()
|
||||
m_signature.apply_async.assert_called_once()
|
||||
called_val = m_celery.signature.call_args
|
||||
assert called_val[0][0] == "create_archive_task"
|
||||
assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": "someone@example.com", "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None}
|
||||
|
||||
# missing id should use ALLOW_ANY_EMAIL
|
||||
response = client_with_token.post("/url/archive", json={"url": "https://example.com", "author_id": None})
|
||||
assert response.status_code == 201
|
||||
called_val = m_celery.signature.call_args
|
||||
assert called_val[0][0] == "create_archive_task"
|
||||
assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": ALLOW_ANY_EMAIL, "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None}
|
||||
|
||||
|
||||
def test_search_by_url_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.get, "/url/search")
|
||||
|
||||
|
||||
def test_search_by_url(client_with_auth, client_with_token, db_session):
|
||||
# tests the search endpoint, including through some db data for the endpoint params
|
||||
response = client_with_auth.get("/url/search")
|
||||
assert response.status_code == 422
|
||||
assert response.json()["detail"][0]["msg"] == "Field required"
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
from app.shared import schemas
|
||||
from app.shared.db import worker_crud
|
||||
for i in range(11):
|
||||
worker_crud.create_archive(db_session, ArchiveCreate(id=f"url-456-{i}", url="https://example.com" if i < 10 else "https://something-else.com", result={}, public=True, author_id="rick@example.com"), [], [])
|
||||
# NB: this insertion is too fast for the ordering to be correct as they are within the same second
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com")
|
||||
assert response.status_code == 200
|
||||
assert len(j := response.json()) == 10
|
||||
assert "url-456-0" in [i["id"] for i in j]
|
||||
assert "url-456-9" in [i["id"] for i in j]
|
||||
assert "url-456-10" not in [i["id"] for i in j]
|
||||
assert j[0].keys() == schemas.ArchiveResult.model_fields.keys()
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com&limit=5")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 5
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com&skip=5&limit=2")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com&archived_before=2010-01-01")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 0
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com&archived_after=2010-01-01")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 10
|
||||
|
||||
# API token will also work
|
||||
response = client_with_token.get("/url/search?url=https://example.com&archived_after=2010-01-01")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 10
|
||||
|
||||
|
||||
@patch("app.web.endpoints.url.UserState")
|
||||
def test_search_no_read_access(mock_user_state, client_with_auth):
|
||||
mock_user_state.return_value.read = False
|
||||
mock_user_state.return_value.read_public = False
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com")
|
||||
assert response.status_code == 403
|
||||
assert response.json() == {"detail": "User does not have read access."}
|
||||
|
||||
|
||||
def test_delete_task_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.delete, "/url/123-456-789")
|
||||
|
||||
|
||||
def test_delete_task(client_with_auth, db_session):
|
||||
response = client_with_auth.delete("/url/delete-123-456-789")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": "delete-123-456-789", "deleted": False}
|
||||
|
||||
from app.shared.db import worker_crud
|
||||
worker_crud.create_archive(db_session, ArchiveCreate(id="delete-123-456-789", url="https://example.com", result={}, public=True, author_id="morty@example.com"), [], [])
|
||||
|
||||
response = client_with_auth.delete("/url/delete-123-456-789")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": "delete-123-456-789", "deleted": True}
|
||||
@@ -1,15 +1,20 @@
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from loguru import logger
|
||||
|
||||
from app.shared.schemas import Usage, UsageResponse
|
||||
from app.shared.user_groups import GroupInfo
|
||||
from app.web.config import VERSION
|
||||
from app.tests.web.db.test_crud import test_data
|
||||
from app.web.security import get_user_state
|
||||
from app.web.utils.metrics import measure_regular_metrics
|
||||
|
||||
|
||||
def test_endpoint_home(client_with_auth):
|
||||
r = client_with_auth.get("/")
|
||||
assert r.status_code == 200
|
||||
assert r.status_code == HTTPStatus.OK
|
||||
j = r.json()
|
||||
assert "version" in j and j["version"] == VERSION
|
||||
assert "breakingChanges" in j
|
||||
@@ -18,7 +23,7 @@ def test_endpoint_home(client_with_auth):
|
||||
|
||||
def test_endpoint_health(client_with_auth):
|
||||
r = client_with_auth.get("/health")
|
||||
assert r.status_code == 200
|
||||
assert r.status_code == HTTPStatus.OK
|
||||
assert r.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@@ -29,32 +34,31 @@ def test_endpoint_active_no_auth(client, test_no_auth):
|
||||
def test_endpoint_active(app):
|
||||
m_user_state = MagicMock()
|
||||
|
||||
from app.web.security import get_user_state
|
||||
app.dependency_overrides[get_user_state] = lambda: m_user_state
|
||||
|
||||
# inactive user
|
||||
m_user_state.active = False
|
||||
client = TestClient(app)
|
||||
r = client.get("/user/active")
|
||||
assert r.status_code == 200
|
||||
assert r.status_code == HTTPStatus.OK
|
||||
assert r.json() == {"active": False}
|
||||
|
||||
# active user
|
||||
m_user_state.active = True
|
||||
client = TestClient(app)
|
||||
r = client.get("/user/active")
|
||||
assert r.status_code == 200
|
||||
assert r.status_code == HTTPStatus.OK
|
||||
assert r.json() == {"active": True}
|
||||
|
||||
|
||||
def test_no_serve_local_archive_by_default(client_with_auth):
|
||||
r = client_with_auth.get("/app/local_archive_test/temp.txt")
|
||||
assert r.status_code == 404
|
||||
assert r.status_code == HTTPStatus.NOT_FOUND
|
||||
|
||||
|
||||
def test_favicon(client_with_auth):
|
||||
r = client_with_auth.get("/favicon.ico")
|
||||
assert r.status_code == 200
|
||||
assert r.status_code == HTTPStatus.OK
|
||||
assert r.headers["content-type"] == "image/vnd.microsoft.icon"
|
||||
|
||||
|
||||
@@ -70,8 +74,10 @@ def test_endpoint_test_prometheus_no_user_auth(client_with_auth, test_no_auth):
|
||||
async def test_prometheus_metrics(test_data, client_with_token, get_settings):
|
||||
# before metrics calculation
|
||||
r = client_with_token.get("/metrics")
|
||||
assert r.status_code == 200
|
||||
assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
|
||||
assert r.status_code == HTTPStatus.OK
|
||||
assert (
|
||||
r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
|
||||
)
|
||||
assert "disk_utilization" in r.text
|
||||
assert "database_metrics" in r.text
|
||||
assert "exceptions" in r.text
|
||||
@@ -79,8 +85,9 @@ async def test_prometheus_metrics(test_data, client_with_token, get_settings):
|
||||
assert 'disk_utilization{type="used"}' not in r.text
|
||||
|
||||
# after metrics calculation
|
||||
from app.web.utils.metrics import measure_regular_metrics
|
||||
await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100)
|
||||
await measure_regular_metrics(
|
||||
get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100
|
||||
)
|
||||
r2 = client_with_token.get("/metrics")
|
||||
assert 'disk_utilization{type="used"}' in r2.text
|
||||
assert 'disk_utilization{type="free"}' in r2.text
|
||||
@@ -88,20 +95,37 @@ async def test_prometheus_metrics(test_data, client_with_token, get_settings):
|
||||
assert 'database_metrics{query="count_archives"} 100.0' in r2.text
|
||||
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r2.text
|
||||
assert 'database_metrics{query="count_users"} 3.0' in r2.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r2.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r2.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r2.text
|
||||
assert (
|
||||
'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0'
|
||||
in r2.text
|
||||
)
|
||||
assert (
|
||||
'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0'
|
||||
in r2.text
|
||||
)
|
||||
assert (
|
||||
'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0'
|
||||
in r2.text
|
||||
)
|
||||
|
||||
# 30s window, should not change the gauges nor the total in the counters
|
||||
from app.web.utils.metrics import measure_regular_metrics
|
||||
await measure_regular_metrics(get_settings.DATABASE_PATH, 30)
|
||||
r3 = client_with_token.get("/metrics")
|
||||
assert 'database_metrics{query="count_archives"} 100.0' in r3.text
|
||||
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text
|
||||
assert 'database_metrics{query="count_users"} 3.0' in r3.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r3.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r3.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r3.text
|
||||
assert (
|
||||
'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0'
|
||||
in r3.text
|
||||
)
|
||||
assert (
|
||||
'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0'
|
||||
in r3.text
|
||||
)
|
||||
assert (
|
||||
'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0'
|
||||
in r3.text
|
||||
)
|
||||
|
||||
|
||||
def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth):
|
||||
@@ -109,14 +133,12 @@ def test_endpoint_get_user_permissions_no_user_auth(client, test_no_auth):
|
||||
|
||||
|
||||
def test_endpoint_get_user_permissions(app):
|
||||
from app.web.security import get_user_state
|
||||
|
||||
m_user_state = MagicMock()
|
||||
rv = {
|
||||
"all": GroupInfo(read=True),
|
||||
"group1": GroupInfo(archive_url=True),
|
||||
}
|
||||
from loguru import logger
|
||||
|
||||
logger.info(rv)
|
||||
m_user_state.permissions = rv
|
||||
|
||||
@@ -124,13 +146,13 @@ def test_endpoint_get_user_permissions(app):
|
||||
|
||||
client = TestClient(app)
|
||||
r = client.get("/user/permissions")
|
||||
assert r.status_code == 200
|
||||
assert r.status_code == HTTPStatus.OK
|
||||
response = r.json()
|
||||
assert response.keys() == {"all", "group1"}
|
||||
assert response["all"]["read"]
|
||||
assert response["group1"]["read"] == []
|
||||
assert response["group1"]["archive_url"]
|
||||
assert response["all"]["archive_url"] == False
|
||||
assert response["all"]["archive_url"] is False
|
||||
|
||||
|
||||
def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth):
|
||||
@@ -138,8 +160,6 @@ def test_endpoint_get_user_usage_no_user_auth(client, test_no_auth):
|
||||
|
||||
|
||||
def test_endpoint_get_user_usage_inactive(app):
|
||||
from app.web.security import get_user_state
|
||||
|
||||
m_user_state = MagicMock()
|
||||
m_user_state.active = False
|
||||
|
||||
@@ -147,13 +167,11 @@ def test_endpoint_get_user_usage_inactive(app):
|
||||
|
||||
client = TestClient(app)
|
||||
r = client.get("/user/usage")
|
||||
assert r.status_code == 403
|
||||
assert r.status_code == HTTPStatus.FORBIDDEN
|
||||
assert r.json() == {"detail": "User is not active."}
|
||||
|
||||
|
||||
def test_endpoint_get_user_usage_active(app):
|
||||
from app.web.security import get_user_state
|
||||
|
||||
m_user_state = MagicMock()
|
||||
m_user_state.active = True
|
||||
mock_usage = UsageResponse(
|
||||
@@ -162,8 +180,8 @@ def test_endpoint_get_user_usage_active(app):
|
||||
total_sheets=3,
|
||||
groups={
|
||||
"group1": Usage(monthly_urls=4, monthly_mbs=5, total_sheets=6),
|
||||
"group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9)
|
||||
}
|
||||
"group2": Usage(monthly_urls=7, monthly_mbs=8, total_sheets=9),
|
||||
},
|
||||
)
|
||||
m_user_state.usage.return_value = mock_usage
|
||||
|
||||
@@ -171,5 +189,5 @@ def test_endpoint_get_user_usage_active(app):
|
||||
|
||||
client = TestClient(app)
|
||||
r = client.get("/user/usage")
|
||||
assert r.status_code == 200
|
||||
assert r.status_code == HTTPStatus.OK
|
||||
assert UsageResponse(**r.json()) == mock_usage
|
||||
147
app/tests/web/routers/test_interoperability.py
Normal file
147
app/tests/web/routers/test_interoperability.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.shared.db import models
|
||||
|
||||
|
||||
def test_submit_manual_archive_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.post, "/interop/submit-archive")
|
||||
|
||||
|
||||
def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
|
||||
test_no_auth(client_with_auth.post, "/interop/submit-archive")
|
||||
|
||||
|
||||
@patch(
|
||||
"app.web.routers.interoperability.business_logic",
|
||||
return_value=MagicMock(
|
||||
get_store_archive_until=MagicMock(return_value=datetime)
|
||||
),
|
||||
)
|
||||
def test_submit_manual_archive(m1, client_with_token, db_session):
|
||||
# normal workflow
|
||||
aa_metadata = json.dumps(
|
||||
{
|
||||
"status": "test: success",
|
||||
"metadata": {"url": "http://example.com"},
|
||||
"media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}],
|
||||
}
|
||||
)
|
||||
r = client_with_token.post(
|
||||
"/interop/submit-archive",
|
||||
json={
|
||||
"result": aa_metadata,
|
||||
"public": True,
|
||||
"author_id": "jerry@gmail.com",
|
||||
"group_id": "spaceship",
|
||||
"tags": ["test"],
|
||||
"url": "http://example.com",
|
||||
},
|
||||
)
|
||||
assert r.status_code == HTTPStatus.CREATED
|
||||
assert "id" in r.json()
|
||||
|
||||
inserted = (
|
||||
db_session.query(models.Archive)
|
||||
.filter(models.Archive.id == r.json()["id"])
|
||||
.first()
|
||||
)
|
||||
assert inserted.url == "http://example.com"
|
||||
assert inserted.group_id == "spaceship"
|
||||
assert inserted.author_id == "jerry@gmail.com"
|
||||
assert sorted([t.id for t in inserted.tags]) == sorted(["test", "manual"])
|
||||
assert inserted.public
|
||||
assert isinstance(inserted.result, dict)
|
||||
assert [u.url for u in inserted.urls] == ["http://example.s3.com"]
|
||||
assert isinstance(inserted.store_until, datetime)
|
||||
|
||||
# cannot have the same URL twice
|
||||
aa_metadata = json.dumps(
|
||||
{
|
||||
"status": "test: success",
|
||||
"metadata": {"url": "http://example.com"},
|
||||
"media": [
|
||||
{
|
||||
"filename": "fn1",
|
||||
"urls": ["http://example.com", "http://example.com"],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
r = client_with_token.post(
|
||||
"/interop/submit-archive",
|
||||
json={
|
||||
"result": aa_metadata,
|
||||
"public": False,
|
||||
"author_id": "jerry@gmail.com",
|
||||
"tags": ["test"],
|
||||
"url": "http://example.com",
|
||||
},
|
||||
)
|
||||
assert r.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
|
||||
assert r.json() == {
|
||||
"detail": "Cannot insert into DB due to integrity error, likely duplicate urls."
|
||||
}
|
||||
|
||||
|
||||
# test with invalid JSON
|
||||
def test_submit_manual_archive_invalid_json(client_with_token):
|
||||
r = client_with_token.post(
|
||||
"/interop/submit-archive",
|
||||
json={
|
||||
"result": "invalid json",
|
||||
"public": False,
|
||||
"author_id": "jer",
|
||||
"tags": ["test"],
|
||||
"url": "http://example.com",
|
||||
},
|
||||
)
|
||||
assert r.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
|
||||
assert r.json() == {"detail": "Invalid JSON in result field."}
|
||||
|
||||
|
||||
@patch(
|
||||
"app.web.routers.interoperability.business_logic.get_store_archive_until",
|
||||
side_effect=AssertionError("AssertionError"),
|
||||
)
|
||||
def test_submit_manual_archive_no_store_until(
|
||||
m_sau, client_with_token, db_session
|
||||
):
|
||||
aa_metadata = json.dumps(
|
||||
{
|
||||
"status": "test: success",
|
||||
"metadata": {"url": "http://example.com"},
|
||||
"media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}],
|
||||
}
|
||||
)
|
||||
r = client_with_token.post(
|
||||
"/interop/submit-archive",
|
||||
json={
|
||||
"result": aa_metadata,
|
||||
"public": True,
|
||||
"author_id": "jerry@gmail.com",
|
||||
"group_id": "spaceship",
|
||||
"tags": ["test"],
|
||||
"url": "http://example.com",
|
||||
},
|
||||
)
|
||||
assert r.status_code == HTTPStatus.CREATED
|
||||
assert len(r.json()["id"]) == 36
|
||||
res = (
|
||||
db_session.query(models.Archive)
|
||||
.filter(models.Archive.id == r.json()["id"])
|
||||
.first()
|
||||
)
|
||||
assert res.store_until is None
|
||||
# testing that store_until = None is not comparable with datetime, and will always return False
|
||||
res = (
|
||||
db_session.query(models.Archive)
|
||||
.filter(
|
||||
models.Archive.id == r.json()["id"],
|
||||
models.Archive.store_until < datetime.now(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
assert res is None
|
||||
268
app/tests/web/routers/test_sheet.py
Normal file
268
app/tests/web/routers/test_sheet.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.shared.constants import STATUS_PENDING
|
||||
from app.shared.db import models
|
||||
from app.shared.schemas import TaskResult
|
||||
from app.web.db.user_state import UserState
|
||||
from app.web.security import get_user_state
|
||||
|
||||
|
||||
def test_endpoints_no_auth(client, test_no_auth):
|
||||
test_no_auth(client.post, "/sheet/create")
|
||||
test_no_auth(client.get, "/sheet/mine")
|
||||
test_no_auth(client.delete, "/sheet/123-sheet-id")
|
||||
test_no_auth(client.post, "/sheet/123-sheet-id/archive")
|
||||
|
||||
|
||||
def test_create_sheet_endpoint(app_with_auth, db_session):
|
||||
client_with_auth = TestClient(app_with_auth)
|
||||
good_data = {
|
||||
"id": "123-sheet-id",
|
||||
"name": "Test Sheet",
|
||||
"group_id": "spaceship",
|
||||
"frequency": "daily",
|
||||
}
|
||||
|
||||
# with good data
|
||||
response = client_with_auth.post("/sheet/create", json=good_data)
|
||||
assert response.status_code == HTTPStatus.CREATED
|
||||
j = response.json()
|
||||
assert datetime.fromisoformat(j.pop("created_at"))
|
||||
assert datetime.fromisoformat(j.pop("last_url_archived_at"))
|
||||
assert j.pop("author_id") == "morty@example.com"
|
||||
assert j == good_data
|
||||
|
||||
# already exists
|
||||
response = client_with_auth.post("/sheet/create", json=good_data)
|
||||
assert response.status_code == HTTPStatus.BAD_REQUEST
|
||||
assert response.json() == {
|
||||
"detail": "Sheet with this ID is already being archived."
|
||||
}
|
||||
|
||||
# bad group
|
||||
bad_data = good_data.copy()
|
||||
bad_data["group_id"] = "not a group"
|
||||
response = client_with_auth.post("/sheet/create", json=bad_data)
|
||||
assert response.status_code == HTTPStatus.FORBIDDEN
|
||||
assert response.json() == {
|
||||
"detail": "User does not have access to this group."
|
||||
}
|
||||
|
||||
# switch to jerry who's got less quota/permissions
|
||||
app_with_auth.dependency_overrides[get_user_state] = lambda: UserState(
|
||||
db_session, "jerry@example.com"
|
||||
)
|
||||
client_jerry = TestClient(app_with_auth)
|
||||
|
||||
# frequency not allowed
|
||||
jerry_data = good_data.copy()
|
||||
jerry_data["group_id"] = "animated-characters"
|
||||
jerry_data["frequency"] = "hourly"
|
||||
jerry_data["id"] = "jerry-sheet-id"
|
||||
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||
assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
|
||||
assert response.json() == {
|
||||
"detail": "Invalid frequency selected for this group."
|
||||
}
|
||||
|
||||
jerry_data["frequency"] = "daily"
|
||||
# success for the first sheet, bad quota on second
|
||||
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||
assert response.status_code == HTTPStatus.CREATED
|
||||
|
||||
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||
assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
|
||||
assert response.json() == {
|
||||
"detail": "User has reached their sheet quota for this group."
|
||||
}
|
||||
|
||||
|
||||
def test_get_user_sheets_endpoint(client_with_auth, db_session):
|
||||
# no data
|
||||
response = client_with_auth.get("/sheet/mine")
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert response.json() == []
|
||||
|
||||
# with data
|
||||
db_session.add(
|
||||
models.Sheet(
|
||||
id="123",
|
||||
name="Test Sheet 1",
|
||||
author_id="morty@example.com",
|
||||
group_id="spaceship",
|
||||
frequency="hourly",
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
db_session.add_all(
|
||||
[
|
||||
models.Sheet(
|
||||
id="456",
|
||||
name="Test Sheet 2",
|
||||
author_id="morty@example.com",
|
||||
group_id="interdimensional",
|
||||
frequency="daily",
|
||||
),
|
||||
models.Sheet(
|
||||
id="789",
|
||||
name="Test Sheet 3",
|
||||
author_id="rick@example.com",
|
||||
group_id="interdimensional",
|
||||
frequency="hourly",
|
||||
),
|
||||
]
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
response = client_with_auth.get("/sheet/mine")
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
r = response.json()
|
||||
assert isinstance(r, list)
|
||||
assert len(r) == 2
|
||||
assert datetime.fromisoformat(r[0].pop("created_at"))
|
||||
assert datetime.fromisoformat(r[0].pop("last_url_archived_at"))
|
||||
assert datetime.fromisoformat(r[1].pop("created_at"))
|
||||
assert datetime.fromisoformat(r[1].pop("last_url_archived_at"))
|
||||
assert r[0] == {
|
||||
"id": "123",
|
||||
"author_id": "morty@example.com",
|
||||
"frequency": "hourly",
|
||||
"group_id": "spaceship",
|
||||
"name": "Test Sheet 1",
|
||||
}
|
||||
assert r[1] == {
|
||||
"id": "456",
|
||||
"author_id": "morty@example.com",
|
||||
"frequency": "daily",
|
||||
"group_id": "interdimensional",
|
||||
"name": "Test Sheet 2",
|
||||
}
|
||||
|
||||
|
||||
def test_delete_sheet_endpoint(client_with_auth, db_session):
|
||||
# missing sheet
|
||||
response = client_with_auth.delete("/sheet/123-sheet-id")
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert response.json() == {"id": "123-sheet-id", "deleted": False}
|
||||
|
||||
# add sheets for deletion
|
||||
db_session.add_all(
|
||||
[
|
||||
models.Sheet(
|
||||
id="123-sheet-id",
|
||||
name="Test Sheet 1",
|
||||
author_id="morty@example.com",
|
||||
group_id="interdimensional",
|
||||
frequency="daily",
|
||||
),
|
||||
models.Sheet(
|
||||
id="456-sheet-id",
|
||||
name="Test Sheet 2",
|
||||
author_id="rick@example.com",
|
||||
group_id="spaceship",
|
||||
frequency="hourly",
|
||||
),
|
||||
]
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# morty can delete his
|
||||
response = client_with_auth.delete("/sheet/123-sheet-id")
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert response.json() == {"id": "123-sheet-id", "deleted": True}
|
||||
# but only once
|
||||
response = client_with_auth.delete("/sheet/123-sheet-id")
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert response.json() == {"id": "123-sheet-id", "deleted": False}
|
||||
# and not Rick's
|
||||
response = client_with_auth.delete("/sheet/456-sheet-id")
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert response.json() == {"id": "456-sheet-id", "deleted": False}
|
||||
|
||||
|
||||
class TestArchiveUserSheetEndpoint:
|
||||
@patch("app.web.routers.sheet.celery", return_value=MagicMock())
|
||||
def test_normal_flow(self, m_celery, client_with_auth, db_session):
|
||||
db_session.add(
|
||||
models.Sheet(
|
||||
id="123-sheet-id",
|
||||
name="Test Sheet 1",
|
||||
author_id="morty@example.com",
|
||||
group_id="spaceship",
|
||||
frequency="hourly",
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
m_signature = MagicMock()
|
||||
m_signature.apply_async.return_value = TaskResult(
|
||||
id="123-taskid", status=STATUS_PENDING, result=""
|
||||
)
|
||||
m_celery.signature.return_value = m_signature
|
||||
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == HTTPStatus.CREATED
|
||||
assert r.json() == {"id": "123-taskid"}
|
||||
m_celery.signature.assert_called_once()
|
||||
m_signature.apply_async.assert_called_once()
|
||||
|
||||
def test_token_auth(self, client_with_token, test_no_auth):
|
||||
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
|
||||
|
||||
def test_missing_data(self, client_with_auth):
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == HTTPStatus.FORBIDDEN
|
||||
assert r.json() == {"detail": "No access to this sheet."}
|
||||
|
||||
def test_no_access(self, client_with_auth, db_session):
|
||||
db_session.add(
|
||||
models.Sheet(
|
||||
id="123-sheet-id",
|
||||
name="Test Sheet 1",
|
||||
author_id="rick@example.com",
|
||||
group_id="spaceship",
|
||||
frequency="hourly",
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == HTTPStatus.FORBIDDEN
|
||||
assert r.json() == {"detail": "No access to this sheet."}
|
||||
|
||||
def test_user_not_in_group(self, client_with_auth, db_session):
|
||||
db_session.add(
|
||||
models.Sheet(
|
||||
id="123-sheet-id",
|
||||
name="Test Sheet 1",
|
||||
author_id="morty@example.com",
|
||||
group_id="interdimensional",
|
||||
frequency="hourly",
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == HTTPStatus.FORBIDDEN
|
||||
assert r.json() == {
|
||||
"detail": "User does not have access to this group."
|
||||
}
|
||||
|
||||
def test_user_cannot_manually_trigger(self, client_with_auth, db_session):
|
||||
db_session.add(
|
||||
models.Sheet(
|
||||
id="123-sheet-id",
|
||||
name="Test Sheet 1",
|
||||
author_id="morty@example.com",
|
||||
group_id="default",
|
||||
frequency="hourly",
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == HTTPStatus.TOO_MANY_REQUESTS
|
||||
assert r.json() == {
|
||||
"detail": "User cannot manually trigger sheet archiving in this group."
|
||||
}
|
||||
@@ -1,51 +1,53 @@
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.shared.constants import STATUS_FAILURE, STATUS_PENDING, STATUS_SUCCESS
|
||||
|
||||
|
||||
def test_endpoint_task_status_no_auth(client, test_no_auth):
|
||||
test_no_auth(client.get, "/task/test-task-id")
|
||||
|
||||
|
||||
@patch("app.web.endpoints.task.AsyncResult")
|
||||
@patch("app.web.routers.task.AsyncResult")
|
||||
def test_get_status_success(mock_async_result, client_with_auth):
|
||||
mock_async_result.return_value.status = "SUCCESS"
|
||||
mock_async_result.return_value.status = STATUS_SUCCESS
|
||||
mock_async_result.return_value.result = {"data": "some result"}
|
||||
|
||||
response = client_with_auth.get("/task/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert response.json() == {
|
||||
"id": "test-task-id",
|
||||
"status": "SUCCESS",
|
||||
"result": {"data": "some result"}
|
||||
"status": STATUS_SUCCESS,
|
||||
"result": {"data": "some result"},
|
||||
}
|
||||
|
||||
|
||||
@patch("app.web.endpoints.task.AsyncResult")
|
||||
@patch("app.web.routers.task.AsyncResult")
|
||||
def test_get_status_failure(mock_async_result, client_with_auth):
|
||||
|
||||
mock_async_result.return_value.status = "FAILURE"
|
||||
mock_async_result.return_value.status = STATUS_FAILURE
|
||||
mock_async_result.return_value.result = Exception("Some error")
|
||||
|
||||
response = client_with_auth.get("/task/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert response.json() == {
|
||||
"id": "test-task-id",
|
||||
"status": "FAILURE",
|
||||
"result": {"error": "Some error"}
|
||||
"status": STATUS_FAILURE,
|
||||
"result": {"error": "Some error"},
|
||||
}
|
||||
|
||||
|
||||
@patch("app.web.endpoints.task.AsyncResult")
|
||||
@patch("app.web.routers.task.AsyncResult")
|
||||
def test_get_status_pending(mock_async_result, client_with_auth):
|
||||
mock_async_result.return_value.status = "PENDING"
|
||||
mock_async_result.return_value.status = STATUS_PENDING
|
||||
mock_async_result.return_value.result = None
|
||||
|
||||
response = client_with_auth.get("/task/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert response.json() == {
|
||||
"id": "test-task-id",
|
||||
"status": "PENDING",
|
||||
"result": None
|
||||
"status": STATUS_PENDING,
|
||||
"result": None,
|
||||
}
|
||||
312
app/tests/web/routers/test_url.py
Normal file
312
app/tests/web/routers/test_url.py
Normal file
@@ -0,0 +1,312 @@
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.shared import schemas
|
||||
from app.shared.constants import STATUS_PENDING
|
||||
from app.shared.db import worker_crud
|
||||
from app.shared.schemas import ArchiveCreate, TaskResult
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
|
||||
|
||||
def test_archive_url_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.post, "/url/archive")
|
||||
|
||||
|
||||
@patch("app.web.routers.url.UserState")
|
||||
@patch("app.web.routers.url.celery", return_value=MagicMock())
|
||||
def test_archive_url(m_celery, m2, client_with_auth):
|
||||
m_signature = MagicMock()
|
||||
m_signature.apply_async.return_value = TaskResult(
|
||||
id="123-456-789", status=STATUS_PENDING, result=""
|
||||
)
|
||||
m_celery.signature.return_value = m_signature
|
||||
|
||||
m_user_state = MagicMock()
|
||||
m2.return_value = m_user_state
|
||||
|
||||
# url is too short
|
||||
response = client_with_auth.post("/url/archive", json={"url": "bad"})
|
||||
assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
|
||||
assert (
|
||||
response.json()["detail"][0]["msg"]
|
||||
== "String should have at least 5 characters"
|
||||
)
|
||||
m_celery.signature.assert_not_called()
|
||||
|
||||
# url is invalid
|
||||
response = client_with_auth.post(
|
||||
"/url/archive", json={"url": "example.com"}
|
||||
)
|
||||
assert response.status_code == HTTPStatus.BAD_REQUEST
|
||||
assert response.json()["detail"] == "Invalid URL received."
|
||||
|
||||
# valid request
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = True
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = True
|
||||
response = client_with_auth.post(
|
||||
"/url/archive", json={"url": "https://example.com"}
|
||||
)
|
||||
assert response.status_code == HTTPStatus.CREATED
|
||||
assert response.json() == {"id": "123-456-789"}
|
||||
m_celery.signature.assert_called_once()
|
||||
m_signature.apply_async.assert_called_once()
|
||||
called_val = m_celery.signature.call_args
|
||||
assert called_val[0][0] == "create_archive_task"
|
||||
assert json.loads(called_val[1]["args"][0]) == {
|
||||
"id": None,
|
||||
"url": "https://example.com",
|
||||
"result": None,
|
||||
"public": False,
|
||||
"author_id": "rick@example.com",
|
||||
"group_id": "default",
|
||||
"tags": None,
|
||||
"sheet_id": None,
|
||||
"store_until": None,
|
||||
"urls": None,
|
||||
}
|
||||
m_user_state.has_quota_max_monthly_urls.assert_called_once()
|
||||
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
|
||||
m_user_state.in_group.assert_called_once_with("default")
|
||||
|
||||
# user is not in group
|
||||
m_user_state.in_group.return_value = False
|
||||
response = client_with_auth.post(
|
||||
"/url/archive",
|
||||
json={"url": "https://example.com", "group_id": "new-group"},
|
||||
)
|
||||
assert response.status_code == HTTPStatus.FORBIDDEN
|
||||
assert (
|
||||
response.json()["detail"] == "User does not have access to this group."
|
||||
)
|
||||
m_user_state.in_group.assert_called_with("new-group")
|
||||
|
||||
# user is in group
|
||||
m_user_state.in_group.return_value = True
|
||||
response = client_with_auth.post(
|
||||
"/url/archive",
|
||||
json={"url": "https://example.com", "group_id": "spaceship"},
|
||||
)
|
||||
assert response.status_code == HTTPStatus.CREATED
|
||||
assert response.json() == {"id": "123-456-789"}
|
||||
assert m_celery.signature.call_count == 2
|
||||
assert m_signature.apply_async.call_count == 2
|
||||
called_val = m_celery.signature.call_args
|
||||
assert json.loads(called_val[1]["args"][0])["group_id"] == "spaceship"
|
||||
m_user_state.in_group.assert_called_with("spaceship")
|
||||
|
||||
# user is over monthly URL quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = False
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = True
|
||||
response = client_with_auth.post(
|
||||
"/url/archive",
|
||||
json={"url": "https://example.com", "group_id": "spaceship"},
|
||||
)
|
||||
assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
|
||||
assert (
|
||||
response.json()["detail"] == "User has reached their monthly URL quota."
|
||||
)
|
||||
m_user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")
|
||||
|
||||
# user is over monthly MB quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = True
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = False
|
||||
response = client_with_auth.post(
|
||||
"/url/archive",
|
||||
json={"url": "https://example.com", "group_id": "spacesuit"},
|
||||
)
|
||||
assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
|
||||
assert (
|
||||
response.json()["detail"] == "User has reached their monthly MB quota."
|
||||
)
|
||||
m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
|
||||
assert m_celery.signature.call_count == 2
|
||||
assert m_signature.apply_async.call_count == 2
|
||||
|
||||
|
||||
@patch("app.web.routers.url.UserState")
|
||||
def test_archive_url_quotas(m1, client_with_auth):
|
||||
m_user_state = MagicMock()
|
||||
m1.return_value = m_user_state
|
||||
|
||||
# misses on monthly URLs quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = False
|
||||
response = client_with_auth.post(
|
||||
"/url/archive", json={"url": "https://example.com"}
|
||||
)
|
||||
assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
|
||||
assert (
|
||||
response.json()["detail"] == "User has reached their monthly URL quota."
|
||||
)
|
||||
m_user_state.has_quota_max_monthly_urls.assert_called_once()
|
||||
|
||||
# misses on monthly MBs quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = True
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = False
|
||||
response = client_with_auth.post(
|
||||
"/url/archive", json={"url": "https://example.com"}
|
||||
)
|
||||
assert response.status_code == HTTPStatus.TOO_MANY_REQUESTS
|
||||
assert (
|
||||
response.json()["detail"] == "User has reached their monthly MB quota."
|
||||
)
|
||||
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
|
||||
|
||||
|
||||
@patch("app.web.routers.url.celery", return_value=MagicMock())
|
||||
def test_archive_url_with_api_token(m_celery, client_with_token):
|
||||
m_signature = MagicMock()
|
||||
m_signature.apply_async.return_value = TaskResult(
|
||||
id="123-456-789", status=STATUS_PENDING, result=""
|
||||
)
|
||||
m_celery.signature.return_value = m_signature
|
||||
response = client_with_token.post(
|
||||
"/url/archive",
|
||||
json={"url": "https://example.com", "author_id": "someone@example.com"},
|
||||
)
|
||||
assert response.status_code == HTTPStatus.CREATED
|
||||
assert response.json() == {"id": "123-456-789"}
|
||||
m_celery.signature.assert_called_once()
|
||||
m_signature.apply_async.assert_called_once()
|
||||
called_val = m_celery.signature.call_args
|
||||
assert called_val[0][0] == "create_archive_task"
|
||||
assert json.loads(called_val[1]["args"][0]) == {
|
||||
"id": None,
|
||||
"url": "https://example.com",
|
||||
"result": None,
|
||||
"public": False,
|
||||
"author_id": "someone@example.com",
|
||||
"group_id": "default",
|
||||
"tags": None,
|
||||
"sheet_id": None,
|
||||
"store_until": None,
|
||||
"urls": None,
|
||||
}
|
||||
|
||||
# missing id should use ALLOW_ANY_EMAIL
|
||||
response = client_with_token.post(
|
||||
"/url/archive", json={"url": "https://example.com", "author_id": None}
|
||||
)
|
||||
assert response.status_code == HTTPStatus.CREATED
|
||||
called_val = m_celery.signature.call_args
|
||||
assert called_val[0][0] == "create_archive_task"
|
||||
assert json.loads(called_val[1]["args"][0]) == {
|
||||
"id": None,
|
||||
"url": "https://example.com",
|
||||
"result": None,
|
||||
"public": False,
|
||||
"author_id": ALLOW_ANY_EMAIL,
|
||||
"group_id": "default",
|
||||
"tags": None,
|
||||
"sheet_id": None,
|
||||
"store_until": None,
|
||||
"urls": None,
|
||||
}
|
||||
|
||||
|
||||
def test_search_by_url_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.get, "/url/search")
|
||||
|
||||
|
||||
def test_search_by_url(client_with_auth, client_with_token, db_session):
|
||||
# tests the search endpoint, including through some db data for the endpoint params
|
||||
response = client_with_auth.get("/url/search")
|
||||
assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
|
||||
assert response.json()["detail"][0]["msg"] == "Field required"
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com")
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert response.json() == []
|
||||
|
||||
for i in range(11):
|
||||
worker_crud.create_archive(
|
||||
db_session,
|
||||
ArchiveCreate(
|
||||
id=f"url-456-{i}",
|
||||
url="https://example.com"
|
||||
if i < 10
|
||||
else "https://something-else.com",
|
||||
result={},
|
||||
public=True,
|
||||
author_id="rick@example.com",
|
||||
),
|
||||
[],
|
||||
[],
|
||||
)
|
||||
# NB: this insertion is too fast for the ordering to be correct as they are within the same second
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com")
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert len(j := response.json()) == 10
|
||||
assert "url-456-0" in [i["id"] for i in j]
|
||||
assert "url-456-9" in [i["id"] for i in j]
|
||||
assert "url-456-10" not in [i["id"] for i in j]
|
||||
assert j[0].keys() == schemas.ArchiveResult.model_fields.keys()
|
||||
|
||||
response = client_with_auth.get(
|
||||
"/url/search?url=https://example.com&limit=5"
|
||||
)
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert len(response.json()) == 5
|
||||
|
||||
response = client_with_auth.get(
|
||||
"/url/search?url=https://example.com&skip=5&limit=2"
|
||||
)
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert len(response.json()) == 2
|
||||
|
||||
response = client_with_auth.get(
|
||||
"/url/search?url=https://example.com&archived_before=2010-01-01"
|
||||
)
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert len(response.json()) == 0
|
||||
|
||||
response = client_with_auth.get(
|
||||
"/url/search?url=https://example.com&archived_after=2010-01-01"
|
||||
)
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert len(response.json()) == 10
|
||||
|
||||
# API token will also work
|
||||
response = client_with_token.get(
|
||||
"/url/search?url=https://example.com&archived_after=2010-01-01"
|
||||
)
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert len(response.json()) == 10
|
||||
|
||||
|
||||
@patch("app.web.routers.url.UserState")
|
||||
def test_search_no_read_access(mock_user_state, client_with_auth):
|
||||
mock_user_state.return_value.read = False
|
||||
mock_user_state.return_value.read_public = False
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com")
|
||||
assert response.status_code == HTTPStatus.FORBIDDEN
|
||||
assert response.json() == {"detail": "User does not have read access."}
|
||||
|
||||
|
||||
def test_delete_task_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.delete, "/url/123-456-789")
|
||||
|
||||
|
||||
def test_delete_task(client_with_auth, db_session):
|
||||
response = client_with_auth.delete("/url/delete-123-456-789")
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert response.json() == {"id": "delete-123-456-789", "deleted": False}
|
||||
|
||||
worker_crud.create_archive(
|
||||
db_session,
|
||||
ArchiveCreate(
|
||||
id="delete-123-456-789",
|
||||
url="https://example.com",
|
||||
result={},
|
||||
public=True,
|
||||
author_id="morty@example.com",
|
||||
),
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
response = client_with_auth.delete("/url/delete-123-456-789")
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
assert response.json() == {"id": "delete-123-456-789", "deleted": True}
|
||||
@@ -1,31 +1,55 @@
|
||||
import os
|
||||
import shutil
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import patch
|
||||
|
||||
import alembic.config
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import shutil
|
||||
from app.web.main import app_factory
|
||||
from app.web.utils.metrics import EXCEPTION_COUNTER
|
||||
|
||||
import pytest
|
||||
|
||||
def test_lifespan(app):
|
||||
with TestClient(app) as client:
|
||||
r = client.get("/health")
|
||||
assert r.status_code == 200
|
||||
assert r.status_code == HTTPStatus.OK
|
||||
assert r.json() == {"status": "ok"}
|
||||
|
||||
def test_alembic(db_session):
|
||||
import alembic.config
|
||||
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
|
||||
alembic.config.main(argv=['--raiseerr', 'downgrade', 'base'])
|
||||
|
||||
@patch("app.web.endpoints.url.crud.soft_delete_archive", side_effect=Exception('mocked error'))
|
||||
def test_alembic(db_session):
|
||||
alembic.config.main(
|
||||
argv=[
|
||||
"-c",
|
||||
"./app/migrations/alembic.ini",
|
||||
"--raiseerr",
|
||||
"upgrade",
|
||||
"head",
|
||||
]
|
||||
)
|
||||
alembic.config.main(
|
||||
argv=[
|
||||
"-c",
|
||||
"./app/migrations/alembic.ini",
|
||||
"--raiseerr",
|
||||
"downgrade",
|
||||
"base",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@patch(
|
||||
"app.web.routers.url.crud.soft_delete_archive",
|
||||
side_effect=Exception("mocked error"),
|
||||
)
|
||||
def test_logging_middleware(m1, client_with_auth):
|
||||
from app.web.utils.metrics import EXCEPTION_COUNTER
|
||||
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 0
|
||||
with pytest.raises(Exception, match="mocked error"):
|
||||
client_with_auth.delete("/url/123")
|
||||
# creates one empty and one from above
|
||||
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 2
|
||||
|
||||
|
||||
|
||||
def test_serve_local_archive_logic(get_settings):
|
||||
# create a test file first
|
||||
@@ -36,13 +60,13 @@ def test_serve_local_archive_logic(get_settings):
|
||||
try:
|
||||
# modify the settings
|
||||
get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test"
|
||||
from app.web.main import app_factory
|
||||
|
||||
app = app_factory(get_settings)
|
||||
|
||||
|
||||
# test
|
||||
client = TestClient(app)
|
||||
r = client.get("/app/local_archive_test/temp.txt")
|
||||
assert r.status_code == 200
|
||||
assert r.status_code == HTTPStatus.OK
|
||||
assert r.text == "test"
|
||||
finally:
|
||||
# cleanup
|
||||
|
||||
@@ -1,101 +1,168 @@
|
||||
from http import HTTPStatus
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
import pytest
|
||||
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.web.db.user_state import UserState
|
||||
from app.web.security import (
|
||||
authenticate_user,
|
||||
get_token_or_user_auth,
|
||||
get_user_auth,
|
||||
get_user_state,
|
||||
secure_compare,
|
||||
token_api_key_auth,
|
||||
)
|
||||
|
||||
|
||||
def test_secure_compare():
|
||||
from app.web.security import secure_compare
|
||||
|
||||
assert secure_compare("test", "test")
|
||||
assert not secure_compare("test", "test2")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_token_or_user_auth_with_api():
|
||||
from app.web.security import get_token_or_user_auth
|
||||
mock_api = HTTPAuthorizationCredentials(scheme="lorem", credentials="this_is_the_test_api_token")
|
||||
mock_api = HTTPAuthorizationCredentials(
|
||||
scheme="lorem", credentials="this_is_the_test_api_token"
|
||||
)
|
||||
assert await get_token_or_user_auth(mock_api) == ALLOW_ANY_EMAIL
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_token_or_user_auth_with_user():
|
||||
from app.web.security import get_token_or_user_auth
|
||||
bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="invalid")
|
||||
e: pytest.ExceptionInfo = None
|
||||
bad_user = HTTPAuthorizationCredentials(
|
||||
scheme="ipsum", credentials="invalid"
|
||||
)
|
||||
with pytest.raises(HTTPException) as e:
|
||||
await get_token_or_user_auth(bad_user)
|
||||
assert e.value.status_code == 401
|
||||
assert e.value.status_code == HTTPStatus.UNAUTHORIZED
|
||||
assert e.value.detail == "invalid access_token"
|
||||
|
||||
|
||||
@patch("app.web.security.authenticate_user", return_value=(True, "summer@example.com"))
|
||||
@patch(
|
||||
"app.web.security.authenticate_user",
|
||||
return_value=(True, "summer@example.com"),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_auth(m1):
|
||||
from app.web.security import get_user_auth
|
||||
good_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good")
|
||||
good_user = HTTPAuthorizationCredentials(
|
||||
scheme="ipsum", credentials="valid-and-good"
|
||||
)
|
||||
assert await get_user_auth(good_user) == "summer@example.com"
|
||||
|
||||
|
||||
@patch("app.web.security.secure_compare", return_value=False)
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_api_key_auth_exception(m1):
|
||||
from app.web.security import token_api_key_auth
|
||||
|
||||
e: pytest.ExceptionInfo = None
|
||||
with pytest.raises(HTTPException) as e:
|
||||
await token_api_key_auth(HTTPAuthorizationCredentials(scheme="ipsum", credentials="does-not-matter"), auto_error=True)
|
||||
assert e.value.status_code == 401
|
||||
await token_api_key_auth(
|
||||
HTTPAuthorizationCredentials(
|
||||
scheme="ipsum", credentials="does-not-matter"
|
||||
),
|
||||
auto_error=True,
|
||||
)
|
||||
assert e.value.status_code == HTTPStatus.UNAUTHORIZED
|
||||
assert e.value.detail == "Wrong auth credentials"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user():
|
||||
from app.web.security import authenticate_user
|
||||
|
||||
assert authenticate_user("test") == (False, "invalid access_token")
|
||||
assert authenticate_user(123) == (False, "invalid access_token")
|
||||
|
||||
with patch("app.web.security.requests.get") as mock_get:
|
||||
# bad response from oauth2
|
||||
mock_get.return_value.status_code = 403
|
||||
assert authenticate_user("this-will-call-requests") == (False, "invalid token")
|
||||
mock_get.return_value.status_code = HTTPStatus.FORBIDDEN
|
||||
assert authenticate_user("this-will-call-requests") == (
|
||||
False,
|
||||
"invalid token",
|
||||
)
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# 200 but invalid json
|
||||
mock_get.return_value.status_code = 200
|
||||
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
|
||||
mock_get.return_value.status_code = HTTPStatus.OK
|
||||
assert authenticate_user("this-will-call-requests") == (
|
||||
False,
|
||||
"token does not belong to valid APP_ID",
|
||||
)
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
# 200 but invalid azp and aud
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
|
||||
mock_get.return_value.json.return_value = {
|
||||
"email": "summer@example.com",
|
||||
"azp": "not_an_app",
|
||||
}
|
||||
assert authenticate_user("this-will-call-requests") == (
|
||||
False,
|
||||
"token does not belong to valid APP_ID",
|
||||
)
|
||||
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "aud": "not_an_app"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
|
||||
mock_get.return_value.json.return_value = {
|
||||
"email": "summer@example.com",
|
||||
"aud": "not_an_app",
|
||||
}
|
||||
assert authenticate_user("this-will-call-requests") == (
|
||||
False,
|
||||
"token does not belong to valid APP_ID",
|
||||
)
|
||||
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "not_an_app"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
|
||||
mock_get.return_value.json.return_value = {
|
||||
"email": "summer@example.com",
|
||||
"azp": "not_an_app",
|
||||
"aud": "not_an_app",
|
||||
}
|
||||
assert authenticate_user("this-will-call-requests") == (
|
||||
False,
|
||||
"token does not belong to valid APP_ID",
|
||||
)
|
||||
|
||||
# blocked email
|
||||
mock_get.return_value.json.return_value = {"email": "blocked@example.com", "azp": "test_app_id_1", "aud": "not_an_app"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "email 'blocked@example.com' not allowed")
|
||||
mock_get.return_value.json.return_value = {
|
||||
"email": "blocked@example.com",
|
||||
"azp": "test_app_id_1",
|
||||
"aud": "not_an_app",
|
||||
}
|
||||
assert authenticate_user("this-will-call-requests") == (
|
||||
False,
|
||||
"email 'blocked@example.com' not allowed",
|
||||
)
|
||||
|
||||
# not verified
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "test_app_id_1"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "email 'summer@example.com' not verified")
|
||||
mock_get.return_value.json.return_value = {
|
||||
"email": "summer@example.com",
|
||||
"azp": "not_an_app",
|
||||
"aud": "test_app_id_1",
|
||||
}
|
||||
assert authenticate_user("this-will-call-requests") == (
|
||||
False,
|
||||
"email 'summer@example.com' not verified",
|
||||
)
|
||||
|
||||
# token expired
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "Token expired")
|
||||
mock_get.return_value.json.return_value = {
|
||||
"email": "summer@example.com",
|
||||
"azp": "test_app_id_2",
|
||||
"email_verified": "true",
|
||||
}
|
||||
assert authenticate_user("this-will-call-requests") == (
|
||||
False,
|
||||
"Token expired",
|
||||
)
|
||||
|
||||
# 200 and valid azp and aup and verified
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true", "expires_in": 100}
|
||||
assert authenticate_user("this-will-call-requests") == (True, "summer@example.com")
|
||||
mock_get.return_value.json.return_value = {
|
||||
"email": "summer@example.com",
|
||||
"azp": "test_app_id_2",
|
||||
"email_verified": "true",
|
||||
"expires_in": 100,
|
||||
}
|
||||
assert authenticate_user("this-will-call-requests") == (
|
||||
True,
|
||||
"summer@example.com",
|
||||
)
|
||||
assert mock_get.call_count == 9
|
||||
|
||||
|
||||
@@ -104,6 +171,7 @@ async def test_authenticate_user():
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_with_id_token(m_init):
|
||||
from firebase_admin import exceptions
|
||||
|
||||
from app.web.security import authenticate_user
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
@@ -113,12 +181,20 @@ async def test_authenticate_user_with_id_token(m_init):
|
||||
with patch("app.web.security.auth.verify_id_token") as mock_verify:
|
||||
# missing email
|
||||
mock_verify.return_value = {"email": None}
|
||||
assert authenticate_user("fake_token") == (False, "email not found in token")
|
||||
assert authenticate_user("fake_token") == (
|
||||
False,
|
||||
"email not found in token",
|
||||
)
|
||||
assert mock_verify.call_count == 1
|
||||
|
||||
# blocked email
|
||||
mock_verify.return_value = {"email": "blocked@example.com", }
|
||||
assert authenticate_user("fake_token") == (False, "email 'blocked@example.com' not allowed")
|
||||
mock_verify.return_value = {
|
||||
"email": "blocked@example.com",
|
||||
}
|
||||
assert authenticate_user("fake_token") == (
|
||||
False,
|
||||
"email 'blocked@example.com' not allowed",
|
||||
)
|
||||
assert mock_verify.call_count == 2
|
||||
|
||||
# valid email
|
||||
@@ -132,17 +208,16 @@ async def test_authenticate_user_with_id_token(m_init):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_exception():
|
||||
from app.web.security import authenticate_user
|
||||
with patch("app.web.security.requests.get") as mock_get:
|
||||
mock_get.return_value.status_code = 200
|
||||
mock_get.return_value.status_code = HTTPStatus.OK
|
||||
mock_get.return_value.json.side_effect = Exception("mocked error")
|
||||
assert authenticate_user("this-will-call-requests") == (False, "exception occurred")
|
||||
assert authenticate_user("this-will-call-requests") == (
|
||||
False,
|
||||
"exception occurred",
|
||||
)
|
||||
|
||||
|
||||
def test_get_user_state():
|
||||
from app.web.security import get_user_state
|
||||
from app.web.db.user_state import UserState
|
||||
|
||||
mock_session = Mock()
|
||||
test_email = "test@example.com"
|
||||
|
||||
|
||||
@@ -1,29 +1,47 @@
|
||||
from datetime import datetime
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared import schemas
|
||||
from auto_archiver.core import Media, Metadata
|
||||
|
||||
from app.shared import constants, schemas
|
||||
from app.shared.db import models
|
||||
from app.web.utils.misc import get_all_urls
|
||||
from app.worker.main import create_archive_task, create_sheet_task
|
||||
|
||||
class Test_create_archive_task():
|
||||
|
||||
class TestCreateArchiveTask:
|
||||
URL = "https://example-live.com"
|
||||
archive = schemas.ArchiveCreate(url=URL, tags=["tag-celery"], public=True, author_id="rick@example.com", group_id="interstellar")
|
||||
archive = schemas.ArchiveCreate(
|
||||
url=URL,
|
||||
tags=["tag-celery"],
|
||||
public=True,
|
||||
author_id="rick@example.com",
|
||||
group_id="interstellar",
|
||||
)
|
||||
|
||||
@patch("app.worker.main.ArchivingOrchestrator")
|
||||
@patch("app.worker.main.get_all_urls", return_value=[])
|
||||
@patch("app.worker.main.insert_result_into_db")
|
||||
@patch("app.worker.main.get_store_until", return_value=datetime.now())
|
||||
@patch("app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"])
|
||||
@patch(
|
||||
"app.worker.main.get_orchestrator_args", return_value=["arg1", "arg2"]
|
||||
)
|
||||
@patch("celery.app.task.Task.request")
|
||||
def test_success(self, m_req, m_args, m_store, m_insert, m_urls, m_orchestrator, db_session):
|
||||
from app.worker.main import create_archive_task
|
||||
|
||||
def test_success(
|
||||
self,
|
||||
m_req,
|
||||
m_args,
|
||||
m_store,
|
||||
m_insert,
|
||||
m_urls,
|
||||
m_orchestrator,
|
||||
db_session,
|
||||
):
|
||||
m_req.id = "this-just-in"
|
||||
m_orchestrator.return_value.feed.return_value = iter([Metadata().set_url(self.URL).success()])
|
||||
m_orchestrator.return_value.feed.return_value = iter(
|
||||
[Metadata().set_url(self.URL).success()]
|
||||
)
|
||||
|
||||
task = create_archive_task(self.archive.model_dump_json())
|
||||
|
||||
@@ -39,15 +57,15 @@ class Test_create_archive_task():
|
||||
assert len(task["media"]) == 0
|
||||
|
||||
def test_raise_invalid(self):
|
||||
from app.worker.main import create_archive_task
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(Exception) as _:
|
||||
create_archive_task(self.archive.model_dump_json())
|
||||
|
||||
@patch("app.worker.main.ArchivingOrchestrator")
|
||||
@patch("app.worker.main.get_orchestrator_args")
|
||||
def test_raise_db_error(self, m_args, m_orchestrator):
|
||||
from app.worker.main import create_archive_task
|
||||
m_orchestrator.return_value.feed.side_effect = Exception("Orchestrator failed")
|
||||
m_orchestrator.return_value.feed.side_effect = Exception(
|
||||
"Orchestrator failed"
|
||||
)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
create_archive_task(self.archive.model_dump_json())
|
||||
@@ -59,7 +77,6 @@ class Test_create_archive_task():
|
||||
@patch("app.worker.main.insert_result_into_db", return_value=None)
|
||||
@patch("app.worker.main.get_orchestrator_args")
|
||||
def test_raise_empty_result(self, m_args, m_insert, m_orchestrator):
|
||||
from app.worker.main import create_archive_task
|
||||
m_orchestrator.return_value.feed.return_value = iter([None])
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
@@ -68,61 +85,83 @@ class Test_create_archive_task():
|
||||
m_orchestrator.return_value.feed.assert_called_once()
|
||||
|
||||
|
||||
class Test_create_sheet_task():
|
||||
class TestCreateSheetTask:
|
||||
URL = "https://example-live.com"
|
||||
sheet = schemas.SubmitSheet(sheet_id="123", author_id="rick@example.com", group_id="interstellar", tags=["spaceship"])
|
||||
sheet = schemas.SubmitSheet(
|
||||
sheet_id="123",
|
||||
author_id="rick@example.com",
|
||||
group_id="interstellar",
|
||||
tags=["spaceship"],
|
||||
)
|
||||
|
||||
@patch("app.worker.main.get_all_urls", return_value=[])
|
||||
@patch("app.worker.main.ArchivingOrchestrator")
|
||||
@patch("app.worker.main.models.generate_uuid", return_value="constant-uuid")
|
||||
@patch("app.worker.main.get_store_until", return_value=datetime.now())
|
||||
@patch("app.worker.main.get_orchestrator_args")
|
||||
def test_success(self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session):
|
||||
from app.worker.main import create_sheet_task
|
||||
|
||||
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
|
||||
def test_success(
|
||||
self, m_args, m_store, m_uuid, m_orchestrator, m_urls, db_session
|
||||
):
|
||||
assert (
|
||||
db_session.query(models.Archive)
|
||||
.filter(models.Archive.url == self.URL)
|
||||
.count()
|
||||
== 0
|
||||
)
|
||||
|
||||
mock_metadata = Metadata().set_url(self.URL).success()
|
||||
mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"]))
|
||||
|
||||
m_orchestrator.return_value.feed.return_value = iter([False, mock_metadata, mock_metadata])
|
||||
m_orchestrator.return_value.feed.return_value = iter(
|
||||
[False, mock_metadata, mock_metadata]
|
||||
)
|
||||
|
||||
res = create_sheet_task(self.sheet.model_dump_json())
|
||||
|
||||
m_args.assert_called_once_with("interstellar", True, ["--gsheet_feeder.sheet_id", "123"])
|
||||
m_args.assert_called_once_with(
|
||||
"interstellar", True, [constants.SHEET_ID, "123"]
|
||||
)
|
||||
m_orchestrator.return_value.setup.assert_called_once()
|
||||
m_orchestrator.return_value.feed.assert_called_once()
|
||||
m_store.assert_called_with("interstellar")
|
||||
m_store.call_count == 2
|
||||
m_uuid.call_count == 2
|
||||
assert type(res) == dict
|
||||
assert m_store.call_count == 2
|
||||
assert m_uuid.call_count == 2
|
||||
assert isinstance(res, dict)
|
||||
assert res["stats"]["archived"] == 1
|
||||
assert res["stats"]["failed"] == 1
|
||||
assert len(res["stats"]["errors"]) == 1
|
||||
assert res["sheet_id"] == "123"
|
||||
assert res["success"]
|
||||
assert type(res["time"]) == datetime
|
||||
assert isinstance(res["time"], datetime)
|
||||
|
||||
# query created archive entry
|
||||
inserted = db_session.query(models.Archive).filter(models.Archive.url == self.URL).one()
|
||||
inserted = (
|
||||
db_session.query(models.Archive)
|
||||
.filter(models.Archive.url == self.URL)
|
||||
.one()
|
||||
)
|
||||
assert inserted is not None
|
||||
assert inserted.url == self.URL
|
||||
assert len(inserted.tags) == 1
|
||||
assert inserted.tags[0].id == "spaceship"
|
||||
assert inserted.group_id == "interstellar"
|
||||
assert inserted.author_id == "rick@example.com"
|
||||
assert inserted.public == False
|
||||
assert inserted.public is False
|
||||
|
||||
|
||||
def test_get_all_urls(db_session):
|
||||
from app.worker.main import get_all_urls
|
||||
|
||||
meta = Metadata().set_url("https://example.com")
|
||||
m1 = meta.add_media(Media("fn1.txt", urls=["outcome1.com"]))
|
||||
m2 = meta.add_media(Media("fn2.txt", urls=["outcome2.com"]))
|
||||
m3 = meta.add_media(Media("fn3.txt", urls=["outcome3.com"]))
|
||||
m1.set("screenshot", Media("screenshot.png", urls=["screenshot.com"]))
|
||||
m2.set("thumbnails", [Media("thumb1.png", urls=["thumb1.com"]), Media("thumb2.png", urls=["thumb2.com"])])
|
||||
m2.set(
|
||||
"thumbnails",
|
||||
[
|
||||
Media("thumb1.png", urls=["thumb1.com"]),
|
||||
Media("thumb2.png", urls=["thumb2.com"]),
|
||||
],
|
||||
)
|
||||
m3.set("ssl_data", Media("ssl_data.txt", urls=["ssl_data.com"]).to_dict())
|
||||
m3.set("bad_data", {"bad": "dict is ignored"})
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from app.web.main import app_factory
|
||||
|
||||
app = app_factory
|
||||
|
||||
app = app_factory
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
VERSION = "0.9.4"
|
||||
VERSION = "0.10.0"
|
||||
|
||||
API_DESCRIPTION = """
|
||||
#### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets.
|
||||
|
||||
**Usage notes:**
|
||||
- The API requires a Bearer token for most operations, which you can obtain by logging in with your Google account.
|
||||
- You can use this API to archive single URLs or entire Google Sheets.
|
||||
- You can use this API to archive single URLs or entire Google Sheets.
|
||||
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
|
||||
"""
|
||||
BREAKING_CHANGES = {"minVersion": "0.4.0", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
|
||||
BREAKING_CHANGES = {
|
||||
"minVersion": "0.4.0",
|
||||
"message": "The latest update has breaking changes, please update the extension to the most recent version.",
|
||||
}
|
||||
|
||||
# changing this will corrupt the database logic
|
||||
ALLOW_ANY_EMAIL = "*"
|
||||
|
||||
@@ -1,18 +1,30 @@
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
from sqlalchemy import Column, or_, func, select
|
||||
from loguru import logger
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Any, Type
|
||||
|
||||
from cachetools import LRUCache, cached
|
||||
from cachetools.keys import hashkey
|
||||
from loguru import logger
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
ColumnElement,
|
||||
ScalarResult,
|
||||
false,
|
||||
func,
|
||||
not_,
|
||||
or_,
|
||||
select,
|
||||
true,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.shared.db import models
|
||||
from app.shared.db.models import Archive, Group
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.user_groups import UserGroups
|
||||
from app.shared.utils.misc import fnv1a_hash_mod
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.web.utils.misc import convert_priority_to_queue_dict
|
||||
|
||||
|
||||
@@ -22,24 +34,48 @@ DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
|
||||
def get_limit(user_limit: int):
|
||||
return max(1, min(user_limit, DATABASE_QUERY_LIMIT))
|
||||
|
||||
|
||||
# --------------- TASK = Archive
|
||||
|
||||
|
||||
def base_query(db: Session):
|
||||
# NOTE: load_only is for optimization and not obfuscation, use .with_entities() if needed
|
||||
return db.query(models.Archive)\
|
||||
.filter(models.Archive.deleted == False)\
|
||||
.options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result, models.Archive.store_until))
|
||||
# NOTE: load_only is for optimization and not obfuscation, use
|
||||
# .with_entities() if needed
|
||||
return (
|
||||
db.query(models.Archive)
|
||||
.filter(not_(models.Archive.deleted))
|
||||
.options(
|
||||
load_only(
|
||||
models.Archive.id,
|
||||
models.Archive.created_at,
|
||||
models.Archive.url,
|
||||
models.Archive.result,
|
||||
models.Archive.store_until,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool | set[str], read_public: bool, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False) -> list[models.Archive]:
|
||||
# searches for partial URLs, if email is * no ownership (or read/read_public) filtering happens
|
||||
def search_archives_by_url(
|
||||
db: Session,
|
||||
url: str,
|
||||
email: str,
|
||||
read_groups: bool | set[str],
|
||||
read_public: bool,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
archived_after: datetime = None,
|
||||
archived_before: datetime = None,
|
||||
absolute_search: bool = False,
|
||||
) -> list[Type[Archive]]:
|
||||
# searches for partial URLs, if email is * no ownership
|
||||
# (or read/read_public) filtering happens
|
||||
query = base_query(db)
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
or_filters = [models.Archive.author_id == email]
|
||||
if read_public:
|
||||
or_filters.append(models.Archive.public == True)
|
||||
if read_groups == True:
|
||||
or_filters.append(models.Archive.public.is_(true()))
|
||||
if read_groups is True:
|
||||
or_filters.append(models.Archive.group_id.isnot(None))
|
||||
else:
|
||||
or_filters.append(models.Archive.group_id.in_(read_groups))
|
||||
@@ -47,21 +83,43 @@ def search_archives_by_url(db: Session, url: str, email: str, read_groups: bool
|
||||
if absolute_search:
|
||||
query = query.filter(models.Archive.url == url)
|
||||
else:
|
||||
query = query.filter(models.Archive.url.like(f'%{url}%'))
|
||||
query = query.filter(models.Archive.url.like(f"%{url}%"))
|
||||
if archived_after:
|
||||
query = query.filter(models.Archive.created_at > archived_after)
|
||||
if archived_before:
|
||||
query = query.filter(models.Archive.created_at < archived_before)
|
||||
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
|
||||
return (
|
||||
query.order_by(models.Archive.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(get_limit(limit))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
|
||||
return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
|
||||
def search_archives_by_email(
|
||||
db: Session, email: str, skip: int = 0, limit: int = 100
|
||||
):
|
||||
return (
|
||||
base_query(db)
|
||||
.filter(models.Archive.author_id == email)
|
||||
.order_by(models.Archive.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(get_limit(limit))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def soft_delete_archive(db: Session, id: str, email: str) -> bool:
|
||||
# TODO: implement hard-delete with cronjob that deletes from S3
|
||||
db_archive = db.query(models.Archive).filter(models.Archive.id == id, models.Archive.author_id == email, models.Archive.deleted == False).first()
|
||||
db_archive = (
|
||||
db.query(models.Archive)
|
||||
.filter(
|
||||
models.Archive.id == id,
|
||||
models.Archive.author_id == email,
|
||||
models.Archive.deleted.is_(false()),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if db_archive:
|
||||
db_archive.deleted = True
|
||||
db.commit()
|
||||
@@ -82,22 +140,29 @@ def count_users(db: Session):
|
||||
|
||||
def count_by_user_since(db: Session, seconds_delta: int = 15):
|
||||
time_threshold = datetime.now() - timedelta(seconds=seconds_delta)
|
||||
return db.query(models.Archive.author_id, func.count().label('total'))\
|
||||
.filter(models.Archive.created_at >= time_threshold)\
|
||||
.group_by(models.Archive.author_id)\
|
||||
.order_by(func.count().desc())\
|
||||
.limit(500).all()
|
||||
return (
|
||||
db.query(models.Archive.author_id, func.count().label("total"))
|
||||
.filter(models.Archive.created_at >= time_threshold)
|
||||
.group_by(models.Archive.author_id)
|
||||
.order_by(func.count().desc())
|
||||
.limit(500)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
async def find_by_store_until(db: AsyncSession, store_until_is_before: datetime) -> list[models.Archive]:
|
||||
async def find_by_store_until(
|
||||
db: AsyncSession, store_until_is_before: datetime
|
||||
) -> ScalarResult[Archive]:
|
||||
res = await db.execute(
|
||||
select(models.Archive)
|
||||
.filter(models.Archive.deleted == False, models.Archive.store_until < store_until_is_before)
|
||||
select(models.Archive).filter(
|
||||
models.Archive.deleted.is_(false()),
|
||||
models.Archive.store_until < store_until_is_before,
|
||||
)
|
||||
)
|
||||
return res.scalars()
|
||||
|
||||
|
||||
async def soft_delete_expired_archives(db: AsyncSession) -> dict:
|
||||
async def soft_delete_expired_archives(db: AsyncSession) -> int:
|
||||
to_delete = await find_by_store_until(db, datetime.now())
|
||||
counter = 0
|
||||
for archive in to_delete:
|
||||
@@ -105,47 +170,86 @@ async def soft_delete_expired_archives(db: AsyncSession) -> dict:
|
||||
counter += 1
|
||||
await db.commit()
|
||||
return counter
|
||||
|
||||
|
||||
# --------------- TAG
|
||||
|
||||
|
||||
async def get_group_priority_async(db: AsyncSession, group_id: str) -> dict:
|
||||
db_group = await db.get(models.Group, group_id)
|
||||
priority = db_group.permissions.get("priority", "low") if db_group else "low"
|
||||
priority = (
|
||||
db_group.permissions.get("priority", "low") if db_group else "low"
|
||||
)
|
||||
return convert_priority_to_queue_dict(priority)
|
||||
|
||||
|
||||
@cached(cache=LRUCache(maxsize=128), key=lambda db, email: hashkey(email))
|
||||
def get_user_group_names(db: Session, email: str) -> list[str]:
|
||||
def get_user_group_names(
|
||||
db: Session, email: str
|
||||
) -> list[Any] | list[ColumnElement[Any]]:
|
||||
"""
|
||||
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user.
|
||||
given an email retrieves the user groups from the DB and then the
|
||||
email-domain groups from a global variable, the email does not need to
|
||||
belong to an existing user.
|
||||
"""
|
||||
# TODO: the read: [group1, group2] permissions don't currently work
|
||||
if not email or not len(email) or "@" not in email: return []
|
||||
if not email or not len(email) or "@" not in email:
|
||||
return []
|
||||
|
||||
# get user groups
|
||||
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
|
||||
user_groups = (
|
||||
db.query(models.association_table_user_groups)
|
||||
.filter_by(user_id=email)
|
||||
.with_entities(Column("group_id"))
|
||||
.all()
|
||||
)
|
||||
user_level_groups_names = [g[0] for g in user_groups]
|
||||
|
||||
# get domain groups
|
||||
domain = email.split('@')[1]
|
||||
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
|
||||
domain = email.split("@")[1]
|
||||
domain_level_groups = (
|
||||
db.query(models.Group.id)
|
||||
.filter(models.Group.domains.contains(domain))
|
||||
.with_entities(Column("id"))
|
||||
.all()
|
||||
)
|
||||
domain_level_groups_names = [g[0] for g in domain_level_groups]
|
||||
|
||||
return list(set(user_level_groups_names + domain_level_groups_names))
|
||||
|
||||
|
||||
def get_user_groups_by_name(db: Session, groups: list[str]) -> list[models.Group]:
|
||||
return db.query(models.Group).filter(
|
||||
models.Group.id.in_(groups)
|
||||
).all()
|
||||
def get_user_groups_by_name(
|
||||
db: Session, groups: list[str]
|
||||
) -> list[Type[Group]]:
|
||||
return db.query(models.Group).filter(models.Group.id.in_(groups)).all()
|
||||
|
||||
|
||||
# --------------- INIT User-Groups
|
||||
|
||||
|
||||
def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, service_account_email: str, permissions: dict, domains: list) -> models.Group:
|
||||
db_group = db.query(models.Group).filter(models.Group.id == group_name).first()
|
||||
def upsert_group(
|
||||
db: Session,
|
||||
group_name: str,
|
||||
description: str,
|
||||
orchestrator: str,
|
||||
orchestrator_sheet: str,
|
||||
service_account_email: str,
|
||||
permissions: dict,
|
||||
domains: list,
|
||||
) -> models.Group:
|
||||
db_group = (
|
||||
db.query(models.Group).filter(models.Group.id == group_name).first()
|
||||
)
|
||||
if db_group is None:
|
||||
db_group = models.Group(id=group_name, description=description, orchestrator=orchestrator, orchestrator_sheet=orchestrator_sheet, service_account_email=service_account_email, permissions=permissions, domains=domains)
|
||||
db_group = models.Group(
|
||||
id=group_name,
|
||||
description=description,
|
||||
orchestrator=orchestrator,
|
||||
orchestrator_sheet=orchestrator_sheet,
|
||||
service_account_email=service_account_email,
|
||||
permissions=permissions,
|
||||
domains=domains,
|
||||
)
|
||||
db.add(db_group)
|
||||
else:
|
||||
db_group.description = description
|
||||
@@ -172,8 +276,9 @@ def upsert_user(db: Session, email: str):
|
||||
def upsert_user_groups(db: Session):
|
||||
def display_email_pii(email: str):
|
||||
return f"'{email[0:3]}...@{email.split('@')[1]}'"
|
||||
|
||||
"""
|
||||
reads the user_groups yaml file and inserts any new users, groups,
|
||||
reads the user_groups yaml file and inserts any new users, groups,
|
||||
along with new participation of users in groups
|
||||
"""
|
||||
filename = get_settings().USER_GROUPS_FILENAME
|
||||
@@ -192,18 +297,33 @@ def upsert_user_groups(db: Session):
|
||||
for group in explicit_groups:
|
||||
group_domains[group].add(domain)
|
||||
import json
|
||||
|
||||
# upsert groups and save a map of groupid -> dbobject
|
||||
for group_id, g in ug.groups.items():
|
||||
upsert_group(db, group_id, g.description, g.orchestrator, g.orchestrator_sheet, g.service_account_email, json.loads(g.permissions.model_dump_json()), list(group_domains.get(group_id, [])))
|
||||
db_groups: dict[str, models.Group] = {g.id: g for g in db.query(models.Group).all()}
|
||||
upsert_group(
|
||||
db,
|
||||
group_id,
|
||||
g.description,
|
||||
g.orchestrator,
|
||||
g.orchestrator_sheet,
|
||||
g.service_account_email,
|
||||
json.loads(g.permissions.model_dump_json()),
|
||||
list(group_domains.get(group_id, [])),
|
||||
)
|
||||
db_groups: dict[str, models.Group] = {
|
||||
g.id: g for g in db.query(models.Group).all()
|
||||
}
|
||||
|
||||
# integrity checks
|
||||
for group_in_domains in group_domains:
|
||||
if group_in_domains not in db_groups:
|
||||
logger.warning(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.")
|
||||
logger.warning(
|
||||
f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work."
|
||||
)
|
||||
|
||||
# reinsert users in their EXPLICITLY DEFINED groups
|
||||
# domain groups are check live, as there may be new users that are not explicitly registered but belong to a domain
|
||||
# domain groups are check live, as there may be new users that are not
|
||||
# explicitly registered but belong to a domain
|
||||
for email, explicit_groups in ug.users.items():
|
||||
explicit_groups = explicit_groups or []
|
||||
logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}")
|
||||
@@ -213,7 +333,9 @@ def upsert_user_groups(db: Session):
|
||||
# connect users to groups
|
||||
for group_id in explicit_groups:
|
||||
if group_id not in db_groups:
|
||||
logger.warning(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.")
|
||||
logger.warning(
|
||||
f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}."
|
||||
)
|
||||
continue
|
||||
db_groups[group_id].users.append(db_user)
|
||||
|
||||
@@ -221,12 +343,27 @@ def upsert_user_groups(db: Session):
|
||||
count_user_groups = db.query(models.association_table_user_groups).count()
|
||||
count_groups = db.query(func.count(models.Group.id)).scalar()
|
||||
|
||||
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
|
||||
logger.success(
|
||||
f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}]."
|
||||
)
|
||||
|
||||
|
||||
# --------------- SHEET
|
||||
def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: str, frequency: str):
|
||||
db_sheet = models.Sheet(id=sheet_id, name=name, author_id=email, group_id=group_id, frequency=frequency)
|
||||
def create_sheet(
|
||||
db: Session,
|
||||
sheet_id: str,
|
||||
name: str,
|
||||
email: str,
|
||||
group_id: str,
|
||||
frequency: str,
|
||||
):
|
||||
db_sheet = models.Sheet(
|
||||
id=sheet_id,
|
||||
name=name,
|
||||
author_id=email,
|
||||
group_id=group_id,
|
||||
frequency=frequency,
|
||||
)
|
||||
db.add(db_sheet)
|
||||
db.commit()
|
||||
db.refresh(db_sheet)
|
||||
@@ -234,20 +371,31 @@ def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: st
|
||||
|
||||
|
||||
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
|
||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
|
||||
return (
|
||||
db.query(models.Sheet)
|
||||
.filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
|
||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_url_archived_at.desc()).all()
|
||||
return (
|
||||
db.query(models.Sheet)
|
||||
.filter(models.Sheet.author_id == email)
|
||||
.order_by(models.Sheet.last_url_archived_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: int) -> list[models.Sheet]:
|
||||
async def get_sheets_by_id_hash(
|
||||
db: AsyncSession, frequency: str, modulo: str, id_hash: int
|
||||
) -> list[models.Sheet]:
|
||||
result = await db.execute(
|
||||
select(models.Sheet).filter(models.Sheet.frequency == frequency)
|
||||
)
|
||||
filtered = []
|
||||
for sheet in result.scalars():
|
||||
if fnv1a_hash_mod(sheet.id, modulo) == id_hash:
|
||||
if fnv1a_hash_mod(sheet.id, int(modulo)) == id_hash:
|
||||
filtered.append(sheet)
|
||||
return filtered
|
||||
|
||||
@@ -255,7 +403,9 @@ async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, i
|
||||
async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
|
||||
time_threshold = datetime.now() - timedelta(days=inactivity_days)
|
||||
result = await db.execute(
|
||||
select(models.Sheet).filter(models.Sheet.last_url_archived_at < time_threshold)
|
||||
select(models.Sheet).filter(
|
||||
models.Sheet.last_url_archived_at < time_threshold
|
||||
)
|
||||
)
|
||||
deleted = defaultdict(list)
|
||||
for sheet in result.scalars():
|
||||
@@ -266,7 +416,11 @@ async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
|
||||
|
||||
|
||||
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
|
||||
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
|
||||
db_sheet = (
|
||||
db.query(models.Sheet)
|
||||
.filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email)
|
||||
.first()
|
||||
)
|
||||
if db_sheet:
|
||||
db.delete(db_sheet)
|
||||
db.commit()
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
|
||||
from typing import Dict, Set
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from datetime import datetime
|
||||
from typing import Dict, Set
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared.user_groups import GroupInfo, GroupPermissions
|
||||
from app.shared.schemas import Usage, UsageResponse
|
||||
from app.shared.user_groups import GroupInfo, GroupPermissions
|
||||
from app.web.db import crud
|
||||
from app.web.utils.misc import convert_priority_to_queue_dict
|
||||
|
||||
@@ -20,14 +20,15 @@ class UserState:
|
||||
def __init__(self, db: Session, email: str):
|
||||
self.db = db
|
||||
self.email = email.lower()
|
||||
self._permissions = {}
|
||||
|
||||
@property
|
||||
def permissions(self) -> Dict[str, GroupInfo]:
|
||||
"""
|
||||
Returns a dict of all group permissions and a special {"all": read/archive_url/archive_sheet} key
|
||||
Returns a dict of all group permissions and a special
|
||||
{"all": read/archive_url/archive_sheet} key
|
||||
"""
|
||||
if not hasattr(self, '_permissions'):
|
||||
self._permissions = {}
|
||||
if not self._permissions:
|
||||
self._permissions["all"] = GroupInfo(
|
||||
read=self.read,
|
||||
read_public=self.read_public,
|
||||
@@ -37,23 +38,33 @@ class UserState:
|
||||
max_archive_lifespan_months=self.max_archive_lifespan_months,
|
||||
max_monthly_urls=self.max_monthly_urls,
|
||||
max_monthly_mbs=self.max_monthly_mbs,
|
||||
priority=self.priority
|
||||
priority=self.priority,
|
||||
)
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
self._permissions[group.id] = GroupInfo(**group.permissions, description=group.description, service_account_email=group.service_account_email)
|
||||
if not group.permissions:
|
||||
continue
|
||||
self._permissions[group.id] = GroupInfo(
|
||||
**group.permissions,
|
||||
description=group.description,
|
||||
service_account_email=group.service_account_email,
|
||||
)
|
||||
return self._permissions
|
||||
|
||||
@property
|
||||
def user_groups_names(self):
|
||||
if not hasattr(self, '_user_groups_names'):
|
||||
self._user_groups_names = crud.get_user_group_names(self.db, self.email) + ["default"]
|
||||
if not hasattr(self, "_user_groups_names"):
|
||||
# TODO: Define hidden properties in __init__ method
|
||||
self._user_groups_names = crud.get_user_group_names(
|
||||
self.db, self.email
|
||||
) + ["default"]
|
||||
return self._user_groups_names
|
||||
|
||||
@property
|
||||
def user_groups(self):
|
||||
if not hasattr(self, '_user_groups'):
|
||||
self._user_groups = crud.get_user_groups_by_name(self.db, self.user_groups_names)
|
||||
if not hasattr(self, "_user_groups"):
|
||||
self._user_groups = crud.get_user_groups_by_name(
|
||||
self.db, self.user_groups_names
|
||||
)
|
||||
return self._user_groups
|
||||
|
||||
@property
|
||||
@@ -61,10 +72,11 @@ class UserState:
|
||||
"""
|
||||
Read can be a list of group names or True, if all can be read.
|
||||
"""
|
||||
if not hasattr(self, '_read'):
|
||||
if not hasattr(self, "_read"):
|
||||
self._read = set()
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if not group.permissions:
|
||||
continue
|
||||
group_read_permissions = group.permissions.get("read", [])
|
||||
if "all" in group_read_permissions:
|
||||
self._read = True
|
||||
@@ -78,10 +90,11 @@ class UserState:
|
||||
"""
|
||||
Read public permission
|
||||
"""
|
||||
if not hasattr(self, '_read_public'):
|
||||
if not hasattr(self, "_read_public"):
|
||||
self._read_public = False
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if not group.permissions:
|
||||
continue
|
||||
if group.permissions.get("read_public", False):
|
||||
self._read_public = True
|
||||
return self._read_public
|
||||
@@ -92,10 +105,11 @@ class UserState:
|
||||
"""
|
||||
Archive URL permission
|
||||
"""
|
||||
if not hasattr(self, '_archive_url'):
|
||||
if not hasattr(self, "_archive_url"):
|
||||
self._archive_url = False
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if not group.permissions:
|
||||
continue
|
||||
if group.permissions.get("archive_url", False):
|
||||
self._archive_url = True
|
||||
return self._archive_url
|
||||
@@ -106,10 +120,11 @@ class UserState:
|
||||
"""
|
||||
Archive sheet permission
|
||||
"""
|
||||
if not hasattr(self, '_archive_sheet'):
|
||||
if not hasattr(self, "_archive_sheet"):
|
||||
self._archive_sheet = False
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if not group.permissions:
|
||||
continue
|
||||
if group.permissions.get("archive_sheet", False):
|
||||
self._archive_sheet = True
|
||||
return self._archive_sheet
|
||||
@@ -117,37 +132,53 @@ class UserState:
|
||||
|
||||
@property
|
||||
def sheet_frequency(self):
|
||||
if not hasattr(self, '_sheet_frequency'):
|
||||
if not hasattr(self, "_sheet_frequency"):
|
||||
self._sheet_frequency = set()
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
self._sheet_frequency.update(group.permissions.get("sheet_frequency", None))
|
||||
if not group.permissions:
|
||||
continue
|
||||
self._sheet_frequency.update(
|
||||
group.permissions.get("sheet_frequency", None)
|
||||
)
|
||||
return self._sheet_frequency
|
||||
|
||||
@property
|
||||
def max_archive_lifespan_months(self) -> int:
|
||||
if not hasattr(self, '_max_archive_lifespan_months'):
|
||||
self._max_archive_lifespan_months = self._helper_for_grouping_max_numerical_permissions("max_archive_lifespan_months")
|
||||
if not hasattr(self, "_max_archive_lifespan_months"):
|
||||
self._max_archive_lifespan_months = (
|
||||
self._helper_for_grouping_max_numerical_permissions(
|
||||
"max_archive_lifespan_months"
|
||||
)
|
||||
)
|
||||
return self._max_archive_lifespan_months
|
||||
|
||||
@property
|
||||
def max_monthly_urls(self) -> int:
|
||||
if not hasattr(self, '_max_monthly_urls'):
|
||||
self._max_monthly_urls = self._helper_for_grouping_max_numerical_permissions("max_monthly_urls")
|
||||
if not hasattr(self, "_max_monthly_urls"):
|
||||
self._max_monthly_urls = (
|
||||
self._helper_for_grouping_max_numerical_permissions(
|
||||
"max_monthly_urls"
|
||||
)
|
||||
)
|
||||
return self._max_monthly_urls
|
||||
|
||||
@property
|
||||
def max_monthly_mbs(self) -> int:
|
||||
if not hasattr(self, '_max_monthly_mbs'):
|
||||
self._max_monthly_mbs = self._helper_for_grouping_max_numerical_permissions("max_monthly_mbs")
|
||||
if not hasattr(self, "_max_monthly_mbs"):
|
||||
self._max_monthly_mbs = (
|
||||
self._helper_for_grouping_max_numerical_permissions(
|
||||
"max_monthly_mbs"
|
||||
)
|
||||
)
|
||||
return self._max_monthly_mbs
|
||||
|
||||
@property
|
||||
def priority(self) -> str:
|
||||
if not hasattr(self, '_priority'):
|
||||
if not hasattr(self, "_priority"):
|
||||
self._priority = "low"
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if not group.permissions:
|
||||
continue
|
||||
if group.permissions.get("priority", self._priority) == "high":
|
||||
self._priority = "high"
|
||||
break
|
||||
@@ -158,18 +189,28 @@ class UserState:
|
||||
"""
|
||||
A user is active if they can read/archive anything
|
||||
"""
|
||||
if not hasattr(self, '_active'):
|
||||
self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet)
|
||||
if not hasattr(self, "_active"):
|
||||
self._active = bool(
|
||||
self.read
|
||||
or self.read_public
|
||||
or self.archive_url
|
||||
or self.archive_sheet
|
||||
)
|
||||
return self._active
|
||||
|
||||
def _helper_for_grouping_max_numerical_permissions(self, permission_name: str) -> int:
|
||||
def _helper_for_grouping_max_numerical_permissions(
|
||||
self, permission_name: str
|
||||
) -> int:
|
||||
"""
|
||||
Iterates one of the numerical permissions where -1 means no restrictions and returns either -1 or the maximum value, defaults according to GroupPermissions
|
||||
Iterates one of the numerical permissions where -1 means no restrictions
|
||||
and returns either -1 or the maximum value, defaults according to
|
||||
GroupPermissions
|
||||
"""
|
||||
default = GroupPermissions.model_fields[permission_name].default
|
||||
max_value = default
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if not group.permissions:
|
||||
continue
|
||||
group_value = group.permissions.get(permission_name, default)
|
||||
if group_value == -1:
|
||||
max_value = -1
|
||||
@@ -180,43 +221,65 @@ class UserState:
|
||||
def in_group(self, group_id: str) -> bool:
|
||||
return group_id in self.user_groups_names
|
||||
|
||||
def usage(self) -> Dict:
|
||||
def usage(self) -> UsageResponse:
|
||||
"""
|
||||
returns the monthly quotas for the URLs/MBs and the totals for Sheets
|
||||
Returns the monthly quotas for the URLs/MBs and the totals for Sheets
|
||||
"""
|
||||
current_month = datetime.now().month
|
||||
current_year = datetime.now().year
|
||||
|
||||
# find and sum all user sheets over this month
|
||||
user_sheets = self.db.query(
|
||||
models.Sheet.group_id,
|
||||
func.count(models.Sheet.id).label('sheet_count')
|
||||
).filter(models.Sheet.author_id == self.email).group_by(models.Sheet.group_id).all()
|
||||
user_sheets = (
|
||||
self.db.query(
|
||||
models.Sheet.group_id,
|
||||
func.count(models.Sheet.id).label("sheet_count"),
|
||||
)
|
||||
.filter(models.Sheet.author_id == self.email)
|
||||
.group_by(models.Sheet.group_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
sheets_by_group = {sheet.group_id: sheet.sheet_count for sheet in user_sheets}
|
||||
sheets_by_group = {
|
||||
sheet.group_id: sheet.sheet_count for sheet in user_sheets
|
||||
}
|
||||
|
||||
# find and sum all user urls over this month
|
||||
urls_by_group = self.db.query(
|
||||
models.Archive.group_id,
|
||||
func.count(models.Archive.id).label('url_count'),
|
||||
func.coalesce(func.sum(
|
||||
urls_by_group = (
|
||||
self.db.query(
|
||||
models.Archive.group_id,
|
||||
func.count(models.Archive.id).label("url_count"),
|
||||
func.coalesce(
|
||||
func.cast(
|
||||
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
|
||||
sqlalchemy.Integer
|
||||
), 0
|
||||
)
|
||||
), 0).label('total_bytes')
|
||||
).filter(
|
||||
models.Archive.author_id == self.email,
|
||||
func.extract('month', models.Archive.created_at) == current_month,
|
||||
func.extract('year', models.Archive.created_at) == current_year
|
||||
).group_by(models.Archive.group_id).all()
|
||||
func.sum(
|
||||
func.coalesce(
|
||||
func.cast(
|
||||
func.json_extract(
|
||||
models.Archive.result,
|
||||
"$.metadata.total_bytes",
|
||||
),
|
||||
sqlalchemy.Integer,
|
||||
),
|
||||
0,
|
||||
)
|
||||
),
|
||||
0,
|
||||
).label("total_bytes"),
|
||||
)
|
||||
.filter(
|
||||
models.Archive.author_id == self.email,
|
||||
func.extract("month", models.Archive.created_at)
|
||||
== current_month,
|
||||
func.extract("year", models.Archive.created_at) == current_year,
|
||||
)
|
||||
.group_by(models.Archive.group_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# merge the two queries
|
||||
usage_by_group: Dict[str, Usage] = {
|
||||
(url.group_id or ""):
|
||||
Usage(monthly_urls=url.url_count, monthly_mbs=int(url.total_bytes / 1024 / 1024))
|
||||
(url.group_id or ""): Usage(
|
||||
monthly_urls=url.url_count,
|
||||
monthly_mbs=int(url.total_bytes / 1024 / 1024),
|
||||
)
|
||||
for url in urls_by_group
|
||||
}
|
||||
for group_id, sheet_count in sheets_by_group.items():
|
||||
@@ -235,7 +298,7 @@ class UserState:
|
||||
monthly_urls=total_urls,
|
||||
monthly_mbs=int(total_bytes / 1024 / 1024),
|
||||
total_sheets=total_sheets,
|
||||
groups=usage_by_group
|
||||
groups=usage_by_group,
|
||||
)
|
||||
|
||||
def has_quota_monthly_sheets(self, group_id: str) -> bool:
|
||||
@@ -245,7 +308,14 @@ class UserState:
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
|
||||
user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email, models.Sheet.group_id == group_id).count()
|
||||
user_sheets = (
|
||||
self.db.query(models.Sheet)
|
||||
.filter(
|
||||
models.Sheet.author_id == self.email,
|
||||
models.Sheet.group_id == group_id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
sheet_quota = self.permissions[group_id].max_sheets
|
||||
if sheet_quota == -1:
|
||||
@@ -254,13 +324,15 @@ class UserState:
|
||||
|
||||
def has_quota_max_monthly_urls(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user has reached their monthly url quota for a group, if global then group should be empty string
|
||||
Checks if a user has reached their monthly url quota for a group, if
|
||||
global then group should be empty string
|
||||
"""
|
||||
quota = 0
|
||||
if not group_id:
|
||||
quota = self.max_monthly_urls
|
||||
else:
|
||||
if group_id not in self.permissions: return False
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
quota = self.permissions[group_id].max_monthly_urls
|
||||
|
||||
if quota == -1:
|
||||
@@ -268,24 +340,31 @@ class UserState:
|
||||
|
||||
current_month = datetime.now().month
|
||||
current_year = datetime.now().year
|
||||
user_urls = self.db.query(models.Archive).filter(
|
||||
models.Archive.author_id == self.email,
|
||||
models.Archive.group_id == group_id,
|
||||
func.extract('month', models.Archive.created_at) == current_month,
|
||||
func.extract('year', models.Archive.created_at) == current_year
|
||||
).count()
|
||||
user_urls = (
|
||||
self.db.query(models.Archive)
|
||||
.filter(
|
||||
models.Archive.author_id == self.email,
|
||||
models.Archive.group_id == group_id,
|
||||
func.extract("month", models.Archive.created_at)
|
||||
== current_month,
|
||||
func.extract("year", models.Archive.created_at) == current_year,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
return user_urls < quota
|
||||
|
||||
def has_quota_max_monthly_mbs(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user has reached their monthly MBs quota for a group, if global then group should be empty string
|
||||
Checks if a user has reached their monthly MBs quota for a group, if
|
||||
global then group should be empty string
|
||||
"""
|
||||
quota = 0
|
||||
if not group_id:
|
||||
quota = self.max_monthly_mbs
|
||||
else:
|
||||
if group_id not in self.permissions: return False
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
quota = self.permissions[group_id].max_monthly_mbs
|
||||
|
||||
if quota == -1:
|
||||
@@ -295,19 +374,34 @@ class UserState:
|
||||
current_year = datetime.now().year
|
||||
|
||||
# find and sum all user bytes over this month
|
||||
user_bytes = self.db.query(models.Archive).filter(
|
||||
models.Archive.author_id == self.email,
|
||||
models.Archive.group_id == group_id,
|
||||
func.extract('month', models.Archive.created_at) == current_month,
|
||||
func.extract('year', models.Archive.created_at) == current_year
|
||||
).with_entities(func.coalesce(func.sum(
|
||||
func.coalesce(
|
||||
func.cast(
|
||||
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
|
||||
sqlalchemy.Integer
|
||||
), 0
|
||||
user_bytes = (
|
||||
self.db.query(models.Archive)
|
||||
.filter(
|
||||
models.Archive.author_id == self.email,
|
||||
models.Archive.group_id == group_id,
|
||||
func.extract("month", models.Archive.created_at)
|
||||
== current_month,
|
||||
func.extract("year", models.Archive.created_at) == current_year,
|
||||
)
|
||||
), 0).label('total')).scalar()
|
||||
.with_entities(
|
||||
func.coalesce(
|
||||
func.sum(
|
||||
func.coalesce(
|
||||
func.cast(
|
||||
func.json_extract(
|
||||
models.Archive.result,
|
||||
"$.metadata.total_bytes",
|
||||
),
|
||||
sqlalchemy.Integer,
|
||||
),
|
||||
0,
|
||||
)
|
||||
),
|
||||
0,
|
||||
).label("total")
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
# convert bytes to mb
|
||||
user_mbs = int(user_bytes / 1024 / 1024)
|
||||
@@ -315,7 +409,7 @@ class UserState:
|
||||
|
||||
def can_manually_trigger(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user is allowed to manually trigger a sheet
|
||||
Checks if a user is allowed to manually trigger a sheet
|
||||
"""
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
@@ -324,18 +418,21 @@ class UserState:
|
||||
|
||||
def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool:
|
||||
"""
|
||||
checks if a user is allowed to create a sheet with this frequency for this group
|
||||
Checks if a user is allowed to create a sheet with this frequency for
|
||||
this group
|
||||
"""
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
|
||||
return frequency in self.permissions[group_id].sheet_frequency
|
||||
|
||||
def priority_group(self, group_id: str) -> str:
|
||||
def priority_group(self, group_id: str) -> dict:
|
||||
priority = "low"
|
||||
for group in self.user_groups:
|
||||
if group.id != group_id: continue
|
||||
if not group.permissions: continue
|
||||
if group.id != group_id:
|
||||
continue
|
||||
if not group.permissions:
|
||||
continue
|
||||
priority = group.permissions.get("priority", priority)
|
||||
break
|
||||
return convert_priority_to_queue_dict(priority)
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
|
||||
from typing import Dict
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
|
||||
from app.web.config import VERSION, BREAKING_CHANGES
|
||||
from app.shared.schemas import ActiveUser, UsageResponse
|
||||
from app.web.db.user_state import UserState
|
||||
from app.web.security import get_user_state
|
||||
from app.shared.user_groups import GroupInfo
|
||||
|
||||
default_router = APIRouter()
|
||||
|
||||
|
||||
@default_router.get("/")
|
||||
async def home():
|
||||
return JSONResponse({"version": VERSION, "breakingChanges": BREAKING_CHANGES})
|
||||
|
||||
|
||||
@default_router.get("/health")
|
||||
async def health():
|
||||
return JSONResponse({"status": "ok"})
|
||||
|
||||
|
||||
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
|
||||
async def active(
|
||||
user: UserState = Depends(get_user_state),
|
||||
) -> ActiveUser:
|
||||
return {"active": user.active}
|
||||
|
||||
|
||||
@default_router.get("/user/permissions", summary="Get the user's global 'all' permissions and the permissions for each group they belong to.")
|
||||
def get_user_permissions(
|
||||
user: UserState = Depends(get_user_state),
|
||||
) -> Dict[str, GroupInfo]:
|
||||
return user.permissions
|
||||
|
||||
@default_router.get("/user/usage", summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.")
|
||||
def get_user_usage(
|
||||
user: UserState = Depends(get_user_state),
|
||||
) -> UsageResponse:
|
||||
if not user.active:
|
||||
raise HTTPException(status_code=403, detail="User is not active.")
|
||||
return user.usage()
|
||||
|
||||
|
||||
|
||||
@default_router.get('/favicon.ico', include_in_schema=False)
|
||||
async def favicon() -> FileResponse:
|
||||
return FileResponse("app/web/static/favicon.ico")
|
||||
@@ -1,81 +0,0 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.web.db.user_state import UserState
|
||||
from app.shared import schemas
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.security import get_user_state
|
||||
from app.web.db import crud
|
||||
from app.shared.db.database import get_db_dependency
|
||||
|
||||
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
|
||||
def create_sheet(
|
||||
sheet: schemas.SheetAdd,
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> schemas.SheetResponse:
|
||||
|
||||
if not user.in_group(sheet.group_id):
|
||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||
|
||||
if not user.has_quota_monthly_sheets(sheet.group_id):
|
||||
raise HTTPException(status_code=429, detail="User has reached their sheet quota for this group.")
|
||||
|
||||
if not user.is_sheet_frequency_allowed(sheet.group_id, sheet.frequency):
|
||||
raise HTTPException(status_code=422, detail="Invalid frequency selected for this group.")
|
||||
|
||||
try:
|
||||
return crud.create_sheet(db, sheet.id, sheet.name, user.email, sheet.group_id, sheet.frequency)
|
||||
except exc.IntegrityError as e:
|
||||
raise HTTPException(status_code=400, detail="Sheet with this ID is already being archived.") from e
|
||||
|
||||
|
||||
@sheet_router.get("/mine", status_code=200, summary="Get the authenticated user's Google Sheets.")
|
||||
def get_user_sheets(
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency)
|
||||
) -> list[schemas.SheetResponse]:
|
||||
return crud.get_user_sheets(db, user.email)
|
||||
|
||||
|
||||
@sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.")
|
||||
def delete_sheet(
|
||||
id: str,
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> schemas.DeleteResponse:
|
||||
return JSONResponse({
|
||||
"id": id,
|
||||
"deleted": crud.delete_sheet(db, id, user.email)
|
||||
})
|
||||
|
||||
|
||||
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
|
||||
def archive_user_sheet(
|
||||
id: str,
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> schemas.Task:
|
||||
|
||||
sheet = crud.get_user_sheet(db, user.email, sheet_id=id)
|
||||
if not sheet:
|
||||
raise HTTPException(status_code=403, detail="No access to this sheet.")
|
||||
|
||||
if not user.in_group(sheet.group_id):
|
||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||
|
||||
if not user.can_manually_trigger(sheet.group_id):
|
||||
raise HTTPException(status_code=429, detail="User cannot manually trigger sheet archiving in this group.")
|
||||
|
||||
group_queue = user.priority_group(sheet.group_id)
|
||||
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=id, author_id=user.email, group_id=sheet.group_id).model_dump_json()]).apply_async(**group_queue)
|
||||
|
||||
return JSONResponse({"id": task.id}, status_code=201)
|
||||
@@ -1,40 +0,0 @@
|
||||
from celery.result import AsyncResult
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.security import get_token_or_user_auth
|
||||
from app.shared import schemas
|
||||
from app.shared.log import log_error
|
||||
from app.web.utils.misc import custom_jsonable_encoder
|
||||
|
||||
|
||||
task_router = APIRouter(prefix="/task", tags=["Async task operations"])
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
@task_router.get("/{task_id}", summary="Check the status of an async task by its id, works for URLs and Sheet tasks.")
|
||||
def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult:
|
||||
task = AsyncResult(task_id, app=celery)
|
||||
try:
|
||||
if task.status == "FAILURE":
|
||||
# *FAILURE* The task raised an exception, or has exceeded the retry limit.
|
||||
# The :attr:`result` attribute then contains the exception raised by the task.
|
||||
# https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
|
||||
raise task.result
|
||||
|
||||
response = {
|
||||
"id": task_id,
|
||||
"status": task.status,
|
||||
"result": task.result
|
||||
}
|
||||
return JSONResponse(jsonable_encoder(response, exclude_unset=True, custom_encoder={bytes: custom_jsonable_encoder}))
|
||||
|
||||
except Exception as e:
|
||||
log_error(e)
|
||||
return JSONResponse({
|
||||
"id": task_id,
|
||||
"status": "FAILURE",
|
||||
"result": {"error": str(e)}
|
||||
})
|
||||
@@ -1,85 +0,0 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.shared import schemas
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.security import get_token_or_user_auth, get_user_state
|
||||
from app.web.db import crud
|
||||
from app.web.db.user_state import UserState
|
||||
from app.shared.db.database import get_db_dependency
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from app.web.utils.misc import convert_priority_to_queue_dict
|
||||
|
||||
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
|
||||
def archive_url(
|
||||
archive: schemas.ArchiveTrigger,
|
||||
email=Depends(get_token_or_user_auth),
|
||||
db: Session = Depends(get_db_dependency)
|
||||
) -> schemas.Task:
|
||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
|
||||
|
||||
parsed_url = urlparse(archive.url)
|
||||
if not all([parsed_url.scheme, parsed_url.netloc]):
|
||||
raise HTTPException(status_code=400, detail="Invalid URL received.")
|
||||
|
||||
archive_create = schemas.ArchiveCreate(**archive.model_dump())
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
archive_create.author_id = email
|
||||
user = UserState(db, email)
|
||||
if archive.group_id and not user.in_group(archive.group_id):
|
||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||
if not user.has_quota_max_monthly_urls(archive.group_id):
|
||||
raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.")
|
||||
if not user.has_quota_max_monthly_mbs(archive.group_id):
|
||||
raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.")
|
||||
group_queue = user.priority_group(archive_create.group_id)
|
||||
else:
|
||||
archive_create.author_id = archive.author_id or email
|
||||
group_queue = convert_priority_to_queue_dict("high")
|
||||
|
||||
|
||||
task = celery.signature("create_archive_task", args=[archive_create.model_dump_json()]).apply_async(**group_queue)
|
||||
task_response = schemas.Task(id=task.id)
|
||||
return JSONResponse(task_response.model_dump(), status_code=201)
|
||||
|
||||
|
||||
@url_router.get("/search", summary="Search for archive entries by URL.")
|
||||
def search_by_url(
|
||||
url: str, skip: int = 0, limit: int = 25,
|
||||
archived_after: datetime = None, archived_before: datetime = None,
|
||||
db: Session = Depends(get_db_dependency),
|
||||
email: str = Depends(get_token_or_user_auth)
|
||||
) -> list[schemas.ArchiveResult]:
|
||||
|
||||
read_groups, read_public = False, False
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
user = UserState(db, email)
|
||||
if not user.read and not user.read_public:
|
||||
raise HTTPException(status_code=403, detail="User does not have read access.")
|
||||
read_groups = user.read
|
||||
read_public = user.read_public
|
||||
return crud.search_archives_by_url(db, url.strip(), email, read_groups, read_public, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
||||
|
||||
|
||||
@url_router.delete("/{id}", summary="Delete a single URL archive by id.")
|
||||
def delete_archive(
|
||||
id:str,
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency)
|
||||
) -> schemas.DeleteResponse:
|
||||
logger.info(f"deleting url archive task {id} request by {user.email}")
|
||||
return JSONResponse({
|
||||
"id": id,
|
||||
"deleted": crud.soft_delete_archive(db, id, user.email)
|
||||
})
|
||||
@@ -1,22 +1,32 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
import datetime
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import alembic.config
|
||||
from fastapi import FastAPI
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi_mail import FastMail, MessageSchema, MessageType
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
from loguru import logger
|
||||
from fastapi_mail import FastMail, MessageSchema, MessageType
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared.db.database import get_db, get_db_async, make_engine, wal_checkpoint
|
||||
from app.shared import schemas
|
||||
from app.shared.db import models
|
||||
from app.shared.db.database import (
|
||||
get_db,
|
||||
get_db_async,
|
||||
make_engine,
|
||||
wal_checkpoint,
|
||||
)
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.db import crud
|
||||
from app.web.middleware import increase_exceptions_counter
|
||||
from app.web.utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
|
||||
from app.web.utils.metrics import (
|
||||
measure_regular_metrics,
|
||||
redis_subscribe_worker_exceptions,
|
||||
)
|
||||
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
@@ -28,9 +38,22 @@ async def lifespan(app: FastAPI):
|
||||
# STARTUP
|
||||
engine = make_engine(get_settings().DATABASE_PATH)
|
||||
models.Base.metadata.create_all(bind=engine)
|
||||
alembic.config.main(prog="alembic", argv=['--raiseerr', 'upgrade', 'head'])
|
||||
alembic.config.main(
|
||||
prog="alembic",
|
||||
argv=[
|
||||
"-c",
|
||||
"./app/migrations/alembic.ini",
|
||||
"--raiseerr",
|
||||
"upgrade",
|
||||
"head",
|
||||
],
|
||||
)
|
||||
logging.getLogger("uvicorn.access").disabled = True # loguru
|
||||
asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL))
|
||||
asyncio.create_task(
|
||||
redis_subscribe_worker_exceptions(
|
||||
get_settings().REDIS_EXCEPTIONS_CHANNEL
|
||||
)
|
||||
)
|
||||
asyncio.create_task(repeat_measure_regular_metrics())
|
||||
with get_db() as db:
|
||||
crud.upsert_user_groups(db)
|
||||
@@ -61,41 +84,74 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
|
||||
# CRON JOBS
|
||||
@repeat_every(seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS, on_exception=increase_exceptions_counter)
|
||||
@repeat_every(
|
||||
seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS,
|
||||
on_exception=increase_exceptions_counter,
|
||||
)
|
||||
async def repeat_measure_regular_metrics():
|
||||
await measure_regular_metrics(get_settings().DATABASE_PATH, get_settings().REPEAT_COUNT_METRICS_SECONDS)
|
||||
await measure_regular_metrics(
|
||||
get_settings().DATABASE_PATH,
|
||||
get_settings().REPEAT_COUNT_METRICS_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
@repeat_every(seconds=60, wait_first=120, on_exception=increase_exceptions_counter)
|
||||
@repeat_every(
|
||||
seconds=60, wait_first=120, on_exception=increase_exceptions_counter
|
||||
)
|
||||
async def archive_hourly_sheets_cronjob():
|
||||
await archive_sheets_cronjob("hourly", 60, datetime.datetime.now().minute)
|
||||
|
||||
|
||||
@repeat_every(seconds=3600, wait_first=120, on_exception=increase_exceptions_counter)
|
||||
@repeat_every(
|
||||
seconds=3600, wait_first=120, on_exception=increase_exceptions_counter
|
||||
)
|
||||
async def archive_daily_sheets_cronjob():
|
||||
await archive_sheets_cronjob("daily", 24, datetime.datetime.now().hour)
|
||||
|
||||
|
||||
async def archive_sheets_cronjob(frequency: str, interval: int, current_time_unit: int):
|
||||
async def archive_sheets_cronjob(
|
||||
frequency: str, interval: int, current_time_unit: int
|
||||
):
|
||||
triggered_jobs = []
|
||||
|
||||
async with get_db_async() as db:
|
||||
sheets = await crud.get_sheets_by_id_hash(db, frequency, interval, current_time_unit)
|
||||
sheets = await crud.get_sheets_by_id_hash(
|
||||
db, frequency, str(interval), current_time_unit
|
||||
)
|
||||
for s in sheets:
|
||||
group_queue = await crud.get_group_priority_async(db, s.group_id)
|
||||
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group_id=s.group_id).model_dump_json()]).apply_async(**group_queue)
|
||||
task = celery.signature(
|
||||
"create_sheet_task",
|
||||
args=[
|
||||
schemas.SubmitSheet(
|
||||
sheet_id=s.id,
|
||||
author_id=s.author_id,
|
||||
group_id=s.group_id,
|
||||
).model_dump_json()
|
||||
],
|
||||
).apply_async(**group_queue)
|
||||
|
||||
triggered_jobs.append({"sheet_id": s.id, "task_id": task.id})
|
||||
logger.debug(f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}")
|
||||
logger.debug(
|
||||
f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}"
|
||||
)
|
||||
|
||||
|
||||
# TODO: on exception should logerror but also prometheus counter
|
||||
DELETE_WINDOW = get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS * 24 * 60 * 60
|
||||
DELETE_WINDOW = (
|
||||
get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS * 24 * 60 * 60
|
||||
)
|
||||
|
||||
|
||||
@repeat_every(seconds=DELETE_WINDOW, wait_first=180, on_exception=increase_exceptions_counter)
|
||||
@repeat_every(
|
||||
seconds=DELETE_WINDOW,
|
||||
wait_first=180,
|
||||
on_exception=increase_exceptions_counter,
|
||||
)
|
||||
async def notify_about_expired_archives():
|
||||
notify_from = datetime.datetime.now() + datetime.timedelta(days=get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS)
|
||||
notify_from = datetime.datetime.now() + datetime.timedelta(
|
||||
days=get_settings().DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS
|
||||
)
|
||||
async with get_db_async() as db:
|
||||
scheduled_deletions = await crud.find_by_store_until(db, notify_from)
|
||||
|
||||
@@ -104,10 +160,15 @@ async def notify_about_expired_archives():
|
||||
user_archives[archive.author_id].append(archive)
|
||||
|
||||
if user_archives:
|
||||
fastmail = FastMail(get_settings().MAIL_CONFIG)
|
||||
fastmail = FastMail(get_settings().mail_config)
|
||||
# notify users
|
||||
for email in user_archives:
|
||||
list_of_archives = "\n".join([f'{a.url}, {a.id}, {a.store_until.isoformat()}<br/>' for a in user_archives[email]])
|
||||
list_of_archives = "\n".join(
|
||||
[
|
||||
f"{a.url}, {a.id}, {a.store_until.isoformat()}<br/>"
|
||||
for a in user_archives[email]
|
||||
]
|
||||
)
|
||||
# TODO: how can users download them in bulk?
|
||||
message = MessageSchema(
|
||||
subject="Auto Archiver: Archives Scheduled for Deletion",
|
||||
@@ -127,16 +188,23 @@ async def notify_about_expired_archives():
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
subtype=MessageType.html
|
||||
subtype=MessageType.html,
|
||||
)
|
||||
await fastmail.send_message(message)
|
||||
logger.debug(f"[CRON] Email sent to {email} about {len(user_archives[email])} scheduled archives deletion.")
|
||||
logger.debug(
|
||||
f"[CRON] Email sent to {email} about {len(user_archives[email])} scheduled archives deletion."
|
||||
)
|
||||
|
||||
# now schedule the deletion event
|
||||
asyncio.create_task(delete_expired_archives())
|
||||
|
||||
|
||||
@repeat_every(max_repetitions=1, wait_first=10, seconds=0, on_exception=increase_exceptions_counter)
|
||||
@repeat_every(
|
||||
max_repetitions=1,
|
||||
wait_first=10,
|
||||
seconds=0,
|
||||
on_exception=increase_exceptions_counter,
|
||||
)
|
||||
async def delete_expired_archives():
|
||||
async with get_db_async() as db:
|
||||
count_deleted = await crud.soft_delete_expired_archives(db)
|
||||
@@ -144,19 +212,27 @@ async def delete_expired_archives():
|
||||
logger.debug(f"[CRON] Deleted {count_deleted} archives.")
|
||||
|
||||
|
||||
@repeat_every(seconds=86400, wait_first=150, on_exception=increase_exceptions_counter)
|
||||
@repeat_every(
|
||||
seconds=86400, wait_first=150, on_exception=increase_exceptions_counter
|
||||
)
|
||||
async def delete_stale_sheets():
|
||||
STALE_DAYS = get_settings().DELETE_STALE_SHEETS_DAYS
|
||||
logger.debug(f"[CRON] Deleting stale sheets older than {STALE_DAYS} days.")
|
||||
async with get_db_async() as db:
|
||||
user_sheets = await crud.delete_stale_sheets(db, STALE_DAYS)
|
||||
|
||||
if not user_sheets: return
|
||||
if not user_sheets:
|
||||
return
|
||||
|
||||
fastmail = FastMail(get_settings().MAIL_CONFIG)
|
||||
fastmail = FastMail(get_settings().mail_config)
|
||||
# notify users
|
||||
for email in user_sheets:
|
||||
list_of_sheets = "\n".join([f'<li><a href="https://docs.google.com/spreadsheets/d/{s.id}">{s.name}</a></li>' for s in user_sheets[email]])
|
||||
list_of_sheets = "\n".join(
|
||||
[
|
||||
f'<li><a href="https://docs.google.com/spreadsheets/d/{s.id}">{s.name}</a></li>'
|
||||
for s in user_sheets[email]
|
||||
]
|
||||
)
|
||||
message = MessageSchema(
|
||||
subject="Auto Archiver: Stale Sheets Removed",
|
||||
recipients=[email],
|
||||
@@ -173,14 +249,16 @@ async def delete_stale_sheets():
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
subtype=MessageType.html
|
||||
subtype=MessageType.html,
|
||||
)
|
||||
await fastmail.send_message(message)
|
||||
logger.debug(f"[CRON] Email sent to {email} about stale sheets deletion.")
|
||||
logger.debug(
|
||||
f"[CRON] Email sent to {email} about stale sheets deletion."
|
||||
)
|
||||
|
||||
|
||||
# @repeat_at
|
||||
async def generate_users_export_csv():
|
||||
#TODO: implement a cronjob that regularly requested user data to a CSV file
|
||||
# TODO: implement a cronjob that regularly requested user data to a CSV file
|
||||
# see https://colab.research.google.com/drive/1QDbo3QXHPBdiTuANlA1AWVvN-rqxuCPa?authuser=0#scrollTo=4nPXeSdK8RBT
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -1,34 +1,42 @@
|
||||
import os
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from loguru import logger
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
from app.web.middleware import logging_middleware
|
||||
from app.shared.settings import Settings, get_settings
|
||||
from app.shared.task_messaging import get_celery
|
||||
|
||||
from app.web.security import token_api_key_auth
|
||||
from app.web.config import VERSION, API_DESCRIPTION
|
||||
from app.web.config import API_DESCRIPTION, VERSION
|
||||
from app.web.events import lifespan
|
||||
from app.shared.settings import get_settings
|
||||
from app.web.middleware import logging_middleware
|
||||
from app.web.routers.default import router as default_router
|
||||
from app.web.routers.interoperability import router as interoperability_router
|
||||
from app.web.routers.sheet import router as sheet_router
|
||||
from app.web.routers.task import router as task_router
|
||||
from app.web.routers.url import router as url_router
|
||||
from app.web.security import token_api_key_auth
|
||||
|
||||
|
||||
from app.web.endpoints.default import default_router
|
||||
from app.web.endpoints.url import url_router
|
||||
from app.web.endpoints.sheet import sheet_router
|
||||
from app.web.endpoints.task import task_router
|
||||
from app.web.endpoints.interoperability import interoperability_router
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
def app_factory(settings = get_settings()):
|
||||
|
||||
def app_factory(settings: Settings = None):
|
||||
# TODO: Create dev, test, and prod versions of settings that do not have
|
||||
# TODO: to be passed in as a parameter
|
||||
if settings is None:
|
||||
settings = get_settings()
|
||||
|
||||
app = FastAPI(
|
||||
title="Auto-Archiver API",
|
||||
description=API_DESCRIPTION,
|
||||
version=VERSION,
|
||||
contact={"name": "GitHub", "url": "https://github.com/bellingcat/auto-archiver-api"},
|
||||
lifespan=lifespan
|
||||
contact={
|
||||
"name": "GitHub",
|
||||
"url": "https://github.com/bellingcat/auto-archiver-api",
|
||||
},
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
@@ -47,14 +55,30 @@ def app_factory(settings = get_settings()):
|
||||
app.include_router(interoperability_router)
|
||||
|
||||
# prometheus exposed in /metrics with authentication
|
||||
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health", "/openapi.json", "/favicon.ico"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
|
||||
Instrumentator(
|
||||
should_group_status_codes=False,
|
||||
excluded_handlers=[
|
||||
"/metrics",
|
||||
"/health",
|
||||
"/openapi.json",
|
||||
"/favicon.ico",
|
||||
],
|
||||
).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
|
||||
|
||||
if settings.SERVE_LOCAL_ARCHIVE:
|
||||
local_dir = settings.SERVE_LOCAL_ARCHIVE
|
||||
if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")):
|
||||
if not os.path.isdir(local_dir) and os.path.isdir(
|
||||
local_dir.replace("/app", ".")
|
||||
):
|
||||
local_dir = local_dir.replace("/app", ".")
|
||||
if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
|
||||
logger.warning(f"MOUNTing local archive, use this in development only {settings.SERVE_LOCAL_ARCHIVE}")
|
||||
app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE)
|
||||
logger.warning(
|
||||
f"MOUNTing local archive, use this in development only {settings.SERVE_LOCAL_ARCHIVE}"
|
||||
)
|
||||
app.mount(
|
||||
settings.SERVE_LOCAL_ARCHIVE,
|
||||
StaticFiles(directory=local_dir),
|
||||
name=settings.SERVE_LOCAL_ARCHIVE,
|
||||
)
|
||||
|
||||
return app
|
||||
return app
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
|
||||
import traceback
|
||||
from loguru import logger
|
||||
|
||||
from fastapi import Request
|
||||
from loguru import logger
|
||||
|
||||
from app.shared.log import log_error
|
||||
from app.web.utils.metrics import EXCEPTION_COUNTER
|
||||
|
||||
@@ -9,23 +10,33 @@ from app.web.utils.metrics import EXCEPTION_COUNTER
|
||||
async def logging_middleware(request: Request, call_next):
|
||||
try:
|
||||
response = await call_next(request)
|
||||
#TODO: use Origin to have summary prometheus metrics on where requests come from
|
||||
# TODO: use Origin to have summary prometheus metrics on where
|
||||
# requests come from
|
||||
# origin = request.headers.get("origin")
|
||||
logger.info(f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}")
|
||||
logger.info(
|
||||
f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}"
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
location = f"{request.method} {request.url._url}"
|
||||
await increase_exceptions_counter(e, location)
|
||||
logger.info(f"{request.client.host}:{request.client.port} {location} - {e.__class__.__name__} {e}")
|
||||
logger.info(
|
||||
f"{request.client.host}:{request.client.port} {location} - {e.__class__.__name__} {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
async def increase_exceptions_counter(e: Exception, location:str="cronjob"):
|
||||
|
||||
async def increase_exceptions_counter(
|
||||
e: Exception, location: str = "cronjob"
|
||||
) -> None:
|
||||
if location == "cronjob":
|
||||
try:
|
||||
last_trace = traceback.extract_tb(e.__traceback__)[-1]
|
||||
_file, _line, func_name, _text = last_trace
|
||||
location = func_name
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to get function name from cronjob exception traceback: {e}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unable to get function name from cronjob exception traceback: {e}"
|
||||
)
|
||||
EXCEPTION_COUNTER.labels(type=e.__class__.__name__, location=location).inc()
|
||||
log_error(e)
|
||||
log_error(e)
|
||||
|
||||
0
app/web/routers/__init__.py
Normal file
0
app/web/routers/__init__.py
Normal file
64
app/web/routers/default.py
Normal file
64
app/web/routers/default.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from http import HTTPStatus
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
|
||||
from app.shared.schemas import ActiveUser, UsageResponse
|
||||
from app.shared.user_groups import GroupInfo
|
||||
from app.web.config import BREAKING_CHANGES, VERSION
|
||||
from app.web.db.user_state import UserState
|
||||
from app.web.security import get_user_state
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def home() -> JSONResponse:
|
||||
return JSONResponse(
|
||||
{"version": VERSION, "breakingChanges": BREAKING_CHANGES}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health() -> JSONResponse:
|
||||
return JSONResponse({"status": "ok"})
|
||||
|
||||
|
||||
@router.get(
|
||||
"/user/active", summary="Check if the user is active and can use the tool."
|
||||
)
|
||||
async def active(
|
||||
user: UserState = Depends(get_user_state),
|
||||
) -> ActiveUser:
|
||||
return ActiveUser(active=user.active)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/user/permissions",
|
||||
summary="Get the user's global 'all' permissions and the permissions for each group they belong to.",
|
||||
)
|
||||
def get_user_permissions(
|
||||
user: UserState = Depends(get_user_state),
|
||||
) -> Dict[str, GroupInfo]:
|
||||
return user.permissions
|
||||
|
||||
|
||||
@router.get(
|
||||
"/user/usage",
|
||||
summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.",
|
||||
)
|
||||
def get_user_usage(
|
||||
user: UserState = Depends(get_user_state),
|
||||
) -> UsageResponse:
|
||||
if not user.active:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN, detail="User is not active."
|
||||
)
|
||||
return user.usage()
|
||||
|
||||
|
||||
@router.get("/favicon.ico", include_in_schema=False)
|
||||
async def favicon() -> FileResponse:
|
||||
return FileResponse("app/web/static/favicon.ico")
|
||||
@@ -1,41 +1,53 @@
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
|
||||
import sqlalchemy
|
||||
from auto_archiver.core import Metadata
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
import sqlalchemy
|
||||
from auto_archiver.core import Metadata
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.aa_utils import get_all_urls
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.shared import business_logic, schemas
|
||||
from app.shared.db import worker_crud
|
||||
from app.shared.db import models, worker_crud
|
||||
from app.shared.db.database import get_db_dependency
|
||||
from app.web.security import token_api_key_auth
|
||||
from app.shared.db import models
|
||||
from app.shared.log import log_error
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.web.security import token_api_key_auth
|
||||
from app.web.utils.misc import get_all_urls
|
||||
|
||||
|
||||
interoperability_router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."])
|
||||
router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."])
|
||||
|
||||
|
||||
# ----- endpoint to submit data archived elsewhere
|
||||
@interoperability_router.post("/submit-archive", status_code=201, summary="Submit a manual archive entry, for data that was archived elsewhere.")
|
||||
@router.post(
|
||||
"/submit-archive",
|
||||
status_code=HTTPStatus.CREATED,
|
||||
summary="Submit a manual archive entry, for data that was archived elsewhere.",
|
||||
)
|
||||
def submit_manual_archive(
|
||||
manual: schemas.SubmitManualArchive,
|
||||
auth=Depends(token_api_key_auth),
|
||||
db: Session = Depends(get_db_dependency)
|
||||
db: Session = Depends(get_db_dependency),
|
||||
):
|
||||
try:
|
||||
result: Metadata = Metadata.from_json(manual.result)
|
||||
except json.JSONDecodeError as e:
|
||||
log_error(e)
|
||||
raise HTTPException(status_code=422, detail="Invalid JSON in result field.")
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
|
||||
detail="Invalid JSON in result field.",
|
||||
) from e
|
||||
manual.author_id = manual.author_id or ALLOW_ANY_EMAIL
|
||||
manual.tags.add("manual")
|
||||
|
||||
store_until = business_logic.get_store_archive_until_or_never(db, manual.group_id)
|
||||
logger.debug(f"[MANUAL ARCHIVE] {manual.author_id} {manual.url} {store_until}")
|
||||
store_until = business_logic.get_store_archive_until_or_never(
|
||||
db, manual.group_id
|
||||
)
|
||||
logger.debug(
|
||||
f"[MANUAL ARCHIVE] {manual.author_id} {manual.url} {store_until}"
|
||||
)
|
||||
|
||||
try:
|
||||
archive = schemas.ArchiveCreate(
|
||||
@@ -51,8 +63,15 @@ def submit_manual_archive(
|
||||
)
|
||||
|
||||
db_archive = worker_crud.store_archived_url(db, archive)
|
||||
logger.debug(f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}")
|
||||
return JSONResponse({"id": db_archive.id}, status_code=201)
|
||||
logger.debug(
|
||||
f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}"
|
||||
)
|
||||
return JSONResponse(
|
||||
{"id": db_archive.id}, status_code=HTTPStatus.CREATED
|
||||
)
|
||||
except sqlalchemy.exc.IntegrityError as e:
|
||||
log_error(e)
|
||||
raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error, likely duplicate urls.")
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
|
||||
detail="Cannot insert into DB due to integrity error, likely duplicate urls.",
|
||||
) from e
|
||||
132
app/web/routers/sheet.py
Normal file
132
app/web/routers/sheet.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.db.database import get_db_dependency
|
||||
from app.shared.schemas import (
|
||||
DeleteResponse,
|
||||
SheetAdd,
|
||||
SheetResponse,
|
||||
SubmitSheet,
|
||||
)
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.db import crud
|
||||
from app.web.db.user_state import UserState
|
||||
from app.web.security import get_user_state
|
||||
|
||||
|
||||
router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/create",
|
||||
status_code=HTTPStatus.CREATED,
|
||||
summary="Store a new Google Sheet for regular archiving.",
|
||||
)
|
||||
def create_sheet(
|
||||
sheet: SheetAdd,
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> SheetResponse:
|
||||
if not user.in_group(sheet.group_id):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="User does not have access to this group.",
|
||||
)
|
||||
|
||||
if not user.has_quota_monthly_sheets(sheet.group_id):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.TOO_MANY_REQUESTS,
|
||||
detail="User has reached their sheet quota for this group.",
|
||||
)
|
||||
|
||||
if not user.is_sheet_frequency_allowed(sheet.group_id, sheet.frequency):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
|
||||
detail="Invalid frequency selected for this group.",
|
||||
)
|
||||
|
||||
try:
|
||||
return crud.create_sheet(
|
||||
db,
|
||||
sheet.id,
|
||||
sheet.name,
|
||||
user.email,
|
||||
sheet.group_id,
|
||||
sheet.frequency,
|
||||
)
|
||||
except exc.IntegrityError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
detail="Sheet with this ID is already being archived.",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/mine",
|
||||
status_code=HTTPStatus.OK,
|
||||
summary="Get the authenticated user's Google Sheets.",
|
||||
)
|
||||
def get_user_sheets(
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> list[SheetResponse]:
|
||||
return crud.get_user_sheets(db, user.email)
|
||||
|
||||
|
||||
@router.delete("/{sheet_id}", summary="Delete a Google Sheet by ID.")
|
||||
def delete_sheet(
|
||||
sheet_id: str,
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> DeleteResponse:
|
||||
return DeleteResponse(
|
||||
id=sheet_id, deleted=crud.delete_sheet(db, sheet_id, user.email)
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{sheet_id}/archive",
|
||||
status_code=HTTPStatus.CREATED,
|
||||
summary="Trigger an archiving task for a GSheet you own.",
|
||||
response_description="task_id for the archiving task.",
|
||||
)
|
||||
def archive_user_sheet(
|
||||
sheet_id: str,
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> JSONResponse:
|
||||
sheet = crud.get_user_sheet(db, user.email, sheet_id=sheet_id)
|
||||
if not sheet:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN, detail="No access to this sheet."
|
||||
)
|
||||
|
||||
if not user.in_group(sheet.group_id):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="User does not have access to this group.",
|
||||
)
|
||||
|
||||
if not user.can_manually_trigger(sheet.group_id):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.TOO_MANY_REQUESTS,
|
||||
detail="User cannot manually trigger sheet archiving in this group.",
|
||||
)
|
||||
|
||||
group_queue = user.priority_group(sheet.group_id)
|
||||
task = celery.signature(
|
||||
"create_sheet_task",
|
||||
args=[
|
||||
SubmitSheet(
|
||||
sheet_id=sheet_id, author_id=user.email, group_id=sheet.group_id
|
||||
).model_dump_json()
|
||||
],
|
||||
).apply_async(**group_queue)
|
||||
|
||||
return JSONResponse({"id": task.id}, status_code=HTTPStatus.CREATED)
|
||||
52
app/web/routers/task.py
Normal file
52
app/web/routers/task.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from celery.result import AsyncResult
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.shared import schemas
|
||||
from app.shared.constants import STATUS_FAILURE
|
||||
from app.shared.log import log_error
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.security import get_token_or_user_auth
|
||||
from app.web.utils.misc import custom_jsonable_encoder
|
||||
|
||||
|
||||
router = APIRouter(prefix="/task", tags=["Async task operations"])
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{task_id}",
|
||||
summary="Check the status of an async task by its id, works for URLs and Sheet tasks.",
|
||||
)
|
||||
def get_status(
|
||||
task_id, email=Depends(get_token_or_user_auth)
|
||||
) -> schemas.TaskResult:
|
||||
task = AsyncResult(task_id, app=celery)
|
||||
try:
|
||||
if task.status == STATUS_FAILURE:
|
||||
# *FAILURE* The task raised an exception, or has exceeded the retry limit.
|
||||
# The :attr:`result` attribute then contains the exception raised by
|
||||
# the task.
|
||||
# https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
|
||||
raise task.result
|
||||
|
||||
response = {"id": task_id, "status": task.status, "result": task.result}
|
||||
return JSONResponse(
|
||||
jsonable_encoder(
|
||||
response,
|
||||
exclude_unset=True,
|
||||
custom_encoder={bytes: custom_jsonable_encoder},
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log_error(e)
|
||||
return JSONResponse(
|
||||
{
|
||||
"id": task_id,
|
||||
"status": STATUS_FAILURE,
|
||||
"result": {"error": str(e)},
|
||||
}
|
||||
)
|
||||
125
app/web/routers/url.py
Normal file
125
app/web/routers/url.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared import schemas
|
||||
from app.shared.db.database import get_db_dependency
|
||||
from app.shared.schemas import DeleteResponse
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.web.db import crud
|
||||
from app.web.db.user_state import UserState
|
||||
from app.web.security import get_token_or_user_auth, get_user_state
|
||||
from app.web.utils.misc import convert_priority_to_queue_dict
|
||||
|
||||
|
||||
router = APIRouter(prefix="/url", tags=["Single URL operations"])
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/archive",
|
||||
status_code=HTTPStatus.CREATED,
|
||||
summary="Submit a single URL archive request, starts an archiving task.",
|
||||
response_description="task_id for the archiving task, will match the archive id.",
|
||||
)
|
||||
def archive_url(
|
||||
archive: schemas.ArchiveTrigger,
|
||||
email=Depends(get_token_or_user_auth),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> JSONResponse:
|
||||
logger.info(
|
||||
f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}"
|
||||
)
|
||||
|
||||
parsed_url = urlparse(archive.url)
|
||||
if not all([parsed_url.scheme, parsed_url.netloc]):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST, detail="Invalid URL received."
|
||||
)
|
||||
|
||||
archive_create = schemas.ArchiveCreate(**archive.model_dump())
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
archive_create.author_id = email
|
||||
user = UserState(db, email)
|
||||
if archive.group_id and not user.in_group(archive.group_id):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="User does not have access to this group.",
|
||||
)
|
||||
if not user.has_quota_max_monthly_urls(archive.group_id):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.TOO_MANY_REQUESTS,
|
||||
detail="User has reached their monthly URL quota.",
|
||||
)
|
||||
if not user.has_quota_max_monthly_mbs(archive.group_id):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.TOO_MANY_REQUESTS,
|
||||
detail="User has reached their monthly MB quota.",
|
||||
)
|
||||
group_queue = user.priority_group(archive_create.group_id)
|
||||
else:
|
||||
archive_create.author_id = archive.author_id or email
|
||||
group_queue = convert_priority_to_queue_dict("high")
|
||||
|
||||
task = celery.signature(
|
||||
"create_archive_task", args=[archive_create.model_dump_json()]
|
||||
).apply_async(**group_queue)
|
||||
task_response = schemas.Task(id=task.id)
|
||||
return JSONResponse(
|
||||
task_response.model_dump(), status_code=HTTPStatus.CREATED
|
||||
)
|
||||
|
||||
|
||||
@router.get("/search", summary="Search for archive entries by URL.")
|
||||
def search_by_url(
|
||||
url: str,
|
||||
skip: int = 0,
|
||||
limit: int = 25,
|
||||
archived_after: datetime = None,
|
||||
archived_before: datetime = None,
|
||||
db: Session = Depends(get_db_dependency),
|
||||
email: str = Depends(get_token_or_user_auth),
|
||||
) -> list[schemas.ArchiveResult]:
|
||||
read_groups, read_public = False, False
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
user = UserState(db, email)
|
||||
if not user.read and not user.read_public:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="User does not have read access.",
|
||||
)
|
||||
read_groups = user.read
|
||||
read_public = user.read_public
|
||||
return crud.search_archives_by_url(
|
||||
db,
|
||||
url.strip(),
|
||||
email,
|
||||
read_groups,
|
||||
read_public,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
archived_after=archived_after,
|
||||
archived_before=archived_before,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{archive_id}", summary="Delete a single URL archive by id.")
|
||||
def delete_archive(
|
||||
archive_id: str,
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> DeleteResponse:
|
||||
logger.info(
|
||||
f"deleting url archive task {archive_id} request by {user.email}"
|
||||
)
|
||||
return DeleteResponse(
|
||||
id=archive_id,
|
||||
deleted=crud.soft_delete_archive(db, archive_id, user.email),
|
||||
)
|
||||
@@ -1,27 +1,32 @@
|
||||
from loguru import logger
|
||||
import requests, secrets
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.orm import Session
|
||||
import secrets
|
||||
from http import HTTPStatus
|
||||
|
||||
import firebase_admin
|
||||
from firebase_admin import credentials, auth, exceptions
|
||||
import requests
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from firebase_admin import auth, credentials, exceptions
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.db.database import get_db_dependency
|
||||
from app.shared.settings import get_settings
|
||||
from app.web.config import ALLOW_ANY_EMAIL
|
||||
from app.web.db.user_state import UserState
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
bearer_security = HTTPBearer()
|
||||
|
||||
FIREBASE_OAUTH_ENABLED = settings.FIREBASE_SERVICE_ACCOUNT_JSON != ""
|
||||
if FIREBASE_OAUTH_ENABLED:
|
||||
logger.debug("Firebase OAUTH enabled, initializing...")
|
||||
firebase_admin.initialize_app(credentials.Certificate(settings.FIREBASE_SERVICE_ACCOUNT_JSON))
|
||||
firebase_admin.initialize_app(
|
||||
credentials.Certificate(settings.FIREBASE_SERVICE_ACCOUNT_JSON)
|
||||
)
|
||||
|
||||
|
||||
def secure_compare(token, api_key):
|
||||
def secure_compare(token, api_key) -> bool:
|
||||
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
|
||||
|
||||
|
||||
@@ -29,9 +34,13 @@ def secure_compare(token, api_key):
|
||||
def api_key_auth(api_key):
|
||||
assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars"
|
||||
|
||||
async def auth(bearer: HTTPAuthorizationCredentials = Depends(bearer_security), auto_error=True):
|
||||
async def auth(
|
||||
bearer: HTTPAuthorizationCredentials = Depends(bearer_security),
|
||||
auto_error=True,
|
||||
):
|
||||
is_correct = secure_compare(bearer.credentials, api_key)
|
||||
if is_correct: return True
|
||||
if is_correct:
|
||||
return True
|
||||
|
||||
if auto_error:
|
||||
raise HTTPException(
|
||||
@@ -43,18 +52,23 @@ def api_key_auth(api_key):
|
||||
return auth
|
||||
|
||||
|
||||
# --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus
|
||||
# --- Token Auth for AA itself to query the API, AA setup tool and Prometheus
|
||||
token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN)
|
||||
|
||||
|
||||
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||
async def get_token_or_user_auth(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(bearer_security),
|
||||
):
|
||||
# tries to use the static API_KEY and defaults to google JWT auth
|
||||
if await token_api_key_auth(credentials, auto_error=False): return ALLOW_ANY_EMAIL
|
||||
if await token_api_key_auth(credentials, auto_error=False):
|
||||
return ALLOW_ANY_EMAIL
|
||||
return await get_user_auth(credentials)
|
||||
|
||||
|
||||
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||
# validates the Bearer token in the case that it requires it
|
||||
async def get_user_auth(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(bearer_security),
|
||||
):
|
||||
# Validates the Bearer token in the case that it requires it
|
||||
valid_user, info = authenticate_user(credentials.credentials)
|
||||
if valid_user:
|
||||
return info.lower()
|
||||
@@ -66,39 +80,55 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
|
||||
)
|
||||
|
||||
|
||||
def authenticate_user(access_token):
|
||||
def authenticate_user(access_token) -> (bool, str):
|
||||
if FIREBASE_OAUTH_ENABLED:
|
||||
try:
|
||||
j = auth.verify_id_token(access_token)
|
||||
email = j.get('email', None)
|
||||
logger.debug(f"Successfully verified the ID token for {email}")
|
||||
if email is None:
|
||||
return False, "email not found in token"
|
||||
if email in settings.BLOCKED_EMAILS:
|
||||
return False, f"email '{email}' not allowed"
|
||||
return True, email
|
||||
except exceptions.FirebaseError as e:
|
||||
logger.warning(f"Error verifying ID token: {str(e)[:80]}...")
|
||||
return firebase_login_attempt(access_token)
|
||||
except exceptions.FirebaseError:
|
||||
# used a non-Firebase token, fallback to Google OAuth
|
||||
pass
|
||||
|
||||
# https://cloud.google.com/docs/authentication/token-types#access
|
||||
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
|
||||
r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token": access_token})
|
||||
if r.status_code != 200: return False, "invalid token"
|
||||
if not isinstance(access_token, str) or len(access_token) < 10:
|
||||
return False, "invalid access_token"
|
||||
r = requests.get(
|
||||
"https://oauth2.googleapis.com/tokeninfo",
|
||||
{"access_token": access_token},
|
||||
)
|
||||
if r.status_code != HTTPStatus.OK:
|
||||
return False, "invalid token"
|
||||
try:
|
||||
j = r.json()
|
||||
if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS:
|
||||
return False, f"token does not belong to valid APP_ID"
|
||||
if (
|
||||
j.get("azp") not in settings.CHROME_APP_IDS
|
||||
and j.get("aud") not in settings.CHROME_APP_IDS
|
||||
):
|
||||
return False, "token does not belong to valid APP_ID"
|
||||
if j.get("email") in settings.BLOCKED_EMAILS:
|
||||
return False, f"email '{j.get('email')}' not allowed"
|
||||
if j.get("email_verified") != "true":
|
||||
return False, f"email '{j.get('email')}' not verified"
|
||||
if int(j.get("expires_in", -1)) <= 0:
|
||||
return False, "Token expired"
|
||||
return True, j.get('email').lower()
|
||||
return True, j.get("email").lower()
|
||||
except Exception as e:
|
||||
logger.warning(f"AUTH EXCEPTION occurred: {e}")
|
||||
return False, "exception occurred"
|
||||
|
||||
|
||||
def get_user_state(email: str = Depends(get_user_auth), db: Session = Depends(get_db_dependency)):
|
||||
def firebase_login_attempt(access_token) -> (bool, str):
|
||||
j = auth.verify_id_token(access_token)
|
||||
email = j.get("email", None)
|
||||
logger.debug(f"Successfully verified the ID token for {email}")
|
||||
if email is None:
|
||||
return False, "email not found in token"
|
||||
if email in settings.BLOCKED_EMAILS:
|
||||
return False, f"email '{email}' not allowed"
|
||||
return True, email
|
||||
|
||||
|
||||
def get_user_state(
|
||||
email: str = Depends(get_user_auth),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> UserState:
|
||||
return UserState(db, email)
|
||||
|
||||
@@ -2,52 +2,57 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from prometheus_client import Counter, Gauge
|
||||
|
||||
from app.web.db import crud
|
||||
from app.shared.db.database import get_db
|
||||
from app.shared.log import log_error
|
||||
from app.shared.task_messaging import get_redis
|
||||
from app.web.db import crud
|
||||
|
||||
|
||||
# Custom metrics
|
||||
EXCEPTION_COUNTER = Counter(
|
||||
"exceptions",
|
||||
"Number of times a certain exception has occurred.",
|
||||
labelnames=["type", "location"]
|
||||
labelnames=["type", "location"],
|
||||
)
|
||||
WORKER_EXCEPTION = Counter(
|
||||
"worker_exceptions_total",
|
||||
"Number of times a certain exception has occurred on the worker.",
|
||||
labelnames=["type", "exception", "task", "traceback"]
|
||||
labelnames=["type", "exception", "task", "traceback"],
|
||||
)
|
||||
DISK_UTILIZATION = Gauge(
|
||||
"disk_utilization",
|
||||
"Disk utilization in GB",
|
||||
labelnames=["type"]
|
||||
"disk_utilization", "Disk utilization in GB", labelnames=["type"]
|
||||
)
|
||||
DATABASE_METRICS = Gauge(
|
||||
"database_metrics",
|
||||
"Database metric readings at a certain point in time",
|
||||
labelnames=["query"]
|
||||
labelnames=["query"],
|
||||
)
|
||||
DATABASE_METRICS_COUNTER = Counter(
|
||||
"database_metrics_counter",
|
||||
"Database metrics that increase over time",
|
||||
labelnames=["query", "user"]
|
||||
labelnames=["query", "user"],
|
||||
)
|
||||
|
||||
|
||||
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL: str):
|
||||
# Subscribe to Redis channel and increment the counter for each exception with info on the exception and task
|
||||
async def redis_subscribe_worker_exceptions(redis_exceptions_channel: str):
|
||||
# Subscribe to Redis channel and increment the counter for each exception
|
||||
# with info on the exception and task
|
||||
Redis = get_redis()
|
||||
PubSubExceptions = Redis.pubsub()
|
||||
PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL)
|
||||
PubSubExceptions.subscribe(redis_exceptions_channel)
|
||||
while True:
|
||||
message = PubSubExceptions.get_message()
|
||||
if message and message["type"] == "message":
|
||||
data = json.loads(message["data"].decode("utf-8"))
|
||||
WORKER_EXCEPTION.labels(type=data["type"], exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc()
|
||||
WORKER_EXCEPTION.labels(
|
||||
type=data["type"],
|
||||
exception=data["exception"],
|
||||
task=data["task"],
|
||||
traceback=data["traceback"],
|
||||
).inc()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@@ -58,12 +63,19 @@ async def measure_regular_metrics(sqlite_db_url: str, repeat_in_seconds: int):
|
||||
try:
|
||||
fs = os.stat(sqlite_db_url.replace("sqlite:///", ""))
|
||||
DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30))
|
||||
except Exception as e: log_error(e)
|
||||
except Exception as e:
|
||||
log_error(e)
|
||||
|
||||
with get_db() as db:
|
||||
DATABASE_METRICS.labels(query="count_archives").set(crud.count_archives(db))
|
||||
DATABASE_METRICS.labels(query="count_archive_urls").set(crud.count_archive_urls(db))
|
||||
DATABASE_METRICS.labels(query="count_archives").set(
|
||||
crud.count_archives(db)
|
||||
)
|
||||
DATABASE_METRICS.labels(query="count_archive_urls").set(
|
||||
crud.count_archive_urls(db)
|
||||
)
|
||||
DATABASE_METRICS.labels(query="count_users").set(crud.count_users(db))
|
||||
|
||||
for user in crud.count_by_user_since(db, repeat_in_seconds):
|
||||
DATABASE_METRICS_COUNTER.labels(query="count_by_user", user=user.author_id).inc(user.total)
|
||||
DATABASE_METRICS_COUNTER.labels(
|
||||
query="count_by_user", user=user.author_id
|
||||
).inc(user.total)
|
||||
|
||||
@@ -1,15 +1,62 @@
|
||||
import base64
|
||||
from typing import List
|
||||
|
||||
from auto_archiver.core import Media, Metadata
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from loguru import logger
|
||||
|
||||
from app.shared.db import models
|
||||
|
||||
|
||||
def custom_jsonable_encoder(obj):
|
||||
if isinstance(obj, bytes):
|
||||
return base64.b64encode(obj).decode('utf-8')
|
||||
return base64.b64encode(obj).decode("utf-8")
|
||||
return jsonable_encoder(obj)
|
||||
|
||||
|
||||
def convert_priority_to_queue_dict(priority: str) -> dict:
|
||||
return {
|
||||
"priority": 0 if priority == "high" else 10,
|
||||
"queue": f"{priority}_priority"
|
||||
"queue": f"{priority}_priority",
|
||||
}
|
||||
|
||||
|
||||
def convert_if_media(media):
|
||||
if isinstance(media, Media):
|
||||
return media
|
||||
elif isinstance(media, dict):
|
||||
try:
|
||||
return Media.from_dict(media)
|
||||
except Exception as e:
|
||||
logger.debug(f"error parsing {media} : {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
|
||||
db_urls = []
|
||||
for m in result.media:
|
||||
for i, url in enumerate(m.urls):
|
||||
db_urls.append(
|
||||
models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}"))
|
||||
)
|
||||
for k, prop in m.properties.items():
|
||||
if prop_converted := convert_if_media(prop):
|
||||
for i, url in enumerate(prop_converted.urls):
|
||||
db_urls.append(
|
||||
models.ArchiveUrl(
|
||||
url=url, key=prop_converted.get("id", f"{k}_{i}")
|
||||
)
|
||||
)
|
||||
if isinstance(prop, list):
|
||||
for i, prop_media in enumerate(prop):
|
||||
if prop_media := convert_if_media(prop_media):
|
||||
for j, url in enumerate(prop_media.urls):
|
||||
db_urls.append(
|
||||
models.ArchiveUrl(
|
||||
url=url,
|
||||
key=prop_media.get(
|
||||
"id", f"{k}{prop_media.key}_{i}.{j}"
|
||||
),
|
||||
)
|
||||
)
|
||||
return db_urls
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
import datetime
|
||||
import json
|
||||
import traceback
|
||||
|
||||
import traceback, datetime
|
||||
from auto_archiver.core.orchestrator import ArchivingOrchestrator
|
||||
from celery.signals import task_failure
|
||||
from loguru import logger
|
||||
from sqlalchemy import exc
|
||||
from auto_archiver.core.orchestrator import ArchivingOrchestrator
|
||||
|
||||
from app.shared.db import models
|
||||
from app.shared import business_logic, constants, schemas
|
||||
from app.shared.db import models, worker_crud
|
||||
from app.shared.db.database import get_db
|
||||
from app.shared import business_logic, schemas
|
||||
from app.shared.task_messaging import get_celery, get_redis
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.log import log_error
|
||||
from app.shared.aa_utils import get_all_urls
|
||||
from app.shared.db import worker_crud
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.task_messaging import get_celery, get_redis
|
||||
from app.web.utils.misc import get_all_urls
|
||||
from app.worker.worker_log import setup_celery_logger
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
celery = get_celery("worker")
|
||||
@@ -24,26 +25,36 @@ Redis = get_redis()
|
||||
USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
|
||||
|
||||
setup_celery_logger(celery)
|
||||
AA_LOGGER_ID = None
|
||||
|
||||
# TODO: these are temporary PATCHES for new aa's functionality
|
||||
# logger.add("app/worker/worker_log.log", level="DEBUG")
|
||||
logger.remove = lambda x: print(f"logger.remove({x})")
|
||||
|
||||
# TODO: after release, as it requires updating past entries with sheet_id where tag is used, drop tags
|
||||
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 1})
|
||||
# TODO: after release, as it requires updating past entries with sheet_id where tag
|
||||
# is used, drop tags
|
||||
@celery.task(
|
||||
name="create_archive_task",
|
||||
bind=True,
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 1},
|
||||
)
|
||||
def create_archive_task(self, archive_json: str):
|
||||
global AA_LOGGER_ID
|
||||
archive = schemas.ArchiveCreate.model_validate_json(archive_json)
|
||||
|
||||
# call auto-archiver
|
||||
args = get_orchestrator_args(archive.group_id, False, [archive.url])
|
||||
result = None
|
||||
try:
|
||||
orchestrator = ArchivingOrchestrator()
|
||||
orchestrator.logger_id = AA_LOGGER_ID # ensure single logger
|
||||
orchestrator.setup(args)
|
||||
result = next(orchestrator.feed())
|
||||
AA_LOGGER_ID = orchestrator.logger_id
|
||||
for orch_res in orchestrator.feed():
|
||||
result = orch_res
|
||||
except SystemExit as e:
|
||||
log_error(e, f"create_archive_task: SystemExit from AA")
|
||||
log_error(e, "create_archive_task: SystemExit from AA")
|
||||
except Exception as e:
|
||||
log_error(e, f"create_archive_task")
|
||||
log_error(e, "create_archive_task")
|
||||
raise e
|
||||
assert result, f"UNABLE TO archive: {archive.url}"
|
||||
|
||||
@@ -59,13 +70,20 @@ def create_archive_task(self, archive_json: str):
|
||||
|
||||
@celery.task(name="create_sheet_task", bind=True)
|
||||
def create_sheet_task(self, sheet_json: str):
|
||||
global AA_LOGGER_ID
|
||||
sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
|
||||
queue_name = (create_sheet_task.request.delivery_info or {}).get('routing_key', 'unknown')
|
||||
queue_name = (create_sheet_task.request.delivery_info or {}).get(
|
||||
"routing_key", "unknown"
|
||||
)
|
||||
logger.info(f"[queue={queue_name}] SHEET START {sheet=}")
|
||||
|
||||
args = get_orchestrator_args(sheet.group_id, True, ["--gsheet_feeder.sheet_id", sheet.sheet_id])
|
||||
args = get_orchestrator_args(
|
||||
sheet.group_id, True, [constants.SHEET_ID, sheet.sheet_id]
|
||||
)
|
||||
orchestrator = ArchivingOrchestrator()
|
||||
orchestrator.logger_id = AA_LOGGER_ID # ensure single logger
|
||||
orchestrator.setup(args)
|
||||
AA_LOGGER_ID = orchestrator.logger_id
|
||||
|
||||
stats = {"archived": 0, "failed": 0, "errors": []}
|
||||
try:
|
||||
@@ -81,7 +99,7 @@ def create_sheet_task(self, sheet_json: str):
|
||||
result=json.loads(result.to_json()),
|
||||
sheet_id=sheet.sheet_id,
|
||||
urls=get_all_urls(result),
|
||||
store_until=get_store_until(sheet.group_id)
|
||||
store_until=get_store_until(sheet.group_id),
|
||||
)
|
||||
insert_result_into_db(archive)
|
||||
stats["archived"] += 1
|
||||
@@ -94,25 +112,38 @@ def create_sheet_task(self, sheet_json: str):
|
||||
stats["errors"].append(str(e))
|
||||
|
||||
except SystemExit as e:
|
||||
log_error(e, f"create_sheet_task: SystemExit from AA")
|
||||
log_error(e, "create_sheet_task: SystemExit from AA")
|
||||
|
||||
if stats["archived"] > 0:
|
||||
with get_db() as session:
|
||||
worker_crud.update_sheet_last_url_archived_at(session, sheet.sheet_id)
|
||||
worker_crud.update_sheet_last_url_archived_at(
|
||||
session, sheet.sheet_id
|
||||
)
|
||||
|
||||
logger.info(f"SHEET DONE {sheet=}")
|
||||
# TODO: is this used anywhere? maybe drop it
|
||||
return schemas.CelerySheetTask(success=True, sheet_id=sheet.sheet_id, time=datetime.datetime.now().isoformat(), stats=stats).model_dump()
|
||||
return schemas.CelerySheetTask(
|
||||
success=True,
|
||||
sheet_id=sheet.sheet_id,
|
||||
time=datetime.datetime.now().isoformat(),
|
||||
stats=stats,
|
||||
).model_dump()
|
||||
|
||||
|
||||
def get_orchestrator_args(group_id: str, orchestrator_for_sheet: bool, cli_args: list = []) -> list:
|
||||
def get_orchestrator_args(
|
||||
group_id: str, orchestrator_for_sheet: bool, cli_args: list = None
|
||||
) -> list:
|
||||
cli_args.append("--logging.enabled=false")
|
||||
|
||||
aa_configs = []
|
||||
with get_db() as session:
|
||||
group = worker_crud.get_group(session, group_id)
|
||||
if orchestrator_for_sheet:
|
||||
orchestrator_fn = group.orchestrator_sheet
|
||||
else:
|
||||
orchestrator_fn = worker_crud.get_group(session, group_id).orchestrator
|
||||
orchestrator_fn = worker_crud.get_group(
|
||||
session, group_id
|
||||
).orchestrator
|
||||
assert orchestrator_fn, f"no orchestrator found for {group_id}"
|
||||
aa_configs.extend(["--config", orchestrator_fn])
|
||||
aa_configs.extend(cli_args)
|
||||
@@ -122,7 +153,9 @@ def get_orchestrator_args(group_id: str, orchestrator_for_sheet: bool, cli_args:
|
||||
def insert_result_into_db(archive: schemas.ArchiveCreate) -> str:
|
||||
with get_db() as session:
|
||||
db_archive = worker_crud.store_archived_url(session, archive)
|
||||
logger.debug(f"[ARCHIVE STORED] {db_archive.author_id} {db_archive.url}")
|
||||
logger.debug(
|
||||
f"[ARCHIVE STORED] {db_archive.author_id} {db_archive.url}"
|
||||
)
|
||||
return db_archive.id
|
||||
|
||||
|
||||
@@ -131,13 +164,22 @@ def get_store_until(group_id: str) -> datetime.datetime:
|
||||
return business_logic.get_store_archive_until(session, group_id)
|
||||
|
||||
|
||||
def redis_publish_exception(exception, task_name, traceback: str = ""):
|
||||
def redis_publish_exception(exception, task_name, trace_back: str = ""):
|
||||
REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL
|
||||
try:
|
||||
exception_data = {"task": task_name, "type": exception.__class__.__name__, "exception": exception, "traceback": traceback}
|
||||
Redis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps(exception_data, default=str))
|
||||
exception_data = {
|
||||
"task": task_name,
|
||||
"type": exception.__class__.__name__,
|
||||
"exception": exception,
|
||||
"traceback": trace_back,
|
||||
}
|
||||
Redis.publish(
|
||||
REDIS_EXCEPTIONS_CHANNEL, json.dumps(exception_data, default=str)
|
||||
)
|
||||
except Exception as e:
|
||||
log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}")
|
||||
log_error(
|
||||
e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}"
|
||||
)
|
||||
|
||||
|
||||
@task_failure.connect(sender=create_sheet_task)
|
||||
@@ -145,6 +187,10 @@ def redis_publish_exception(exception, task_name, traceback: str = ""):
|
||||
def task_failure_notifier(sender, **kwargs):
|
||||
# automatically capture exceptions in the worker tasks
|
||||
logger.warning(f"⚠️ worker task failed: {sender.name}")
|
||||
traceback_msg = "\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback'])))
|
||||
log_error(kwargs['exception'], traceback_msg, f"task_failure: {sender.name}")
|
||||
redis_publish_exception(kwargs['exception'], sender.name, traceback_msg)
|
||||
traceback_msg = "\n".join(
|
||||
traceback.format_list(traceback.extract_tb(kwargs["traceback"]))
|
||||
)
|
||||
log_error(
|
||||
kwargs["exception"], traceback_msg, f"task_failure: {sender.name}"
|
||||
)
|
||||
redis_publish_exception(kwargs["exception"], sender.name, traceback_msg)
|
||||
|
||||
@@ -1,29 +1,38 @@
|
||||
from loguru import logger
|
||||
from celery import Celery
|
||||
import sys
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.shared.task_messaging import get_celery
|
||||
|
||||
|
||||
celery = get_celery("worker")
|
||||
|
||||
def setup_celery_logger(celery):
|
||||
# Remove Celery's default handlers to prevent duplicate logs
|
||||
celery_logger = celery.log.get_default_logger()
|
||||
for handler in celery_logger.handlers[:]:
|
||||
celery_logger.removeHandler(handler)
|
||||
|
||||
# Set up Loguru logging
|
||||
logger.add("logs/celery_logs.log", retention="30 days", level="DEBUG")
|
||||
logger.add("logs/celery_error_logs.log", retention="30 days", level="ERROR")
|
||||
def setup_celery_logger(c):
|
||||
# Remove Celery's default handlers to prevent duplicate logs
|
||||
celery_logger = c.log.get_default_logger()
|
||||
for handler in celery_logger.handlers[:]:
|
||||
celery_logger.removeHandler(handler)
|
||||
|
||||
# Redirect Celery logs to Loguru
|
||||
class InterceptHandler:
|
||||
def write(self, message):
|
||||
if message.strip():
|
||||
logger.info(message.strip())
|
||||
# Required to prevent issues with buffered output
|
||||
def flush(self): pass
|
||||
def isatty(self): return False
|
||||
# Set up Loguru logging
|
||||
logger.add("logs/celery_logs.log", retention="30 days", level="DEBUG")
|
||||
logger.add("logs/celery_error_logs.log", retention="30 days", level="ERROR")
|
||||
|
||||
sys.stdout = InterceptHandler()
|
||||
sys.stderr = InterceptHandler()
|
||||
# Redirect Celery logs to Loguru
|
||||
class InterceptHandler:
|
||||
@staticmethod
|
||||
def write(message):
|
||||
if message.strip():
|
||||
logger.info(message.strip())
|
||||
|
||||
# Required to prevent issues with buffered output
|
||||
@staticmethod
|
||||
def flush():
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def isatty():
|
||||
return False
|
||||
|
||||
sys.stdout = InterceptHandler()
|
||||
sys.stderr = InterceptHandler()
|
||||
|
||||
@@ -12,7 +12,7 @@ services:
|
||||
- ALLOWED_ORIGINS=["http://localhost:8000","http://localhost:8004","http://localhost:8081","chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp"]
|
||||
- USER_GROUPS_FILENAME=/aa-api/app/user-groups.dev.yaml
|
||||
- DATABASE_PATH=sqlite:////aa-api/database/auto-archiver.db
|
||||
|
||||
|
||||
|
||||
worker:
|
||||
# command: watchmedo auto-restart --patterns="*.py" --recursive --ignore-directories -- celery -- --app=app.worker.main.celery worker -Q high_priority,low_priority --concurrency=${CONCURRENCY} --max-tasks-per-child=100
|
||||
|
||||
@@ -5,9 +5,9 @@ volumes:
|
||||
name: "auto-archiver-api"
|
||||
services:
|
||||
web:
|
||||
build:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: web.Dockerfile
|
||||
dockerfile: docker/web/Dockerfile
|
||||
restart: always
|
||||
env_file: .env.prod
|
||||
environment:
|
||||
@@ -29,9 +29,9 @@ services:
|
||||
retries: 3
|
||||
|
||||
worker:
|
||||
build:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: worker.Dockerfile
|
||||
dockerfile: docker/worker/Dockerfile
|
||||
restart: always
|
||||
env_file: .env.prod
|
||||
command: celery --app=app.worker.main.celery worker -Q high_priority,low_priority --concurrency=${CONCURRENCY} --max-tasks-per-child=100 -O fair
|
||||
@@ -68,4 +68,4 @@ services:
|
||||
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
retries: 3
|
||||
|
||||
@@ -10,13 +10,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install --no-cache-dir poetry
|
||||
COPY pyproject.toml poetry.lock README.md .
|
||||
COPY ../../pyproject.toml ../../poetry.lock ../../README.md ./
|
||||
RUN poetry install --with web --no-interaction --no-ansi --no-cache
|
||||
|
||||
# Copy the application code and configurations
|
||||
COPY alembic.ini ./
|
||||
COPY ./app/ ./app/
|
||||
COPY user-groups.* ./app/
|
||||
COPY ../../app ./app/
|
||||
COPY ../../user-groups.* ./app/
|
||||
|
||||
# Run the FastAPI app with Uvicorn
|
||||
ENTRYPOINT ["poetry", "run"]
|
||||
@@ -20,14 +20,13 @@ RUN apt update -y && \
|
||||
python3 -m venv ./poetry-venv && \
|
||||
./poetry-venv/bin/python -m pip install --upgrade pip && \
|
||||
./poetry-venv/bin/python -m pip install "poetry>=2.0.0,<3.0.0"
|
||||
COPY pyproject.toml poetry.lock ./
|
||||
COPY ../../pyproject.toml ../../poetry.lock ./
|
||||
RUN ./poetry-venv/bin/poetry install --without dev --no-root --no-cache
|
||||
|
||||
# install dependencies
|
||||
|
||||
# copy source code and .env files over
|
||||
COPY alembic.ini ./
|
||||
COPY ./app/ ./app/
|
||||
COPY user-groups.* ./app/
|
||||
COPY ../../app ./app/
|
||||
COPY ../../user-groups.* ./app/
|
||||
|
||||
ENTRYPOINT ["./poetry-venv/bin/poetry", "run"]
|
||||
ENTRYPOINT ["./poetry-venv/bin/poetry", "run"]
|
||||
1669
poetry.lock
generated
1669
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -22,7 +22,6 @@ requires-python = ">=3.10,<3.13"
|
||||
|
||||
dependencies = [
|
||||
"auto-archiver (>=0.13.1)",
|
||||
"oscrypto @ git+https://github.com/wbond/oscrypto.git@d5f3437ed24257895ae1edd9e503cfb352e635a8",
|
||||
"celery (>=5.0)",
|
||||
"redis (==3.5.3)",
|
||||
"loguru (>=0.7.3,<0.8.0)",
|
||||
@@ -31,6 +30,16 @@ dependencies = [
|
||||
"requests (>=2.25.1)",
|
||||
"pyopenssl (>=23.3.0)",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = "."
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["app/migrations/*"]
|
||||
|
||||
[tool.ruff.lint.flake8-bugbear]
|
||||
extend-immutable-calls = ["fastapi.Depends", "fastapi.Query"]
|
||||
|
||||
[tool.poetry.group.worker.dependencies]
|
||||
watchdog = ">=6.0.0,<7.0.0"
|
||||
setuptools = "^75.8.0"
|
||||
@@ -53,4 +62,4 @@ pytest = ">=8.3.4,<9.0.0"
|
||||
httpx = ">=0.28.1,<0.29.0"
|
||||
coverage = ">=7.6.11,<8.0.0"
|
||||
pytest-asyncio = ">=0.25.3,<0.26.0"
|
||||
|
||||
pre-commit = "^4.1.0"
|
||||
|
||||
@@ -59,4 +59,3 @@ groups:
|
||||
permissions:
|
||||
read: ["default"]
|
||||
read_public: true
|
||||
|
||||
Reference in New Issue
Block a user