refactor to use pydantic settings and WAL sqlite mode

This commit is contained in:
msramalho
2024-10-18 11:08:51 +01:00
parent 11a8e6f4e0
commit ca7e862855
21 changed files with 246 additions and 135 deletions

7
src/.env.test Normal file
View File

@@ -0,0 +1,7 @@
CHROME_APP_IDS='["test_app_id_1","test_app_id_2"]'
ALLOWED_ORIGINS='["chrome-extension://example1","chrome-extension://example2","http://localhost:8081"]'
BLOCKED_EMAILS='["blocked@example.com"]'
DATABASE_PATH="sqlite:////app/auto-archiver.test.db"
API_BEARER_TOKEN=this_is_the_test_api_token

View File

@@ -19,6 +19,7 @@ alembic = "*"
fastapi-utils = "*"
prometheus-fastapi-instrumentator = "*"
auto-archiver = "*"
pydantic-settings = "*"
[dev-packages]
watchdog = "*"

19
src/Pipfile.lock generated
View File

@@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
"sha256": "359638472cb3c3914fac7040bc702463c0ed1ae2e5cee00abddd59d9b34e923e"
"sha256": "c34b5745f3a6f67222d3f26e6c7f2d13615a3301d0ca4d1f2b0ec58474b1d43a"
},
"pipfile-spec": 6,
"requires": {
@@ -1470,11 +1470,11 @@
},
"marshmallow": {
"hashes": [
"sha256:4972f529104a220bb8637d595aa4c9762afbe7f7a77d82dc58c1615d70c5823e",
"sha256:71a2dce49ef901c3f97ed296ae5051135fd3febd2bf43afe0ae9a82143a494d9"
"sha256:82f20a2397834fe6d9611b241f2f7e7b680ed89c49f84728a1ad937be6b4bdf4",
"sha256:98d8827a9f10c03d44ead298d2e99c6aea8197df18ccfad360dae7f89a50da2e"
],
"markers": "python_version >= '3.8'",
"version": "==3.22.0"
"markers": "python_version >= '3.9'",
"version": "==3.23.0"
},
"mccabe": {
"hashes": [
@@ -2182,6 +2182,15 @@
"markers": "python_version >= '3.8'",
"version": "==2.23.4"
},
"pydantic-settings": {
"hashes": [
"sha256:44a1804abffac9e6a30372bb45f6cafab945ef5af25e66b1c634c01dd39e0188",
"sha256:4a819166f119b74d7f8c765196b165f95cc7487ce58ea27dec8a5a26be0970e0"
],
"index": "pypi",
"markers": "python_version >= '3.8'",
"version": "==2.6.0"
},
"pyflakes": {
"hashes": [
"sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f",

View File

@@ -1,5 +1,3 @@
import os
VERSION = "0.7.0"
API_DESCRIPTION = """
#### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets.
@@ -9,16 +7,4 @@ API_DESCRIPTION = """
- You can use this API to archive single URLs or entire Google Sheets.
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
"""
ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "chrome-extension://ondkcheoicfckabcnkdgbepofpjmjcmb,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp").split(",")
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
SERVE_LOCAL_ARCHIVE = os.environ.get("SERVE_LOCAL_ARCHIVE", "")
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_PATH")
REPEAT_COUNT_METRICS_SECONDS = 15
CHROME_APP_IDS = set([app_id.strip() for app_id in os.environ.get("CHROME_APP_IDS", "").split(",")])
BLOCKED_EMAILS = set([e.strip().lower() for e in os.environ.get("BLOCKED_EMAILS", "").split(",")])
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}

View File

@@ -2,15 +2,16 @@ import asyncio
import logging
import alembic.config
from fastapi import FastAPI
from sqlalchemy.orm import Session
from contextlib import asynccontextmanager
from fastapi_utils.tasks import repeat_every
from loguru import logger
from db import crud, models
from db.database import get_db, engine
from db.database import get_db, make_engine
from shared.settings import Settings
from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
from core.config import REPEAT_COUNT_METRICS_SECONDS
settings = Settings()
@asynccontextmanager
@@ -18,13 +19,14 @@ async def lifespan(app: FastAPI):
# see https://fastapi.tiangolo.com/advanced/events/#lifespan
# STARTUP
engine = make_engine(settings.DATABASE_PATH)
models.Base.metadata.create_all(bind=engine)
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
# disabling uvicorn logger since we use loguru in logging_middleware
logging.getLogger("uvicorn.access").disabled = True
asyncio.create_task(redis_subscribe_worker_exceptions())
asyncio.create_task(refresh_user_groups())
asyncio.create_task(measure_regular_metrics())
asyncio.create_task(repeat_measure_regular_metrics())
yield # separates startup from shutdown instructions
@@ -36,9 +38,10 @@ async def lifespan(app: FastAPI):
@repeat_every(seconds=60 * 60) # 1 hour
async def refresh_user_groups():
db: Session = next(get_db())
crud.upsert_user_groups(db)
with get_db() as db:
crud.upsert_user_groups(db)
@repeat_every(seconds=settings.REPEAT_COUNT_METRICS_SECONDS)
async def repeat_measure_regular_metrics():
    """Periodically refresh the disk and database metrics.

    Scheduled by ``repeat_every`` using the configured
    ``REPEAT_COUNT_METRICS_SECONDS`` interval; the same interval is passed on
    as the look-back window for the per-user counts.
    """
    # measure_regular_metrics is a coroutine function: it must be awaited,
    # otherwise the call only creates a coroutine object that never runs.
    await measure_regular_metrics(settings.DATABASE_PATH, settings.REPEAT_COUNT_METRICS_SECONDS)

View File

@@ -5,12 +5,13 @@ from loguru import logger
from datetime import datetime, timedelta
from security import ALLOW_ANY_EMAIL
from shared.settings import Settings
from . import models, schemas
import yaml, os
import yaml
DOMAIN_GROUPS = {}
DOMAIN_GROUPS_LOADED = False
MAX_LIMIT = 100
DATABASE_QUERY_LIMIT = Settings().DATABASE_QUERY_LIMIT
# --------------- TASK = Archive
@@ -39,12 +40,12 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim
query = query.filter(models.Archive.created_at >= archived_after)
if archived_before:
query = query.filter(models.Archive.created_at <= archived_before)
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, MAX_LIMIT)).all()
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
email = email.lower()
return base_query(db).filter(models.Archive.author.has(email=email)).offset(skip).limit(min(limit, MAX_LIMIT)).all()
return base_query(db).filter(models.Archive.author.has(email=email)).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]):
@@ -76,7 +77,7 @@ def count_by_user_since(db:Session, seconds_delta: int = 15):
return db.query(models.Archive.author_id,func.count().label('total'))\
.filter(models.Archive.created_at >= time_threshold)\
.group_by(models.Archive.author_id)\
.order_by(func.count().desc()).limit(5 * MAX_LIMIT).all()
.order_by(func.count().desc()).limit(5 * DATABASE_QUERY_LIMIT).all()
def base_query(db: Session):
# allow only some fields to be returned, for example author should remain hidden
@@ -98,7 +99,7 @@ def create_tag(db: Session, tag: str):
def search_tags(db: Session, tag: str, skip: int = 0, limit: int = 100):
return db.query(models.Tag).filter(models.Tag.url.like(f'%{tag}%')).offset(skip).limit(min(limit, MAX_LIMIT)).all()
return db.query(models.Tag).filter(models.Tag.url.like(f'%{tag}%')).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
@@ -148,7 +149,7 @@ def upsert_user_groups(db: Session):
along with new participation of users in groups
"""
logger.debug("Updating user-groups configuration.")
filename = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml")
filename = Settings().USER_GROUPS_FILENAME
# read yaml safely
with open(filename) as inf:

View File

@@ -1,17 +1,36 @@
from sqlalchemy import create_engine
from sqlalchemy import Engine, create_engine, event
from sqlalchemy.orm import sessionmaker, declarative_base
from core.config import SQLALCHEMY_DATABASE_URL
from shared.settings import Settings
from contextlib import contextmanager
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
settings = Settings()
Base = declarative_base()
def make_engine(database_url: str) -> Engine:
    """Create a SQLAlchemy engine whose connections use SQLite WAL journaling.

    Args:
        database_url: SQLAlchemy database URL to connect to.

    Returns:
        The configured engine. Every new DBAPI connection it opens runs
        ``PRAGMA journal_mode=WAL`` before being handed out.
    """
    engine = create_engine(database_url, connect_args={"check_same_thread": False})

    @event.listens_for(engine, "connect")
    def set_sqlite_pragma(conn, _connection_record) -> None:
        cursor = conn.cursor()
        try:
            cursor.execute("PRAGMA journal_mode=WAL")
        finally:
            # release the cursor even if the PRAGMA statement fails
            cursor.close()

    return engine
def make_session_local(engine: Engine):
    """Build a session factory bound to *engine*.

    Sessions produced by the returned factory neither autocommit nor
    autoflush; callers commit explicitly.
    """
    return sessionmaker(bind=engine, autoflush=False, autocommit=False)
# One session factory (and therefore one engine / connection pool) per
# database URL, created lazily on first use and reused afterwards.
_SESSION_FACTORIES = {}


def _session_factory(database_url: str):
    """Return the cached session factory for *database_url*, creating it once."""
    factory = _SESSION_FACTORIES.get(database_url)
    if factory is None:
        factory = make_session_local(make_engine(database_url))
        _SESSION_FACTORIES[database_url] = factory
    return factory


@contextmanager
def get_db():
    """Yield a database session and always close it on exit.

    The engine behind the session is built once per database URL and reused;
    creating a fresh engine for every session would open a new connection
    pool on each call that is never disposed.
    """
    session = _session_factory(settings.DATABASE_PATH)()
    try:
        yield session
    finally:
        session.close()
def get_db_dependency():
    """Generator wrapper around ``get_db`` for use with FastAPI ``Depends``.

    The session is yielded from inside the ``get_db`` context manager, so it
    is closed once the dependency generator is finalized.
    """
    with get_db() as session:
        yield session

View File

@@ -1,8 +1,10 @@
from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from sqlalchemy.orm import relationship, declarative_base
import uuid
from .database import Base
Base = declarative_base()
def generate_uuid():
    """Return a newly generated random (version 4) UUID in string form."""
    new_id = uuid.uuid4()
    return str(new_id)
@@ -59,7 +61,6 @@ class Tag(Base):
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
class User(Base):
__tablename__ = "users"

View File

@@ -2,6 +2,13 @@ from pydantic import BaseModel
from datetime import datetime
class Tag(BaseModel):
    """Schema for a tag: its identifier and creation timestamp."""

    # allow building the schema from objects that expose these as attributes
    model_config = {"from_attributes": True}

    id: str
    created_at: datetime

    # hash by object identity so instances can be stored in sets
    __hash__ = object.__hash__
class ArchiveCreate(BaseModel):
id: str | None = None
url: str
@@ -9,7 +16,7 @@ class ArchiveCreate(BaseModel):
public: bool = True
author_id: str | None = None
group_id: str | None = None
tags: set = set()
tags: set[Tag] | None = set()
rearchive: bool = True
# urls: list = []
@@ -28,7 +35,7 @@ class SubmitSheet(BaseModel):
public: bool = False
author_id: str | None = None
group_id: str | None = None
tags: set | None = set()
tags: set[Tag] | None = set()
columns: dict | None = {} # TODO: implement
class SubmitManual(BaseModel):
@@ -36,7 +43,7 @@ class SubmitManual(BaseModel):
public: bool = False
author_id: str | None = None
group_id: str | None = None
tags: set | None = set()
tags: set[Tag] | None = set()
class Task(BaseModel):
id: str

View File

@@ -6,20 +6,20 @@ from sqlalchemy.orm import Session
from core.config import VERSION, BREAKING_CHANGES
from db import crud
from db.database import get_db
from db.database import get_db_dependency, get_db
from security import get_user_auth, bearer_security
default_router = APIRouter()
@default_router.get("/")
async def home(request: Request):
# TODO: maybe split into 2 routes: one non authenticated and one authenticated for the groups info only
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
try:
email = await get_user_auth(await bearer_security(request))
db: Session = next(get_db())
status["groups"] = crud.get_user_groups(db, email)
with get_db() as db:
status["groups"] = crud.get_user_groups(db, email)
except HTTPException: pass # not authenticated is fine
except Exception as e: logger.error(e)
return JSONResponse(status)
@@ -29,8 +29,9 @@ async def home(request: Request):
async def health():
return JSONResponse({"status": "ok"})
@default_router.get("/groups", response_model=list[str])
def get_user_groups(db: Session = Depends(get_db), email=Depends(get_user_auth)):
def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
return crud.get_user_groups(db, email)

View File

@@ -8,7 +8,7 @@ from security import get_user_auth, get_token_or_user_auth
from sqlalchemy.orm import Session
from db import crud, schemas
from db.database import get_db
from db.database import get_db_dependency
from worker import create_archive_task
@@ -32,23 +32,23 @@ def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_
def search_by_url(
url: str, skip: int = 0, limit: int = 25,
archived_after: datetime = None, archived_before: datetime = None,
db: Session = Depends(get_db),
db: Session = Depends(get_db_dependency),
email=Depends(get_token_or_user_auth)):
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
@url_router.get("/latest", response_model=list[schemas.Archive], summary="Fetch latest URL archives for the authenticated user.")
def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db), email=Depends(get_user_auth)):
def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
return crud.search_archives_by_email(db, email, skip=skip, limit=limit)
@url_router.get("/{id}", response_model=schemas.Archive, summary="Fetch a single URL archive by the associated id.")
def lookup(id, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)):
def lookup(id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
return crud.get_archive(db, id, email)
@url_router.delete("/{id}", response_model=schemas.TaskDelete, summary="Delete a single URL archive by id.")
def delete_task(id, db: Session = Depends(get_db), email=Depends(get_user_auth)):
def delete_task(id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
logger.info(f"deleting url archive task {id} request by {email}")
#TODO: use response model?
return JSONResponse({

View File

@@ -16,14 +16,16 @@ from worker import create_archive_task, create_sheet_task, celery, insert_result
from db import crud, models, schemas
from security import get_user_auth, token_api_key_auth, get_token_or_user_auth
from core.config import ALLOWED_ORIGINS, VERSION, SERVE_LOCAL_ARCHIVE, API_DESCRIPTION
from db.database import get_db
from core.config import VERSION, API_DESCRIPTION
from db.database import get_db_dependency
from core.events import lifespan
from shared.settings import Settings
from auto_archiver import Metadata
from endpoints import default_router, url_router, sheet_router, task_router, interoperability_router
settings = Settings()
app = FastAPI(
title="Auto-Archiver API",
@@ -35,7 +37,7 @@ app = FastAPI(
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_origins=settings.ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
@@ -48,17 +50,15 @@ app.include_router(task_router)
app.include_router(interoperability_router)
# prometheus exposed in /metrics with authentication
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
def setup_local_archive_serve():
# if env SERVE_LOCAL_ARCHIVE is set it serves files from that dir, useful for development and using local_archive
SERVE_LOCAL_ARCHIVE = os.environ.get("SERVE_LOCAL_ARCHIVE", "")
local_dir = SERVE_LOCAL_ARCHIVE
local_dir = settings.SERVE_LOCAL_ARCHIVE
if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")):
local_dir = local_dir.replace("/app", ".")
if len(SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
logger.warning(f"MOUNTing local archive {SERVE_LOCAL_ARCHIVE}")
app.mount(SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=SERVE_LOCAL_ARCHIVE)
if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
logger.warning(f"MOUNTing local archive {settings.SERVE_LOCAL_ARCHIVE}")
app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE)
setup_local_archive_serve()
@@ -68,12 +68,12 @@ app.middleware("http")(logging_middleware)
@app.get("/tasks/search-url", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)):
def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
@app.get("/tasks/sync", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email=Depends(get_user_auth)):
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
return crud.search_archives_by_email(db, email, skip=skip, limit=limit)
@@ -90,7 +90,7 @@ def archive_tasks(archive: schemas.ArchiveCreate, email=Depends(get_token_or_use
@app.get("/archive/{task_id}", deprecated=True) # DEPRECATED
def lookup(task_id, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)):
def lookup(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
return crud.get_archive(db, task_id, email)
@@ -123,7 +123,7 @@ def get_status(task_id, email=Depends(get_token_or_user_auth)):
@app.delete("/tasks/{task_id}", deprecated=True) # DEPRECATED
def delete_task(task_id, db: Session = Depends(get_db), email=Depends(get_user_auth)):
def delete_task(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
logger.info(f"deleting task {task_id} request by {email}")
return JSONResponse({
"id": task_id,

View File

@@ -5,10 +5,12 @@ from sqlalchemy import pool
from alembic import context
from shared.settings import Settings
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
config.set_main_option('sqlalchemy.url', os.environ.get("DATABASE_PATH"))
config.set_main_option('sqlalchemy.url', Settings().DATABASE_PATH)
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:

View File

@@ -2,21 +2,18 @@ from loguru import logger
import requests, os, secrets
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from core.config import CHROME_APP_IDS, BLOCKED_EMAILS
# Configuration checks
assert len(CHROME_APP_IDS) > 0, "CHROME_APP_IDS env variable not properly set, it's a csv"
for app_id in CHROME_APP_IDS:
assert len(app_id) > 10, f"CHROME_APP_IDS got invalid id: {app_id} env variable not set"
# Auth logic
bearer_security = HTTPBearer()
from shared.settings import Settings
ALLOW_ANY_EMAIL = "*"
settings = Settings()
bearer_security = HTTPBearer()
def secure_compare(token, api_key):
    """Return True when *token* equals *api_key*.

    Both strings are UTF-8 encoded and compared with
    ``secrets.compare_digest`` rather than ``==``.
    """
    supplied = token.encode("utf8")
    expected = api_key.encode("utf8")
    return secrets.compare_digest(supplied, expected)
# Factory method to create an authentication dependency for a specific key
def api_key_auth(api_key):
@@ -35,9 +32,10 @@ def api_key_auth(api_key):
return auth
# --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus
API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars
token_api_key_auth = api_key_auth(API_BEARER_TOKEN)
token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN)
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# tries to use the static API_KEY and defaults to google JWT auth
@@ -45,6 +43,7 @@ async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Dep
if token_api_key_auth(access_token, auto_error=False): return ALLOW_ANY_EMAIL
return await get_user_auth(credentials)
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# validates the Bearer token in the case that it requires it
valid_user, info = authenticate_user(credentials.credentials)
@@ -56,6 +55,7 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
headers={"WWW-Authenticate": "Bearer"},
)
def authenticate_user(access_token):
# https://cloud.google.com/docs/authentication/token-types#access
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
@@ -63,9 +63,9 @@ def authenticate_user(access_token):
if r.status_code != 200: return False, "error occurred"
try:
j = r.json()
if j.get("azp") not in CHROME_APP_IDS and j.get("aud") not in CHROME_APP_IDS:
if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS:
return False, f"token does not belong to valid APP_ID"
if j.get("email") in BLOCKED_EMAILS:
if j.get("email") in settings.BLOCKED_EMAILS:
return False, f"email '{j.get('email')}' not allowed"
if j.get("email_verified") != "true":
return False, f"email '{j.get('email')}' not verified"
@@ -75,4 +75,3 @@ def authenticate_user(access_token):
except Exception as e:
logger.warning(f"EXCEPTION occurred: {e}")
return False, f"EXCEPTION occurred"

31
src/shared/settings.py Normal file
View File

@@ -0,0 +1,31 @@
from pydantic_settings import BaseSettings
from pydantic import ConfigDict
from typing import Annotated, Set
from annotated_types import Len
class Settings(BaseSettings):
    """Application configuration, validated on instantiation.

    Fields without a default (``DATABASE_PATH``, ``API_BEARER_TOKEN``,
    ``ALLOWED_ORIGINS``, ``CHROME_APP_IDS``) are required. Unknown inputs are
    ignored and string values have surrounding whitespace stripped.
    """

    model_config = ConfigDict(extra='ignore', str_strip_whitespace=True)

    # --- general
    SERVE_LOCAL_ARCHIVE: str = ""
    USER_GROUPS_FILENAME: str = "user-groups.yaml"

    # --- database
    DATABASE_PATH: str
    DATABASE_QUERY_LIMIT: int = 100

    # --- redis
    CELERY_BROKER_URL: str = "redis://localhost:6379"
    CELERY_RESULT_BACKEND: str = "redis://localhost:6379"
    REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"

    # --- observability
    REPEAT_COUNT_METRICS_SECONDS: int = 15

    # --- security
    # bearer token must be at least 20 characters long
    API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
    # at least one allowed origin is required
    ALLOWED_ORIGINS: Annotated[set[str], Len(min_length=1)]
    # at least one app id is required, each at least 10 characters long
    CHROME_APP_IDS: Annotated[set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
    # optional; defaults to no blocked emails
    BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()

View File

@@ -1,6 +1,7 @@
import pytest
import os
import pytest
from unittest.mock import patch
from shared.settings import Settings
@pytest.fixture(autouse=True)
def mock_logger_add():
@@ -8,6 +9,52 @@ def mock_logger_add():
with patch('loguru.logger.add') as mock_add:
yield mock_add # This makes the mock available to tests
os.environ["CHROME_APP_IDS"] = 'test_app_id_1,test_app_id_2'
os.environ["DATABASE_PATH"] = "sqlite:////app/auto-archiver.test.db"
os.environ["BLOCKED_EMAILS"] = "blocked@example.com"
# @pytest.fixture(autouse=True)
# def settings():
# return Settings(_env_file=".env.test")
@pytest.fixture(autouse=True)
def settings():
    """Patch ``shared.settings.Settings`` so it returns the ``.env.test`` settings.

    Yields the patch mock; the Settings instance built from ``.env.test`` is
    what calling that mock returns.
    """
    test_settings = Settings(_env_file=".env.test")
    with patch('shared.settings.Settings', return_value=test_settings) as mock_settings:
        yield mock_settings
@pytest.fixture()
def test_db(settings):
    """Provide a connection to a freshly created test database.

    All model tables are created before the test. Afterwards the connection
    is closed, the tables are dropped and the database file is deleted, even
    when the test fails.
    """
    from db.database import make_engine
    from db import models

    # The `settings` fixture yields the patch mock, whose return value is the
    # real Settings instance; unwrap it, but also accept a plain instance.
    active_settings = getattr(settings, "return_value", settings)
    database_url = active_settings.DATABASE_PATH
    # DATABASE_PATH is a SQLAlchemy URL, not a filesystem path: strip the
    # scheme before handing it to os functions.
    database_file = database_url.replace("sqlite:///", "", 1)

    # SQLite creates the database file itself on first connect.
    engine = make_engine(database_url)
    models.Base.metadata.create_all(engine)
    connection = engine.connect()
    try:
        yield connection
    finally:
        connection.close()
        models.Base.metadata.drop_all(bind=engine)
        engine.dispose()
        # WAL journaling may leave -wal / -shm companion files next to the db
        for suffix in ("", "-wal", "-shm"):
            try:
                os.remove(database_file + suffix)
            except FileNotFoundError:
                pass
# @pytest.fixture()
# def db_session(test_db):
# session_local = make_session_local(test_db)
# with session_local() as session:
# yield session
# # create test data and insert it into the database
# def create_test_data():
# from db.database import SessionLocal
# from db.models import Task
# db = SessionLocal()
# task = Task(id="test-task-id", status="PENDING")
# db.add(task)
# db.commit()
# db.refresh(task)
# db.close()
# return task.id

View File

@@ -1,8 +1,7 @@
from unittest.mock import AsyncMock, patch
from unittest.mock import patch
from fastapi.testclient import TestClient
def setup_client():
from main import app
from security import get_token_or_user_auth
@@ -10,6 +9,7 @@ def setup_client():
app.dependency_overrides[get_token_or_user_auth] = mock_get_token_or_user_auth
return TestClient(app), app
@patch("endpoints.task.AsyncResult")
def test_get_status_success(mock_async_result):
client, app = setup_client()

View File

@@ -1,22 +1,23 @@
import os
from unittest.mock import patch
from fastapi.testclient import TestClient
def test_serve_local_archive_logic():
os.environ["SERVE_LOCAL_ARCHIVE"] = "/app/local_archive_test"
with patch("main.settings.SERVE_LOCAL_ARCHIVE", "/app/local_archive_test"):
# create a test file
os.makedirs("local_archive_test", exist_ok=True)
with open("local_archive_test/temp.txt", "w") as f:
f.write("test")
# create a test file
os.makedirs("local_archive_test", exist_ok=True)
with open("local_archive_test/temp.txt", "w") as f:
f.write("test")
from main import app, setup_local_archive_serve
setup_local_archive_serve()
client = TestClient(app)
from main import app, setup_local_archive_serve
setup_local_archive_serve()
client = TestClient(app)
r = client.get("/app/local_archive_test/temp.txt")
assert r.status_code == 200
assert r.text == "test"
r = client.get("/app/local_archive_test/temp.txt")
assert r.status_code == 200
assert r.text == "test"
os.remove("local_archive_test/temp.txt")
os.rmdir("local_archive_test")
os.remove("local_archive_test/temp.txt")
os.rmdir("local_archive_test")

View File

@@ -6,10 +6,8 @@ from loguru import logger
from prometheus_client import Counter, Gauge
from sqlalchemy.orm import Session
from core.config import REPEAT_COUNT_METRICS_SECONDS
from db import crud
from db.database import get_db
from core.config import SQLALCHEMY_DATABASE_URL
from worker import REDIS_EXCEPTIONS_CHANNEL, Rdis
@@ -47,20 +45,20 @@ async def redis_subscribe_worker_exceptions():
WORKER_EXCEPTION.labels(exception=data["exception"], task=data["task"]).inc()
await asyncio.sleep(1)
async def measure_regular_metrics(sqlite_db_url: str, repeat_in_seconds: int):
    """Update the disk-utilization and database gauges.

    Args:
        sqlite_db_url: database URL; the ``sqlite:///`` prefix is stripped to
            obtain the file whose size is reported.
        repeat_in_seconds: look-back window, in seconds, passed to the
            per-user archive count.
    """
    # disk figures are reported in GiB (bytes / 2**30)
    _total, used, free = shutil.disk_usage("/")
    DISK_UTILIZATION.labels(type="used").set(used / (2**30))
    DISK_UTILIZATION.labels(type="free").set(free / (2**30))

    # best effort: a missing/unreadable database file must not stop the
    # remaining metrics from being collected
    try:
        fs = os.stat(sqlite_db_url.replace("sqlite:///", ""))
        DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30))
    except Exception as e:
        logger.error(e)

    # get_db is a context-manager factory: it has to be called to obtain the
    # context manager that yields the session
    with get_db() as db:
        count_archives = crud.count_archives(db)
        count_archive_urls = crud.count_archive_urls(db)
        DATABASE_METRICS.labels(query="count_archives", user="-").set(count_archives)
        DATABASE_METRICS.labels(query="count_archive_urls", user="-").set(count_archive_urls)

        for user in crud.count_by_user_since(db, repeat_in_seconds):
            DATABASE_METRICS.labels(query="count_by_user", user=user.author_id).set(user.total)

View File

@@ -1,5 +1,5 @@
import os, traceback, yaml, datetime, sys
import traceback, yaml, datetime
from typing import List, Set
from celery import Celery
@@ -9,29 +9,25 @@ from auto_archiver.core import Media
from loguru import logger
from db import crud, schemas, models
from db.database import SessionLocal
from contextlib import contextmanager
from db.database import get_db
from shared.settings import Settings
import json
import redis
from sqlalchemy import exc
settings = Settings()
celery = Celery(__name__)
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379")
celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379")
USER_GROUPS_FILENAME = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml")
REDIS_EXCEPTIONS_CHANNEL = "exceptions-channel"
celery.conf.broker_url = settings.CELERY_BROKER_URL
celery.conf.result_backend = settings.CELERY_RESULT_BACKEND
USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL
Rdis = redis.Redis.from_url(celery.conf.broker_url)
@contextmanager
def get_db():
session = SessionLocal()
try: yield session
finally: session.close()
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 3})
def create_archive_task(self, archive_json: str):
archive = schemas.ArchiveCreate.parse_raw(archive_json)
archive = schemas.ArchiveCreate.model_validate_json(archive_json)
logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}")
invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)
if invalid:
@@ -63,7 +59,7 @@ def create_archive_task(self, archive_json: str):
@celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
def create_sheet_task(self, sheet_json: str):
sheet = schemas.SubmitSheet.parse_raw(sheet_json)
sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
sheet.tags.add("gsheet")
logger.info(f"SHEET START {sheet=}")