mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-12 21:48:35 +03:00
major refactor of structure for worker V web: docker/app/secrets/envs/...
This commit is contained in:
18
app/example.user-groups.yaml
Normal file
18
app/example.user-groups.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
# email-level group access
|
||||
users:
|
||||
email1@example.com:
|
||||
- group1
|
||||
- group2
|
||||
email2@example.com:
|
||||
- group2
|
||||
email3@example-no-group.com:
|
||||
|
||||
# domain-level group access (taken from the emails)
|
||||
domains:
|
||||
example.com:
|
||||
- group3
|
||||
|
||||
orchestrators:
|
||||
group1: secrets/orchestration-group1.yaml
|
||||
group2: secrets/orchestration-group2.yaml
|
||||
default: secrets/orchestration-default.yaml
|
||||
0
app/logs/.gitkeep
Normal file
0
app/logs/.gitkeep
Normal file
79
app/migrations/env.py
Normal file
79
app/migrations/env.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from logging.config import fileConfig
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
from app.shared.settings import get_settings
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
config.set_main_option('sqlalchemy.url', get_settings().DATABASE_PATH)
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name, disable_existing_loggers=False) # disable_existing_loggers prevents loguru disabling
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = None
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
24
app/migrations/script.py.mako
Normal file
24
app/migrations/script.py.mako
Normal file
@@ -0,0 +1,24 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
0
app/migrations/versions/.gitkeep
Normal file
0
app/migrations/versions/.gitkeep
Normal file
@@ -0,0 +1,34 @@
|
||||
"""create archives.store_until column
|
||||
|
||||
Revision ID: 02b2f6d17ed0
|
||||
Revises: 1636724ec4b1
|
||||
Create Date: 2025-02-08 15:22:20.392522
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '02b2f6d17ed0'
|
||||
down_revision = '1636724ec4b1'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
STORE_UNTIL_COL = "store_until"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('archives')]
|
||||
|
||||
if STORE_UNTIL_COL not in columns:
|
||||
op.add_column('archives', sa.Column(STORE_UNTIL_COL, sa.DateTime(), nullable=True, default=None))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('archives')]
|
||||
if STORE_UNTIL_COL in columns:
|
||||
op.drop_column('archives', STORE_UNTIL_COL)
|
||||
@@ -0,0 +1,32 @@
|
||||
"""rename sheets last_archived col
|
||||
|
||||
Revision ID: 1636724ec4b1
|
||||
Revises: a23aaf3ae930
|
||||
Create Date: 2025-02-05 19:19:01.984396
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '1636724ec4b1'
|
||||
down_revision = 'a23aaf3ae930'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('sheets')]
|
||||
if 'last_archived_at' in columns:
|
||||
op.alter_column('sheets', 'last_archived_at', new_column_name='last_url_archived_at')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('sheets')]
|
||||
if 'last_url_archived_at' in columns:
|
||||
op.alter_column('sheets', 'last_url_archived_at', new_column_name='last_archived_at')
|
||||
@@ -0,0 +1,42 @@
|
||||
"""add sheet_id to archive table
|
||||
|
||||
Revision ID: 89121d2c96d8
|
||||
Revises: fa012ec405b8
|
||||
Create Date: 2024-11-04 11:12:30.237299
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '89121d2c96d8'
|
||||
down_revision = 'fa012ec405b8'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('archives')]
|
||||
|
||||
if 'sheet_id' not in columns:
|
||||
with op.batch_alter_table('archives') as batch_op:
|
||||
batch_op.add_column(sa.Column('sheet_id', sa.String(), nullable=True, default=None))
|
||||
batch_op.create_foreign_key('fk_sheet_id', 'sheets', ['sheet_id'], ['id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
foreign_keys = [fk['name'] for fk in inspector.get_foreign_keys('archives')]
|
||||
columns = [col['name'] for col in inspector.get_columns('archives')]
|
||||
|
||||
with op.batch_alter_table('archives') as batch_op:
|
||||
if 'fk_sheet_id' in foreign_keys:
|
||||
batch_op.drop_constraint('fk_sheet_id', type_='foreignkey')
|
||||
|
||||
if 'sheet_id' in columns:
|
||||
batch_op.drop_column('sheet_id')
|
||||
@@ -0,0 +1,27 @@
|
||||
"""modify archive url to have uuid id instead of url unique constraint
|
||||
|
||||
Revision ID: 9369a264945b
|
||||
Revises:
|
||||
Create Date: 2023-12-20 17:24:59.320691
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '9369a264945b'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# since the primary key constraint is not named, we have to recreate it first
|
||||
with op.batch_alter_table("archive_urls") as batch_op:
|
||||
batch_op.create_primary_key("pk_url", ["url"])
|
||||
batch_op.drop_constraint("pk_url", type_='primary')
|
||||
batch_op.create_primary_key("pk_url_archive_id", ["url", "archive_id"])
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table("archive_urls") as batch_op:
|
||||
batch_op.drop_constraint("pk_url_archive_id", type_='primary')
|
||||
batch_op.create_primary_key("url", ["url"])
|
||||
@@ -0,0 +1,28 @@
|
||||
"""vacuum database (if there's enough space)
|
||||
|
||||
Revision ID: 93a611e4c066
|
||||
Revises: 9369a264945b
|
||||
Create Date: 2023-12-20 18:33:27.132566
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '93a611e4c066'
|
||||
down_revision = '9369a264945b'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
try:
|
||||
with op.get_context().autocommit_block():
|
||||
op.execute("VACUUM")
|
||||
except Exception as e:
|
||||
print("Unable to run vacuum, maybe there's not enough disk space. it should be 2x the size of the database")
|
||||
print(e)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
34
app/migrations/versions/a23aaf3ae930_drop_active_column.py
Normal file
34
app/migrations/versions/a23aaf3ae930_drop_active_column.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""drop active column
|
||||
|
||||
Revision ID: a23aaf3ae930
|
||||
Revises: 89121d2c96d8
|
||||
Create Date: 2025-02-04 12:19:20.753570
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'a23aaf3ae930'
|
||||
down_revision = '89121d2c96d8'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('users')]
|
||||
|
||||
if 'is_active' in columns:
|
||||
op.drop_column('users', 'is_active')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('users')]
|
||||
|
||||
if 'is_active' not in columns:
|
||||
op.add_column('users', sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.false()))
|
||||
@@ -0,0 +1,45 @@
|
||||
"""add columns to groups table
|
||||
|
||||
Revision ID: fa012ec405b8
|
||||
Revises: 93a611e4c066
|
||||
Create Date: 2024-10-31 09:36:50.360710
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'fa012ec405b8'
|
||||
down_revision = '93a611e4c066'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('groups')]
|
||||
|
||||
if 'description' not in columns:
|
||||
op.add_column('groups', sa.Column('description', sa.String(), nullable=True))
|
||||
if 'orchestrator' not in columns:
|
||||
op.add_column('groups', sa.Column('orchestrator', sa.String(), nullable=True))
|
||||
if 'orchestrator_sheet' not in columns:
|
||||
op.add_column('groups', sa.Column('orchestrator_sheet', sa.String(), nullable=True))
|
||||
if 'permissions' not in columns:
|
||||
op.add_column('groups', sa.Column('permissions', sa.JSON(), nullable=True))
|
||||
if 'domains' not in columns:
|
||||
op.add_column('groups', sa.Column('domains', sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('groups')]
|
||||
|
||||
column_names = ['description', 'orchestrator', 'orchestrator_sheet', 'permissions', 'domains']
|
||||
for column_name in column_names:
|
||||
if column_name in columns:
|
||||
op.drop_column('groups', column_name)
|
||||
0
app/shared/__init__.py
Normal file
0
app/shared/__init__.py
Normal file
33
app/shared/aa_utils.py
Normal file
33
app/shared/aa_utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# TODO: code in this file should eventually be moved to the auto-archiver code base
|
||||
|
||||
from typing import List
|
||||
from loguru import logger
|
||||
from auto_archiver import Metadata
|
||||
from auto_archiver.core import Media
|
||||
|
||||
from app.shared.db import models
|
||||
|
||||
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
|
||||
db_urls = []
|
||||
for m in result.media:
|
||||
for i, url in enumerate(m.urls): db_urls.append(models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}")))
|
||||
for k, prop in m.properties.items():
|
||||
if prop_converted := convert_if_media(prop):
|
||||
for i, url in enumerate(prop_converted.urls): db_urls.append(models.ArchiveUrl(url=url, key=prop_converted.get("id", f"{k}_{i}")))
|
||||
if isinstance(prop, list):
|
||||
for i, prop_media in enumerate(prop):
|
||||
if prop_media := convert_if_media(prop_media):
|
||||
for j, url in enumerate(prop_media.urls):
|
||||
db_urls.append(models.ArchiveUrl(url=url, key=prop_media.get("id", f"{k}{prop_media.key}_{i}.{j}")))
|
||||
return db_urls
|
||||
|
||||
|
||||
|
||||
def convert_if_media(media):
|
||||
if isinstance(media, Media): return media
|
||||
elif isinstance(media, dict):
|
||||
try: return Media.from_dict(media)
|
||||
except Exception as e:
|
||||
logger.debug(f"error parsing {media} : {e}")
|
||||
return False
|
||||
|
||||
15
app/shared/business_logic.py
Normal file
15
app/shared/business_logic.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do decide
|
||||
|
||||
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.db import crud
|
||||
|
||||
|
||||
def get_store_archive_until(db: Session, group_id: str) -> datetime.datetime:
|
||||
group = crud.get_group(db, group_id)
|
||||
max_lifespan = group.permissions.get("max_archive_lifespan_months", -1)
|
||||
if max_lifespan == -1: return None
|
||||
|
||||
return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan)
|
||||
13
app/shared/config.py
Normal file
13
app/shared/config.py
Normal file
@@ -0,0 +1,13 @@
|
||||
VERSION = "0.8.0"
|
||||
API_DESCRIPTION = """
|
||||
#### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets.
|
||||
|
||||
**Usage notes:**
|
||||
- The API requires a Bearer token for most operations, which you can obtain by logging in with your Google account.
|
||||
- You can use this API to archive single URLs or entire Google Sheets.
|
||||
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
|
||||
"""
|
||||
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
|
||||
|
||||
# changing this will corrupt the database logic
|
||||
ALLOW_ANY_EMAIL = "*"
|
||||
0
app/shared/db/__init__.py
Normal file
0
app/shared/db/__init__.py
Normal file
314
app/shared/db/crud.py
Normal file
314
app/shared/db/crud.py
Normal file
@@ -0,0 +1,314 @@
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
from sqlalchemy import Column, or_, func, select
|
||||
from loguru import logger
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
from app.shared.db.database import get_db
|
||||
from app.shared.db import models
|
||||
from app.shared import schemas
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.user_groups import UserGroups
|
||||
from app.shared.utils.misc import fnv1a_hash_mod
|
||||
|
||||
DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT
|
||||
|
||||
|
||||
def get_limit(user_limit: int):
|
||||
return max(1, min(user_limit, DATABASE_QUERY_LIMIT))
|
||||
|
||||
# --------------- TASK = Archive
|
||||
|
||||
def base_query(db: Session):
|
||||
# NOTE: load_only is for optimization and not obfuscation, use .with_entities() if needed
|
||||
return db.query(models.Archive)\
|
||||
.filter(models.Archive.deleted == False)\
|
||||
.options(load_only(models.Archive.id, models.Archive.created_at, models.Archive.url, models.Archive.result, models.Archive.store_until))
|
||||
|
||||
def get_archive(db: Session, id: str, email: str):
|
||||
query = base_query(db).filter(models.Archive.id == id)
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
groups = get_user_groups(email)
|
||||
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
||||
return query.first()
|
||||
|
||||
def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, absolute_search: bool = False)-> list[models.Archive]:
|
||||
# searches for partial URLs, if email is * no ownership filtering happens
|
||||
query = base_query(db)
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
groups = get_user_groups(email)
|
||||
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
||||
if absolute_search:
|
||||
query = query.filter(models.Archive.url == url)
|
||||
else:
|
||||
query = query.filter(models.Archive.url.like(f'%{url}%'))
|
||||
if archived_after:
|
||||
query = query.filter(models.Archive.created_at > archived_after)
|
||||
if archived_before:
|
||||
query = query.filter(models.Archive.created_at < archived_before)
|
||||
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
|
||||
|
||||
|
||||
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
|
||||
return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all()
|
||||
|
||||
#TODO: rename task to archive
|
||||
def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]) -> models.Archive:
|
||||
db_task = models.Archive(id=task.id, url=task.url, result=task.result, public=task.public, author_id=task.author_id, group_id=task.group_id, sheet_id=task.sheet_id, store_until=task.store_until)
|
||||
db_task.tags = tags
|
||||
db_task.urls = urls
|
||||
db.add(db_task)
|
||||
db.commit()
|
||||
db.refresh(db_task)
|
||||
return db_task
|
||||
|
||||
|
||||
def soft_delete_task(db: Session, task_id: str, email: str) -> bool:
|
||||
# TODO: implement hard-delete with cronjob that deletes from S3
|
||||
db_task = db.query(models.Archive).filter(models.Archive.id == task_id, models.Archive.author_id == email, models.Archive.deleted == False).first()
|
||||
if db_task:
|
||||
db_task.deleted = True
|
||||
db.commit()
|
||||
return db_task is not None
|
||||
|
||||
|
||||
def count_archives(db: Session):
|
||||
return db.query(func.count(models.Archive.id)).scalar()
|
||||
|
||||
|
||||
def count_archive_urls(db: Session):
|
||||
return db.query(func.count(models.ArchiveUrl.url)).scalar()
|
||||
|
||||
|
||||
def count_users(db: Session):
|
||||
return db.query(func.count(models.User.email)).scalar()
|
||||
|
||||
|
||||
def count_by_user_since(db: Session, seconds_delta: int = 15):
|
||||
time_threshold = datetime.now() - timedelta(seconds=seconds_delta)
|
||||
return db.query(models.Archive.author_id, func.count().label('total'))\
|
||||
.filter(models.Archive.created_at >= time_threshold)\
|
||||
.group_by(models.Archive.author_id)\
|
||||
.order_by(func.count().desc())\
|
||||
.limit(500).all()
|
||||
|
||||
async def find_by_store_until(db: AsyncSession, store_until_is_before:datetime) -> dict:
|
||||
res = await db.execute(
|
||||
select(models.Archive)
|
||||
.filter(models.Archive.deleted ==False, models.Archive.store_until < store_until_is_before)
|
||||
)
|
||||
return res.scalars()
|
||||
|
||||
async def soft_delete_expired_archives(db: AsyncSession) -> dict:
|
||||
to_delete = await find_by_store_until(db, datetime.now())
|
||||
counter = 0
|
||||
for archive in to_delete:
|
||||
archive.deleted = True
|
||||
counter += 1
|
||||
await db.commit()
|
||||
return counter
|
||||
# --------------- TAG
|
||||
|
||||
|
||||
def create_tag(db: Session, tag: str) -> models.Tag:
|
||||
db_tag = db.query(models.Tag).filter(models.Tag.id == tag).first()
|
||||
if not db_tag:
|
||||
db_tag = models.Tag(id=tag)
|
||||
db.add(db_tag)
|
||||
db.commit()
|
||||
db.refresh(db_tag)
|
||||
return db_tag
|
||||
|
||||
|
||||
def is_user_in_group(db: Session, email: str, group_name: str) -> models.Group:
|
||||
if email == ALLOW_ANY_EMAIL: return True
|
||||
return len(group_name) and len(email) and group_name in get_user_groups(email)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_user_groups(email: str) -> list[str]:
|
||||
"""
|
||||
given an email retrieves the user groups from the DB and then the email-domain groups from a global variable, the email does not need to belong to an existing user.
|
||||
"""
|
||||
if not email or not len(email) or "@" not in email: return []
|
||||
|
||||
with get_db() as db:
|
||||
# get user groups
|
||||
user_groups = db.query(models.association_table_user_groups).filter_by(user_id=email).with_entities(Column("group_id")).all()
|
||||
user_level_groups_names = [g[0] for g in user_groups]
|
||||
|
||||
# get domain groups
|
||||
domain = email.split('@')[1]
|
||||
domain_level_groups = db.query(models.Group.id).filter(models.Group.domains.contains(domain)).with_entities(Column("id")).all()
|
||||
domain_level_groups_names = [g[0] for g in domain_level_groups]
|
||||
|
||||
return list(set(user_level_groups_names + domain_level_groups_names))
|
||||
|
||||
|
||||
# --------------- INIT User-Groups
|
||||
|
||||
def get_group(db: Session, group_name: str) -> models.Group:
|
||||
return db.query(models.Group).filter(models.Group.id == group_name).first()
|
||||
|
||||
|
||||
def create_or_get_user(db: Session, author_id: str) -> models.User:
|
||||
if type(author_id) == str: author_id = author_id.lower()
|
||||
db_user = db.query(models.User).filter(models.User.email == author_id).first()
|
||||
if not db_user:
|
||||
db_user = models.User(email=author_id)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
return db_user
|
||||
|
||||
|
||||
def upsert_group(db: Session, group_name: str, description: str, orchestrator: str, orchestrator_sheet: str, permissions: dict, domains: list) -> models.Group:
|
||||
db_group = db.query(models.Group).filter(models.Group.id == group_name).first()
|
||||
if db_group is None:
|
||||
db_group = models.Group(id=group_name, description=description, orchestrator=orchestrator, orchestrator_sheet=orchestrator_sheet, permissions=permissions, domains=domains)
|
||||
db.add(db_group)
|
||||
else:
|
||||
db_group.description = description
|
||||
db_group.orchestrator = orchestrator
|
||||
db_group.orchestrator_sheet = orchestrator_sheet
|
||||
db_group.permissions = permissions
|
||||
db_group.domains = domains
|
||||
db.commit()
|
||||
db.refresh(db_group)
|
||||
return db_group
|
||||
|
||||
|
||||
def upsert_user(db: Session, email: str):
|
||||
email = email.lower()
|
||||
db_user = db.query(models.User).filter(models.User.email == email).first()
|
||||
if db_user is None:
|
||||
db_user = models.User(email=email)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
return db_user
|
||||
|
||||
|
||||
def upsert_user_groups(db: Session):
|
||||
def display_email_pii(email: str):
|
||||
return f"'{email[0:3]}...@{email.split('@')[1]}'"
|
||||
"""
|
||||
reads the user_groups yaml file and inserts any new users, groups,
|
||||
along with new participation of users in groups
|
||||
"""
|
||||
logger.debug("Updating user-groups configuration.")
|
||||
filename = get_settings().USER_GROUPS_FILENAME
|
||||
|
||||
ug = UserGroups(filename)
|
||||
|
||||
# delete all user-groups relationships
|
||||
db.query(models.association_table_user_groups).delete()
|
||||
|
||||
# create a map of group_id -> domains and another of domain -> groups
|
||||
group_domains = defaultdict(set)
|
||||
domain_groups = defaultdict(list)
|
||||
for domain, explicit_groups in ug.domains.items():
|
||||
domain_groups[domain] = list(set(explicit_groups))
|
||||
for group in explicit_groups:
|
||||
group_domains[group].add(domain)
|
||||
import json
|
||||
# upsert groups and save a map of groupid -> dbobject
|
||||
for group_id, g in ug.groups.items():
|
||||
upsert_group(db, group_id, g.description, g.orchestrator, g.orchestrator_sheet, json.loads(g.permissions.model_dump_json()), list(group_domains.get(group_id, [])))
|
||||
db_groups: dict[str, models.Group] = {g.id: g for g in db.query(models.Group).all()}
|
||||
|
||||
# integrity checks
|
||||
for group_in_domains in group_domains:
|
||||
if group_in_domains not in db_groups:
|
||||
logger.warning(f"[CONFIG] Group '{group_in_domains}' does not exist in the database: domains setting will not work.")
|
||||
|
||||
# reinsert users in their EXPLICITLY DEFINED groups
|
||||
# domain groups are check live, as there may be new users that are not explicitly registered but belong to a domain
|
||||
for email, explicit_groups in ug.users.items():
|
||||
explicit_groups = explicit_groups or []
|
||||
logger.info(f"EXPLICIT {display_email_pii(email)} => {explicit_groups}")
|
||||
|
||||
db_user = upsert_user(db, email)
|
||||
|
||||
# connect users to groups
|
||||
for group_id in explicit_groups:
|
||||
if group_id not in db_groups:
|
||||
logger.warning(f"[CONFIG] Group {group_id} does not exist in config file, skipping for email={display_email_pii(email)}.")
|
||||
continue
|
||||
db_groups[group_id].users.append(db_user)
|
||||
|
||||
db.commit()
|
||||
count_user_groups = db.query(models.association_table_user_groups).count()
|
||||
count_groups = db.query(func.count(models.Group.id)).scalar()
|
||||
|
||||
logger.success(f"[CONFIG] DONE: [users={count_users(db)}, groups={count_groups}, explicit user groups={count_user_groups}].")
|
||||
|
||||
|
||||
# --------------- SHEET
|
||||
def create_sheet(db: Session, sheet_id: str, name: str, email: str, group_id: str, frequency: str):
|
||||
db_sheet = models.Sheet(id=sheet_id, name=name, author_id=email, group_id=group_id, frequency=frequency)
|
||||
db.add(db_sheet)
|
||||
db.commit()
|
||||
db.refresh(db_sheet)
|
||||
return db_sheet
|
||||
|
||||
|
||||
def get_user_sheet(db: Session, email: str, sheet_id: str) -> models.Sheet:
|
||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email, models.Sheet.id == sheet_id).first()
|
||||
|
||||
|
||||
def get_user_sheets(db: Session, email: str) -> list[models.Sheet]:
|
||||
return db.query(models.Sheet).filter(models.Sheet.author_id == email).order_by(models.Sheet.last_url_archived_at.desc()).all()
|
||||
|
||||
|
||||
async def get_sheets_by_id_hash(db: AsyncSession, frequency: str, modulo: str, id_hash: str) -> list[models.Sheet]:
|
||||
result = await db.execute(
|
||||
select(models.Sheet).filter(models.Sheet.frequency == frequency)
|
||||
)
|
||||
filtered = []
|
||||
for sheet in result.scalars():
|
||||
if fnv1a_hash_mod(sheet.id, modulo) == id_hash:
|
||||
filtered.append(sheet)
|
||||
return filtered
|
||||
|
||||
async def delete_stale_sheets(db: AsyncSession, inactivity_days: int) -> dict:
|
||||
time_threshold = datetime.now() - timedelta(days=inactivity_days)
|
||||
result = await db.execute(
|
||||
select(models.Sheet).filter(models.Sheet.last_url_archived_at < time_threshold)
|
||||
)
|
||||
deleted = defaultdict(list)
|
||||
for sheet in result.scalars():
|
||||
await db.delete(sheet)
|
||||
deleted[sheet.author_id].append(sheet)
|
||||
await db.commit()
|
||||
return dict(deleted)
|
||||
|
||||
def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
|
||||
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
|
||||
if db_sheet:
|
||||
db_sheet.last_url_archived_at = datetime.now()
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def delete_sheet(db: Session, sheet_id: str, email: str) -> bool:
|
||||
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id, models.Sheet.author_id == email).first()
|
||||
if db_sheet:
|
||||
db.delete(db_sheet)
|
||||
db.commit()
|
||||
return db_sheet is not None
|
||||
|
||||
|
||||
#--- Celery worker tasks
|
||||
|
||||
|
||||
def store_archived_url(db: Session, archive: schemas.ArchiveCreate) -> models.Archive:
|
||||
# create and load user, tags, if needed
|
||||
create_or_get_user(db, archive.author_id)
|
||||
db_tags = [create_tag(db, tag) for tag in archive.tags]
|
||||
# insert everything
|
||||
db_task = create_task(db, task=archive, tags=db_tags, urls=archive.urls)
|
||||
return db_task
|
||||
67
app/shared/db/database.py
Normal file
67
app/shared/db/database.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from functools import lru_cache
|
||||
from sqlalchemy import Engine, create_engine, event, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine, async_sessionmaker
|
||||
|
||||
from app.shared.settings import get_settings
|
||||
|
||||
|
||||
@lru_cache
|
||||
def make_engine(database_url: str):
|
||||
engine = create_engine(database_url, connect_args={"check_same_thread": False})
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(conn, _) -> None:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def make_session_local(engine: Engine):
|
||||
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
return session_local
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db():
|
||||
session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
|
||||
try: yield session
|
||||
finally: session.close()
|
||||
|
||||
|
||||
def get_db_dependency():
|
||||
# to use with Depends and ensure proper session closing
|
||||
with get_db() as db:
|
||||
yield db
|
||||
|
||||
def wal_checkpoint():
|
||||
# WAL checkpointing, make sure the .sqlite file receives the latest changes
|
||||
# to be called at startup as it halts writes
|
||||
with get_db() as db:
|
||||
db.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
|
||||
|
||||
|
||||
# ASYNC connections
|
||||
async def make_async_engine(database_url: str) -> AsyncEngine:
|
||||
engine = create_async_engine(database_url, connect_args={"check_same_thread": False})
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(lambda sync_conn: sync_conn.execute(text("PRAGMA journal_mode=WAL;")))
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
async def make_async_session_local(engine: AsyncEngine) -> AsyncSession:
|
||||
return async_sessionmaker(engine, expire_on_commit=False, autoflush=False, autocommit=False)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_async():
|
||||
engine = await make_async_engine(get_settings().ASYNC_DATABASE_PATH)
|
||||
async_session = await make_async_session_local(engine)
|
||||
async with async_session() as session:
|
||||
try: yield session
|
||||
finally: await engine.dispose()
|
||||
113
app/shared/db/models.py
Normal file
113
app/shared/db/models.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
import uuid
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def generate_uuid():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
# many to many association tables
|
||||
association_table_archive_tags = Table(
|
||||
"mtm_archives_tags",
|
||||
Base.metadata,
|
||||
Column("archive_id", ForeignKey("archives.id")),
|
||||
Column("tag_id", ForeignKey("tags.id")),
|
||||
)
|
||||
association_table_user_groups = Table(
|
||||
"mtm_users_groups",
|
||||
Base.metadata,
|
||||
Column("user_id", ForeignKey("users.email")),
|
||||
Column("group_id", ForeignKey("groups.id")),
|
||||
)
|
||||
|
||||
|
||||
# data model tables
|
||||
class Archive(Base):
|
||||
__tablename__ = "archives"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
url = Column(String, index=True)
|
||||
result = Column(JSON, default=None)
|
||||
public = Column(Boolean, default=True) # if public=false, access by group and author
|
||||
deleted = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
store_until = Column(DateTime(timezone=True), default=None)
|
||||
|
||||
group_id = Column(String, ForeignKey("groups.id"), default=None)
|
||||
author_id = Column(String, ForeignKey("users.email"))
|
||||
sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
|
||||
|
||||
tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags)
|
||||
group = relationship("Group", back_populates="archives")
|
||||
author = relationship("User", back_populates="archives")
|
||||
urls = relationship("ArchiveUrl", back_populates="archive")
|
||||
sheet = relationship("Sheet", back_populates="archives")
|
||||
|
||||
|
||||
class ArchiveUrl(Base):
|
||||
__tablename__ = "archive_urls"
|
||||
|
||||
url = Column(String, primary_key=True, index=True)
|
||||
archive_id = Column(String, ForeignKey("archives.id"), primary_key=True)
|
||||
key = Column(String, default=None)
|
||||
|
||||
archive = relationship("Archive", back_populates="urls")
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
email = Column(String, primary_key=True, index=True)
|
||||
|
||||
archives = relationship("Archive", back_populates="author")
|
||||
sheets = relationship("Sheet", back_populates="author")
|
||||
groups = relationship("Group", back_populates="users", secondary=association_table_user_groups)
|
||||
|
||||
|
||||
class Group(Base):
|
||||
__tablename__ = "groups"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
description = Column(String, default=None)
|
||||
orchestrator = Column(String, default=None)
|
||||
orchestrator_sheet = Column(String, default=None)
|
||||
permissions = Column(JSON, default={})
|
||||
domains = Column(JSON, default=[])
|
||||
|
||||
archives = relationship("Archive", back_populates="group")
|
||||
sheets = relationship("Sheet", back_populates="group")
|
||||
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
|
||||
|
||||
|
||||
class Sheet(Base):
|
||||
__tablename__ = "sheets"
|
||||
|
||||
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
|
||||
name = Column(String, default=None)
|
||||
author_id = Column(String, ForeignKey("users.email"))
|
||||
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
|
||||
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
|
||||
# TODO: stats is not needed, is it?
|
||||
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...")
|
||||
last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
group = relationship("Group", back_populates="sheets")
|
||||
author = relationship("User", back_populates="sheets")
|
||||
archives = relationship("Archive", back_populates="sheet")
|
||||
328
app/shared/db/user_state.py
Normal file
328
app/shared/db/user_state.py
Normal file
@@ -0,0 +1,328 @@
|
||||
|
||||
from typing import Dict, Set
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from datetime import datetime
|
||||
|
||||
from app.shared.db import crud, models
|
||||
from app.shared.user_groups import GroupInfo, GroupPermissions
|
||||
from app.shared.schemas import Usage, UsageResponse
|
||||
|
||||
class UserState:
|
||||
"""
|
||||
Manage a user's state and permissions
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session, email: str):
|
||||
self.db = db
|
||||
self.email = email.lower()
|
||||
|
||||
@property
|
||||
def permissions(self) -> Dict[str, GroupInfo]:
|
||||
"""
|
||||
Returns a dict of all group permissions and a special {"all": read/archive_url/archive_sheet} key
|
||||
"""
|
||||
if not hasattr(self, '_permissions'):
|
||||
self._permissions = {}
|
||||
self._permissions["all"] = GroupInfo(
|
||||
read=self.read,
|
||||
read_public=self.read_public,
|
||||
archive_url=self.archive_url,
|
||||
archive_sheet=self.archive_sheet,
|
||||
# below are relevant only for /url endpoints
|
||||
max_archive_lifespan_months=self.max_archive_lifespan_months,
|
||||
max_monthly_urls=self.max_monthly_urls,
|
||||
max_monthly_mbs=self.max_monthly_mbs,
|
||||
priority=self.priority
|
||||
)
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
self._permissions[group.id] = GroupInfo(**group.permissions, description=group.description)
|
||||
return self._permissions
|
||||
|
||||
@property
|
||||
def user_groups_names(self):
|
||||
if not hasattr(self, '_user_groups_names'):
|
||||
self._user_groups_names = crud.get_user_groups(self.email) + ["default"]
|
||||
return self._user_groups_names
|
||||
|
||||
@property
|
||||
def user_groups(self):
|
||||
if not hasattr(self, '_user_groups'):
|
||||
self._user_groups = self.db.query(models.Group).filter(
|
||||
models.Group.id.in_(self.user_groups_names)
|
||||
).all()
|
||||
return self._user_groups
|
||||
|
||||
@property
|
||||
def read(self) -> Set[str] | bool:
|
||||
"""
|
||||
Read can be a list of group names or True, if all can be read.
|
||||
"""
|
||||
if not hasattr(self, '_read'):
|
||||
self._read = set()
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
group_read_permissions = group.permissions.get("read", [])
|
||||
if "all" in group_read_permissions:
|
||||
self._read = True
|
||||
return self._read
|
||||
else:
|
||||
self._read.update(group.permissions.get("read", []))
|
||||
return self._read
|
||||
|
||||
@property
|
||||
def read_public(self) -> bool:
|
||||
"""
|
||||
Read public permission
|
||||
"""
|
||||
if not hasattr(self, '_read_public'):
|
||||
self._read_public = False
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if group.permissions.get("read_public", False):
|
||||
self._read_public = True
|
||||
return self._read_public
|
||||
return self._read_public
|
||||
|
||||
@property
|
||||
def archive_url(self) -> bool:
|
||||
"""
|
||||
Archive URL permission
|
||||
"""
|
||||
if not hasattr(self, '_archive_url'):
|
||||
self._archive_url = False
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if group.permissions.get("archive_url", False):
|
||||
self._archive_url = True
|
||||
return self._archive_url
|
||||
return self._archive_url
|
||||
|
||||
@property
|
||||
def archive_sheet(self) -> bool:
|
||||
"""
|
||||
Archive sheet permission
|
||||
"""
|
||||
if not hasattr(self, '_archive_sheet'):
|
||||
self._archive_sheet = False
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if group.permissions.get("archive_sheet", False):
|
||||
self._archive_sheet = True
|
||||
return self._archive_sheet
|
||||
return self._archive_sheet
|
||||
|
||||
@property
|
||||
def sheet_frequency(self):
|
||||
if not hasattr(self, '_sheet_frequency'):
|
||||
self._sheet_frequency = set()
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
self._sheet_frequency.update(group.permissions.get("sheet_frequency", None))
|
||||
return self._sheet_frequency
|
||||
|
||||
@property
|
||||
def max_archive_lifespan_months(self) -> int:
|
||||
if not hasattr(self, '_max_archive_lifespan_months'):
|
||||
self._max_archive_lifespan_months = self._helper_for_grouping_max_numerical_permissions("max_archive_lifespan_months")
|
||||
return self._max_archive_lifespan_months
|
||||
|
||||
@property
|
||||
def max_monthly_urls(self) -> int:
|
||||
if not hasattr(self, '_max_monthly_urls'):
|
||||
self._max_monthly_urls = self._helper_for_grouping_max_numerical_permissions("max_monthly_urls")
|
||||
return self._max_monthly_urls
|
||||
|
||||
@property
|
||||
def max_monthly_mbs(self) -> int:
|
||||
if not hasattr(self, '_max_monthly_mbs'):
|
||||
self._max_monthly_mbs = self._helper_for_grouping_max_numerical_permissions("max_monthly_mbs")
|
||||
return self._max_monthly_mbs
|
||||
|
||||
@property
|
||||
def priority(self) -> str:
|
||||
if not hasattr(self, '_priority'):
|
||||
self._priority = "low"
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
if group.permissions.get("priority", "low") == "high":
|
||||
self._priority = "high"
|
||||
return self._priority
|
||||
|
||||
@property
|
||||
def active(self) -> bool:
|
||||
"""
|
||||
A user is active if they can read/archive anything
|
||||
"""
|
||||
if not hasattr(self, '_active'):
|
||||
self._active = bool(self.read or self.read_public or self.archive_url or self.archive_sheet)
|
||||
return self._active
|
||||
|
||||
def _helper_for_grouping_max_numerical_permissions(self, permission_name: str) -> int:
|
||||
"""
|
||||
Iterates one of the numerical permissions where -1 means no restrictions and returns either -1 or the maximum value, defaults according to GroupPermissions
|
||||
"""
|
||||
default = GroupPermissions.model_fields[permission_name].default
|
||||
max_value = default
|
||||
for group in self.user_groups:
|
||||
if not group.permissions: continue
|
||||
group_value = group.permissions.get(permission_name, default)
|
||||
if group_value == -1:
|
||||
max_value = -1
|
||||
return max_value
|
||||
max_value = max(max_value, group_value)
|
||||
return max_value
|
||||
|
||||
def in_group(self, group_id: str) -> bool:
|
||||
return group_id in self.user_groups_names
|
||||
|
||||
def usage(self) -> Dict:
|
||||
"""
|
||||
returns the monthly quotas for the URLs/MBs and the totals for Sheets
|
||||
"""
|
||||
current_month = datetime.now().month
|
||||
current_year = datetime.now().year
|
||||
|
||||
# find and sum all user sheets over this month
|
||||
user_sheets = self.db.query(
|
||||
models.Sheet.group_id,
|
||||
func.count(models.Sheet.id).label('sheet_count')
|
||||
).filter(models.Sheet.author_id == self.email).group_by(models.Sheet.group_id).all()
|
||||
|
||||
sheets_by_group = {sheet.group_id: sheet.sheet_count for sheet in user_sheets}
|
||||
|
||||
# find and sum all user urls over this month
|
||||
urls_by_group = self.db.query(
|
||||
models.Archive.group_id,
|
||||
func.count(models.Archive.id).label('url_count'),
|
||||
func.coalesce(func.sum(
|
||||
func.coalesce(
|
||||
func.cast(
|
||||
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
|
||||
sqlalchemy.Integer
|
||||
), 0
|
||||
)
|
||||
), 0).label('total_bytes')
|
||||
).filter(
|
||||
models.Archive.author_id == self.email,
|
||||
func.extract('month', models.Archive.created_at) == current_month,
|
||||
func.extract('year', models.Archive.created_at) == current_year
|
||||
).group_by(models.Archive.group_id).all()
|
||||
|
||||
# merge the two queries
|
||||
usage_by_group: Dict[str, Usage] = {
|
||||
(url.group_id or ""):
|
||||
Usage(monthly_urls=url.url_count, monthly_mbs=int(url.total_bytes / 1024 / 1024))
|
||||
for url in urls_by_group
|
||||
}
|
||||
for group_id, sheet_count in sheets_by_group.items():
|
||||
group_id = group_id or ""
|
||||
if group_id in usage_by_group:
|
||||
usage_by_group[group_id].total_sheets = sheet_count
|
||||
else:
|
||||
usage_by_group[group_id] = Usage(total_sheets=sheet_count)
|
||||
|
||||
# calculate totals
|
||||
total_sheets = sum([sheet.sheet_count for sheet in user_sheets])
|
||||
total_bytes = sum([url.total_bytes for url in urls_by_group])
|
||||
total_urls = sum([url.url_count for url in urls_by_group])
|
||||
|
||||
return UsageResponse(
|
||||
monthly_urls=total_urls,
|
||||
monthly_mbs=int(total_bytes / 1024 / 1024),
|
||||
total_sheets=total_sheets,
|
||||
groups=usage_by_group
|
||||
)
|
||||
|
||||
def has_quota_monthly_sheets(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user has reached their sheet quota for a given group
|
||||
"""
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
|
||||
user_sheets = self.db.query(models.Sheet).filter(models.Sheet.author_id == self.email, models.Sheet.group_id == group_id).count()
|
||||
|
||||
sheet_quota = self.permissions[group_id].max_sheets
|
||||
if sheet_quota == -1:
|
||||
return True
|
||||
return user_sheets < sheet_quota
|
||||
|
||||
def has_quota_max_monthly_urls(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user has reached their monthly url quota for a group, if global then group should be empty string
|
||||
"""
|
||||
quota = 0
|
||||
if not group_id:
|
||||
quota = self.max_monthly_urls
|
||||
else:
|
||||
if group_id not in self.permissions: return False
|
||||
quota = self.permissions[group_id].max_monthly_urls
|
||||
|
||||
if quota == -1:
|
||||
return True
|
||||
|
||||
current_month = datetime.now().month
|
||||
current_year = datetime.now().year
|
||||
user_urls = self.db.query(models.Archive).filter(
|
||||
models.Archive.author_id == self.email,
|
||||
func.extract('month', models.Archive.created_at) == current_month,
|
||||
func.extract('year', models.Archive.created_at) == current_year
|
||||
).count()
|
||||
|
||||
return user_urls < quota
|
||||
|
||||
def has_quota_max_monthly_mbs(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user has reached their monthly MBs quota for a group, if global then group should be empty string
|
||||
"""
|
||||
quota = 0
|
||||
if not group_id:
|
||||
quota = self.max_monthly_mbs
|
||||
else:
|
||||
if group_id not in self.permissions: return False
|
||||
quota = self.permissions[group_id].max_monthly_mbs
|
||||
|
||||
if quota == -1:
|
||||
return True
|
||||
|
||||
current_month = datetime.now().month
|
||||
current_year = datetime.now().year
|
||||
|
||||
# find and sum all user bytes over this month
|
||||
user_bytes = self.db.query(models.Archive).filter(
|
||||
models.Archive.author_id == self.email,
|
||||
func.extract('month', models.Archive.created_at) == current_month,
|
||||
func.extract('year', models.Archive.created_at) == current_year
|
||||
).with_entities(func.coalesce(func.sum(
|
||||
func.coalesce(
|
||||
func.cast(
|
||||
func.json_extract(models.Archive.result, '$.metadata.total_bytes'),
|
||||
sqlalchemy.Integer
|
||||
), 0
|
||||
)
|
||||
), 0).label('total')).scalar()
|
||||
|
||||
# convert bytes to mb
|
||||
user_mbs = int(user_bytes / 1024 / 1024)
|
||||
return user_mbs < quota
|
||||
|
||||
def can_manually_trigger(self, group_id: str) -> bool:
|
||||
"""
|
||||
checks if a user is allowed to manually trigger a sheet
|
||||
"""
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
|
||||
return self.permissions[group_id].manually_trigger_sheet
|
||||
|
||||
def is_sheet_frequency_allowed(self, group_id: str, frequency: str) -> bool:
|
||||
"""
|
||||
checks if a user is allowed to create a sheet with this frequency for this group
|
||||
"""
|
||||
if group_id not in self.permissions:
|
||||
return False
|
||||
|
||||
return frequency in self.permissions[group_id].sheet_frequency
|
||||
13
app/shared/log.py
Normal file
13
app/shared/log.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import traceback
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# logging configurations
|
||||
logger.add("logs/api_logs.log", retention="30 days", rotation="3 days")
|
||||
logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
|
||||
|
||||
|
||||
def log_error(e: Exception, traceback_str: str = None, extra:str = ""):
|
||||
if not traceback_str: traceback_str = traceback.format_exc()
|
||||
if extra: extra = f"{extra}\n"
|
||||
logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")
|
||||
117
app/shared/schemas.py
Normal file
117
app/shared/schemas.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import Annotated
|
||||
from annotated_types import Len
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class SubmitSheet(BaseModel):
|
||||
sheet_id: str | None
|
||||
author_id: str | None = None
|
||||
group_id: str = "default"
|
||||
tags: set[str] | None = set()
|
||||
columns: dict | None = {} # TODO: implement/remove
|
||||
|
||||
|
||||
class SubmitManual(BaseModel): # deprecated
|
||||
result: str # should be a Metadata.to_json()
|
||||
public: bool = False
|
||||
author_id: str | None = None
|
||||
group_id: str | None = None
|
||||
tags: set[str] | None = set()
|
||||
|
||||
# API REQUESTS BELOW
|
||||
# TODO: replace existing schemas with these
|
||||
|
||||
|
||||
class ArchiveUrl(BaseModel):
|
||||
url: str
|
||||
public: bool = False
|
||||
author_id: str | None
|
||||
group_id: str | None
|
||||
tags: set[str] | None = set()
|
||||
|
||||
# API RESPONSES BELOW
|
||||
|
||||
|
||||
class ArchiveResult(BaseModel):
|
||||
id: str
|
||||
url: str
|
||||
result: dict
|
||||
created_at: datetime
|
||||
store_until: datetime | None
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
class TaskResult(Task):
|
||||
status: str
|
||||
result: str
|
||||
|
||||
|
||||
class TaskDelete(Task):
|
||||
deleted: bool
|
||||
|
||||
|
||||
class ActiveUser(BaseModel):
|
||||
active: bool
|
||||
|
||||
|
||||
class SheetAdd(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
group_id: str
|
||||
frequency: str
|
||||
|
||||
|
||||
class SheetResponse(SheetAdd):
|
||||
author_id: str
|
||||
created_at: datetime
|
||||
last_url_archived_at: datetime | None
|
||||
|
||||
|
||||
class ArchiveTrigger(BaseModel):
|
||||
author_id: str | None = None
|
||||
url: Annotated[str, Len(min_length=5)]
|
||||
public: bool = False
|
||||
group_id: Annotated[str, Len(min_length=1)] = "default"
|
||||
tags: set[str] | None = None
|
||||
|
||||
|
||||
class ArchiveCreate(ArchiveTrigger):
|
||||
id: str | None = None
|
||||
result: dict | None = None
|
||||
sheet_id: str | None = None
|
||||
urls: list | None = None
|
||||
store_until: datetime | None = None
|
||||
|
||||
|
||||
class Archive(ArchiveCreate):
|
||||
created_at: datetime
|
||||
updated_at: datetime | None
|
||||
deleted: bool
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
monthly_urls: int = 0
|
||||
monthly_mbs: int = 0
|
||||
total_sheets: int = 0
|
||||
|
||||
|
||||
class UsageResponse(Usage):
|
||||
groups: dict[str, Usage]
|
||||
|
||||
|
||||
class CelerySheetTask(BaseModel):
|
||||
success: bool
|
||||
sheet_id: str
|
||||
time: datetime
|
||||
stats: dict
|
||||
|
||||
|
||||
class SubmitManualArchive(ArchiveTrigger):
|
||||
url: None = None
|
||||
result: str # should be a Metadata.to_json()
|
||||
73
app/shared/settings.py
Normal file
73
app/shared/settings.py
Normal file
@@ -0,0 +1,73 @@
|
||||
|
||||
from functools import lru_cache
|
||||
from fastapi_mail import ConnectionConfig
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import ConfigDict
|
||||
from typing import Annotated, Set
|
||||
from annotated_types import Len
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = ConfigDict(extra='ignore', str_strip_whitespace=True)
|
||||
|
||||
# general
|
||||
SERVE_LOCAL_ARCHIVE: str = ""
|
||||
USER_GROUPS_FILENAME: str = "user-groups.yaml"
|
||||
SHEET_ORCHESTRATION_YAML : str = "secrets/orchestration-sheet.yaml"
|
||||
|
||||
# cronjobs
|
||||
#TODO: disable by default?
|
||||
CRON_ARCHIVE_SHEETS: bool = False
|
||||
CRON_DELETE_STALE_SHEETS: bool = True
|
||||
DELETE_STALE_SHEETS_DAYS: int = 14
|
||||
CRON_DELETE_SCHEDULED_ARCHIVES: bool = True
|
||||
DELETE_SCHEDULED_ARCHIVES_NOTIFY_DAYS: int = 14
|
||||
|
||||
# database
|
||||
DATABASE_PATH: str
|
||||
DATABASE_QUERY_LIMIT: int = 100
|
||||
@property
|
||||
def ASYNC_DATABASE_PATH(self) -> str:
|
||||
return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://")
|
||||
|
||||
# redis
|
||||
REDIS_PASSWORD: str = ""
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379"
|
||||
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
|
||||
|
||||
# observability
|
||||
REPEAT_COUNT_METRICS_SECONDS: int = 30
|
||||
|
||||
# security
|
||||
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
|
||||
ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)]
|
||||
CHROME_APP_IDS: Annotated[Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
|
||||
#TODO: deprecate blocklist?
|
||||
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
|
||||
|
||||
# email configuration, if needed
|
||||
MAIL_FROM: str = "noreply@bellingcat.com"
|
||||
MAIL_FROM_NAME: str = "Bellingcat's Auto Archiver"
|
||||
MAIL_USERNAME: str = ""
|
||||
MAIL_PASSWORD: str = ""
|
||||
MAIL_SERVER: str = ""
|
||||
MAIL_PORT: int = 587
|
||||
MAIL_STARTTLS: bool = False
|
||||
MAIL_SSL_TLS: bool = True
|
||||
@property
|
||||
def MAIL_CONFIG(self) -> str:
|
||||
return ConnectionConfig(
|
||||
MAIL_FROM=self.MAIL_FROM,
|
||||
MAIL_FROM_NAME=self.MAIL_FROM_NAME,
|
||||
MAIL_USERNAME=self.MAIL_USERNAME,
|
||||
MAIL_PASSWORD=self.MAIL_PASSWORD,
|
||||
MAIL_SERVER=self.MAIL_SERVER,
|
||||
MAIL_PORT=self.MAIL_PORT,
|
||||
MAIL_STARTTLS=self.MAIL_STARTTLS,
|
||||
MAIL_SSL_TLS=self.MAIL_SSL_TLS,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings():
|
||||
return Settings()
|
||||
18
app/shared/task_messaging.py
Normal file
18
app/shared/task_messaging.py
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
from functools import lru_cache
|
||||
from celery import Celery
|
||||
import redis
|
||||
|
||||
from app.shared.settings import get_settings
|
||||
|
||||
@lru_cache
|
||||
def get_celery(name:str="") -> Celery:
|
||||
return Celery(
|
||||
name,
|
||||
broker_url=get_settings().CELERY_BROKER_URL,
|
||||
result_backend=get_settings().CELERY_BROKER_URL,
|
||||
)
|
||||
|
||||
|
||||
def get_redis() -> redis.Redis:
|
||||
return redis.Redis.from_url(get_settings().CELERY_BROKER_URL)
|
||||
140
app/shared/user_groups.py
Normal file
140
app/shared/user_groups.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import os
|
||||
import yaml
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, field_validator, Field, model_validator
|
||||
from typing import Dict, List, Set
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class UserGroups:
|
||||
def __init__(self, filename):
|
||||
user_groups = self.read_yaml(filename)
|
||||
self.validate_and_load(user_groups)
|
||||
|
||||
def read_yaml(self, user_groups_filename):
|
||||
# read yaml safely
|
||||
with open(user_groups_filename) as inf:
|
||||
try:
|
||||
return yaml.safe_load(inf)
|
||||
except yaml.YAMLError as e:
|
||||
logger.error(f"could not open user groups filename {user_groups_filename}: {e}")
|
||||
raise e
|
||||
|
||||
def validate_and_load(self, user_groups):
|
||||
try:
|
||||
configs = UserGroupModel(**user_groups)
|
||||
self.users = configs.users
|
||||
self.domains = configs.domains
|
||||
self.groups = configs.groups
|
||||
except Exception as e:
|
||||
logger.error(f"Validation error: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
class GroupPermissions(BaseModel):
|
||||
read: Set[str] | bool = Field(default_factory=list)
|
||||
read_public: bool = False
|
||||
archive_url: bool = False
|
||||
archive_sheet: bool = False
|
||||
manually_trigger_sheet: bool = False
|
||||
sheet_frequency: Set[str] = Field(default_factory=list)
|
||||
max_sheets: int = 0
|
||||
max_archive_lifespan_months: int = 12
|
||||
max_monthly_urls: int = 0
|
||||
max_monthly_mbs: int = 0
|
||||
priority: str = "low"
|
||||
|
||||
@field_validator('max_sheets', 'max_archive_lifespan_months', 'max_monthly_urls', 'max_monthly_mbs', mode='before')
|
||||
def validate_max_values(cls, v):
|
||||
if v < -1:
|
||||
raise ValueError("max_* values should be positive integers or -1 (for no limit).")
|
||||
return v
|
||||
|
||||
@field_validator('sheet_frequency', mode='before')
|
||||
def validate_sheet_frequency(cls, v):
|
||||
if not v: return []
|
||||
allowed = ["daily", "hourly"]
|
||||
for k in v:
|
||||
if k not in allowed:
|
||||
raise ValueError(f"Invalid sheet frequency: '{k}', expected one of {allowed}")
|
||||
return v
|
||||
|
||||
@field_validator('priority', mode='before')
|
||||
def validate_priority(cls, v):
|
||||
v = v.lower()
|
||||
if v not in ["low", "high"]:
|
||||
raise ValueError("priority must be either 'low' or 'high'.")
|
||||
return v
|
||||
|
||||
|
||||
class GroupModel(BaseModel):
|
||||
description: str
|
||||
orchestrator: str
|
||||
orchestrator_sheet: str
|
||||
permissions: GroupPermissions
|
||||
|
||||
@field_validator('orchestrator', 'orchestrator_sheet', mode='before')
|
||||
def validate_priority(cls, v):
|
||||
if not os.path.exists(v):
|
||||
raise ValueError(f"Orchestrator file not found with this path: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class UserGroupModel(BaseModel):
|
||||
users: Dict[str, List[str]] = Field(default_factory=dict)
|
||||
domains: Dict[str, List[str]] = Field(default_factory=dict)
|
||||
groups: Dict[str, GroupModel] = Field(default_factory=dict)
|
||||
|
||||
@field_validator('users', mode='before')
|
||||
@classmethod
|
||||
def validate_emails(cls, v):
|
||||
for email in v.keys():
|
||||
if '@' not in email:
|
||||
raise ValueError(f"Invalid user, it should be an address: {email}")
|
||||
if not v[email]:
|
||||
raise ValueError(f"User {email} has no explicitly listed groups, only include them here if they should be in a group.")
|
||||
# all users belong to the default group
|
||||
return {k.lower().strip(): list(set(["default"] + [g.lower().strip() for g in v])) for k, v in v.items()}
|
||||
|
||||
@field_validator('domains', mode='before')
|
||||
@classmethod
|
||||
def validate_domains(cls, v):
|
||||
for domain, members in v.items():
|
||||
if '.' not in domain:
|
||||
raise ValueError(f"Invalid domain, it should contain a dot: {domain}")
|
||||
if not members:
|
||||
raise ValueError(f"Domain {domain} should have at least one member.")
|
||||
return {k.lower().strip(): list(set([g.lower().strip() for g in v])) for k, v in v.items()}
|
||||
|
||||
@field_validator('groups', mode='before')
|
||||
@classmethod
|
||||
def validate_groups(cls, v):
|
||||
if "default" not in v.keys():
|
||||
raise ValueError("Please include a 'default' group.")
|
||||
if "all" in v.keys():
|
||||
raise ValueError("'all' is a reserved group name.")
|
||||
for group in v.keys():
|
||||
if not group == group.lower():
|
||||
raise ValueError(f"Group names should be lowercase: {group}")
|
||||
return v
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_groups_consistency(self) -> Self:
|
||||
groups_in_domains = set([g for domain in self.domains for g in self.domains[domain]])
|
||||
groups_in_users = set([g for user in self.users for g in self.users[user]])
|
||||
configured_groups = set(self.groups.keys())
|
||||
|
||||
# groups mentioned in domains and users should be defined, but this is not a ValueError since historical DB data may require it
|
||||
if groups_in_domains - configured_groups:
|
||||
logger.warning(f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}")
|
||||
if groups_in_users - configured_groups:
|
||||
logger.warning(f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}")
|
||||
|
||||
return self
|
||||
|
||||
# for the API return values
|
||||
|
||||
|
||||
class GroupInfo(GroupPermissions):
|
||||
description: str = ""
|
||||
service_account_emails: list[str] = []
|
||||
10
app/shared/utils/misc.py
Normal file
10
app/shared/utils/misc.py
Normal file
@@ -0,0 +1,10 @@
|
||||
|
||||
def fnv1a_hash_mod(s: str, modulo:int) -> int:
|
||||
# receives a string and returns a number in [0:modulo-1], ensures an even distribution over the modulo range
|
||||
hash = 0x811c9dc5 # FNV offset basis
|
||||
fnv_prime = 0x01000193 # FNV prime
|
||||
for char in s:
|
||||
hash ^= ord(char)
|
||||
hash *= fnv_prime
|
||||
hash &= 0xFFFFFFFF # Keep it 32-bit
|
||||
return (hash if hash < 0x80000000 else hash - 0x100000000) % modulo
|
||||
114
app/tests/conftest.py
Normal file
114
app/tests/conftest.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
from db.user_state import UserState
|
||||
from shared.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_logger_add():
|
||||
"""Fixture to mock loguru.logger.add for all tests."""
|
||||
with patch('loguru.logger.add') as mock_add:
|
||||
yield mock_add # This makes the mock available to tests
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def get_settings():
|
||||
return Settings(_env_file=".env.test")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_settings():
|
||||
with patch('shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings:
|
||||
yield mock_settings
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_db(get_settings: Settings):
|
||||
from db.database import make_engine
|
||||
from db import models
|
||||
from db.crud import get_user_groups
|
||||
|
||||
get_user_groups.cache_clear()
|
||||
make_engine.cache_clear()
|
||||
engine = make_engine(get_settings.DATABASE_PATH)
|
||||
|
||||
fs = get_settings.DATABASE_PATH.replace("sqlite:///", "")
|
||||
if not os.path.exists(fs):
|
||||
open(fs, 'w').close()
|
||||
|
||||
models.Base.metadata.create_all(engine)
|
||||
|
||||
connection = engine.connect()
|
||||
yield connection
|
||||
connection.close()
|
||||
|
||||
models.Base.metadata.drop_all(bind=engine)
|
||||
for suffix in ["", "-wal", "-shm"]:
|
||||
new_fs = fs + suffix
|
||||
if os.path.exists(new_fs):
|
||||
os.remove(new_fs)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session(test_db):
|
||||
from db.database import make_session_local
|
||||
session_local = make_session_local(test_db)
|
||||
with session_local() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app(db_session):
|
||||
from web.main import app_factory
|
||||
from db import crud
|
||||
app = app_factory()
|
||||
crud.upsert_user_groups(db_session)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(app):
|
||||
client = TestClient(app)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app_with_auth(app, db_session):
|
||||
from web.security import get_token_or_user_auth, get_user_auth, get_user_state
|
||||
app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com"
|
||||
app.dependency_overrides[get_user_auth] = lambda: "morty@example.com"
|
||||
app.dependency_overrides[get_user_state] = lambda: UserState(db_session, "MORTY@example.com")
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client_with_auth(app_with_auth):
|
||||
client = TestClient(app_with_auth)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app_with_token(app):
|
||||
from web.security import token_api_key_auth, get_token_or_user_auth
|
||||
app.dependency_overrides[token_api_key_auth] = lambda: ALLOW_ANY_EMAIL
|
||||
app.dependency_overrides[get_token_or_user_auth] = lambda: ALLOW_ANY_EMAIL
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client_with_token(app_with_token):
|
||||
client = TestClient(app_with_token)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_no_auth():
|
||||
# reusable code to ensure a method/endpoint combination is unauthorized
|
||||
def no_auth(http_method, endpoint):
|
||||
response = http_method(endpoint)
|
||||
assert response.status_code == 403
|
||||
assert response.json() == {"detail": "Not authenticated"}
|
||||
return no_auth
|
||||
468
app/tests/db/test_crud.py
Normal file
468
app/tests/db/test_crud.py
Normal file
@@ -0,0 +1,468 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from db import models
|
||||
from shared.settings import Settings
|
||||
|
||||
authors = ["rick@example.com", "morty@example.com", "jerry@example.com"]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_data(db_session):
|
||||
|
||||
# creates 3 users
|
||||
for email in authors:
|
||||
db_session.add(models.User(email=email))
|
||||
db_session.commit()
|
||||
assert db_session.query(models.User).count() == 3
|
||||
|
||||
# creates 100 archives for 3 users over 2 months with repeating URLs
|
||||
for i in range(100):
|
||||
author = authors[i % 3]
|
||||
archive = models.Archive(
|
||||
id=f"archive-id-456-{i}",
|
||||
url=f"https://example-{i%3}.com",
|
||||
result={},
|
||||
public=author == "jerry@example.com",
|
||||
author_id=author,
|
||||
group_id="spaceship" if author == "morty@example.com" and i % 2 == 0 else None,
|
||||
created_at=datetime(2021, (i % 2) + 1, (i % 25) + 1)
|
||||
)
|
||||
if i % 5 == 0:
|
||||
archive.tags.append(models.Tag(id=f"tag-{i}"))
|
||||
if i % 10 == 0:
|
||||
archive.tags.append(models.Tag(id=f"tag-second-{i}"))
|
||||
if i % 4 == 0:
|
||||
archive.tags.append(models.Tag(id=f"tag-third-{i}"))
|
||||
for j in range(10):
|
||||
archive.urls.append(models.ArchiveUrl(url=f"https://example-{i}.com/{j}", key=f"media_{j}"))
|
||||
db_session.add(archive)
|
||||
|
||||
# creates a sheet for each user
|
||||
for i, email in enumerate(authors):
|
||||
db_session.add(models.Sheet(id=f"sheet-{i}", name=f"sheet-{i}", author_id=email, group_id=None, frequency="daily"))
|
||||
if email == "rick@example.com":
|
||||
db_session.add(models.Sheet(id=f"sheet-{i}-2", name=f"sheet-{i}-2", author_id=email, group_id="spaceship", frequency="hourly"))
|
||||
|
||||
db_session.commit()
|
||||
|
||||
assert db_session.query(models.Archive).count() == 100
|
||||
assert db_session.query(models.Tag).count() == 20 + 10 + 25
|
||||
assert db_session.query(models.ArchiveUrl).count() == 1000
|
||||
assert db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.archive_id == "archive-id-456-0").count() == 10
|
||||
|
||||
# setup groups
|
||||
assert db_session.query(models.Group).count() == 0
|
||||
from db import crud
|
||||
crud.upsert_user_groups(db_session)
|
||||
assert db_session.query(models.Group).count() == 4
|
||||
assert db_session.query(models.User).count() == 3
|
||||
|
||||
|
||||
def test_get_archive(test_data, db_session):
|
||||
from db import crud
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
|
||||
print(db_session.query(models.Group).all())
|
||||
|
||||
# each author's archives work
|
||||
assert (a0 := crud.get_archive(db_session, "archive-id-456-0", authors[0])) is not None
|
||||
assert a0.id == "archive-id-456-0"
|
||||
assert a0.url == "https://example-0.com"
|
||||
assert a0.author_id == authors[0]
|
||||
assert a0.public == False
|
||||
|
||||
assert crud.get_archive(db_session, "archive-id-456-1", authors[1]) is not None
|
||||
assert crud.get_archive(db_session, "archive-id-456-2", authors[2]) is not None
|
||||
|
||||
# ALLOW_ANY_EMAIL
|
||||
assert crud.get_archive(db_session, "archive-id-456-0", ALLOW_ANY_EMAIL) is not None
|
||||
assert crud.get_archive(db_session, "archive-id-456-1", ALLOW_ANY_EMAIL) is not None
|
||||
|
||||
# not found
|
||||
assert crud.get_archive(db_session, "archive-missing", authors[0]) is None
|
||||
|
||||
# public
|
||||
assert (a_public := crud.get_archive(db_session, "archive-id-456-2", authors[0])) is not None
|
||||
assert a_public.public == True
|
||||
|
||||
# not public - rick's
|
||||
assert crud.get_archive(db_session, "archive-id-456-0", authors[1]) is None
|
||||
|
||||
|
||||
def test_search_archives_by_url(test_data, db_session):
|
||||
from db import crud
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
|
||||
# rick's archives are private
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "rick@example.com")) == 34
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", ALLOW_ANY_EMAIL)) == 34
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-0.com", "morty@example.com")) == 0
|
||||
|
||||
# morty's archives are public but half are in spaceship group
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-1.com", "rick@example.com")) == 16
|
||||
|
||||
# jerry's archives are public
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "jerry@example.com")) == 33
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", "rick@example.com")) == 33
|
||||
|
||||
# fuzzy search
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL)) == 100
|
||||
assert len(crud.search_archives_by_url(db_session, "https://EXAMPLE", ALLOW_ANY_EMAIL)) == 100
|
||||
assert len(crud.search_archives_by_url(db_session, "2.com", ALLOW_ANY_EMAIL)) == 33
|
||||
|
||||
# absolute search
|
||||
assert len(crud.search_archives_by_url(db_session, "example-2.com", ALLOW_ANY_EMAIL, absolute_search=True)) == 0
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example-2.com", ALLOW_ANY_EMAIL, absolute_search=True)) == 33
|
||||
|
||||
# archived_after
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_after=datetime(2010, 1, 1))) == 100
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_after=datetime(2021, 1, 15))) == 70
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_after=datetime(2031, 1, 1))) == 0
|
||||
|
||||
# archived before
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_before=datetime(2010, 1, 1))) == 0
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_before=datetime(2021, 1, 15))) == 28
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_before=datetime(2031, 1, 1))) == 100
|
||||
|
||||
# archived before and after
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_after=datetime(2001, 1, 1), archived_before=datetime(2031, 1, 11))) == 100
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, archived_after=datetime(2021, 1, 14), archived_before=datetime(2021, 1, 16))) == 2
|
||||
|
||||
# limit
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, limit=10)) == 10
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, limit=-1)) == 1
|
||||
|
||||
# skip
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, skip=10)) == 90
|
||||
|
||||
|
||||
def test_search_archives_by_email(test_data, db_session):
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
from db import crud
|
||||
|
||||
# lower/upper case
|
||||
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 34
|
||||
|
||||
# ALLOW_ANY_EMAIL is not a user
|
||||
assert len(crud.search_archives_by_email(db_session, ALLOW_ANY_EMAIL)) == 0
|
||||
|
||||
# most recent first
|
||||
a1 = crud.search_archives_by_email(db_session, "rick@example.com", limit=1)
|
||||
assert len(a1) == 1
|
||||
assert a1[0].created_at == datetime(2021, 2, 25)
|
||||
|
||||
# earliest is the last
|
||||
a2 = crud.search_archives_by_email(db_session, "rick@example.com", skip=33)
|
||||
assert len(a2) == 1
|
||||
assert a2[0].created_at == datetime(2021, 1, 1)
|
||||
|
||||
|
||||
@patch("db.crud.DATABASE_QUERY_LIMIT", new=25)
|
||||
def test_max_query_limit(test_data, db_session):
|
||||
from db import crud
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL)) == 25
|
||||
assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, limit=1000)) == 25
|
||||
|
||||
assert len(crud.search_archives_by_email(db_session, "rick@example.com")) == 25
|
||||
assert len(crud.search_archives_by_email(db_session, "rick@example.com", limit=1000)) == 25
|
||||
|
||||
|
||||
def test_create_task(db_session):
|
||||
from db import crud
|
||||
from db import schemas
|
||||
|
||||
task = schemas.ArchiveCreate(
|
||||
id="archive-id-456-101",
|
||||
url="https://example-0.com",
|
||||
result={},
|
||||
public=False,
|
||||
author_id="rick@example.com",
|
||||
group_id="spaceship",
|
||||
tags=[],
|
||||
urls=[]
|
||||
)
|
||||
|
||||
# with tags and urls
|
||||
nt = crud.create_task(db_session, task, [models.Tag(id="tag-101")], [models.ArchiveUrl(url="https://example-0.com/0", key="media_0")])
|
||||
|
||||
assert nt is not None
|
||||
assert nt.id == "archive-id-456-101"
|
||||
assert nt.url == "https://example-0.com"
|
||||
assert nt.author_id == "rick@example.com"
|
||||
assert nt.public == False
|
||||
assert nt.group_id == "spaceship"
|
||||
assert len(nt.tags) == 1
|
||||
assert nt.tags[0].id == "tag-101"
|
||||
assert len(nt.urls) == 1
|
||||
assert nt.urls[0].url == "https://example-0.com/0"
|
||||
assert nt.urls[0].key == "media_0"
|
||||
assert nt.created_at is not None
|
||||
|
||||
# without tags and urls
|
||||
task.id = "archive-id-456-102"
|
||||
nt = crud.create_task(db_session, task, [], [])
|
||||
assert nt is not None
|
||||
assert nt.id == "archive-id-456-102"
|
||||
assert nt.url == "https://example-0.com"
|
||||
assert nt.author_id == "rick@example.com"
|
||||
assert nt.public == False
|
||||
assert nt.group_id == "spaceship"
|
||||
assert len(nt.tags) == 0
|
||||
assert len(nt.urls) == 0
|
||||
assert nt.created_at is not None
|
||||
|
||||
|
||||
def test_soft_delete(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
# none deleted yet
|
||||
assert crud.get_archive(db_session, "archive-id-456-0", "rick@example.com") is not None
|
||||
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 0
|
||||
|
||||
# delete
|
||||
assert crud.soft_delete_task(db_session, "archive-id-456-0", "rick@example.com") == True
|
||||
|
||||
# ensure soft delete
|
||||
assert db_session.query(models.Archive).filter(models.Archive.deleted == True).count() == 1
|
||||
assert crud.get_archive(db_session, "archive-id-456-0", "rick@example.com") is None
|
||||
|
||||
# already deleted
|
||||
assert crud.soft_delete_task(db_session, "archive-id-456-0", "rick@example.com") == False
|
||||
|
||||
|
||||
def test_count_archives(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert crud.count_archives(db_session) == 100
|
||||
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
|
||||
db_session.commit()
|
||||
assert crud.count_archives(db_session) == 99
|
||||
|
||||
|
||||
def test_count_archive_urls(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert crud.count_archive_urls(db_session) == 1000
|
||||
db_session.query(models.ArchiveUrl).filter(models.ArchiveUrl.url == "https://example-0.com/0").delete()
|
||||
db_session.commit()
|
||||
assert crud.count_archive_urls(db_session) == 999
|
||||
|
||||
db_session.query(models.Archive).filter(models.Archive.id == "archive-id-456-0").delete()
|
||||
db_session.commit()
|
||||
# no Cascade is enabled
|
||||
assert crud.count_archives(db_session) == 99
|
||||
assert crud.count_archive_urls(db_session) == 999
|
||||
|
||||
|
||||
def test_count_users(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert crud.count_users(db_session) == 3
|
||||
db_session.query(models.User).filter(models.User.email == "rick@example.com").delete()
|
||||
db_session.commit()
|
||||
assert crud.count_users(db_session) == 2
|
||||
|
||||
|
||||
def test_count_by_users_since(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
# 100y window
|
||||
assert len(cu := crud.count_by_user_since(db_session, 60 * 60 * 24 * 31 * 12 * 100)) == 3
|
||||
assert cu[0].total == 34
|
||||
assert cu[1].total == 33
|
||||
assert cu[2].total == 33
|
||||
|
||||
|
||||
def test_create_tag(db_session):
|
||||
from db import crud
|
||||
|
||||
assert db_session.query(models.Tag).count() == 0
|
||||
|
||||
# create first
|
||||
create_tag = crud.create_tag(db_session, "tag-101")
|
||||
assert create_tag is not None
|
||||
assert create_tag.id == "tag-101"
|
||||
assert db_session.query(models.Tag).count() == 1
|
||||
assert db_session.query(models.Tag).filter(models.Tag.id == "tag-101").first() == create_tag
|
||||
|
||||
# same id does not add new db entry
|
||||
existing_tag = crud.create_tag(db_session, "tag-101")
|
||||
assert existing_tag == create_tag
|
||||
assert db_session.query(models.Tag).count() == 1
|
||||
|
||||
# create second
|
||||
second_tag = crud.create_tag(db_session, "tag-102")
|
||||
assert second_tag is not None
|
||||
assert second_tag.id == "tag-102"
|
||||
assert db_session.query(models.Tag).count() == 2
|
||||
|
||||
|
||||
def test_is_user_in_group(test_data, db_session):
|
||||
from db import crud
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
|
||||
# see user-groups.test.yaml
|
||||
test_pairs = [
|
||||
(ALLOW_ANY_EMAIL, "spaceship", True),
|
||||
(ALLOW_ANY_EMAIL, "non-existant!@#!%!", True),
|
||||
|
||||
("rick@example.com", "spaceship", True),
|
||||
("rick@example.com", "SPACESHIP", False),
|
||||
("rick@example.com", "interdimensional", True),
|
||||
("rick@example.com", "animated-characters", True),
|
||||
("rick@example.com", "the-jerrys-club", False),
|
||||
|
||||
("morty@example.com", "spaceship", True),
|
||||
("morty@example.com", "interdimensional", False),
|
||||
("morty@example.com", "the-jerrys-club", False),
|
||||
|
||||
("jerry@example.com", "spaceship", False),
|
||||
("jerry@example.com", "interdimensional", False),
|
||||
("jerry@example.com", "the-jerrys-club", False), # group not in 'groups'
|
||||
|
||||
("rick@example.com", "animated-characters", True),
|
||||
("morty@example.com", "animated-characters", True),
|
||||
("jerry@example.com", "animated-characters", True),
|
||||
("anyone@example.com", "animated-characters", True),
|
||||
("anyone@birdy.com", "animated-characters", True),
|
||||
|
||||
("summer@herself.com", "animated-characters", False),
|
||||
|
||||
("rick@example.com", "", False),
|
||||
("", "spaceship", False),
|
||||
("bademailexample.com", "spaceship", False),
|
||||
]
|
||||
for email, group, expected in test_pairs:
|
||||
print(f"{email} in {group} == {expected}")
|
||||
assert crud.is_user_in_group(db_session, email, group) == expected
|
||||
|
||||
|
||||
def test_get_group(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert crud.get_group(db_session, "spaceship") is not None
|
||||
assert crud.get_group(db_session, "interdimensional") is not None
|
||||
assert crud.get_group(db_session, "animated-characters") is not None
|
||||
assert crud.get_group(db_session, "non-existent!@#!%!") is None
|
||||
|
||||
|
||||
def test_create_or_get_user(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert db_session.query(models.User).count() == 3
|
||||
|
||||
# already exists
|
||||
assert (u1 := crud.create_or_get_user(db_session, "rick@example.com")) is not None
|
||||
assert u1.email == "rick@example.com"
|
||||
|
||||
# new user
|
||||
assert (u2 := crud.create_or_get_user(db_session, "beth@example.com")) is not None
|
||||
assert u2.email == "beth@example.com"
|
||||
|
||||
assert db_session.query(models.User).count() == 4
|
||||
|
||||
|
||||
def test_upsert_group(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert db_session.query(models.Group).count() == 4
|
||||
|
||||
repeatable_params = ["desc 1", "orch.yaml", "sheet.yaml", {"read": ["all"]}, ["example.com"]]
|
||||
|
||||
assert (g1 := crud.upsert_group(db_session, "spaceship", *repeatable_params)) is not None
|
||||
assert g1.id == "spaceship"
|
||||
assert g1.description == "desc 1"
|
||||
assert g1.orchestrator == "orch.yaml"
|
||||
assert g1.orchestrator_sheet == "sheet.yaml"
|
||||
assert g1.permissions == {"read": ["all"]}
|
||||
assert g1.domains == ["example.com"]
|
||||
assert len(g1.users) == 2
|
||||
assert [u.email for u in g1.users] == ["rick@example.com", "morty@example.com"]
|
||||
|
||||
assert (g2 := crud.upsert_group(db_session, "interdimensional", *repeatable_params)) is not None
|
||||
assert g2.id == "interdimensional"
|
||||
assert len(g2.users) == 1
|
||||
assert [u.email for u in g2.users] == ["rick@example.com"]
|
||||
|
||||
assert (g3 := crud.upsert_group(db_session, "this-is-a-new-group", *repeatable_params)) is not None
|
||||
assert g3.id == "this-is-a-new-group"
|
||||
assert len(g3.users) == 0
|
||||
|
||||
assert db_session.query(models.Group).count() == 5
|
||||
|
||||
|
||||
def test_upsert_user_groups(db_session):
|
||||
from db import crud
|
||||
|
||||
@patch('db.crud.get_settings', new=lambda: bad_setings)
|
||||
def test_missing_yaml(db_session):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
crud.upsert_user_groups(db_session)
|
||||
|
||||
@patch('db.crud.get_settings', new=lambda: bad_setings)
|
||||
def test_broken_yaml(db_session):
|
||||
with pytest.raises(yaml.YAMLError):
|
||||
crud.upsert_user_groups(db_session)
|
||||
|
||||
bad_setings = Settings(_env_file=".env.test")
|
||||
|
||||
bad_setings.USER_GROUPS_FILENAME = "tests/user-groups.test.missing.yaml"
|
||||
test_missing_yaml(db_session)
|
||||
|
||||
bad_setings.USER_GROUPS_FILENAME = "tests/user-groups.test.broken.yaml"
|
||||
test_broken_yaml(db_session)
|
||||
|
||||
|
||||
def test_create_sheet(db_session):
|
||||
from db import crud
|
||||
|
||||
assert db_session.query(models.Sheet).count() == 0
|
||||
|
||||
s = crud.create_sheet(db_session, "sheet-id-123", "sheet name", "email@example.com", "group-id", "hourly")
|
||||
assert s is not None
|
||||
assert s.id == "sheet-id-123"
|
||||
assert s.name == "sheet name"
|
||||
assert s.author_id == "email@example.com"
|
||||
assert s.group_id == "group-id"
|
||||
assert s.frequency == "hourly"
|
||||
|
||||
assert db_session.query(models.Sheet).count() == 1
|
||||
|
||||
# duplicate id
|
||||
import sqlalchemy
|
||||
with pytest.raises(sqlalchemy.exc.IntegrityError):
|
||||
crud.create_sheet(db_session, "sheet-id-123", "I thought this was another sheet", "email", "group-id", "hourly")
|
||||
|
||||
|
||||
def test_get_user_sheet(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert crud.get_user_sheet(db_session, "", "sheet-0") is None
|
||||
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-0") is None
|
||||
|
||||
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0") is not None
|
||||
assert crud.get_user_sheet(db_session, "rick@example.com", "sheet-0-2") is not None
|
||||
assert crud.get_user_sheet(db_session, "morty@example.com", "sheet-1") is not None
|
||||
|
||||
|
||||
def test_get_user_sheets(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert len(crud.get_user_sheets(db_session, "")) == 0
|
||||
rick_sheets = crud.get_user_sheets(db_session, "rick@example.com")
|
||||
assert len(rick_sheets) == 2
|
||||
assert [s.id for s in rick_sheets] == ["sheet-0", "sheet-0-2"]
|
||||
assert len(crud.get_user_sheets(db_session, "morty@example.com")) == 1
|
||||
|
||||
def test_delete_sheet(test_data, db_session):
|
||||
from db import crud
|
||||
|
||||
assert crud.delete_sheet(db_session, "sheet-0", "") == False
|
||||
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == True
|
||||
assert crud.delete_sheet(db_session, "sheet-0", "rick@example.com") == False
|
||||
|
||||
6
app/tests/db/test_models.py
Normal file
6
app/tests/db/test_models.py
Normal file
@@ -0,0 +1,6 @@
|
||||
def test_generate_uuid():
|
||||
from db.models import generate_uuid
|
||||
|
||||
assert generate_uuid() != generate_uuid()
|
||||
assert len(generate_uuid()) == 36
|
||||
assert generate_uuid().count("-") == 4
|
||||
128
app/tests/endpoints/test_default.py
Normal file
128
app/tests/endpoints/test_default.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from app.shared.config import VERSION
|
||||
from tests.db.test_crud import test_data
|
||||
|
||||
|
||||
def test_endpoint_home(client_with_auth):
|
||||
r = client_with_auth.get("/")
|
||||
assert r.status_code == 200
|
||||
j = r.json()
|
||||
assert "version" in j and j["version"] == VERSION
|
||||
assert "breakingChanges" in j
|
||||
assert "groups" not in j
|
||||
|
||||
|
||||
@patch("endpoints.default.bearer_security", new_callable=AsyncMock)
|
||||
@patch("endpoints.default.get_user_auth", new_callable=AsyncMock, return_value="test@example.com")
|
||||
@patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"])
|
||||
def test_endpoint_home_with_groups(m1, m2, m3, client_with_auth):
|
||||
r = client_with_auth.get("/")
|
||||
assert r.status_code == 200
|
||||
j = r.json()
|
||||
assert "version" in j and j["version"] == VERSION
|
||||
assert "breakingChanges" in j
|
||||
assert "groups" in j
|
||||
assert j["groups"] == ["group1", "group2"]
|
||||
|
||||
|
||||
@patch("endpoints.default.bearer_security", new_callable=AsyncMock)
|
||||
@patch("endpoints.default.get_user_auth", new_callable=AsyncMock, return_value="test@example.com")
|
||||
@patch("endpoints.default.crud.get_user_groups", side_effect=Exception('mocked error'))
|
||||
def test_endpoint_home_with_groups_exception(m1, m2, m3, client_with_auth): # mocks call that triggers an internal error
|
||||
r = client_with_auth.get("/")
|
||||
assert r.status_code == 200
|
||||
j = r.json()
|
||||
assert "version" in j and j["version"] == VERSION
|
||||
assert "breakingChanges" in j
|
||||
assert "groups" not in j
|
||||
|
||||
|
||||
def test_endpoint_health(client_with_auth):
|
||||
r = client_with_auth.get("/health")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"status": "ok"}
|
||||
|
||||
|
||||
def test_endpoint_active_no_auth(client, test_no_auth):
|
||||
test_no_auth(client.get, "/user/active")
|
||||
|
||||
|
||||
def test_endpoint_active(app):
|
||||
m_user_state = MagicMock()
|
||||
|
||||
from web.security import get_user_state
|
||||
app.dependency_overrides[get_user_state] = lambda: m_user_state
|
||||
|
||||
# inactive user
|
||||
m_user_state.active = False
|
||||
client = TestClient(app)
|
||||
r = client.get("/user/active")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"active": False}
|
||||
|
||||
# active user
|
||||
m_user_state.active = True
|
||||
client = TestClient(app)
|
||||
r = client.get("/user/active")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"active": True}
|
||||
|
||||
|
||||
|
||||
def test_no_serve_local_archive_by_default(client_with_auth):
|
||||
r = client_with_auth.get("/app/local_archive_test/temp.txt")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_favicon(client_with_auth):
|
||||
r = client_with_auth.get("/favicon.ico")
|
||||
assert r.status_code == 200
|
||||
assert r.headers["content-type"] == "image/vnd.microsoft.icon"
|
||||
|
||||
|
||||
def test_endpoint_test_prometheus_no_auth(client, test_no_auth):
|
||||
test_no_auth(client.get, "/metrics")
|
||||
|
||||
|
||||
def test_endpoint_test_prometheus_no_user_auth(client_with_auth, test_no_auth):
|
||||
test_no_auth(client_with_auth.get, "/metrics")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prometheus_metrics(test_data, client_with_token, get_settings):
|
||||
# before metrics calculation
|
||||
r = client_with_token.get("/metrics")
|
||||
assert r.status_code == 200
|
||||
assert r.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8"
|
||||
assert "disk_utilization" in r.text
|
||||
assert "database_metrics" in r.text
|
||||
assert "exceptions" in r.text
|
||||
assert "worker_exceptions_total" in r.text
|
||||
assert 'disk_utilization{type="used"}' not in r.text
|
||||
|
||||
# after metrics calculation
|
||||
from web.utils.metrics import measure_regular_metrics
|
||||
await measure_regular_metrics(get_settings.DATABASE_PATH, 60 * 60 * 24 * 31 * 12 * 100)
|
||||
r2 = client_with_token.get("/metrics")
|
||||
assert 'disk_utilization{type="used"}' in r2.text
|
||||
assert 'disk_utilization{type="free"}' in r2.text
|
||||
assert 'disk_utilization{type="database"}' in r2.text
|
||||
assert 'database_metrics{query="count_archives"} 100.0' in r2.text
|
||||
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r2.text
|
||||
assert 'database_metrics{query="count_users"} 3.0' in r2.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r2.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r2.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r2.text
|
||||
|
||||
# 30s window, should not change the gauges nor the total in the counters
|
||||
from web.utils.metrics import measure_regular_metrics
|
||||
await measure_regular_metrics(get_settings.DATABASE_PATH, 30)
|
||||
r3 = client_with_token.get("/metrics")
|
||||
assert 'database_metrics{query="count_archives"} 100.0' in r3.text
|
||||
assert 'database_metrics{query="count_archive_urls"} 1000.0' in r3.text
|
||||
assert 'database_metrics{query="count_users"} 3.0' in r3.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r3.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r3.text
|
||||
assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r3.text
|
||||
40
app/tests/endpoints/test_interoperability.py
Normal file
40
app/tests/endpoints/test_interoperability.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
from db import crud
|
||||
|
||||
|
||||
def test_submit_manual_archive_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.post, "/interop/submit-archive")
|
||||
|
||||
|
||||
def test_submit_manual_archive_not_user_auth(client_with_auth, test_no_auth):
|
||||
test_no_auth(client_with_auth.post, "/interop/submit-archive")
|
||||
|
||||
|
||||
@patch("endpoints.interoperability.get_store_until", return_value=datetime.now())
|
||||
def test_submit_manual_archive(m1, client_with_token, db_session):
|
||||
# normal workflow
|
||||
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.s3.com"]}]})
|
||||
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": True, "author_id": "jerry@gmail.com", "group_id": "spaceship", "tags": ["test"]})
|
||||
assert r.status_code == 201
|
||||
assert "id" in r.json()
|
||||
|
||||
inserted = crud.get_archive(db_session, r.json()["id"], ALLOW_ANY_EMAIL)
|
||||
assert inserted.url == "http://example.com"
|
||||
assert inserted.group_id == "spaceship"
|
||||
assert inserted.author_id == "jerry@gmail.com"
|
||||
assert sorted([t.id for t in inserted.tags]) == sorted(["test", "manual"])
|
||||
assert inserted.public
|
||||
assert type(inserted.result) == dict
|
||||
assert [u.url for u in inserted.urls] == ["http://example.s3.com"]
|
||||
assert type(inserted.store_until) == datetime
|
||||
|
||||
|
||||
# cannot have the same URL twice
|
||||
aa_metadata = json.dumps({"status": "test: success", "metadata": {"url": "http://example.com"}, "media": [{"filename": "fn1", "urls": ["http://example.com", "http://example.com"]}]})
|
||||
r = client_with_token.post("/interop/submit-archive", json={"result": aa_metadata, "public": False, "author_id": "jerry@gmail.com", "tags": ["test"]})
|
||||
assert r.status_code == 422
|
||||
assert r.json() == {"detail": "Cannot insert into DB due to integrity error, likely duplicate urls."}
|
||||
193
app/tests/endpoints/test_sheet.py
Normal file
193
app/tests/endpoints/test_sheet.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.shared.schemas import TaskResult
|
||||
|
||||
|
||||
def test_endpoints_no_auth(client, test_no_auth):
|
||||
test_no_auth(client.post, "/sheet/create")
|
||||
test_no_auth(client.get, "/sheet/mine")
|
||||
test_no_auth(client.delete, "/sheet/123-sheet-id")
|
||||
test_no_auth(client.post, "/sheet/123-sheet-id/archive")
|
||||
|
||||
|
||||
def test_create_sheet_endpoint(app_with_auth, db_session):
|
||||
client_with_auth = TestClient(app_with_auth)
|
||||
good_data = {
|
||||
"id": "123-sheet-id",
|
||||
"name": "Test Sheet",
|
||||
"group_id": "spaceship",
|
||||
"frequency": "daily"
|
||||
}
|
||||
|
||||
# with good data
|
||||
response = client_with_auth.post("/sheet/create", json=good_data)
|
||||
assert response.status_code == 201
|
||||
j = response.json()
|
||||
assert datetime.fromisoformat(j.pop("created_at"))
|
||||
assert datetime.fromisoformat(j.pop("last_url_archived_at"))
|
||||
assert j.pop("author_id") == 'morty@example.com'
|
||||
assert j == good_data
|
||||
|
||||
# already exists
|
||||
response = client_with_auth.post("/sheet/create", json=good_data)
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "Sheet with this ID is already being archived."}
|
||||
|
||||
# bad group
|
||||
bad_data = good_data.copy()
|
||||
bad_data["group_id"] = "not a group"
|
||||
response = client_with_auth.post("/sheet/create", json=bad_data)
|
||||
assert response.status_code == 403
|
||||
assert response.json() == {"detail": "User does not have access to this group."}
|
||||
|
||||
# switch to jerry who's got less quota/permissions
|
||||
from web.security import get_user_state
|
||||
from db.user_state import UserState
|
||||
app_with_auth.dependency_overrides[get_user_state] = lambda: UserState(db_session, "jerry@example.com")
|
||||
client_jerry = TestClient(app_with_auth)
|
||||
|
||||
# frequency not allowed
|
||||
jerry_data = good_data.copy()
|
||||
jerry_data["group_id"] = "animated-characters"
|
||||
jerry_data["frequency"] = "hourly"
|
||||
jerry_data["id"] = "jerry-sheet-id"
|
||||
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||
assert response.status_code == 422
|
||||
assert response.json() == {"detail": "Invalid frequency selected for this group."}
|
||||
|
||||
jerry_data["frequency"] = "daily"
|
||||
# success for the first sheet, bad quota on second
|
||||
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||
assert response.status_code == 201
|
||||
|
||||
response = client_jerry.post("/sheet/create", json=jerry_data)
|
||||
assert response.status_code == 429
|
||||
assert response.json() == {"detail": "User has reached their sheet quota for this group."}
|
||||
|
||||
|
||||
def test_get_user_sheets_endpoint(client_with_auth, db_session):
|
||||
# no data
|
||||
response = client_with_auth.get("/sheet/mine")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
# with data
|
||||
from db import models
|
||||
db_session.add(
|
||||
models.Sheet(id="123", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly")
|
||||
)
|
||||
db_session.commit()
|
||||
db_session.add_all([
|
||||
models.Sheet(id="456", name="Test Sheet 2", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
|
||||
models.Sheet(id="789", name="Test Sheet 3", author_id="rick@example.com", group_id="interdimensional", frequency="hourly"),
|
||||
])
|
||||
db_session.commit()
|
||||
|
||||
response = client_with_auth.get("/sheet/mine")
|
||||
assert response.status_code == 200
|
||||
r = response.json()
|
||||
assert isinstance(r, list)
|
||||
assert len(r) == 2
|
||||
assert datetime.fromisoformat(r[0].pop("created_at"))
|
||||
assert datetime.fromisoformat(r[0].pop("last_url_archived_at"))
|
||||
assert datetime.fromisoformat(r[1].pop("created_at"))
|
||||
assert datetime.fromisoformat(r[1].pop("last_url_archived_at"))
|
||||
assert r[0] == {
|
||||
'id': '123',
|
||||
'author_id': 'morty@example.com',
|
||||
'frequency': 'hourly',
|
||||
'group_id': 'spaceship',
|
||||
'name': 'Test Sheet 1',
|
||||
}
|
||||
assert r[1] == {
|
||||
'id': '456',
|
||||
'author_id': 'morty@example.com',
|
||||
'frequency': 'daily',
|
||||
'group_id': 'interdimensional',
|
||||
'name': 'Test Sheet 2',
|
||||
}
|
||||
|
||||
|
||||
def test_delete_sheet_endpoint(client_with_auth, db_session):
|
||||
# missing sheet
|
||||
response = client_with_auth.delete("/sheet/123-sheet-id")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"id": "123-sheet-id",
|
||||
"deleted": False
|
||||
}
|
||||
|
||||
# add sheets for deletion
|
||||
from db import models
|
||||
db_session.add_all([
|
||||
models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="daily"),
|
||||
models.Sheet(id="456-sheet-id", name="Test Sheet 2", author_id="rick@example.com", group_id="spaceship", frequency="hourly"),
|
||||
])
|
||||
db_session.commit()
|
||||
|
||||
# morty can delete his
|
||||
response = client_with_auth.delete("/sheet/123-sheet-id")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": "123-sheet-id", "deleted": True}
|
||||
# but only once
|
||||
response = client_with_auth.delete("/sheet/123-sheet-id")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": "123-sheet-id", "deleted": False}
|
||||
# and not rick's
|
||||
response = client_with_auth.delete("/sheet/456-sheet-id")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": "456-sheet-id", "deleted": False}
|
||||
|
||||
|
||||
class TestArchiveUserSheetEndpoint:
|
||||
@patch("endpoints.sheet.celery", return_value=MagicMock())
|
||||
def test_normal_flow(self, m_celery, client_with_auth, db_session):
|
||||
from db import models
|
||||
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="spaceship", frequency="hourly"))
|
||||
db_session.commit()
|
||||
|
||||
m_signature = MagicMock()
|
||||
m_signature.delay.return_value = TaskResult(id="123-taskid", status="PENDING", result="")
|
||||
m_celery.signature.return_value = m_signature
|
||||
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == 201
|
||||
assert r.json() == {"id": "123-taskid"}
|
||||
m_celery.signature.assert_called_once()
|
||||
m_signature.delay.assert_called_once()
|
||||
|
||||
def test_token_auth(self, client_with_token, test_no_auth):
|
||||
test_no_auth(client_with_token.post, "/sheet/123-sheet-id/archive")
|
||||
|
||||
def test_missing_data(self, client_with_auth):
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == 403
|
||||
assert r.json() == {"detail": "No access to this sheet."}
|
||||
|
||||
def test_no_access(self, client_with_auth, db_session):
|
||||
from db import models
|
||||
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="rick@example.com", group_id="spaceship", frequency="hourly"))
|
||||
db_session.commit()
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == 403
|
||||
assert r.json() == {"detail": "No access to this sheet."}
|
||||
|
||||
def test_user_not_in_group(self, client_with_auth, db_session):
|
||||
from db import models
|
||||
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="interdimensional", frequency="hourly"))
|
||||
db_session.commit()
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == 403
|
||||
assert r.json() == {"detail": "User does not have access to this group."}
|
||||
|
||||
def test_user_cannot_manually_trigger(self, client_with_auth, db_session):
|
||||
from db import models
|
||||
db_session.add(models.Sheet(id="123-sheet-id", name="Test Sheet 1", author_id="morty@example.com", group_id="default", frequency="hourly"))
|
||||
db_session.commit()
|
||||
r = client_with_auth.post("/sheet/123-sheet-id/archive")
|
||||
assert r.status_code == 429
|
||||
assert r.json() == {"detail": "User cannot manually trigger sheet archiving in this group."}
|
||||
51
app/tests/endpoints/test_task.py
Normal file
51
app/tests/endpoints/test_task.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def test_endpoint_task_status_no_auth(client, test_no_auth):
|
||||
test_no_auth(client.get, "/task/test-task-id")
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
def test_get_status_success(mock_async_result, client_with_auth):
|
||||
mock_async_result.return_value.status = "SUCCESS"
|
||||
mock_async_result.return_value.result = {"data": "some result"}
|
||||
|
||||
response = client_with_auth.get("/task/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"id": "test-task-id",
|
||||
"status": "SUCCESS",
|
||||
"result": {"data": "some result"}
|
||||
}
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
def test_get_status_failure(mock_async_result, client_with_auth):
|
||||
|
||||
mock_async_result.return_value.status = "FAILURE"
|
||||
mock_async_result.return_value.result = Exception("Some error")
|
||||
|
||||
response = client_with_auth.get("/task/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"id": "test-task-id",
|
||||
"status": "FAILURE",
|
||||
"result": {"error": "Some error"}
|
||||
}
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
def test_get_status_pending(mock_async_result, client_with_auth):
|
||||
mock_async_result.return_value.status = "PENDING"
|
||||
mock_async_result.return_value.result = None
|
||||
|
||||
response = client_with_auth.get("/task/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"id": "test-task-id",
|
||||
"status": "PENDING",
|
||||
"result": None
|
||||
}
|
||||
192
app/tests/endpoints/test_url.py
Normal file
192
app/tests/endpoints/test_url.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.shared.schemas import ArchiveCreate, TaskResult
|
||||
|
||||
|
||||
def test_archive_url_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.post, "/url/archive")
|
||||
|
||||
|
||||
@patch("endpoints.url.UserState")
|
||||
@patch("endpoints.url.celery", return_value=MagicMock())
|
||||
def test_archive_url(m_celery, m2, client_with_auth):
|
||||
m_signature = MagicMock()
|
||||
m_signature.delay.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
|
||||
m_celery.signature.return_value = m_signature
|
||||
|
||||
m_user_state = MagicMock()
|
||||
m2.return_value = m_user_state
|
||||
|
||||
# url is too short
|
||||
response = client_with_auth.post("/url/archive", json={"url": "bad"})
|
||||
assert response.status_code == 422
|
||||
assert response.json()["detail"][0]["msg"] == 'String should have at least 5 characters'
|
||||
m_celery.signature.assert_not_called()
|
||||
|
||||
# url is invalid
|
||||
response = client_with_auth.post("/url/archive", json={"url": "example.com"})
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Invalid URL received."
|
||||
|
||||
# valid request
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = True
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = True
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-456-789'}
|
||||
m_celery.signature.assert_called_once()
|
||||
m_signature.delay.assert_called_once()
|
||||
called_val = m_celery.signature.call_args
|
||||
assert called_val[0][0] == "create_archive_task"
|
||||
assert json.loads(called_val[1]['args'][0]) == {"id": None, "url": "https://example.com", "result": None, "public": False, "author_id": "rick@example.com", "group_id": "default", "tags": None, "sheet_id": None, "store_until": None, "urls": None}
|
||||
m_user_state.has_quota_max_monthly_urls.assert_called_once()
|
||||
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
|
||||
m_user_state.in_group.assert_called_once_with("default")
|
||||
|
||||
# user is not in group
|
||||
m_user_state.in_group.return_value = False
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "new-group"})
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "User does not have access to this group."
|
||||
m_user_state.in_group.assert_called_with("new-group")
|
||||
|
||||
# user is in group
|
||||
m_user_state.in_group.return_value = True
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-456-789'}
|
||||
assert m_celery.signature.call_count == 2
|
||||
assert m_signature.delay.call_count == 2
|
||||
called_val = m_celery.signature.call_args
|
||||
assert json.loads(called_val[1]['args'][0])["group_id"] == "spaceship"
|
||||
m_user_state.in_group.assert_called_with("spaceship")
|
||||
|
||||
# user is over monthly URL quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = False
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = True
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spaceship"})
|
||||
assert response.status_code == 429
|
||||
assert response.json()["detail"] == "User has reached their monthly URL quota."
|
||||
m_user_state.has_quota_max_monthly_urls.assert_called_with("spaceship")
|
||||
|
||||
# user is over monthly MB quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = True
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = False
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com", "group_id": "spacesuit"})
|
||||
assert response.status_code == 429
|
||||
assert response.json()["detail"] == "User has reached their monthly MB quota."
|
||||
m_user_state.has_quota_max_monthly_mbs.assert_called_with("spacesuit")
|
||||
assert m_celery.signature.call_count == 2
|
||||
assert m_signature.delay.call_count == 2
|
||||
|
||||
|
||||
@patch("endpoints.url.UserState")
|
||||
def test_archive_url_quotas(m1, client_with_auth):
|
||||
m_user_state = MagicMock()
|
||||
m1.return_value = m_user_state
|
||||
|
||||
# misses on monthly URLs quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = False
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
|
||||
assert response.status_code == 429
|
||||
assert response.json()["detail"] == "User has reached their monthly URL quota."
|
||||
m_user_state.has_quota_max_monthly_urls.assert_called_once()
|
||||
|
||||
# misses on monthly MBs quota
|
||||
m_user_state.has_quota_max_monthly_urls.return_value = True
|
||||
m_user_state.has_quota_max_monthly_mbs.return_value = False
|
||||
response = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
|
||||
assert response.status_code == 429
|
||||
assert response.json()["detail"] == "User has reached their monthly MB quota."
|
||||
m_user_state.has_quota_max_monthly_mbs.assert_called_once()
|
||||
|
||||
|
||||
@patch("endpoints.url.celery", return_value=MagicMock())
|
||||
def test_archive_url_with_api_token(m_celery, client_with_token):
|
||||
m_signature = MagicMock()
|
||||
m_signature.delay.return_value = TaskResult(id="123-456-789", status="PENDING", result="")
|
||||
m_celery.signature.return_value = m_signature
|
||||
response = client_with_token.post("/url/archive", json={"url": "https://example.com"})
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {'id': '123-456-789'}
|
||||
m_celery.signature.assert_called_once()
|
||||
m_signature.delay.assert_called_once()
|
||||
called_val = m_celery.signature.call_args
|
||||
assert called_val[0][0] == "create_archive_task"
|
||||
|
||||
|
||||
def test_search_by_url_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.get, "/url/search")
|
||||
|
||||
|
||||
def test_search_by_url(client_with_auth, client_with_token, db_session):
|
||||
# tests the search endpoint, including through some db data for the endpoint params
|
||||
response = client_with_auth.get("/url/search")
|
||||
assert response.status_code == 422
|
||||
assert response.json()["detail"][0]["msg"] == "Field required"
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
from db import crud, schemas
|
||||
for i in range(11):
|
||||
crud.create_task(db_session, ArchiveCreate(id=f"url-456-{i}", url="https://example.com" if i < 10 else "https://something-else.com", result={}, public=True, author_id="rick@example.com"), [], [])
|
||||
# NB: this insertion is too fast for the ordering to be correct as they are within the same second
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com")
|
||||
assert response.status_code == 200
|
||||
assert len(j := response.json()) == 10
|
||||
assert "url-456-0" in [i["id"] for i in j]
|
||||
assert "url-456-9" in [i["id"] for i in j]
|
||||
assert "url-456-10" not in [i["id"] for i in j]
|
||||
assert j[0].keys() == schemas.ArchiveResult.model_fields.keys()
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com&limit=5")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 5
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com&skip=5&limit=2")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com&archived_before=2010-01-01")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 0
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com&archived_after=2010-01-01")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 10
|
||||
|
||||
# API token will also work
|
||||
response = client_with_token.get("/url/search?url=https://example.com&archived_after=2010-01-01")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 10
|
||||
|
||||
|
||||
@patch("endpoints.url.UserState")
|
||||
def test_search_no_read_access(mock_user_state, client_with_auth):
|
||||
mock_user_state.return_value.read = False
|
||||
mock_user_state.return_value.read_public = False
|
||||
|
||||
response = client_with_auth.get("/url/search?url=https://example.com")
|
||||
assert response.status_code == 403
|
||||
assert response.json() == {"detail": "User does not have read access."}
|
||||
|
||||
|
||||
def test_delete_task_unauthenticated(client, test_no_auth):
|
||||
test_no_auth(client.delete, "/url/123-456-789")
|
||||
|
||||
|
||||
def test_delete_task(client_with_auth, db_session):
|
||||
response = client_with_auth.delete("/url/delete-123-456-789")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": "delete-123-456-789", "deleted": False}
|
||||
|
||||
from db import crud
|
||||
crud.create_task(db_session, ArchiveCreate(id="delete-123-456-789", url="https://example.com", result={}, public=True, author_id="morty@example.com"), [], [])
|
||||
|
||||
response = client_with_auth.delete("/url/delete-123-456-789")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": "delete-123-456-789", "deleted": True}
|
||||
24
app/tests/orchestration.test.yaml
Normal file
24
app/tests/orchestration.test.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
steps:
|
||||
feeder: cli_feeder
|
||||
archivers: # order matters
|
||||
- youtubedl_archiver
|
||||
enrichers:
|
||||
- hash_enricher
|
||||
|
||||
formatter: html_formatter # defaults to mute_formatter
|
||||
storages:
|
||||
- local_storage
|
||||
databases:
|
||||
- console_db
|
||||
|
||||
configurations:
|
||||
cli_feeder:
|
||||
urls:
|
||||
- "url1"
|
||||
hash_enricher:
|
||||
algorithm: "SHA-256"
|
||||
local_storage:
|
||||
save_to: "./local_archive"
|
||||
save_absolute: true
|
||||
filename_generator: static
|
||||
path_generator: flat
|
||||
6
app/tests/user-groups.test.broken.yaml
Normal file
6
app/tests/user-groups.test.broken.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
broken: True
|
||||
This is just an invalid yaml for testing
|
||||
|
||||
still broken: True
|
||||
- one
|
||||
- two
|
||||
87
app/tests/user-groups.test.yaml
Normal file
87
app/tests/user-groups.test.yaml
Normal file
@@ -0,0 +1,87 @@
|
||||
# NOTE: all emails should be lower-cased
|
||||
users:
|
||||
rick@example.com:
|
||||
- spaceship
|
||||
- interdimensional
|
||||
morty@example.com:
|
||||
- spaceship
|
||||
jerry@example.com:
|
||||
- the-jerrys-club
|
||||
# summer@herself.com:
|
||||
# badyemail.com:
|
||||
|
||||
domains:
|
||||
example.com:
|
||||
- animated-characters
|
||||
birdy.com:
|
||||
- animated-characters
|
||||
- this-does-not-exist
|
||||
|
||||
|
||||
orchestrators:
|
||||
spaceship: tests/orchestration.test.yaml
|
||||
interdimensional: tests/orchestration.test.yaml
|
||||
default: tests/orchestration.test.yaml
|
||||
|
||||
default_orchestrator: tests/orchestration.test.yaml
|
||||
|
||||
groups:
|
||||
spaceship:
|
||||
description: "The spaceship crew"
|
||||
orchestrator: tests/orchestration.test.yaml
|
||||
orchestrator_sheet: tests/orchestration.test.yaml
|
||||
permissions:
|
||||
read: ["all"]
|
||||
archive_url: true
|
||||
archive_sheet: true
|
||||
manually_trigger_sheet: true
|
||||
sheet_frequency: ["hourly", "daily"]
|
||||
max_sheets: -1
|
||||
max_archive_lifespan_months: -1
|
||||
max_monthly_urls: -1
|
||||
max_monthly_mbs: -1
|
||||
priority: "high"
|
||||
interdimensional:
|
||||
description: "Interdimensional travelers"
|
||||
orchestrator: tests/orchestration.test.yaml
|
||||
orchestrator_sheet: tests/orchestration.test.yaml
|
||||
permissions:
|
||||
read: ["interdimensional", "animated-characters"]
|
||||
archive_url: true
|
||||
archive_sheet: true
|
||||
manually_trigger_sheet: true
|
||||
sheet_frequency: ["hourly", "daily"]
|
||||
max_sheets: 5
|
||||
max_archive_lifespan_months: 12
|
||||
max_monthly_urls: 1000
|
||||
max_monthly_mbs: 1000
|
||||
priority: "high"
|
||||
animated-characters:
|
||||
description: "Animated characters"
|
||||
orchestrator: tests/orchestration.test.yaml
|
||||
orchestrator_sheet: tests/orchestration.test.yaml
|
||||
permissions:
|
||||
read: ["animated-characters"]
|
||||
archive_url: true
|
||||
archive_sheet: true
|
||||
sheet_frequency: ["daily"]
|
||||
max_sheets: 1
|
||||
max_archive_lifespan_months: 12
|
||||
max_monthly_urls: 2
|
||||
max_monthly_mbs: 10
|
||||
priority: "low"
|
||||
default:
|
||||
description: "Public access"
|
||||
orchestrator: tests/orchestration.test.yaml
|
||||
orchestrator_sheet: tests/orchestration.test.yaml
|
||||
permissions:
|
||||
# read: []
|
||||
archive_url: true
|
||||
# manually_trigger_sheet: false
|
||||
# archive_sheet: false
|
||||
# sheet_frequency: []
|
||||
# max_sheets: 0
|
||||
# max_archive_lifespan_months: 12
|
||||
max_monthly_urls: 1
|
||||
# max_monthly_mbs: 50
|
||||
priority: "low"
|
||||
49
app/tests/web/test_main.py
Normal file
49
app/tests/web/test_main.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
def test_lifespan(app):
|
||||
with TestClient(app) as client:
|
||||
r = client.get("/health")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"status": "ok"}
|
||||
|
||||
def test_alembic(db_session):
|
||||
import alembic.config
|
||||
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
|
||||
alembic.config.main(argv=['--raiseerr', 'downgrade', 'base'])
|
||||
|
||||
@patch("endpoints.default.crud.soft_delete_task", side_effect=Exception('mocked error'))
|
||||
def test_logging_middleware(m1, client_with_auth):
|
||||
from web.utils.metrics import EXCEPTION_COUNTER
|
||||
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 0
|
||||
with pytest.raises(Exception, match="mocked error"):
|
||||
client_with_auth.delete("/url/123")
|
||||
# creates one empty and one from above
|
||||
assert len(EXCEPTION_COUNTER.collect()[0].samples) == 2
|
||||
|
||||
|
||||
def test_serve_local_archive_logic(get_settings):
|
||||
# create a test file first
|
||||
os.makedirs("local_archive_test", exist_ok=True)
|
||||
with open("local_archive_test/temp.txt", "w") as f:
|
||||
f.write("test")
|
||||
|
||||
try:
|
||||
# modify the settings
|
||||
get_settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test"
|
||||
from web.main import app_factory
|
||||
app = app_factory(get_settings)
|
||||
|
||||
# test
|
||||
client = TestClient(app)
|
||||
r = client.get("/app/local_archive_test/temp.txt")
|
||||
assert r.status_code == 200
|
||||
assert r.text == "test"
|
||||
finally:
|
||||
# cleanup
|
||||
shutil.rmtree("local_archive_test")
|
||||
108
app/tests/web/test_security.py
Normal file
108
app/tests/web/test_security.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
import pytest
|
||||
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
|
||||
|
||||
def test_secure_compare():
|
||||
from web.security import secure_compare
|
||||
|
||||
assert secure_compare("test", "test")
|
||||
assert not secure_compare("test", "test2")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_token_or_user_auth_with_api():
|
||||
from web.security import get_token_or_user_auth
|
||||
mock_api = HTTPAuthorizationCredentials(scheme="lorem", credentials="this_is_the_test_api_token")
|
||||
assert await get_token_or_user_auth(mock_api) == ALLOW_ANY_EMAIL
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_token_or_user_auth_with_user():
|
||||
from web.security import get_token_or_user_auth
|
||||
bad_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="invalid")
|
||||
e: pytest.ExceptionInfo = None
|
||||
with pytest.raises(HTTPException) as e:
|
||||
await get_token_or_user_auth(bad_user)
|
||||
assert e.value.status_code == 401
|
||||
assert e.value.detail == "invalid access_token"
|
||||
|
||||
|
||||
@patch("web.security.authenticate_user", return_value=(True, "summer@example.com"))
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_auth(m1):
|
||||
from web.security import get_user_auth
|
||||
good_user = HTTPAuthorizationCredentials(scheme="ipsum", credentials="valid-and-good")
|
||||
assert await get_user_auth(good_user) == "summer@example.com"
|
||||
|
||||
|
||||
@patch("web.security.secure_compare", return_value=False)
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_api_key_auth_exception(m1):
|
||||
from web.security import token_api_key_auth
|
||||
|
||||
e: pytest.ExceptionInfo = None
|
||||
with pytest.raises(HTTPException) as e:
|
||||
await token_api_key_auth(HTTPAuthorizationCredentials(scheme="ipsum", credentials="does-not-matter"), auto_error=True)
|
||||
assert e.value.status_code == 401
|
||||
assert e.value.detail == "Wrong auth credentials"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user():
|
||||
from web.security import authenticate_user
|
||||
|
||||
assert authenticate_user("test") == (False, "invalid access_token")
|
||||
assert authenticate_user(123) == (False, "invalid access_token")
|
||||
|
||||
with patch("web.security.requests.get") as mock_get:
|
||||
# bad response from oauth2
|
||||
mock_get.return_value.status_code = 403
|
||||
assert authenticate_user("this-will-call-requests") == (False, "invalid token")
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# 200 but invalid json
|
||||
mock_get.return_value.status_code = 200
|
||||
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
# 200 but invalid azp and aud
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
|
||||
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "aud": "not_an_app"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
|
||||
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "not_an_app"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "token does not belong to valid APP_ID")
|
||||
|
||||
# blocked email
|
||||
mock_get.return_value.json.return_value = {"email": "blocked@example.com", "azp": "test_app_id_1", "aud": "not_an_app"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "email 'blocked@example.com' not allowed")
|
||||
|
||||
# not verified
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "not_an_app", "aud": "test_app_id_1"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "email 'summer@example.com' not verified")
|
||||
|
||||
# token expired
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true"}
|
||||
assert authenticate_user("this-will-call-requests") == (False, "Token expired")
|
||||
|
||||
# 200 and valid azp and aup and verified
|
||||
mock_get.return_value.json.return_value = {"email": "summer@example.com", "azp": "test_app_id_2", "email_verified": "true", "expires_in": 100}
|
||||
assert authenticate_user("this-will-call-requests") == (True, "summer@example.com")
|
||||
assert mock_get.call_count == 9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_exception():
|
||||
from web.security import authenticate_user
|
||||
|
||||
with patch("web.security.requests.get") as mock_get:
|
||||
mock_get.return_value.status_code = 200
|
||||
mock_get.return_value.json.side_effect = Exception("mocked error")
|
||||
assert authenticate_user("this-will-call-requests") == (False, "exception occurred")
|
||||
138
app/tests/worker/test_worker_main.py
Normal file
138
app/tests/worker/test_worker_main.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from datetime import datetime
|
||||
from unittest import mock
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from db import models, schemas
|
||||
from auto_archiver import Metadata
|
||||
from auto_archiver.core import Media
|
||||
|
||||
|
||||
|
||||
class Test_create_archive_task():
|
||||
URL = "https://example-live.com"
|
||||
archive = schemas.ArchiveCreate(url=URL, tags=["tag-celery"], public=True, author_id="rick@example.com", group_id="interstellar")
|
||||
|
||||
@patch("worker.main.insert_result_into_db")
|
||||
@patch("worker.main.get_store_until", return_value=datetime.now())
|
||||
@patch("worker.main.load_orchestrator")
|
||||
@patch("celery.app.task.Task.request")
|
||||
def test_success(self, m_req, m_load, m_store, m_insert, db_session):
|
||||
from worker.main import create_archive_task
|
||||
|
||||
m_req.id = "this-just-in"
|
||||
mock_orchestrator = self.mock_orchestrator_choice(m_load)
|
||||
|
||||
task = create_archive_task(self.archive.model_dump_json())
|
||||
|
||||
m_load.assert_called_once_with("interstellar")
|
||||
m_store.assert_called_once_with("interstellar")
|
||||
m_insert.assert_called_once()
|
||||
mock_orchestrator.feed_item.assert_called_once()
|
||||
|
||||
assert task["status"] == "success"
|
||||
assert task["metadata"]["url"] == self.URL
|
||||
assert len(task["media"]) == 0
|
||||
|
||||
def test_raise_invalid(self):
|
||||
from worker.main import create_archive_task
|
||||
with pytest.raises(Exception):
|
||||
create_archive_task(self.archive.model_dump_json())
|
||||
|
||||
@patch("worker.main.insert_result_into_db", side_effect=Exception)
|
||||
@patch("worker.main.load_orchestrator")
|
||||
def test_raise_db_error(self, m_load, m_insert):
|
||||
from worker.main import create_archive_task
|
||||
mock_orchestrator = self.mock_orchestrator_choice(m_load)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
create_archive_task(self.archive.model_dump_json())
|
||||
mock_orchestrator.feed_item.assert_called_once()
|
||||
|
||||
|
||||
@patch("worker.main.insert_result_into_db", return_value=None)
|
||||
@patch("worker.main.load_orchestrator")
|
||||
def test_raise_empty_result(self, m_load, m_insert):
|
||||
from worker.main import create_archive_task
|
||||
mock_orchestrator = self.mock_orchestrator_choice(m_load)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
create_archive_task(self.archive.model_dump_json())
|
||||
assert "UNABLE TO archive" in str(e)
|
||||
mock_orchestrator.feed_item.assert_called_once()
|
||||
|
||||
def mock_orchestrator_choice(self, m_load):
|
||||
mock_orchestrator = mock.MagicMock()
|
||||
mock_orchestrator.configure_mock(feed_item=mock.MagicMock(return_value=Metadata().set_url(self.URL).success()))
|
||||
m_load.return_value = mock_orchestrator
|
||||
return mock_orchestrator
|
||||
|
||||
|
||||
class Test_create_sheet_task():
|
||||
URL = "https://example-live.com"
|
||||
sheet = schemas.SubmitSheet(sheet_id="123", author_id="rick@example.com", group_id="interstellar", tags=["spaceship"])
|
||||
|
||||
@patch("worker.main.models.generate_uuid", return_value="constant-uuid")
|
||||
@patch("worker.main.get_store_until", return_value=datetime.now())
|
||||
@patch("worker.main.load_orchestrator")
|
||||
def test_success(self, m_load, m_store, m_uuid, db_session):
|
||||
from worker.main import create_sheet_task
|
||||
|
||||
assert db_session.query(models.Archive).filter(models.Archive.url == self.URL).count() == 0
|
||||
|
||||
mock_metadata = Metadata().set_url(self.URL).success()
|
||||
mock_metadata.add_media(Media("fn1.txt", urls=["outcome1.com"]))
|
||||
m_orch = MagicMock()
|
||||
m_orch.feed.return_value = iter([False, mock_metadata, mock_metadata])
|
||||
m_load.return_value = m_orch
|
||||
|
||||
res = create_sheet_task(self.sheet.model_dump_json())
|
||||
|
||||
m_load.assert_called_once_with("interstellar", True, {'configurations': {'gsheet_feeder': {'sheet_id': '123'}}})
|
||||
m_orch.feed.assert_called_once()
|
||||
m_store.assert_called_with("interstellar")
|
||||
m_store.call_count == 2
|
||||
m_uuid.call_count == 2
|
||||
assert type(res) == dict
|
||||
assert res["stats"]["archived"] == 1
|
||||
assert res["stats"]["failed"] == 1
|
||||
assert len(res["stats"]["errors"]) == 1
|
||||
assert res["sheet_id"] == "123"
|
||||
assert res["success"]
|
||||
assert type(res["time"]) == datetime
|
||||
|
||||
# query created archive entry
|
||||
inserted = db_session.query(models.Archive).filter(models.Archive.url == self.URL).one()
|
||||
assert inserted is not None
|
||||
assert inserted.url == self.URL
|
||||
assert len(inserted.tags) == 1
|
||||
assert inserted.tags[0].id == "spaceship"
|
||||
assert inserted.group_id == "interstellar"
|
||||
assert inserted.author_id == "rick@example.com"
|
||||
assert inserted.public == False
|
||||
|
||||
|
||||
def test_get_all_urls(db_session):
|
||||
from worker.main import get_all_urls
|
||||
from auto_archiver import Metadata
|
||||
|
||||
meta = Metadata().set_url("https://example.com")
|
||||
m1 = meta.add_media(Media("fn1.txt", urls=["outcome1.com"]))
|
||||
m2 = meta.add_media(Media("fn2.txt", urls=["outcome2.com"]))
|
||||
m3 = meta.add_media(Media("fn3.txt", urls=["outcome3.com"]))
|
||||
m1.set("screenshot", Media("screenshot.png", urls=["screenshot.com"]))
|
||||
m2.set("thumbnails", [Media("thumb1.png", urls=["thumb1.com"]), Media("thumb2.png", urls=["thumb2.com"])])
|
||||
m3.set("ssl_data", Media("ssl_data.txt", urls=["ssl_data.com"]).to_dict())
|
||||
m3.set("bad_data", {"bad": "dict is ignored"})
|
||||
|
||||
urls = [u.url for u in get_all_urls(meta)]
|
||||
assert len(urls) == 7
|
||||
assert "outcome1.com" in urls
|
||||
assert "outcome2.com" in urls
|
||||
assert "outcome3.com" in urls
|
||||
assert "screenshot.com" in urls
|
||||
assert "thumb1.com" in urls
|
||||
assert "thumb2.com" in urls
|
||||
assert "ssl_data.com" in urls
|
||||
3
app/web/__init__.py
Normal file
3
app/web/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.web.main import app_factory
|
||||
|
||||
app = app_factory
|
||||
0
app/web/endpoints/__init__.py
Normal file
0
app/web/endpoints/__init__.py
Normal file
59
app/web/endpoints/default.py
Normal file
59
app/web/endpoints/default.py
Normal file
@@ -0,0 +1,59 @@
|
||||
|
||||
from typing import Dict
|
||||
from fastapi import APIRouter, Depends, Request, HTTPException
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
|
||||
from app.shared.config import VERSION, BREAKING_CHANGES
|
||||
from app.shared.log import log_error
|
||||
from app.shared.db import crud
|
||||
from app.shared.schemas import ActiveUser, UsageResponse
|
||||
from app.shared.db.user_state import UserState
|
||||
from app.web.security import get_user_auth, bearer_security, get_user_state
|
||||
from app.shared.user_groups import GroupInfo
|
||||
|
||||
default_router = APIRouter()
|
||||
|
||||
|
||||
@default_router.get("/")
|
||||
async def home(request: Request):
|
||||
# TODO: maybe split into 2 routes: one non authenticated and one authenticated for the groups info only, necessary only for the extension
|
||||
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
|
||||
try:
|
||||
email = await get_user_auth(await bearer_security(request))
|
||||
status["groups"] = crud.get_user_groups(email)
|
||||
except HTTPException: pass # not authenticated is fine
|
||||
except Exception as e: log_error(e)
|
||||
return JSONResponse(status)
|
||||
|
||||
|
||||
@default_router.get("/health")
|
||||
async def health():
|
||||
return JSONResponse({"status": "ok"})
|
||||
|
||||
|
||||
@default_router.get("/user/active", summary="Check if the user is active and can use the tool.")
|
||||
async def active(
|
||||
user: UserState = Depends(get_user_state),
|
||||
) -> ActiveUser:
|
||||
return {"active": user.active}
|
||||
|
||||
|
||||
@default_router.get("/user/permissions", summary="Get the user's global 'all' permissions and the permissions for each group they belong to.")
|
||||
def get_user_permissions(
|
||||
user: UserState = Depends(get_user_state),
|
||||
) -> Dict[str, GroupInfo]:
|
||||
return user.permissions
|
||||
|
||||
@default_router.get("/user/usage", summary="Get the user's monthly URLs/MBs usage along with the total active sheets, breakdown by group.")
|
||||
def get_user_usage(
|
||||
user: UserState = Depends(get_user_state),
|
||||
) -> UsageResponse:
|
||||
if not user.active:
|
||||
raise HTTPException(status_code=403, detail="User is not active.")
|
||||
return user.usage()
|
||||
|
||||
|
||||
|
||||
@default_router.get('/favicon.ico', include_in_schema=False)
|
||||
async def favicon() -> FileResponse:
|
||||
return FileResponse("web/static/favicon.ico")
|
||||
51
app/web/endpoints/interoperability.py
Normal file
51
app/web/endpoints/interoperability.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import json
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
import sqlalchemy
|
||||
from auto_archiver import Metadata
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.aa_utils import get_all_urls
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
from app.shared import business_logic, schemas
|
||||
from app.shared.db import crud
|
||||
from app.shared.db.database import get_db_dependency
|
||||
from app.web.security import token_api_key_auth
|
||||
from app.shared.db import models
|
||||
from app.shared.log import log_error
|
||||
|
||||
|
||||
interoperability_router = APIRouter(prefix="/interop", tags=["Interoperability endpoints."])
|
||||
|
||||
|
||||
# ----- endpoint to submit data archived elsewhere
|
||||
@interoperability_router.post("/submit-archive", status_code=201, summary="Submit a manual archive entry, for data that was archived elsewhere.")
|
||||
def submit_manual_archive(
|
||||
manual: schemas.SubmitManualArchive,
|
||||
auth=Depends(token_api_key_auth),
|
||||
db: Session = Depends(get_db_dependency)
|
||||
):
|
||||
result: Metadata = Metadata.from_json(manual.result)
|
||||
manual.author_id = manual.author_id or ALLOW_ANY_EMAIL
|
||||
manual.tags.add("manual")
|
||||
|
||||
try:
|
||||
archive = schemas.ArchiveCreate(
|
||||
author_id=manual.author_id,
|
||||
url=result.get_url(),
|
||||
public=manual.public,
|
||||
group_id=manual.group_id,
|
||||
tags=manual.tags,
|
||||
id=models.generate_uuid(),
|
||||
result=json.loads(result.to_json()),
|
||||
urls=get_all_urls(result),
|
||||
store_until=business_logic.get_store_archive_until(db, manual.group_id),
|
||||
)
|
||||
|
||||
db_archive = crud.store_archived_url(db, archive)
|
||||
logger.debug(f"[MANUAL ARCHIVE STORED] {db_archive.author_id} {db_archive.url}")
|
||||
return JSONResponse({"id": db_archive.id}, status_code=201)
|
||||
except sqlalchemy.exc.IntegrityError as e:
|
||||
log_error(e)
|
||||
raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error, likely duplicate urls.")
|
||||
80
app/web/endpoints/sheet.py
Normal file
80
app/web/endpoints/sheet.py
Normal file
@@ -0,0 +1,80 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.db.user_state import UserState
|
||||
from app.shared import schemas
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.security import get_user_state
|
||||
from app.shared.db import crud
|
||||
from app.shared.db.database import get_db_dependency
|
||||
|
||||
sheet_router = APIRouter(prefix="/sheet", tags=["Google Spreadsheet operations"])
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
@sheet_router.post("/create", status_code=201, summary="Store a new Google Sheet for regular archiving.")
|
||||
def create_sheet(
|
||||
sheet: schemas.SheetAdd,
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> schemas.SheetResponse:
|
||||
|
||||
if not user.in_group(sheet.group_id):
|
||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||
|
||||
if not user.has_quota_monthly_sheets(sheet.group_id):
|
||||
raise HTTPException(status_code=429, detail="User has reached their sheet quota for this group.")
|
||||
|
||||
if not user.is_sheet_frequency_allowed(sheet.group_id, sheet.frequency):
|
||||
raise HTTPException(status_code=422, detail="Invalid frequency selected for this group.")
|
||||
|
||||
try:
|
||||
return crud.create_sheet(db, sheet.id, sheet.name, user.email, sheet.group_id, sheet.frequency)
|
||||
except exc.IntegrityError as e:
|
||||
raise HTTPException(status_code=400, detail="Sheet with this ID is already being archived.") from e
|
||||
|
||||
|
||||
@sheet_router.get("/mine", status_code=200, summary="Get the authenticated user's Google Sheets.")
|
||||
def get_user_sheets(
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency)
|
||||
) -> list[schemas.SheetResponse]:
|
||||
return crud.get_user_sheets(db, user.email)
|
||||
|
||||
|
||||
@sheet_router.delete("/{id}", summary="Delete a Google Sheet by ID.")
|
||||
def delete_sheet(
|
||||
id: str,
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> schemas.TaskDelete:
|
||||
return JSONResponse({
|
||||
"id": id,
|
||||
"deleted": crud.delete_sheet(db, id, user.email)
|
||||
})
|
||||
|
||||
|
||||
@sheet_router.post("/{id}/archive", status_code=201, summary="Trigger an archiving task for a GSheet you own.", response_description="task_id for the archiving task.")
|
||||
def archive_user_sheet(
|
||||
id: str,
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
) -> schemas.Task:
|
||||
|
||||
sheet = crud.get_user_sheet(db, user.email, sheet_id=id)
|
||||
if not sheet:
|
||||
raise HTTPException(status_code=403, detail="No access to this sheet.")
|
||||
|
||||
if not user.in_group(sheet.group_id):
|
||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||
|
||||
if not user.can_manually_trigger(sheet.group_id):
|
||||
raise HTTPException(status_code=429, detail="User cannot manually trigger sheet archiving in this group.")
|
||||
|
||||
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=id, author_id=user.email, group_id=sheet.group_id).model_dump_json()]).delay()
|
||||
|
||||
return JSONResponse({"id": task.id}, status_code=201)
|
||||
40
app/web/endpoints/task.py
Normal file
40
app/web/endpoints/task.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from celery.result import AsyncResult
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.security import get_token_or_user_auth
|
||||
from app.shared import schemas
|
||||
from app.shared.log import log_error
|
||||
from app.web.utils.misc import custom_jsonable_encoder
|
||||
|
||||
|
||||
task_router = APIRouter(prefix="/task", tags=["Async task operations"])
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
@task_router.get("/{task_id}", summary="Check the status of an async task by its id, works for URLs and Sheet tasks.")
|
||||
def get_status(task_id, email=Depends(get_token_or_user_auth)) -> schemas.TaskResult:
|
||||
task = AsyncResult(task_id, app=celery)
|
||||
try:
|
||||
if task.status == "FAILURE":
|
||||
# *FAILURE* The task raised an exception, or has exceeded the retry limit.
|
||||
# The :attr:`result` attribute then contains the exception raised by the task.
|
||||
# https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
|
||||
raise task.result
|
||||
|
||||
response = {
|
||||
"id": task_id,
|
||||
"status": task.status,
|
||||
"result": task.result
|
||||
}
|
||||
return JSONResponse(jsonable_encoder(response, exclude_unset=True, custom_encoder={bytes: custom_jsonable_encoder}))
|
||||
|
||||
except Exception as e:
|
||||
log_error(e)
|
||||
return JSONResponse({
|
||||
"id": task_id,
|
||||
"status": "FAILURE",
|
||||
"result": {"error": str(e)}
|
||||
})
|
||||
77
app/web/endpoints/url.py
Normal file
77
app/web/endpoints/url.py
Normal file
@@ -0,0 +1,77 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
from app.shared import schemas
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.security import get_token_or_user_auth, get_user_state
|
||||
from app.shared.db import crud
|
||||
from app.shared.db.user_state import UserState
|
||||
from app.shared.db.database import get_db_dependency
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
url_router = APIRouter(prefix="/url", tags=["Single URL operations"])
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
@url_router.post("/archive", status_code=201, summary="Submit a single URL archive request, starts an archiving task.", response_description="task_id for the archiving task, will match the archive id.")
|
||||
def archive_url(
|
||||
archive: schemas.ArchiveTrigger,
|
||||
email=Depends(get_token_or_user_auth),
|
||||
db: Session = Depends(get_db_dependency)
|
||||
) -> schemas.Task:
|
||||
archive.author_id = email
|
||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {archive.url}")
|
||||
|
||||
parsed_url = urlparse(archive.url)
|
||||
if not all([parsed_url.scheme, parsed_url.netloc]):
|
||||
raise HTTPException(status_code=400, detail="Invalid URL received.")
|
||||
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
user = UserState(db, email)
|
||||
if archive.group_id and not user.in_group(archive.group_id):
|
||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||
if not user.has_quota_max_monthly_urls(archive.group_id):
|
||||
raise HTTPException(status_code=429, detail="User has reached their monthly URL quota.")
|
||||
if not user.has_quota_max_monthly_mbs(archive.group_id):
|
||||
raise HTTPException(status_code=429, detail="User has reached their monthly MB quota.")
|
||||
|
||||
archive_create = schemas.ArchiveCreate(**archive.model_dump())
|
||||
|
||||
task = celery.signature("create_archive_task", args=[archive_create.model_dump_json()]).delay()
|
||||
task_response = schemas.Task(id=task.id)
|
||||
return JSONResponse(task_response.model_dump(), status_code=201)
|
||||
|
||||
|
||||
@url_router.get("/search", summary="Search for archive entries by URL.")
|
||||
def search_by_url(
|
||||
url: str, skip: int = 0, limit: int = 25,
|
||||
archived_after: datetime = None, archived_before: datetime = None,
|
||||
db: Session = Depends(get_db_dependency),
|
||||
email: str = Depends(get_token_or_user_auth)
|
||||
) -> list[schemas.ArchiveResult]:
|
||||
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
user = UserState(db, email)
|
||||
if not user.read and not user.read_public:
|
||||
raise HTTPException(status_code=403, detail="User does not have read access.")
|
||||
|
||||
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
||||
|
||||
|
||||
@url_router.delete("/{id}", summary="Delete a single URL archive by id.")
|
||||
def delete_task(
|
||||
id:str,
|
||||
user: UserState = Depends(get_user_state),
|
||||
db: Session = Depends(get_db_dependency)
|
||||
) -> schemas.TaskDelete:
|
||||
logger.info(f"deleting url archive task {id} request by {user.email}")
|
||||
return JSONResponse({
|
||||
"id": id,
|
||||
"deleted": crud.soft_delete_task(db, id, user.email)
|
||||
})
|
||||
179
app/web/events.py
Normal file
179
app/web/events.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
import datetime
|
||||
import logging
|
||||
import alembic.config
|
||||
from fastapi import FastAPI
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
from loguru import logger
|
||||
from fastapi_mail import FastMail, MessageSchema, MessageType
|
||||
|
||||
from app.shared.db import crud, models
|
||||
from app.shared.db.database import get_db, get_db_async, make_engine, wal_checkpoint
|
||||
from app.shared import schemas
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.task_messaging import get_celery
|
||||
from app.web.utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# see https://fastapi.tiangolo.com/advanced/events/#lifespan
|
||||
|
||||
# STARTUP
|
||||
logger.debug("HERE 00")
|
||||
engine = make_engine(get_settings().DATABASE_PATH)
|
||||
models.Base.metadata.create_all(bind=engine)
|
||||
logger.debug("HERE 01")
|
||||
alembic.config.main(prog="alembic", argv=['--raiseerr', 'upgrade', 'head'])
|
||||
logger.debug("HERE 02")
|
||||
logging.getLogger("uvicorn.access").disabled = True # loguru
|
||||
asyncio.create_task(redis_subscribe_worker_exceptions(get_settings().REDIS_EXCEPTIONS_CHANNEL))
|
||||
asyncio.create_task(repeat_measure_regular_metrics())
|
||||
with get_db() as db:
|
||||
crud.upsert_user_groups(db)
|
||||
|
||||
# setup archive cronjobs
|
||||
if get_settings().CRON_ARCHIVE_SHEETS:
|
||||
asyncio.create_task(archive_hourly_sheets_cronjob())
|
||||
asyncio.create_task(archive_daily_sheets_cronjob())
|
||||
else:
|
||||
logger.warning("[CRON] Sheet archive cronjobs are disabled.")
|
||||
|
||||
if get_settings().CRON_DELETE_STALE_SHEETS:
|
||||
asyncio.create_task(delete_stale_sheets())
|
||||
else:
|
||||
logger.warning("[CRON] Delete stale sheets cronjob is disabled.")
|
||||
|
||||
if get_settings().CRON_DELETE_SCHEDULED_ARCHIVES:
|
||||
asyncio.create_task(notify_about_expired_archives())
|
||||
else:
|
||||
logger.warning("[CRON] Delete scheduled archives cronjob is disabled.")
|
||||
|
||||
wal_checkpoint()
|
||||
|
||||
yield # separates startup from shutdown instructions
|
||||
|
||||
# SHUTDOWN
|
||||
logger.info("shutting down")
|
||||
|
||||
|
||||
# CRON JOBS
|
||||
@repeat_every(seconds=get_settings().REPEAT_COUNT_METRICS_SECONDS, on_exception=logger.error)
|
||||
async def repeat_measure_regular_metrics():
|
||||
await measure_regular_metrics(get_settings().DATABASE_PATH, get_settings().REPEAT_COUNT_METRICS_SECONDS)
|
||||
|
||||
|
||||
@repeat_every(seconds=60, wait_first=120, on_exception=logger.error)
|
||||
async def archive_hourly_sheets_cronjob():
|
||||
await archive_sheets_cronjob("hourly", 60, datetime.datetime.now().minute)
|
||||
|
||||
|
||||
@repeat_every(seconds=3600, wait_first=120, on_exception=logger.error)
|
||||
async def archive_daily_sheets_cronjob():
|
||||
await archive_sheets_cronjob("daily", 24, datetime.datetime.now().hour)
|
||||
|
||||
|
||||
async def archive_sheets_cronjob(frequency: str, interval: int, current_time_unit: int):
|
||||
triggered_jobs = []
|
||||
|
||||
async with get_db_async() as db:
|
||||
sheets = await crud.get_sheets_by_id_hash(db, frequency, interval, current_time_unit)
|
||||
for s in sheets:
|
||||
task = celery.signature("create_sheet_task", args=[schemas.SubmitSheet(sheet_id=s.id, author_id=s.author_id, group_id=s.group_id).model_dump_json()]).apply_async()
|
||||
|
||||
triggered_jobs.append({"sheet_id": s.id, "task_id": task.id})
|
||||
logger.info(f"[CRON {frequency.upper()}:{current_time_unit}] Triggered {len(triggered_jobs)} sheet tasks: {triggered_jobs}")
|
||||
|
||||
|
||||
# TODO: on exception should logerror but also prometheus counter
|
||||
DELETE_WINDOW = get_settings().DELETE_SCHEDULED_ARCHIVES_NOTIFY_DAYS * 24 * 60 * 60
|
||||
|
||||
|
||||
@repeat_every(seconds=DELETE_WINDOW, wait_first=180, on_exception=logger.error)
|
||||
async def notify_about_expired_archives():
|
||||
notify_from = datetime.datetime.now() + datetime.timedelta(days=get_settings().DELETE_SCHEDULED_ARCHIVES_NOTIFY_DAYS)
|
||||
async with get_db_async() as db:
|
||||
scheduled_deletions = await crud.find_by_store_until(db, notify_from)
|
||||
|
||||
user_archives = defaultdict(list)
|
||||
for archive in scheduled_deletions:
|
||||
user_archives[archive.author_id].append(archive)
|
||||
|
||||
if user_archives:
|
||||
fastmail = FastMail(get_settings().MAIL_CONFIG)
|
||||
# notify users
|
||||
for email in user_archives:
|
||||
list_of_archives = "\n".join([f'{a.url},{a.id}<br/>' for a in user_archives[email]])
|
||||
# TODO: how can users download them in bulk?
|
||||
message = MessageSchema(
|
||||
subject="Auto Archiver: Archives Scheduled for Deletion",
|
||||
recipients=[email],
|
||||
body=f"""
|
||||
<html>
|
||||
<body>
|
||||
<p>Hi {email},</p>
|
||||
<p>Some of your archives will be deleted in the next {get_settings().DELETE_SCHEDULED_ARCHIVES_NOTIFY_DAYS} days, as they are reaching their expiration date according to our retention policy for their groups.</p>
|
||||
<p>If you want to preserve any, make sure to download them now.</p>
|
||||
<p>Here is a CSV list of URLs:</p>
|
||||
<code>
|
||||
url,archive_id<br/>
|
||||
{list_of_archives}
|
||||
</code>
|
||||
<p>Best,<br>The Auto Archiver team</p>
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
subtype=MessageType.html
|
||||
)
|
||||
await fastmail.send_message(message)
|
||||
logger.info(f"[CRON] Email sent to {email} about {len(user_archives[email])} scheduled archives deletion.")
|
||||
|
||||
# now schedule the deletion event
|
||||
asyncio.create_task(delete_expired_archives())
|
||||
|
||||
|
||||
@repeat_every(max_repetitions=1, wait_first=DELETE_WINDOW - (60 * 60), seconds=0, on_exception=logger.error)
|
||||
async def delete_expired_archives():
|
||||
async with get_db_async() as db:
|
||||
count_deleted = await crud.soft_delete_expired_archives(db)
|
||||
if count_deleted:
|
||||
logger.info(f"[CRON] Deleted {count_deleted} archives.")
|
||||
|
||||
|
||||
@repeat_every(seconds=86400, wait_first=150, on_exception=logger.error)
|
||||
async def delete_stale_sheets():
|
||||
STALE_DAYS = get_settings().DELETE_STALE_SHEETS_DAYS
|
||||
logger.info(f"[CRON] Deleting stale sheets older than {STALE_DAYS} days.")
|
||||
async with get_db_async() as db:
|
||||
user_sheets = await crud.delete_stale_sheets(db, STALE_DAYS)
|
||||
|
||||
if not user_sheets: return
|
||||
|
||||
fastmail = FastMail(get_settings().MAIL_CONFIG)
|
||||
# notify users
|
||||
for email in user_sheets:
|
||||
list_of_sheets = "\n".join([f'<li><a href="https://docs.google.com/spreadsheets/d/{s.id}">{s.name}</a></li>' for s in user_sheets[email]])
|
||||
message = MessageSchema(
|
||||
subject="Auto Archiver: Stale Sheets Removed",
|
||||
recipients=[email],
|
||||
body=f"""
|
||||
<html>
|
||||
<body>
|
||||
<p>Hi {email},</p>
|
||||
<p>Your stale sheets have been removed from our system as no new URL was archived in the past {STALE_DAYS} days:</p>
|
||||
<ul>
|
||||
{list_of_sheets}
|
||||
</ul>
|
||||
<p>You can always re-add them at https://auto-archiver.bellingcat.com/.</p>
|
||||
<p>Best,<br>The Auto Archiver team</p>
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
subtype=MessageType.html
|
||||
)
|
||||
await fastmail.send_message(message)
|
||||
logger.info(f"[CRON] Email sent to {email} about stale sheets deletion.")
|
||||
174
app/web/main.py
Normal file
174
app/web/main.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import os
|
||||
from celery.result import AsyncResult
|
||||
from fastapi import FastAPI, Depends, HTTPException
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
from app.shared.log import log_error
|
||||
from app.web.middleware import logging_middleware
|
||||
from app.shared import schemas
|
||||
from app.shared.task_messaging import get_celery
|
||||
|
||||
from app.shared.db import crud
|
||||
from app.web.security import get_user_auth, token_api_key_auth, get_token_or_user_auth
|
||||
from app.shared.config import VERSION, API_DESCRIPTION
|
||||
from app.shared.db.database import get_db_dependency
|
||||
from app.web.events import lifespan
|
||||
from app.shared.settings import get_settings
|
||||
|
||||
|
||||
from app.web.endpoints.default import default_router
|
||||
from app.web.endpoints.url import url_router
|
||||
from app.web.endpoints.sheet import sheet_router
|
||||
from app.web.endpoints.task import task_router
|
||||
from app.web.endpoints.interoperability import interoperability_router
|
||||
|
||||
celery = get_celery()
|
||||
|
||||
def app_factory(settings = get_settings()):
|
||||
app = FastAPI(
|
||||
title="Auto-Archiver API",
|
||||
description=API_DESCRIPTION,
|
||||
version=VERSION,
|
||||
contact={"name": "GitHub", "url": "https://github.com/bellingcat/auto-archiver-api"},
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.middleware("http")(logging_middleware)
|
||||
|
||||
app.include_router(default_router)
|
||||
app.include_router(url_router)
|
||||
app.include_router(sheet_router)
|
||||
app.include_router(task_router)
|
||||
app.include_router(interoperability_router)
|
||||
|
||||
# prometheus exposed in /metrics with authentication
|
||||
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health", "/openapi.json", "/favicon.ico"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
|
||||
|
||||
# TODO: recheck this for security, currently only needed for when local_storage is used
|
||||
local_dir = settings.SERVE_LOCAL_ARCHIVE
|
||||
if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")):
|
||||
local_dir = local_dir.replace("/app", ".")
|
||||
if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
|
||||
logger.warning(f"MOUNTing local archive {settings.SERVE_LOCAL_ARCHIVE}")
|
||||
app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE)
|
||||
|
||||
|
||||
|
||||
# -----Submit URL and manipulate tasks. Bearer protected below
|
||||
|
||||
|
||||
@app.get("/tasks/search-url", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
|
||||
def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
||||
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
||||
|
||||
|
||||
@app.get("/tasks/sync", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
|
||||
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
return crud.search_archives_by_email(db, email, skip=skip, limit=limit)
|
||||
|
||||
|
||||
@app.post("/tasks", status_code=201, deprecated=True) # DEPRECATED
|
||||
def archive_tasks(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)):
|
||||
archive.author_id = email
|
||||
url = archive.url
|
||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
|
||||
if type(url) != str or len(url) <= 5:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
|
||||
logger.info("creating task")
|
||||
|
||||
task = celery.signature("create_archive_task", args=[archive.model_dump_json()]).delay()
|
||||
return JSONResponse({"id": task.id})
|
||||
|
||||
|
||||
@app.get("/archive/{task_id}", deprecated=True) # DEPRECATED
|
||||
def lookup(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
||||
return crud.get_archive(db, task_id, email)
|
||||
|
||||
|
||||
@app.get("/tasks/{task_id}", deprecated=True) # DEPRECATED
|
||||
def get_status(task_id, email=Depends(get_token_or_user_auth)):
|
||||
logger.info(f"status check for user {email} task {task_id}")
|
||||
task = AsyncResult(task_id, app=celery)
|
||||
try:
|
||||
if task.status == "FAILURE":
|
||||
# *FAILURE* The task raised an exception, or has exceeded the retry limit.
|
||||
# The :attr:`result` attribute then contains the exception raised by the task.
|
||||
# https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
|
||||
raise task.result
|
||||
|
||||
response = {
|
||||
"id": task_id,
|
||||
"status": task.status,
|
||||
"result": task.result
|
||||
}
|
||||
return JSONResponse(jsonable_encoder(response, exclude_unset=True))
|
||||
|
||||
except Exception as e:
|
||||
log_error(e)
|
||||
return JSONResponse({
|
||||
"id": task_id,
|
||||
"status": "FAILURE",
|
||||
"result": {"error": str(e)}
|
||||
})
|
||||
|
||||
|
||||
@app.delete("/tasks/{task_id}", deprecated=True) # DEPRECATED
|
||||
def delete_task(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
logger.info(f"deleting task {task_id} request by {email}")
|
||||
return JSONResponse({
|
||||
"id": task_id,
|
||||
"deleted": crud.soft_delete_task(db, task_id, email)
|
||||
})
|
||||
|
||||
# ----- Google Sheets Logic
|
||||
|
||||
|
||||
@app.post("/sheet", status_code=201, deprecated=True) # DEPRECATED
|
||||
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth), db: Session = Depends(get_db_dependency)):
|
||||
logger.info(f"SHEET TASK for {sheet=}")
|
||||
sheet.author_id = email
|
||||
if not crud.is_user_in_group(db, email, sheet.group_id):
|
||||
raise HTTPException(status_code=403, detail="User does not have access to this group.")
|
||||
task = celery.signature("create_sheet_task", args=[sheet.model_dump_json()]).delay()
|
||||
return JSONResponse({"id": task.id})
|
||||
|
||||
|
||||
@app.post("/sheet_service", status_code=201, deprecated=True) # DEPRECATED
|
||||
def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth)):
|
||||
logger.info(f"SHEET TASK for {sheet=}")
|
||||
sheet.author_id = sheet.author_id or "api-endpoint"
|
||||
|
||||
task = celery.signature("create_sheet_task", args=[sheet.model_dump_json()]).delay()
|
||||
return JSONResponse({"id": task.id})
|
||||
|
||||
# ----- endpoint to submit data archived elsewhere
|
||||
|
||||
|
||||
@app.post("/submit-archive", status_code=201, deprecated=True) # DEPRECATED
|
||||
def submit_manual_archive(manual: schemas.SubmitManual, auth=Depends(token_api_key_auth)):
|
||||
raise HTTPException(status_code=410, detail="This endpoint is deprecated. Use /interop/submit-archive instead.")
|
||||
# result = Metadata.from_json(manual.result)
|
||||
# logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
|
||||
# manual.tags.add("manual")
|
||||
# try:
|
||||
# # archive_id = insert_result_into_db(result, manual.tags, manual.public, manual.group_id, manual.author_id, models.generate_uuid())
|
||||
# except sqlalchemy.exc.IntegrityError as e:
|
||||
# log_error(e)
|
||||
# raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error")
|
||||
# return JSONResponse({"id": archive_id})
|
||||
|
||||
return app
|
||||
17
app/web/middleware.py
Normal file
17
app/web/middleware.py
Normal file
@@ -0,0 +1,17 @@
|
||||
|
||||
from loguru import logger
|
||||
from fastapi import Request
|
||||
from app.shared.log import log_error
|
||||
|
||||
|
||||
async def logging_middleware(request: Request, call_next):
|
||||
try:
|
||||
response = await call_next(request)
|
||||
logger.info(f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - HTTP {response.status_code}")
|
||||
return response
|
||||
except Exception as e:
|
||||
from web.utils.metrics import EXCEPTION_COUNTER
|
||||
EXCEPTION_COUNTER.labels(type=e.__class__.__name__).inc()
|
||||
logger.info(f"{request.client.host}:{request.client.port} {request.method} {request.url._url} - {e.__class__.__name__} {e}")
|
||||
log_error(e)
|
||||
raise e
|
||||
83
app/web/security.py
Normal file
83
app/web/security.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from loguru import logger
|
||||
import requests, secrets
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
from app.shared.config import ALLOW_ANY_EMAIL
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.db.database import get_db
|
||||
from app.shared.db.user_state import UserState
|
||||
|
||||
settings = get_settings()
|
||||
bearer_security = HTTPBearer()
|
||||
|
||||
|
||||
def secure_compare(token, api_key):
|
||||
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
|
||||
|
||||
|
||||
# Factory method to create an authentication dependency for a specific key
|
||||
def api_key_auth(api_key):
|
||||
assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars"
|
||||
|
||||
async def auth(bearer: HTTPAuthorizationCredentials = Depends(bearer_security), auto_error=True):
|
||||
is_correct = secure_compare(bearer.credentials, api_key)
|
||||
if is_correct: return True
|
||||
|
||||
if auto_error:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Wrong auth credentials",
|
||||
)
|
||||
return False
|
||||
|
||||
return auth
|
||||
|
||||
|
||||
# --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus
|
||||
token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN)
|
||||
|
||||
|
||||
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||
# tries to use the static API_KEY and defaults to google JWT auth
|
||||
if await token_api_key_auth(credentials, auto_error=False): return ALLOW_ANY_EMAIL
|
||||
return await get_user_auth(credentials)
|
||||
|
||||
|
||||
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||
# validates the Bearer token in the case that it requires it
|
||||
valid_user, info = authenticate_user(credentials.credentials)
|
||||
if valid_user:
|
||||
return info.lower()
|
||||
logger.debug(f"TOKEN FAILURE: {valid_user=} {info=}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=info,
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def authenticate_user(access_token):
|
||||
# https://cloud.google.com/docs/authentication/token-types#access
|
||||
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
|
||||
r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token": access_token})
|
||||
if r.status_code != 200: return False, "invalid token"
|
||||
try:
|
||||
j = r.json()
|
||||
if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS:
|
||||
return False, f"token does not belong to valid APP_ID"
|
||||
if j.get("email") in settings.BLOCKED_EMAILS:
|
||||
return False, f"email '{j.get('email')}' not allowed"
|
||||
if j.get("email_verified") != "true":
|
||||
return False, f"email '{j.get('email')}' not verified"
|
||||
if int(j.get("expires_in", -1)) <= 0:
|
||||
return False, "Token expired"
|
||||
return True, j.get('email').lower()
|
||||
except Exception as e:
|
||||
logger.warning(f"AUTH EXCEPTION occurred: {e}")
|
||||
return False, "exception occurred"
|
||||
|
||||
|
||||
def get_user_state(email=Depends(get_user_auth)):
|
||||
with get_db() as db:
|
||||
return UserState(db, email)
|
||||
0
app/web/static/.gitkeep
Normal file
0
app/web/static/.gitkeep
Normal file
BIN
app/web/static/favicon.ico
Normal file
BIN
app/web/static/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 93 KiB |
0
app/web/utils/__init__.py
Normal file
0
app/web/utils/__init__.py
Normal file
69
app/web/utils/metrics.py
Normal file
69
app/web/utils/metrics.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from prometheus_client import Counter, Gauge
|
||||
|
||||
from app.shared.db import crud
|
||||
from app.shared.db.database import get_db
|
||||
from app.shared.log import log_error
|
||||
from app.shared.task_messaging import get_redis
|
||||
|
||||
|
||||
# Custom metrics
|
||||
EXCEPTION_COUNTER = Counter(
|
||||
"exceptions",
|
||||
"Number of times a certain exception has occurred.",
|
||||
labelnames=["type"]
|
||||
)
|
||||
WORKER_EXCEPTION = Counter(
|
||||
"worker_exceptions_total",
|
||||
"Number of times a certain exception has occurred on the worker.",
|
||||
labelnames=["type", "exception", "task", "traceback"]
|
||||
)
|
||||
DISK_UTILIZATION = Gauge(
|
||||
"disk_utilization",
|
||||
"Disk utilization in GB",
|
||||
labelnames=["type"]
|
||||
)
|
||||
DATABASE_METRICS = Gauge(
|
||||
"database_metrics",
|
||||
"Database metric readings at a certain point in time",
|
||||
labelnames=["query"]
|
||||
)
|
||||
DATABASE_METRICS_COUNTER = Counter(
|
||||
"database_metrics_counter",
|
||||
"Database metrics that increase over time",
|
||||
labelnames=["query", "user"]
|
||||
)
|
||||
|
||||
|
||||
async def redis_subscribe_worker_exceptions(REDIS_EXCEPTIONS_CHANNEL: str):
|
||||
# Subscribe to Redis channel and increment the counter for each exception with info on the exception and task
|
||||
Redis = get_redis()
|
||||
PubSubExceptions = Redis.pubsub()
|
||||
PubSubExceptions.subscribe(REDIS_EXCEPTIONS_CHANNEL)
|
||||
while True:
|
||||
message = PubSubExceptions.get_message()
|
||||
if message and message["type"] == "message":
|
||||
data = json.loads(message["data"].decode("utf-8"))
|
||||
WORKER_EXCEPTION.labels(type=data["type"], exception=data["exception"], task=data["task"], traceback=data["traceback"]).inc()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
async def measure_regular_metrics(sqlite_db_url: str, repeat_in_seconds: int):
|
||||
_total, used, free = shutil.disk_usage("/")
|
||||
DISK_UTILIZATION.labels(type="used").set(used / (2**30))
|
||||
DISK_UTILIZATION.labels(type="free").set(free / (2**30))
|
||||
try:
|
||||
fs = os.stat(sqlite_db_url.replace("sqlite:///", ""))
|
||||
DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30))
|
||||
except Exception as e: log_error(e)
|
||||
|
||||
with get_db() as db:
|
||||
DATABASE_METRICS.labels(query="count_archives").set(crud.count_archives(db))
|
||||
DATABASE_METRICS.labels(query="count_archive_urls").set(crud.count_archive_urls(db))
|
||||
DATABASE_METRICS.labels(query="count_users").set(crud.count_users(db))
|
||||
|
||||
for user in crud.count_by_user_since(db, repeat_in_seconds):
|
||||
DATABASE_METRICS_COUNTER.labels(query="count_by_user", user=user.author_id).inc(user.total)
|
||||
7
app/web/utils/misc.py
Normal file
7
app/web/utils/misc.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import base64
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
def custom_jsonable_encoder(obj):
|
||||
if isinstance(obj, bytes):
|
||||
return base64.b64encode(obj).decode('utf-8')
|
||||
return jsonable_encoder(obj)
|
||||
0
app/worker/__init__.py
Normal file
0
app/worker/__init__.py
Normal file
130
app/worker/main.py
Normal file
130
app/worker/main.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import json
|
||||
|
||||
import traceback, datetime
|
||||
from celery.signals import task_failure
|
||||
from loguru import logger
|
||||
from sqlalchemy import exc
|
||||
|
||||
from auto_archiver import Config, ArchivingOrchestrator, Metadata
|
||||
|
||||
from app.shared.db import crud, models
|
||||
from app.shared.db.database import get_db
|
||||
from app.shared import business_logic, schemas
|
||||
from app.shared.task_messaging import get_celery, get_redis
|
||||
from app.shared.settings import get_settings
|
||||
from app.shared.log import log_error
|
||||
from app.shared.aa_utils import get_all_urls
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
celery = get_celery("worker")
|
||||
Redis = get_redis()
|
||||
|
||||
USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
|
||||
|
||||
# TODO: after release, as it requires updating past entries with sheet_id where tag is used, drop tags
|
||||
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
|
||||
def create_archive_task(self, archive_json: str):
|
||||
logger.info(archive_json)
|
||||
archive = schemas.ArchiveCreate.model_validate_json(archive_json)
|
||||
|
||||
# call auto-archiver
|
||||
orchestrator = load_orchestrator(archive.group_id)
|
||||
result = orchestrator.feed_item(Metadata().set_url(archive.url))
|
||||
assert result, f"UNABLE TO archive: {archive.url}"
|
||||
|
||||
# prepare and insert in DB
|
||||
store_until = get_store_until(archive.group_id)
|
||||
archive.store_until = store_until
|
||||
archive.id = self.request.id
|
||||
archive.urls = get_all_urls(result)
|
||||
archive.result = json.loads(result.to_json())
|
||||
insert_result_into_db(archive)
|
||||
|
||||
return archive.result
|
||||
|
||||
|
||||
@celery.task(name="create_sheet_task", bind=True)
|
||||
def create_sheet_task(self, sheet_json: str):
|
||||
sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
|
||||
logger.info(f"SHEET START {sheet=}")
|
||||
|
||||
orchestrator = load_orchestrator(sheet.group_id, True, {"configurations": {"gsheet_feeder": {"sheet_id": sheet.sheet_id}}})
|
||||
|
||||
stats = {"archived": 0, "failed": 0, "errors": []}
|
||||
for result in orchestrator.feed():
|
||||
try:
|
||||
assert result, f"UNABLE TO archive: {result.get_url()}"
|
||||
archive = schemas.ArchiveCreate(
|
||||
author_id=sheet.author_id,
|
||||
url=result.get_url(),
|
||||
group_id=sheet.group_id,
|
||||
tags=sheet.tags,
|
||||
id=models.generate_uuid(),
|
||||
result=json.loads(result.to_json()),
|
||||
sheet_id=sheet.sheet_id,
|
||||
urls=get_all_urls(result),
|
||||
store_until = get_store_until(sheet.group_id)
|
||||
)
|
||||
insert_result_into_db(archive)
|
||||
stats["archived"] += 1
|
||||
except exc.IntegrityError as e:
|
||||
logger.warning(f"cached result detected: {e}")
|
||||
except Exception as e:
|
||||
log_error(e, extra=f"{self.name}: {sheet_json}")
|
||||
redis_publish_exception(e, self.name, traceback.format_exc())
|
||||
stats["failed"] += 1
|
||||
stats["errors"].append(str(e))
|
||||
|
||||
if stats["archived"] > 0:
|
||||
with get_db() as session:
|
||||
crud.update_sheet_last_url_archived_at(session, sheet.sheet_id)
|
||||
|
||||
logger.info(f"SHEET DONE {sheet=}")
|
||||
# TODO: is this used anywhere? maybe drop it
|
||||
return schemas.CelerySheetTask(success=True, sheet_id=sheet.sheet_id, time=datetime.datetime.now().isoformat(), stats=stats).model_dump()
|
||||
|
||||
|
||||
def load_orchestrator(group_id: str, orchestrator_for_sheet: bool = False, overwrite_configs: dict = {}) -> ArchivingOrchestrator:
|
||||
with get_db() as session:
|
||||
group = crud.get_group(session, group_id)
|
||||
if orchestrator_for_sheet:
|
||||
orchestrator_fn = group.orchestrator_sheet
|
||||
else:
|
||||
orchestrator_fn = crud.get_group(session, group_id).orchestrator
|
||||
assert orchestrator_fn, f"no orchestrator found for {group_id}"
|
||||
|
||||
|
||||
config = Config()
|
||||
config.parse(use_cli=False, yaml_config_filename=orchestrator_fn, overwrite_configs=overwrite_configs)
|
||||
return ArchivingOrchestrator(config)
|
||||
|
||||
|
||||
def insert_result_into_db(archive: schemas.ArchiveCreate) -> str:
|
||||
with get_db() as session:
|
||||
db_task = crud.store_archived_url(session, archive)
|
||||
logger.debug(f"[ARCHIVE STORED] {db_task.author_id} {db_task.url}")
|
||||
return db_task.id
|
||||
|
||||
def get_store_until(group_id: str) -> datetime.datetime:
|
||||
with get_db() as session:
|
||||
return business_logic.get_store_archive_until(session, group_id)
|
||||
|
||||
def redis_publish_exception(exception, task_name, traceback: str = ""):
|
||||
REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL
|
||||
try:
|
||||
exception_data = {"task": task_name, "type": exception.__class__.__name__, "exception": exception, "traceback": traceback}
|
||||
Redis.publish(REDIS_EXCEPTIONS_CHANNEL, json.dumps(exception_data, default=str))
|
||||
except Exception as e:
|
||||
log_error(e, f"[CRITICAL] Could not publish to {REDIS_EXCEPTIONS_CHANNEL}")
|
||||
|
||||
|
||||
@task_failure.connect(sender=create_sheet_task)
|
||||
@task_failure.connect(sender=create_archive_task)
|
||||
def task_failure_notifier(sender, **kwargs):
|
||||
# automatically capture exceptions in the worker tasks
|
||||
logger.warning(f"⚠️ worker task failed: {sender.name}")
|
||||
traceback_msg = "\n".join(traceback.format_list(traceback.extract_tb(kwargs['traceback'])))
|
||||
log_error(kwargs['exception'], traceback_msg, f"task_failure: {sender.name}")
|
||||
redis_publish_exception(kwargs['exception'], sender.name, traceback_msg)
|
||||
Reference in New Issue
Block a user