refactor to use pydantic settings and WAL sqlite mode

This commit is contained in:
msramalho
2024-10-18 11:08:51 +01:00
parent 11a8e6f4e0
commit ca7e862855
21 changed files with 246 additions and 135 deletions

4
.gitignore vendored
View File

@@ -19,4 +19,6 @@ src/crawls
.coverage .coverage
.pytest_cache/* .pytest_cache/*
htmlcov htmlcov
local_archive local_archive
*db-wal
*db-shm

7
src/.env.test Normal file
View File

@@ -0,0 +1,7 @@
CHROME_APP_IDS='["test_app_id_1","test_app_id_2"]'
ALLOWED_ORIGINS='["chrome-extension://example1","chrome-extension://example2","http://localhost:8081"]'
BLOCKED_EMAILS='["blocked@example.com"]'
DATABASE_PATH="sqlite:////app/auto-archiver.test.db"
API_BEARER_TOKEN=this_is_the_test_api_token

View File

@@ -19,6 +19,7 @@ alembic = "*"
fastapi-utils = "*" fastapi-utils = "*"
prometheus-fastapi-instrumentator = "*" prometheus-fastapi-instrumentator = "*"
auto-archiver = "*" auto-archiver = "*"
pydantic-settings = "*"
[dev-packages] [dev-packages]
watchdog = "*" watchdog = "*"

19
src/Pipfile.lock generated
View File

@@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "359638472cb3c3914fac7040bc702463c0ed1ae2e5cee00abddd59d9b34e923e" "sha256": "c34b5745f3a6f67222d3f26e6c7f2d13615a3301d0ca4d1f2b0ec58474b1d43a"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
@@ -1470,11 +1470,11 @@
}, },
"marshmallow": { "marshmallow": {
"hashes": [ "hashes": [
"sha256:4972f529104a220bb8637d595aa4c9762afbe7f7a77d82dc58c1615d70c5823e", "sha256:82f20a2397834fe6d9611b241f2f7e7b680ed89c49f84728a1ad937be6b4bdf4",
"sha256:71a2dce49ef901c3f97ed296ae5051135fd3febd2bf43afe0ae9a82143a494d9" "sha256:98d8827a9f10c03d44ead298d2e99c6aea8197df18ccfad360dae7f89a50da2e"
], ],
"markers": "python_version >= '3.8'", "markers": "python_version >= '3.9'",
"version": "==3.22.0" "version": "==3.23.0"
}, },
"mccabe": { "mccabe": {
"hashes": [ "hashes": [
@@ -2182,6 +2182,15 @@
"markers": "python_version >= '3.8'", "markers": "python_version >= '3.8'",
"version": "==2.23.4" "version": "==2.23.4"
}, },
"pydantic-settings": {
"hashes": [
"sha256:44a1804abffac9e6a30372bb45f6cafab945ef5af25e66b1c634c01dd39e0188",
"sha256:4a819166f119b74d7f8c765196b165f95cc7487ce58ea27dec8a5a26be0970e0"
],
"index": "pypi",
"markers": "python_version >= '3.8'",
"version": "==2.6.0"
},
"pyflakes": { "pyflakes": {
"hashes": [ "hashes": [
"sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f", "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f",

View File

@@ -1,5 +1,3 @@
import os
VERSION = "0.7.0" VERSION = "0.7.0"
API_DESCRIPTION = """ API_DESCRIPTION = """
#### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets. #### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets.
@@ -9,16 +7,4 @@ API_DESCRIPTION = """
- You can use this API to archive single URLs or entire Google Sheets. - You can use this API to archive single URLs or entire Google Sheets.
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously. - Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
""" """
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "chrome-extension://ondkcheoicfckabcnkdgbepofpjmjcmb,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp").split(",")
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
SERVE_LOCAL_ARCHIVE = os.environ.get("SERVE_LOCAL_ARCHIVE", "")
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_PATH")
REPEAT_COUNT_METRICS_SECONDS = 15
CHROME_APP_IDS = set([app_id.strip() for app_id in os.environ.get("CHROME_APP_IDS", "").split(",")])
BLOCKED_EMAILS = set([e.strip().lower() for e in os.environ.get("BLOCKED_EMAILS", "").split(",")])

View File

@@ -2,15 +2,16 @@ import asyncio
import logging import logging
import alembic.config import alembic.config
from fastapi import FastAPI from fastapi import FastAPI
from sqlalchemy.orm import Session
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi_utils.tasks import repeat_every from fastapi_utils.tasks import repeat_every
from loguru import logger from loguru import logger
from db import crud, models from db import crud, models
from db.database import get_db, engine from db.database import get_db, make_engine
from shared.settings import Settings
from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
from core.config import REPEAT_COUNT_METRICS_SECONDS
settings = Settings()
@asynccontextmanager @asynccontextmanager
@@ -18,13 +19,14 @@ async def lifespan(app: FastAPI):
# see https://fastapi.tiangolo.com/advanced/events/#lifespan # see https://fastapi.tiangolo.com/advanced/events/#lifespan
# STARTUP # STARTUP
engine = make_engine(settings.DATABASE_PATH)
models.Base.metadata.create_all(bind=engine) models.Base.metadata.create_all(bind=engine)
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head']) alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
# disabling uvicorn logger since we use loguru in logging_middleware # disabling uvicorn logger since we use loguru in logging_middleware
logging.getLogger("uvicorn.access").disabled = True logging.getLogger("uvicorn.access").disabled = True
asyncio.create_task(redis_subscribe_worker_exceptions()) asyncio.create_task(redis_subscribe_worker_exceptions())
asyncio.create_task(refresh_user_groups()) asyncio.create_task(refresh_user_groups())
asyncio.create_task(measure_regular_metrics()) asyncio.create_task(repeat_measure_regular_metrics())
yield # separates startup from shutdown instructions yield # separates startup from shutdown instructions
@@ -36,9 +38,10 @@ async def lifespan(app: FastAPI):
@repeat_every(seconds=60 * 60) # 1 hour @repeat_every(seconds=60 * 60) # 1 hour
async def refresh_user_groups(): async def refresh_user_groups():
db: Session = next(get_db()) with get_db() as db:
crud.upsert_user_groups(db) crud.upsert_user_groups(db)
@repeat_every(seconds=REPEAT_COUNT_METRICS_SECONDS)
@repeat_every(seconds=settings.REPEAT_COUNT_METRICS_SECONDS)
async def repeat_measure_regular_metrics(): async def repeat_measure_regular_metrics():
measure_regular_metrics() measure_regular_metrics(settings.DATABASE_PATH, settings.REPEAT_COUNT_METRICS_SECONDS)

View File

@@ -5,12 +5,13 @@ from loguru import logger
from datetime import datetime, timedelta from datetime import datetime, timedelta
from security import ALLOW_ANY_EMAIL from security import ALLOW_ANY_EMAIL
from shared.settings import Settings
from . import models, schemas from . import models, schemas
import yaml, os import yaml
DOMAIN_GROUPS = {} DOMAIN_GROUPS = {}
DOMAIN_GROUPS_LOADED = False DOMAIN_GROUPS_LOADED = False
MAX_LIMIT = 100 DATABASE_QUERY_LIMIT = Settings().DATABASE_QUERY_LIMIT
# --------------- TASK = Archive # --------------- TASK = Archive
@@ -39,12 +40,12 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim
query = query.filter(models.Archive.created_at >= archived_after) query = query.filter(models.Archive.created_at >= archived_after)
if archived_before: if archived_before:
query = query.filter(models.Archive.created_at <= archived_before) query = query.filter(models.Archive.created_at <= archived_before)
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, MAX_LIMIT)).all() return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100): def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
email = email.lower() email = email.lower()
return base_query(db).filter(models.Archive.author.has(email=email)).offset(skip).limit(min(limit, MAX_LIMIT)).all() return base_query(db).filter(models.Archive.author.has(email=email)).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]): def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]):
@@ -76,7 +77,7 @@ def count_by_user_since(db:Session, seconds_delta: int = 15):
return db.query(models.Archive.author_id,func.count().label('total'))\ return db.query(models.Archive.author_id,func.count().label('total'))\
.filter(models.Archive.created_at >= time_threshold)\ .filter(models.Archive.created_at >= time_threshold)\
.group_by(models.Archive.author_id)\ .group_by(models.Archive.author_id)\
.order_by(func.count().desc()).limit(5 * MAX_LIMIT).all() .order_by(func.count().desc()).limit(5 * DATABASE_QUERY_LIMIT).all()
def base_query(db: Session): def base_query(db: Session):
# allow only some fields to be returned, for example author should remain hidden # allow only some fields to be returned, for example author should remain hidden
@@ -98,7 +99,7 @@ def create_tag(db: Session, tag: str):
def search_tags(db: Session, tag: str, skip: int = 0, limit: int = 100): def search_tags(db: Session, tag: str, skip: int = 0, limit: int = 100):
return db.query(models.Tag).filter(models.Tag.url.like(f'%{tag}%')).offset(skip).limit(min(limit, MAX_LIMIT)).all() return db.query(models.Tag).filter(models.Tag.url.like(f'%{tag}%')).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group: def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
@@ -148,7 +149,7 @@ def upsert_user_groups(db: Session):
along with new participation of users in groups along with new participation of users in groups
""" """
logger.debug("Updating user-groups configuration.") logger.debug("Updating user-groups configuration.")
filename = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml") filename = Settings().USER_GROUPS_FILENAME
# read yaml safely # read yaml safely
with open(filename) as inf: with open(filename) as inf:

View File

@@ -1,17 +1,36 @@
from sqlalchemy import create_engine from sqlalchemy import Engine, create_engine, event
from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy.orm import sessionmaker, declarative_base
from core.config import SQLALCHEMY_DATABASE_URL from shared.settings import Settings
from contextlib import contextmanager
engine = create_engine( settings = Settings()
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() def make_engine(database_url: str):
engine = create_engine(database_url, connect_args={"check_same_thread": False})
@event.listens_for(engine, "connect")
def set_sqlite_pragma(conn, _) -> None:
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.close()
return engine
def make_session_local(engine: Engine):
    """Return a session factory bound to ``engine``.

    Sessions created by the factory have autocommit and autoflush turned off.
    """
    return sessionmaker(bind=engine, autoflush=False, autocommit=False)
@contextmanager
def get_db(): def get_db():
session = SessionLocal() session = make_session_local(make_engine(settings.DATABASE_PATH))()
try: yield session try: yield session
finally: session.close() finally: session.close()
def get_db_dependency():
    """Generator form of ``get_db`` meant for FastAPI's ``Depends``.

    The session is yielded from inside the ``get_db`` context manager, so it
    is closed when the generator is finalised.
    """
    with get_db() as session:
        yield session

View File

@@ -1,8 +1,10 @@
from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey
from sqlalchemy.sql import func from sqlalchemy.sql import func
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship, declarative_base
import uuid import uuid
from .database import Base
Base = declarative_base()
def generate_uuid(): def generate_uuid():
return str(uuid.uuid4()) return str(uuid.uuid4())
@@ -59,7 +61,6 @@ class Tag(Base):
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags) archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
class User(Base): class User(Base):
__tablename__ = "users" __tablename__ = "users"

View File

@@ -2,6 +2,13 @@ from pydantic import BaseModel
from datetime import datetime from datetime import datetime
class Tag(BaseModel):
    """Schema describing a tag by its ``id`` and creation timestamp."""

    # lets the schema be populated from objects exposing these attributes
    model_config = {"from_attributes": True}

    id: str
    created_at: datetime

    # Hash by object identity so instances can be stored in a ``set``.
    # NOTE(review): pydantic models compare by field values while this hash is
    # identity-based, so two equal Tag objects can coexist in one set --
    # confirm that is acceptable for the ``tags`` fields.
    __hash__ = object.__hash__
class ArchiveCreate(BaseModel): class ArchiveCreate(BaseModel):
id: str | None = None id: str | None = None
url: str url: str
@@ -9,7 +16,7 @@ class ArchiveCreate(BaseModel):
public: bool = True public: bool = True
author_id: str | None = None author_id: str | None = None
group_id: str | None = None group_id: str | None = None
tags: set = set() tags: set[Tag] | None = set()
rearchive: bool = True rearchive: bool = True
# urls: list = [] # urls: list = []
@@ -28,7 +35,7 @@ class SubmitSheet(BaseModel):
public: bool = False public: bool = False
author_id: str | None = None author_id: str | None = None
group_id: str | None = None group_id: str | None = None
tags: set | None = set() tags: set[Tag] | None = set()
columns: dict | None = {} # TODO: implement columns: dict | None = {} # TODO: implement
class SubmitManual(BaseModel): class SubmitManual(BaseModel):
@@ -36,7 +43,7 @@ class SubmitManual(BaseModel):
public: bool = False public: bool = False
author_id: str | None = None author_id: str | None = None
group_id: str | None = None group_id: str | None = None
tags: set | None = set() tags: set[Tag] | None = set()
class Task(BaseModel): class Task(BaseModel):
id: str id: str

View File

@@ -6,20 +6,20 @@ from sqlalchemy.orm import Session
from core.config import VERSION, BREAKING_CHANGES from core.config import VERSION, BREAKING_CHANGES
from db import crud from db import crud
from db.database import get_db from db.database import get_db_dependency, get_db
from security import get_user_auth, bearer_security from security import get_user_auth, bearer_security
default_router = APIRouter() default_router = APIRouter()
@default_router.get("/") @default_router.get("/")
async def home(request: Request): async def home(request: Request):
# TODO: maybe split into 2 routes: one non authenticated and one authenticated for the groups info only # TODO: maybe split into 2 routes: one non authenticated and one authenticated for the groups info only
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES} status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
try: try:
email = await get_user_auth(await bearer_security(request)) email = await get_user_auth(await bearer_security(request))
db: Session = next(get_db()) with get_db() as db:
status["groups"] = crud.get_user_groups(db, email) status["groups"] = crud.get_user_groups(db, email)
except HTTPException: pass # not authenticated is fine except HTTPException: pass # not authenticated is fine
except Exception as e: logger.error(e) except Exception as e: logger.error(e)
return JSONResponse(status) return JSONResponse(status)
@@ -29,8 +29,9 @@ async def home(request: Request):
async def health(): async def health():
return JSONResponse({"status": "ok"}) return JSONResponse({"status": "ok"})
@default_router.get("/groups", response_model=list[str]) @default_router.get("/groups", response_model=list[str])
def get_user_groups(db: Session = Depends(get_db), email=Depends(get_user_auth)): def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
return crud.get_user_groups(db, email) return crud.get_user_groups(db, email)

View File

@@ -8,7 +8,7 @@ from security import get_user_auth, get_token_or_user_auth
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from db import crud, schemas from db import crud, schemas
from db.database import get_db from db.database import get_db_dependency
from worker import create_archive_task from worker import create_archive_task
@@ -32,23 +32,23 @@ def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_
def search_by_url( def search_by_url(
url: str, skip: int = 0, limit: int = 25, url: str, skip: int = 0, limit: int = 25,
archived_after: datetime = None, archived_before: datetime = None, archived_after: datetime = None, archived_before: datetime = None,
db: Session = Depends(get_db), db: Session = Depends(get_db_dependency),
email=Depends(get_token_or_user_auth)): email=Depends(get_token_or_user_auth)):
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before) return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
@url_router.get("/latest", response_model=list[schemas.Archive], summary="Fetch latest URL archives for the authenticated user.") @url_router.get("/latest", response_model=list[schemas.Archive], summary="Fetch latest URL archives for the authenticated user.")
def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db), email=Depends(get_user_auth)): def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
return crud.search_archives_by_email(db, email, skip=skip, limit=limit) return crud.search_archives_by_email(db, email, skip=skip, limit=limit)
@url_router.get("/{id}", response_model=schemas.Archive, summary="Fetch a single URL archive by the associated id.") @url_router.get("/{id}", response_model=schemas.Archive, summary="Fetch a single URL archive by the associated id.")
def lookup(id, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)): def lookup(id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
return crud.get_archive(db, id, email) return crud.get_archive(db, id, email)
@url_router.delete("/{id}", response_model=schemas.TaskDelete, summary="Delete a single URL archive by id.") @url_router.delete("/{id}", response_model=schemas.TaskDelete, summary="Delete a single URL archive by id.")
def delete_task(id, db: Session = Depends(get_db), email=Depends(get_user_auth)): def delete_task(id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
logger.info(f"deleting url archive task {id} request by {email}") logger.info(f"deleting url archive task {id} request by {email}")
#TODO: use response model? #TODO: use response model?
return JSONResponse({ return JSONResponse({

View File

@@ -16,14 +16,16 @@ from worker import create_archive_task, create_sheet_task, celery, insert_result
from db import crud, models, schemas from db import crud, models, schemas
from security import get_user_auth, token_api_key_auth, get_token_or_user_auth from security import get_user_auth, token_api_key_auth, get_token_or_user_auth
from core.config import ALLOWED_ORIGINS, VERSION, SERVE_LOCAL_ARCHIVE, API_DESCRIPTION from core.config import VERSION, API_DESCRIPTION
from db.database import get_db from db.database import get_db_dependency
from core.events import lifespan from core.events import lifespan
from shared.settings import Settings
from auto_archiver import Metadata from auto_archiver import Metadata
from endpoints import default_router, url_router, sheet_router, task_router, interoperability_router from endpoints import default_router, url_router, sheet_router, task_router, interoperability_router
settings = Settings()
app = FastAPI( app = FastAPI(
title="Auto-Archiver API", title="Auto-Archiver API",
@@ -35,7 +37,7 @@ app = FastAPI(
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=ALLOWED_ORIGINS, allow_origins=settings.ALLOWED_ORIGINS,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
@@ -48,17 +50,15 @@ app.include_router(task_router)
app.include_router(interoperability_router) app.include_router(interoperability_router)
# prometheus exposed in /metrics with authentication # prometheus exposed in /metrics with authentication
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)]) Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
def setup_local_archive_serve(): def setup_local_archive_serve():
# if env SERVE_LOCAL_ARCHIVE is set it serves files from that dir, useful for development and using local_archive local_dir = settings.SERVE_LOCAL_ARCHIVE
SERVE_LOCAL_ARCHIVE = os.environ.get("SERVE_LOCAL_ARCHIVE", "")
local_dir = SERVE_LOCAL_ARCHIVE
if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")): if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")):
local_dir = local_dir.replace("/app", ".") local_dir = local_dir.replace("/app", ".")
if len(SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir): if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
logger.warning(f"MOUNTing local archive {SERVE_LOCAL_ARCHIVE}") logger.warning(f"MOUNTing local archive {settings.SERVE_LOCAL_ARCHIVE}")
app.mount(SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=SERVE_LOCAL_ARCHIVE) app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE)
setup_local_archive_serve() setup_local_archive_serve()
@@ -68,12 +68,12 @@ app.middleware("http")(logging_middleware)
@app.get("/tasks/search-url", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED @app.get("/tasks/search-url", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)): def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before) return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
@app.get("/tasks/sync", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED @app.get("/tasks/sync", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email=Depends(get_user_auth)): def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
return crud.search_archives_by_email(db, email, skip=skip, limit=limit) return crud.search_archives_by_email(db, email, skip=skip, limit=limit)
@@ -90,7 +90,7 @@ def archive_tasks(archive: schemas.ArchiveCreate, email=Depends(get_token_or_use
@app.get("/archive/{task_id}", deprecated=True) # DEPRECATED @app.get("/archive/{task_id}", deprecated=True) # DEPRECATED
def lookup(task_id, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)): def lookup(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
return crud.get_archive(db, task_id, email) return crud.get_archive(db, task_id, email)
@@ -123,7 +123,7 @@ def get_status(task_id, email=Depends(get_token_or_user_auth)):
@app.delete("/tasks/{task_id}", deprecated=True) # DEPRECATED @app.delete("/tasks/{task_id}", deprecated=True) # DEPRECATED
def delete_task(task_id, db: Session = Depends(get_db), email=Depends(get_user_auth)): def delete_task(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
logger.info(f"deleting task {task_id} request by {email}") logger.info(f"deleting task {task_id} request by {email}")
return JSONResponse({ return JSONResponse({
"id": task_id, "id": task_id,

View File

@@ -5,10 +5,12 @@ from sqlalchemy import pool
from alembic import context from alembic import context
from shared.settings import Settings
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
config = context.config config = context.config
config.set_main_option('sqlalchemy.url', os.environ.get("DATABASE_PATH")) config.set_main_option('sqlalchemy.url', Settings().DATABASE_PATH)
# Interpret the config file for Python logging. # Interpret the config file for Python logging.
# This line sets up loggers basically. # This line sets up loggers basically.
if config.config_file_name is not None: if config.config_file_name is not None:

View File

@@ -2,21 +2,18 @@ from loguru import logger
import requests, os, secrets import requests, os, secrets
from fastapi import HTTPException, status, Depends from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from core.config import CHROME_APP_IDS, BLOCKED_EMAILS from shared.settings import Settings
# Configuration checks
assert len(CHROME_APP_IDS) > 0, "CHROME_APP_IDS env variable not properly set, it's a csv"
for app_id in CHROME_APP_IDS:
assert len(app_id) > 10, f"CHROME_APP_IDS got invalid id: {app_id} env variable not set"
# Auth logic
bearer_security = HTTPBearer()
ALLOW_ANY_EMAIL = "*" ALLOW_ANY_EMAIL = "*"
settings = Settings()
bearer_security = HTTPBearer()
def secure_compare(token, api_key): def secure_compare(token, api_key):
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8")) return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
# Factory method to create an authentication dependency for a specific key # Factory method to create an authentication dependency for a specific key
def api_key_auth(api_key): def api_key_auth(api_key):
@@ -35,9 +32,10 @@ def api_key_auth(api_key):
return auth return auth
# --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus # --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus
API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN)
token_api_key_auth = api_key_auth(API_BEARER_TOKEN)
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# tries to use the static API_KEY and defaults to google JWT auth # tries to use the static API_KEY and defaults to google JWT auth
@@ -45,6 +43,7 @@ async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Dep
if token_api_key_auth(access_token, auto_error=False): return ALLOW_ANY_EMAIL if token_api_key_auth(access_token, auto_error=False): return ALLOW_ANY_EMAIL
return await get_user_auth(credentials) return await get_user_auth(credentials)
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# validates the Bearer token in the case that it requires it # validates the Bearer token in the case that it requires it
valid_user, info = authenticate_user(credentials.credentials) valid_user, info = authenticate_user(credentials.credentials)
@@ -56,6 +55,7 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
def authenticate_user(access_token): def authenticate_user(access_token):
# https://cloud.google.com/docs/authentication/token-types#access # https://cloud.google.com/docs/authentication/token-types#access
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token" if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
@@ -63,9 +63,9 @@ def authenticate_user(access_token):
if r.status_code != 200: return False, "error occurred" if r.status_code != 200: return False, "error occurred"
try: try:
j = r.json() j = r.json()
if j.get("azp") not in CHROME_APP_IDS and j.get("aud") not in CHROME_APP_IDS: if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS:
return False, f"token does not belong to valid APP_ID" return False, f"token does not belong to valid APP_ID"
if j.get("email") in BLOCKED_EMAILS: if j.get("email") in settings.BLOCKED_EMAILS:
return False, f"email '{j.get('email')}' not allowed" return False, f"email '{j.get('email')}' not allowed"
if j.get("email_verified") != "true": if j.get("email_verified") != "true":
return False, f"email '{j.get('email')}' not verified" return False, f"email '{j.get('email')}' not verified"
@@ -75,4 +75,3 @@ def authenticate_user(access_token):
except Exception as e: except Exception as e:
logger.warning(f"EXCEPTION occurred: {e}") logger.warning(f"EXCEPTION occurred: {e}")
return False, f"EXCEPTION occurred" return False, f"EXCEPTION occurred"

31
src/shared/settings.py Normal file
View File

@@ -0,0 +1,31 @@
from typing import Annotated, Set

from annotated_types import Len
from pydantic import ConfigDict, field_validator
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
    """Application configuration loaded through pydantic-settings.

    Fields without a default (``DATABASE_PATH``, ``API_BEARER_TOKEN``,
    ``ALLOWED_ORIGINS`` and ``CHROME_APP_IDS``) are required, and the length
    constraints below are enforced when the settings object is created.
    Unknown keys are ignored and string values are whitespace-stripped.
    """

    model_config = ConfigDict(extra='ignore', str_strip_whitespace=True)

    # general
    SERVE_LOCAL_ARCHIVE: str = ""
    USER_GROUPS_FILENAME: str = "user-groups.yaml"

    # database
    DATABASE_PATH: str
    DATABASE_QUERY_LIMIT: int = 100

    # redis
    CELERY_BROKER_URL: str = "redis://localhost:6379"
    CELERY_RESULT_BACKEND: str = "redis://localhost:6379"
    REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"

    # observability
    REPEAT_COUNT_METRICS_SECONDS: int = 15

    # security
    API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
    ALLOWED_ORIGINS: Annotated[set[str], Len(min_length=1)]
    CHROME_APP_IDS: Annotated[set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
    # may be empty, so no length constraint is needed
    BLOCKED_EMAILS: set[str] = set()

    @field_validator("BLOCKED_EMAILS")
    @classmethod
    def _lowercase_blocked_emails(cls, emails: set[str]) -> set[str]:
        """Lower-case blocked emails, as the previous env-based config did."""
        return {email.lower() for email in emails}

View File

@@ -1,6 +1,7 @@
import pytest
import os import os
import pytest
from unittest.mock import patch from unittest.mock import patch
from shared.settings import Settings
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_logger_add(): def mock_logger_add():
@@ -8,6 +9,52 @@ def mock_logger_add():
with patch('loguru.logger.add') as mock_add: with patch('loguru.logger.add') as mock_add:
yield mock_add # This makes the mock available to tests yield mock_add # This makes the mock available to tests
os.environ["CHROME_APP_IDS"] = 'test_app_id_1,test_app_id_2' # @pytest.fixture(autouse=True)
os.environ["DATABASE_PATH"] = "sqlite:////app/auto-archiver.test.db" # def settings():
os.environ["BLOCKED_EMAILS"] = "blocked@example.com" # return Settings(_env_file=".env.test")
@pytest.fixture(autouse=True)
def settings():
    """Give every test the settings loaded from ``.env.test``.

    ``shared.settings.Settings`` is patched for the duration of the test so
    that calling it returns the test instance. The fixture yields that
    instance rather than the patch mock, so attributes such as
    ``DATABASE_PATH`` hold the configured values.
    """
    test_settings = Settings(_env_file=".env.test")
    # NOTE: the patch only affects lookups made through the ``shared.settings``
    # module attribute; modules that already ran
    # ``from shared.settings import Settings`` keep the real class.
    with patch('shared.settings.Settings', return_value=test_settings):
        yield test_settings


@pytest.fixture()
def test_db(settings):
    """Create the schema in the test database and yield an open connection.

    On teardown the connection is closed, the tables are dropped, the engine
    is disposed, and the SQLite file plus its WAL side files are removed.
    """
    from db.database import make_engine, make_session_local
    from db import models

    engine = make_engine(settings.DATABASE_PATH)
    # DATABASE_PATH is a SQLAlchemy URL, not a filesystem path; the file path
    # is the URL's database component (None/empty for an in-memory database).
    db_file = engine.url.database

    # SQLite creates the database file on first connection
    models.Base.metadata.create_all(engine)
    connection = engine.connect()
    try:
        yield connection
    finally:
        connection.close()
        models.Base.metadata.drop_all(bind=engine)
        # release pooled connections before deleting the files
        engine.dispose()
        if db_file and db_file != ":memory:":
            for suffix in ("", "-wal", "-shm"):
                if os.path.exists(db_file + suffix):
                    os.remove(db_file + suffix)
# @pytest.fixture()
# def db_session(test_db):
# session_local = make_session_local(test_db)
# with session_local() as session:
# yield session
# # create test data and insert it into the database
# def create_test_data():
# from db.database import SessionLocal
# from db.models import Task
# db = SessionLocal()
# task = Task(id="test-task-id", status="PENDING")
# db.add(task)
# db.commit()
# db.refresh(task)
# db.close()
# return task.id

View File

@@ -1,8 +1,7 @@
from unittest.mock import AsyncMock, patch from unittest.mock import patch
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
def setup_client(): def setup_client():
from main import app from main import app
from security import get_token_or_user_auth from security import get_token_or_user_auth
@@ -10,6 +9,7 @@ def setup_client():
app.dependency_overrides[get_token_or_user_auth] = mock_get_token_or_user_auth app.dependency_overrides[get_token_or_user_auth] = mock_get_token_or_user_auth
return TestClient(app), app return TestClient(app), app
@patch("endpoints.task.AsyncResult") @patch("endpoints.task.AsyncResult")
def test_get_status_success(mock_async_result): def test_get_status_success(mock_async_result):
client, app = setup_client() client, app = setup_client()

View File

@@ -1,22 +1,23 @@
import os import os
from unittest.mock import patch
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
def test_serve_local_archive_logic():
    """Check that a file in the configured local archive folder is served.

    `main.settings.SERVE_LOCAL_ARCHIVE` is patched for the duration of the
    test. The temporary file and folder are removed even when an assertion
    fails, so a failed run does not leave stale files behind.
    """
    with patch("main.settings.SERVE_LOCAL_ARCHIVE", "/app/local_archive_test"):
        # create a test file
        os.makedirs("local_archive_test", exist_ok=True)
        try:
            with open("local_archive_test/temp.txt", "w") as f:
                f.write("test")

            from main import app, setup_local_archive_serve
            setup_local_archive_serve()

            client = TestClient(app)
            r = client.get("/app/local_archive_test/temp.txt")
            assert r.status_code == 200
            assert r.text == "test"
        finally:
            if os.path.exists("local_archive_test/temp.txt"):
                os.remove("local_archive_test/temp.txt")
            os.rmdir("local_archive_test")

View File

@@ -6,10 +6,8 @@ from loguru import logger
from prometheus_client import Counter, Gauge from prometheus_client import Counter, Gauge
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.config import REPEAT_COUNT_METRICS_SECONDS
from db import crud from db import crud
from db.database import get_db from db.database import get_db
from core.config import SQLALCHEMY_DATABASE_URL
from worker import REDIS_EXCEPTIONS_CHANNEL, Rdis from worker import REDIS_EXCEPTIONS_CHANNEL, Rdis
@@ -47,20 +45,20 @@ async def redis_subscribe_worker_exceptions():
WORKER_EXCEPTION.labels(exception=data["exception"], task=data["task"]).inc() WORKER_EXCEPTION.labels(exception=data["exception"], task=data["task"]).inc()
await asyncio.sleep(1) await asyncio.sleep(1)
async def measure_regular_metrics(sqlite_db_url:str, repeat_in_seconds:int):
    """Refresh the disk-usage and database gauges.

    Args:
        sqlite_db_url: database URL of the form `sqlite:///<path>`; the
            prefix is stripped to stat the database file.
        repeat_in_seconds: forwarded to `crud.count_by_user_since`
            (presumably the look-back window for per-user counts; confirm
            against crud).
    """
    bytes_per_gib = 2**30

    # Disk usage of the root filesystem, reported in GiB.
    _total, used, free = shutil.disk_usage("/")
    DISK_UTILIZATION.labels(type="used").set(used / bytes_per_gib)
    DISK_UTILIZATION.labels(type="free").set(free / bytes_per_gib)

    # Best effort: a missing or unreadable database file is logged and must
    # not prevent the remaining metrics from being updated.
    try:
        fs = os.stat(sqlite_db_url.replace("sqlite:///", ""))
        DISK_UTILIZATION.labels(type="database").set(fs.st_size / bytes_per_gib)
    except Exception as e:
        logger.error(e)

    # get_db must be called: the function object itself is not a context manager.
    with get_db() as db:
        count_archives = crud.count_archives(db)
        count_archive_urls = crud.count_archive_urls(db)
        DATABASE_METRICS.labels(query="count_archives", user="-").set(count_archives)
        DATABASE_METRICS.labels(query="count_archive_urls", user="-").set(count_archive_urls)

        for user in crud.count_by_user_since(db, repeat_in_seconds):
            DATABASE_METRICS.labels(query="count_by_user", user=user.author_id).set(user.total)

View File

@@ -1,5 +1,5 @@
import os, traceback, yaml, datetime, sys import traceback, yaml, datetime
from typing import List, Set from typing import List, Set
from celery import Celery from celery import Celery
@@ -9,29 +9,25 @@ from auto_archiver.core import Media
from loguru import logger from loguru import logger
from db import crud, schemas, models from db import crud, schemas, models
from db.database import SessionLocal from db.database import get_db
from contextlib import contextmanager from shared.settings import Settings
import json import json
import redis import redis
from sqlalchemy import exc from sqlalchemy import exc
settings = Settings()
celery = Celery(__name__) celery = Celery(__name__)
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379") celery.conf.broker_url = settings.CELERY_BROKER_URL
celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379") celery.conf.result_backend = settings.CELERY_RESULT_BACKEND
USER_GROUPS_FILENAME = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml") USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
REDIS_EXCEPTIONS_CHANNEL = "exceptions-channel" REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL
Rdis = redis.Redis.from_url(celery.conf.broker_url) Rdis = redis.Redis.from_url(celery.conf.broker_url)
@contextmanager
def get_db():
session = SessionLocal()
try: yield session
finally: session.close()
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 3}) @celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 3})
def create_archive_task(self, archive_json: str): def create_archive_task(self, archive_json: str):
archive = schemas.ArchiveCreate.parse_raw(archive_json) archive = schemas.ArchiveCreate.model_validate_json(archive_json)
logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}") logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}")
invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id) invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)
if invalid: if invalid:
@@ -63,7 +59,7 @@ def create_archive_task(self, archive_json: str):
@celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0}) @celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
def create_sheet_task(self, sheet_json: str): def create_sheet_task(self, sheet_json: str):
sheet = schemas.SubmitSheet.parse_raw(sheet_json) sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
sheet.tags.add("gsheet") sheet.tags.add("gsheet")
logger.info(f"SHEET START {sheet=}") logger.info(f"SHEET START {sheet=}")