From ca7e862855e72dc5fb2551b31486f5476ed2a070 Mon Sep 17 00:00:00 2001 From: msramalho <19508417+msramalho@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:08:51 +0100 Subject: [PATCH] refactor to use pydantic settings and WAL sqlite mode --- .gitignore | 4 ++- src/.env.test | 7 ++++ src/Pipfile | 1 + src/Pipfile.lock | 19 ++++++++--- src/core/config.py | 16 +--------- src/core/events.py | 19 ++++++----- src/db/crud.py | 15 +++++---- src/db/database.py | 35 +++++++++++++++----- src/db/models.py | 7 ++-- src/db/schemas.py | 13 ++++++-- src/endpoints/default.py | 11 ++++--- src/endpoints/url.py | 10 +++--- src/main.py | 28 ++++++++-------- src/migrations/env.py | 4 ++- src/security.py | 27 ++++++++-------- src/shared/settings.py | 31 ++++++++++++++++++ src/tests/conftest.py | 55 +++++++++++++++++++++++++++++--- src/tests/endpoints/test_task.py | 4 +-- src/tests/test_main.py | 27 ++++++++-------- src/utils/metrics.py | 20 ++++++------ src/worker.py | 28 +++++++--------- 21 files changed, 246 insertions(+), 135 deletions(-) create mode 100644 src/.env.test create mode 100644 src/shared/settings.py diff --git a/.gitignore b/.gitignore index e9ae646..01ac3f7 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,6 @@ src/crawls .coverage .pytest_cache/* htmlcov -local_archive \ No newline at end of file +local_archive +*db-wal +*db-shm \ No newline at end of file diff --git a/src/.env.test b/src/.env.test new file mode 100644 index 0000000..fca6b15 --- /dev/null +++ b/src/.env.test @@ -0,0 +1,7 @@ +CHROME_APP_IDS='["test_app_id_1","test_app_id_2"]' +ALLOWED_ORIGINS='["chrome-extension://example1","chrome-extension://example2","http://localhost:8081"]' +BLOCKED_EMAILS='["blocked@example.com"]' + + +DATABASE_PATH="sqlite:////app/auto-archiver.test.db" +API_BEARER_TOKEN=this_is_the_test_api_token \ No newline at end of file diff --git a/src/Pipfile b/src/Pipfile index dec6834..4423b6f 100644 --- a/src/Pipfile +++ b/src/Pipfile @@ -19,6 +19,7 @@ alembic = "*" fastapi-utils 
= "*" prometheus-fastapi-instrumentator = "*" auto-archiver = "*" +pydantic-settings = "*" [dev-packages] watchdog = "*" diff --git a/src/Pipfile.lock b/src/Pipfile.lock index 068420f..6e6c0f2 100644 --- a/src/Pipfile.lock +++ b/src/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "359638472cb3c3914fac7040bc702463c0ed1ae2e5cee00abddd59d9b34e923e" + "sha256": "c34b5745f3a6f67222d3f26e6c7f2d13615a3301d0ca4d1f2b0ec58474b1d43a" }, "pipfile-spec": 6, "requires": { @@ -1470,11 +1470,11 @@ }, "marshmallow": { "hashes": [ - "sha256:4972f529104a220bb8637d595aa4c9762afbe7f7a77d82dc58c1615d70c5823e", - "sha256:71a2dce49ef901c3f97ed296ae5051135fd3febd2bf43afe0ae9a82143a494d9" + "sha256:82f20a2397834fe6d9611b241f2f7e7b680ed89c49f84728a1ad937be6b4bdf4", + "sha256:98d8827a9f10c03d44ead298d2e99c6aea8197df18ccfad360dae7f89a50da2e" ], - "markers": "python_version >= '3.8'", - "version": "==3.22.0" + "markers": "python_version >= '3.9'", + "version": "==3.23.0" }, "mccabe": { "hashes": [ @@ -2182,6 +2182,15 @@ "markers": "python_version >= '3.8'", "version": "==2.23.4" }, + "pydantic-settings": { + "hashes": [ + "sha256:44a1804abffac9e6a30372bb45f6cafab945ef5af25e66b1c634c01dd39e0188", + "sha256:4a819166f119b74d7f8c765196b165f95cc7487ce58ea27dec8a5a26be0970e0" + ], + "index": "pypi", + "markers": "python_version >= '3.8'", + "version": "==2.6.0" + }, "pyflakes": { "hashes": [ "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f", diff --git a/src/core/config.py b/src/core/config.py index 9671eeb..0d33cf6 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -1,5 +1,3 @@ -import os - VERSION = "0.7.0" API_DESCRIPTION = """ #### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets. @@ -9,16 +7,4 @@ API_DESCRIPTION = """ - You can use this API to archive single URLs or entire Google Sheets. 
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously. """ - -ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "chrome-extension://ondkcheoicfckabcnkdgbepofpjmjcmb,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp").split(",") - -BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."} - -SERVE_LOCAL_ARCHIVE = os.environ.get("SERVE_LOCAL_ARCHIVE", "") - -SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_PATH") - -REPEAT_COUNT_METRICS_SECONDS = 15 - -CHROME_APP_IDS = set([app_id.strip() for app_id in os.environ.get("CHROME_APP_IDS", "").split(",")]) -BLOCKED_EMAILS = set([e.strip().lower() for e in os.environ.get("BLOCKED_EMAILS", "").split(",")]) \ No newline at end of file +BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."} \ No newline at end of file diff --git a/src/core/events.py b/src/core/events.py index a5e483c..abad52a 100644 --- a/src/core/events.py +++ b/src/core/events.py @@ -2,15 +2,16 @@ import asyncio import logging import alembic.config from fastapi import FastAPI -from sqlalchemy.orm import Session from contextlib import asynccontextmanager from fastapi_utils.tasks import repeat_every from loguru import logger from db import crud, models -from db.database import get_db, engine +from db.database import get_db, make_engine +from shared.settings import Settings from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions -from core.config import REPEAT_COUNT_METRICS_SECONDS + +settings = Settings() @asynccontextmanager @@ -18,13 +19,14 @@ async def lifespan(app: FastAPI): # see https://fastapi.tiangolo.com/advanced/events/#lifespan # STARTUP + engine = make_engine(settings.DATABASE_PATH) 
models.Base.metadata.create_all(bind=engine) alembic.config.main(argv=['--raiseerr', 'upgrade', 'head']) # disabling uvicorn logger since we use loguru in logging_middleware logging.getLogger("uvicorn.access").disabled = True asyncio.create_task(redis_subscribe_worker_exceptions()) asyncio.create_task(refresh_user_groups()) - asyncio.create_task(measure_regular_metrics()) + asyncio.create_task(repeat_measure_regular_metrics()) yield # separates startup from shutdown instructions @@ -36,9 +38,10 @@ async def lifespan(app: FastAPI): @repeat_every(seconds=60 * 60) # 1 hour async def refresh_user_groups(): - db: Session = next(get_db()) - crud.upsert_user_groups(db) + with get_db() as db: + crud.upsert_user_groups(db) -@repeat_every(seconds=REPEAT_COUNT_METRICS_SECONDS) + +@repeat_every(seconds=settings.REPEAT_COUNT_METRICS_SECONDS) async def repeat_measure_regular_metrics(): - measure_regular_metrics() + await measure_regular_metrics(settings.DATABASE_PATH, settings.REPEAT_COUNT_METRICS_SECONDS) diff --git a/src/db/crud.py b/src/db/crud.py index 765ebd3..a62113f 100644 --- a/src/db/crud.py +++ b/src/db/crud.py @@ -5,12 +5,13 @@ from loguru import logger from datetime import datetime, timedelta from security import ALLOW_ANY_EMAIL +from shared.settings import Settings from . 
import models, schemas -import yaml, os +import yaml DOMAIN_GROUPS = {} DOMAIN_GROUPS_LOADED = False -MAX_LIMIT = 100 +DATABASE_QUERY_LIMIT = Settings().DATABASE_QUERY_LIMIT # --------------- TASK = Archive @@ -39,12 +40,12 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim query = query.filter(models.Archive.created_at >= archived_after) if archived_before: query = query.filter(models.Archive.created_at <= archived_before) - return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, MAX_LIMIT)).all() + return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all() def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100): email = email.lower() - return base_query(db).filter(models.Archive.author.has(email=email)).offset(skip).limit(min(limit, MAX_LIMIT)).all() + return base_query(db).filter(models.Archive.author.has(email=email)).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all() def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]): @@ -76,7 +77,7 @@ def count_by_user_since(db:Session, seconds_delta: int = 15): return db.query(models.Archive.author_id,func.count().label('total'))\ .filter(models.Archive.created_at >= time_threshold)\ .group_by(models.Archive.author_id)\ - .order_by(func.count().desc()).limit(5 * MAX_LIMIT).all() + .order_by(func.count().desc()).limit(5 * DATABASE_QUERY_LIMIT).all() def base_query(db: Session): # allow only some fields to be returned, for example author should remain hidden @@ -98,7 +99,7 @@ def create_tag(db: Session, tag: str): def search_tags(db: Session, tag: str, skip: int = 0, limit: int = 100): - return db.query(models.Tag).filter(models.Tag.url.like(f'%{tag}%')).offset(skip).limit(min(limit, MAX_LIMIT)).all() + return db.query(models.Tag).filter(models.Tag.url.like(f'%{tag}%')).offset(skip).limit(min(limit, 
DATABASE_QUERY_LIMIT)).all() def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group: @@ -148,7 +149,7 @@ def upsert_user_groups(db: Session): along with new participation of users in groups """ logger.debug("Updating user-groups configuration.") - filename = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml") + filename = Settings().USER_GROUPS_FILENAME # read yaml safely with open(filename) as inf: diff --git a/src/db/database.py b/src/db/database.py index c212d20..8b72f70 100644 --- a/src/db/database.py +++ b/src/db/database.py @@ -1,17 +1,36 @@ -from sqlalchemy import create_engine +from sqlalchemy import Engine, create_engine, event from sqlalchemy.orm import sessionmaker, declarative_base -from core.config import SQLALCHEMY_DATABASE_URL +from shared.settings import Settings +from contextlib import contextmanager -engine = create_engine( - SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} -) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +settings = Settings() -Base = declarative_base() +def make_engine(database_url: str): + engine = create_engine(database_url, connect_args={"check_same_thread": False}) + + @event.listens_for(engine, "connect") + def set_sqlite_pragma(conn, _) -> None: + cursor = conn.cursor() + cursor.execute("PRAGMA journal_mode=WAL") + cursor.close() + + return engine +def make_session_local(engine: Engine): + session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + return session_local + + +@contextmanager def get_db(): - session = SessionLocal() + session = make_session_local(make_engine(settings.DATABASE_PATH))() try: yield session finally: session.close() + + +def get_db_dependency(): + # to use with Depends and ensure proper session closing + with get_db() as db: + yield db \ No newline at end of file diff --git a/src/db/models.py b/src/db/models.py index 604bbcc..29f0a05 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -1,8 +1,10 
@@ from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey from sqlalchemy.sql import func -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, declarative_base import uuid -from .database import Base + + +Base = declarative_base() def generate_uuid(): return str(uuid.uuid4()) @@ -59,7 +61,6 @@ class Tag(Base): archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags) - class User(Base): __tablename__ = "users" diff --git a/src/db/schemas.py b/src/db/schemas.py index fff03f9..ef06fe7 100644 --- a/src/db/schemas.py +++ b/src/db/schemas.py @@ -2,6 +2,13 @@ from pydantic import BaseModel from datetime import datetime +class Tag(BaseModel): + id: str + created_at: datetime + + model_config = { "from_attributes": True } + __hash__ = object.__hash__ + class ArchiveCreate(BaseModel): id: str | None = None url: str @@ -9,7 +16,7 @@ class ArchiveCreate(BaseModel): public: bool = True author_id: str | None = None group_id: str | None = None - tags: set = set() + tags: set[Tag] | None = set() rearchive: bool = True # urls: list = [] @@ -28,7 +35,7 @@ class SubmitSheet(BaseModel): public: bool = False author_id: str | None = None group_id: str | None = None - tags: set | None = set() + tags: set[Tag] | None = set() columns: dict | None = {} # TODO: implement class SubmitManual(BaseModel): @@ -36,7 +43,7 @@ class SubmitManual(BaseModel): public: bool = False author_id: str | None = None group_id: str | None = None - tags: set | None = set() + tags: set[Tag] | None = set() class Task(BaseModel): id: str diff --git a/src/endpoints/default.py b/src/endpoints/default.py index 091202e..7a6a9ea 100644 --- a/src/endpoints/default.py +++ b/src/endpoints/default.py @@ -6,20 +6,20 @@ from sqlalchemy.orm import Session from core.config import VERSION, BREAKING_CHANGES from db import crud -from db.database import get_db +from db.database import get_db_dependency, get_db from security import 
get_user_auth, bearer_security - default_router = APIRouter() + @default_router.get("/") async def home(request: Request): # TODO: maybe split into 2 routes: one non authenticated and one authenticated for the groups info only status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES} try: email = await get_user_auth(await bearer_security(request)) - db: Session = next(get_db()) - status["groups"] = crud.get_user_groups(db, email) + with get_db() as db: + status["groups"] = crud.get_user_groups(db, email) except HTTPException: pass # not authenticated is fine except Exception as e: logger.error(e) return JSONResponse(status) @@ -29,8 +29,9 @@ async def home(request: Request): async def health(): return JSONResponse({"status": "ok"}) + @default_router.get("/groups", response_model=list[str]) -def get_user_groups(db: Session = Depends(get_db), email=Depends(get_user_auth)): +def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)): return crud.get_user_groups(db, email) diff --git a/src/endpoints/url.py b/src/endpoints/url.py index 199be59..dc1ca68 100644 --- a/src/endpoints/url.py +++ b/src/endpoints/url.py @@ -8,7 +8,7 @@ from security import get_user_auth, get_token_or_user_auth from sqlalchemy.orm import Session from db import crud, schemas -from db.database import get_db +from db.database import get_db_dependency from worker import create_archive_task @@ -32,23 +32,23 @@ def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_ def search_by_url( url: str, skip: int = 0, limit: int = 25, archived_after: datetime = None, archived_before: datetime = None, - db: Session = Depends(get_db), + db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)): return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before) @url_router.get("/latest", response_model=list[schemas.Archive], summary="Fetch latest 
URL archives for the authenticated user.") -def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db), email=Depends(get_user_auth)): +def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)): return crud.search_archives_by_email(db, email, skip=skip, limit=limit) @url_router.get("/{id}", response_model=schemas.Archive, summary="Fetch a single URL archive by the associated id.") -def lookup(id, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)): +def lookup(id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)): return crud.get_archive(db, id, email) @url_router.delete("/{id}", response_model=schemas.TaskDelete, summary="Delete a single URL archive by id.") -def delete_task(id, db: Session = Depends(get_db), email=Depends(get_user_auth)): +def delete_task(id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)): logger.info(f"deleting url archive task {id} request by {email}") #TODO: use response model? 
return JSONResponse({ diff --git a/src/main.py b/src/main.py index 369a206..a4c2adf 100644 --- a/src/main.py +++ b/src/main.py @@ -16,14 +16,16 @@ from worker import create_archive_task, create_sheet_task, celery, insert_result from db import crud, models, schemas from security import get_user_auth, token_api_key_auth, get_token_or_user_auth -from core.config import ALLOWED_ORIGINS, VERSION, SERVE_LOCAL_ARCHIVE, API_DESCRIPTION -from db.database import get_db +from core.config import VERSION, API_DESCRIPTION +from db.database import get_db_dependency from core.events import lifespan +from shared.settings import Settings from auto_archiver import Metadata from endpoints import default_router, url_router, sheet_router, task_router, interoperability_router +settings = Settings() app = FastAPI( title="Auto-Archiver API", @@ -35,7 +37,7 @@ app = FastAPI( app.add_middleware( CORSMiddleware, - allow_origins=ALLOWED_ORIGINS, + allow_origins=settings.ALLOWED_ORIGINS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -48,17 +50,15 @@ app.include_router(task_router) app.include_router(interoperability_router) # prometheus exposed in /metrics with authentication -Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)]) +Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)]) def setup_local_archive_serve(): - # if env SERVE_LOCAL_ARCHIVE is set it serves files from that dir, useful for development and using local_archive - SERVE_LOCAL_ARCHIVE = os.environ.get("SERVE_LOCAL_ARCHIVE", "") - local_dir = SERVE_LOCAL_ARCHIVE + local_dir = settings.SERVE_LOCAL_ARCHIVE if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")): local_dir = local_dir.replace("/app", ".") - if len(SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir): - 
logger.warning(f"MOUNTing local archive {SERVE_LOCAL_ARCHIVE}") - app.mount(SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=SERVE_LOCAL_ARCHIVE) + if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir): + logger.warning(f"MOUNTing local archive {settings.SERVE_LOCAL_ARCHIVE}") + app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE) setup_local_archive_serve() @@ -68,12 +68,12 @@ app.middleware("http")(logging_middleware) @app.get("/tasks/search-url", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED -def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)): +def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)): return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before) @app.get("/tasks/sync", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED -def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email=Depends(get_user_auth)): +def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)): return crud.search_archives_by_email(db, email, skip=skip, limit=limit) @@ -90,7 +90,7 @@ def archive_tasks(archive: schemas.ArchiveCreate, email=Depends(get_token_or_use @app.get("/archive/{task_id}", deprecated=True) # DEPRECATED -def lookup(task_id, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)): +def lookup(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)): return crud.get_archive(db, task_id, email) @@ -123,7 +123,7 @@ def get_status(task_id, 
email=Depends(get_token_or_user_auth)): @app.delete("/tasks/{task_id}", deprecated=True) # DEPRECATED -def delete_task(task_id, db: Session = Depends(get_db), email=Depends(get_user_auth)): +def delete_task(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)): logger.info(f"deleting task {task_id} request by {email}") return JSONResponse({ "id": task_id, diff --git a/src/migrations/env.py b/src/migrations/env.py index 9b95d75..cabd992 100644 --- a/src/migrations/env.py +++ b/src/migrations/env.py @@ -5,10 +5,12 @@ from sqlalchemy import pool from alembic import context +from shared.settings import Settings + # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config -config.set_main_option('sqlalchemy.url', os.environ.get("DATABASE_PATH")) +config.set_main_option('sqlalchemy.url', Settings().DATABASE_PATH) # Interpret the config file for Python logging. # This line sets up loggers basically. 
if config.config_file_name is not None: diff --git a/src/security.py b/src/security.py index 860a396..03ef4d5 100644 --- a/src/security.py +++ b/src/security.py @@ -2,21 +2,18 @@ from loguru import logger import requests, os, secrets from fastapi import HTTPException, status, Depends from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials -from core.config import CHROME_APP_IDS, BLOCKED_EMAILS - -# Configuration checks -assert len(CHROME_APP_IDS) > 0, "CHROME_APP_IDS env variable not properly set, it's a csv" -for app_id in CHROME_APP_IDS: - assert len(app_id) > 10, f"CHROME_APP_IDS got invalid id: {app_id} env variable not set" - -# Auth logic -bearer_security = HTTPBearer() +from shared.settings import Settings ALLOW_ANY_EMAIL = "*" +settings = Settings() +bearer_security = HTTPBearer() + + def secure_compare(token, api_key): return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8")) + # Factory method to create an authentication dependency for a specific key def api_key_auth(api_key): @@ -35,9 +32,10 @@ def api_key_auth(api_key): return auth + # --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus -API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars -token_api_key_auth = api_key_auth(API_BEARER_TOKEN) +token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN) + async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): # tries to use the static API_KEY and defaults to google JWT auth @@ -45,6 +43,7 @@ async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Dep if token_api_key_auth(access_token, auto_error=False): return ALLOW_ANY_EMAIL return await get_user_auth(credentials) + async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): # validates the Bearer token in the case that it requires it valid_user, info = authenticate_user(credentials.credentials) 
@@ -56,6 +55,7 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear headers={"WWW-Authenticate": "Bearer"}, ) + def authenticate_user(access_token): # https://cloud.google.com/docs/authentication/token-types#access if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token" @@ -63,9 +63,9 @@ def authenticate_user(access_token): if r.status_code != 200: return False, "error occurred" try: j = r.json() - if j.get("azp") not in CHROME_APP_IDS and j.get("aud") not in CHROME_APP_IDS: + if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS: return False, f"token does not belong to valid APP_ID" - if j.get("email") in BLOCKED_EMAILS: + if j.get("email") in settings.BLOCKED_EMAILS: return False, f"email '{j.get('email')}' not allowed" if j.get("email_verified") != "true": return False, f"email '{j.get('email')}' not verified" @@ -75,4 +75,3 @@ def authenticate_user(access_token): except Exception as e: logger.warning(f"EXCEPTION occurred: {e}") return False, f"EXCEPTION occurred" - diff --git a/src/shared/settings.py b/src/shared/settings.py new file mode 100644 index 0000000..dc216d3 --- /dev/null +++ b/src/shared/settings.py @@ -0,0 +1,31 @@ + +from pydantic_settings import BaseSettings +from pydantic import ConfigDict +from typing import Annotated, Set +from annotated_types import Len + + +class Settings(BaseSettings): + model_config = ConfigDict(extra='ignore', str_strip_whitespace=True) + + # general + SERVE_LOCAL_ARCHIVE: str = "" + USER_GROUPS_FILENAME: str = "user-groups.yaml" + + # database + DATABASE_PATH: str + DATABASE_QUERY_LIMIT: int = 100 + # redis + CELERY_BROKER_URL: str = "redis://localhost:6379" + CELERY_RESULT_BACKEND: str = "redis://localhost:6379" + REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel" + + # observability + REPEAT_COUNT_METRICS_SECONDS: int = 15 + + # security + API_BEARER_TOKEN: Annotated[str, Len(min_length=20)] + ALLOWED_ORIGINS: 
Annotated[set[str], Len(min_length=1)] + CHROME_APP_IDS: Annotated[set[Annotated[str, Len(min_length=10)]], Len(min_length=1)] + BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set() + diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 6d8a561..62c57bd 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -1,6 +1,7 @@ -import pytest import os +import pytest from unittest.mock import patch +from shared.settings import Settings @pytest.fixture(autouse=True) def mock_logger_add(): @@ -8,6 +9,52 @@ def mock_logger_add(): with patch('loguru.logger.add') as mock_add: yield mock_add # This makes the mock available to tests -os.environ["CHROME_APP_IDS"] = 'test_app_id_1,test_app_id_2' -os.environ["DATABASE_PATH"] = "sqlite:////app/auto-archiver.test.db" -os.environ["BLOCKED_EMAILS"] = "blocked@example.com" \ No newline at end of file +# @pytest.fixture(autouse=True) +# def settings(): +# return Settings(_env_file=".env.test") + + +@pytest.fixture(autouse=True) +def settings(): + with patch('shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings: + yield mock_settings + + +@pytest.fixture() +def test_db(settings): + from db.database import make_engine, make_session_local + from db import models + + engine = make_engine(settings.DATABASE_PATH) + + if not os.path.exists(settings.DATABASE_PATH): + open(settings.DATABASE_PATH, 'w').close() + + models.Base.metadata.create_all(engine) + + connection = engine.connect() + yield connection + connection.close() + + models.Base.metadata.drop_all(bind=engine) + os.remove(settings.DATABASE_PATH) + +# @pytest.fixture() +# def db_session(test_db): +# session_local = make_session_local(test_db) +# with session_local() as session: +# yield session + +# # create test data and insert it into the database +# def create_test_data(): +# from db.database import SessionLocal +# from db.models import Task + +# db = SessionLocal() +# task = Task(id="test-task-id", 
status="PENDING") +# db.add(task) +# db.commit() +# db.refresh(task) +# db.close() + +# return task.id \ No newline at end of file diff --git a/src/tests/endpoints/test_task.py b/src/tests/endpoints/test_task.py index 87a5489..7510a34 100644 --- a/src/tests/endpoints/test_task.py +++ b/src/tests/endpoints/test_task.py @@ -1,8 +1,7 @@ -from unittest.mock import AsyncMock, patch +from unittest.mock import patch from fastapi.testclient import TestClient - def setup_client(): from main import app from security import get_token_or_user_auth @@ -10,6 +9,7 @@ def setup_client(): app.dependency_overrides[get_token_or_user_auth] = mock_get_token_or_user_auth return TestClient(app), app + @patch("endpoints.task.AsyncResult") def test_get_status_success(mock_async_result): client, app = setup_client() diff --git a/src/tests/test_main.py b/src/tests/test_main.py index 4c26073..f19f6ae 100644 --- a/src/tests/test_main.py +++ b/src/tests/test_main.py @@ -1,22 +1,23 @@ import os +from unittest.mock import patch from fastapi.testclient import TestClient def test_serve_local_archive_logic(): - os.environ["SERVE_LOCAL_ARCHIVE"] = "/app/local_archive_test" + with patch("main.settings.SERVE_LOCAL_ARCHIVE", "/app/local_archive_test"): - # create a test file - os.makedirs("local_archive_test", exist_ok=True) - with open("local_archive_test/temp.txt", "w") as f: - f.write("test") + # create a test file + os.makedirs("local_archive_test", exist_ok=True) + with open("local_archive_test/temp.txt", "w") as f: + f.write("test") - from main import app, setup_local_archive_serve - setup_local_archive_serve() - client = TestClient(app) + from main import app, setup_local_archive_serve + setup_local_archive_serve() + client = TestClient(app) - r = client.get("/app/local_archive_test/temp.txt") - assert r.status_code == 200 - assert r.text == "test" + r = client.get("/app/local_archive_test/temp.txt") + assert r.status_code == 200 + assert r.text == "test" - 
os.remove("local_archive_test/temp.txt") - os.rmdir("local_archive_test") + os.remove("local_archive_test/temp.txt") + os.rmdir("local_archive_test") diff --git a/src/utils/metrics.py b/src/utils/metrics.py index 70ff157..f706154 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -6,10 +6,8 @@ from loguru import logger from prometheus_client import Counter, Gauge from sqlalchemy.orm import Session -from core.config import REPEAT_COUNT_METRICS_SECONDS from db import crud from db.database import get_db -from core.config import SQLALCHEMY_DATABASE_URL from worker import REDIS_EXCEPTIONS_CHANNEL, Rdis @@ -47,20 +45,20 @@ async def redis_subscribe_worker_exceptions(): WORKER_EXCEPTION.labels(exception=data["exception"], task=data["task"]).inc() await asyncio.sleep(1) -async def measure_regular_metrics(): +async def measure_regular_metrics(sqlite_db_url:str, repeat_in_seconds:int): _total, used, free = shutil.disk_usage("/") DISK_UTILIZATION.labels(type="used").set(used / (2**30)) DISK_UTILIZATION.labels(type="free").set(free / (2**30)) try: - fs = os.stat(SQLALCHEMY_DATABASE_URL.replace("sqlite:///", "")) + fs = os.stat(sqlite_db_url.replace("sqlite:///", "")) DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30)) except Exception as e: logger.error(e) - session: Session = next(get_db()) - count_archives = crud.count_archives(session) - count_archive_urls = crud.count_archive_urls(session) - DATABASE_METRICS.labels(query="count_archives", user="-").set(count_archives) - DATABASE_METRICS.labels(query="count_archive_urls", user="-").set(count_archive_urls) + with get_db() as db: + count_archives = crud.count_archives(db) + count_archive_urls = crud.count_archive_urls(db) + DATABASE_METRICS.labels(query="count_archives", user="-").set(count_archives) + DATABASE_METRICS.labels(query="count_archive_urls", user="-").set(count_archive_urls) - for user in crud.count_by_user_since(session, REPEAT_COUNT_METRICS_SECONDS): - 
DATABASE_METRICS.labels(query="count_by_user", user=user.author_id).set(user.total) \ No newline at end of file + for user in crud.count_by_user_since(db, repeat_in_seconds): + DATABASE_METRICS.labels(query="count_by_user", user=user.author_id).set(user.total) \ No newline at end of file diff --git a/src/worker.py b/src/worker.py index a2220d7..513ebf1 100644 --- a/src/worker.py +++ b/src/worker.py @@ -1,5 +1,5 @@ -import os, traceback, yaml, datetime, sys +import traceback, yaml, datetime from typing import List, Set from celery import Celery @@ -9,29 +9,25 @@ from auto_archiver.core import Media from loguru import logger from db import crud, schemas, models -from db.database import SessionLocal -from contextlib import contextmanager +from db.database import get_db +from shared.settings import Settings import json import redis from sqlalchemy import exc +settings = Settings() + celery = Celery(__name__) -celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379") -celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379") -USER_GROUPS_FILENAME = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml") -REDIS_EXCEPTIONS_CHANNEL = "exceptions-channel" +celery.conf.broker_url = settings.CELERY_BROKER_URL +celery.conf.result_backend = settings.CELERY_RESULT_BACKEND +USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME +REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL + Rdis = redis.Redis.from_url(celery.conf.broker_url) -@contextmanager -def get_db(): - session = SessionLocal() - try: yield session - finally: session.close() - - @celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 3}) def create_archive_task(self, archive_json: str): - archive = schemas.ArchiveCreate.parse_raw(archive_json) + archive = schemas.ArchiveCreate.model_validate_json(archive_json) logger.info(f"Archiving {archive.url=} {archive.tags=} 
{archive.public=} {archive.group_id=} {archive.author_id=}") invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id) if invalid: @@ -63,7 +59,7 @@ def create_archive_task(self, archive_json: str): @celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0}) def create_sheet_task(self, sheet_json: str): - sheet = schemas.SubmitSheet.parse_raw(sheet_json) + sheet = schemas.SubmitSheet.model_validate_json(sheet_json) sheet.tags.add("gsheet") logger.info(f"SHEET START {sheet=}")