mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-11 13:08:34 +03:00
refactor to use pydantic settings and WAL sqlite mode
This commit is contained in:
7
src/.env.test
Normal file
7
src/.env.test
Normal file
@@ -0,0 +1,7 @@
|
||||
CHROME_APP_IDS='["test_app_id_1","test_app_id_2"]'
|
||||
ALLOWED_ORIGINS='["chrome-extension://example1","chrome-extension://example2","http://localhost:8081"]'
|
||||
BLOCKED_EMAILS='["blocked@example.com"]'
|
||||
|
||||
|
||||
DATABASE_PATH="sqlite:////app/auto-archiver.test.db"
|
||||
API_BEARER_TOKEN=this_is_the_test_api_token
|
||||
@@ -19,6 +19,7 @@ alembic = "*"
|
||||
fastapi-utils = "*"
|
||||
prometheus-fastapi-instrumentator = "*"
|
||||
auto-archiver = "*"
|
||||
pydantic-settings = "*"
|
||||
|
||||
[dev-packages]
|
||||
watchdog = "*"
|
||||
|
||||
19
src/Pipfile.lock
generated
19
src/Pipfile.lock
generated
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "359638472cb3c3914fac7040bc702463c0ed1ae2e5cee00abddd59d9b34e923e"
|
||||
"sha256": "c34b5745f3a6f67222d3f26e6c7f2d13615a3301d0ca4d1f2b0ec58474b1d43a"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
@@ -1470,11 +1470,11 @@
|
||||
},
|
||||
"marshmallow": {
|
||||
"hashes": [
|
||||
"sha256:4972f529104a220bb8637d595aa4c9762afbe7f7a77d82dc58c1615d70c5823e",
|
||||
"sha256:71a2dce49ef901c3f97ed296ae5051135fd3febd2bf43afe0ae9a82143a494d9"
|
||||
"sha256:82f20a2397834fe6d9611b241f2f7e7b680ed89c49f84728a1ad937be6b4bdf4",
|
||||
"sha256:98d8827a9f10c03d44ead298d2e99c6aea8197df18ccfad360dae7f89a50da2e"
|
||||
],
|
||||
"markers": "python_version >= '3.8'",
|
||||
"version": "==3.22.0"
|
||||
"markers": "python_version >= '3.9'",
|
||||
"version": "==3.23.0"
|
||||
},
|
||||
"mccabe": {
|
||||
"hashes": [
|
||||
@@ -2182,6 +2182,15 @@
|
||||
"markers": "python_version >= '3.8'",
|
||||
"version": "==2.23.4"
|
||||
},
|
||||
"pydantic-settings": {
|
||||
"hashes": [
|
||||
"sha256:44a1804abffac9e6a30372bb45f6cafab945ef5af25e66b1c634c01dd39e0188",
|
||||
"sha256:4a819166f119b74d7f8c765196b165f95cc7487ce58ea27dec8a5a26be0970e0"
|
||||
],
|
||||
"index": "pypi",
|
||||
"markers": "python_version >= '3.8'",
|
||||
"version": "==2.6.0"
|
||||
},
|
||||
"pyflakes": {
|
||||
"hashes": [
|
||||
"sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f",
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
VERSION = "0.7.0"
|
||||
API_DESCRIPTION = """
|
||||
#### API for the Auto-Archiver project, a tool to archive web pages and Google Sheets.
|
||||
@@ -9,16 +7,4 @@ API_DESCRIPTION = """
|
||||
- You can use this API to archive single URLs or entire Google Sheets.
|
||||
- Once you submit a URL or Sheet for archiving, the API will return a task_id that you can use to check the status of the archiving process. It works asynchronously.
|
||||
"""
|
||||
|
||||
ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "chrome-extension://ondkcheoicfckabcnkdgbepofpjmjcmb,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp").split(",")
|
||||
|
||||
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
|
||||
|
||||
SERVE_LOCAL_ARCHIVE = os.environ.get("SERVE_LOCAL_ARCHIVE", "")
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_PATH")
|
||||
|
||||
REPEAT_COUNT_METRICS_SECONDS = 15
|
||||
|
||||
CHROME_APP_IDS = set([app_id.strip() for app_id in os.environ.get("CHROME_APP_IDS", "").split(",")])
|
||||
BLOCKED_EMAILS = set([e.strip().lower() for e in os.environ.get("BLOCKED_EMAILS", "").split(",")])
|
||||
BREAKING_CHANGES = {"minVersion": "0.3.1", "message": "The latest update has breaking changes, please update the extension to the most recent version."}
|
||||
@@ -2,15 +2,16 @@ import asyncio
|
||||
import logging
|
||||
import alembic.config
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy.orm import Session
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
from loguru import logger
|
||||
|
||||
from db import crud, models
|
||||
from db.database import get_db, engine
|
||||
from db.database import get_db, make_engine
|
||||
from shared.settings import Settings
|
||||
from utils.metrics import measure_regular_metrics, redis_subscribe_worker_exceptions
|
||||
from core.config import REPEAT_COUNT_METRICS_SECONDS
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -18,13 +19,14 @@ async def lifespan(app: FastAPI):
|
||||
# see https://fastapi.tiangolo.com/advanced/events/#lifespan
|
||||
|
||||
# STARTUP
|
||||
engine = make_engine(settings.DATABASE_PATH)
|
||||
models.Base.metadata.create_all(bind=engine)
|
||||
alembic.config.main(argv=['--raiseerr', 'upgrade', 'head'])
|
||||
# disabling uvicorn logger since we use loguru in logging_middleware
|
||||
logging.getLogger("uvicorn.access").disabled = True
|
||||
asyncio.create_task(redis_subscribe_worker_exceptions())
|
||||
asyncio.create_task(refresh_user_groups())
|
||||
asyncio.create_task(measure_regular_metrics())
|
||||
asyncio.create_task(repeat_measure_regular_metrics())
|
||||
|
||||
yield # separates startup from shutdown instructions
|
||||
|
||||
@@ -36,9 +38,10 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
@repeat_every(seconds=60 * 60) # 1 hour
|
||||
async def refresh_user_groups():
|
||||
db: Session = next(get_db())
|
||||
crud.upsert_user_groups(db)
|
||||
with get_db() as db:
|
||||
crud.upsert_user_groups(db)
|
||||
|
||||
@repeat_every(seconds=REPEAT_COUNT_METRICS_SECONDS)
|
||||
|
||||
@repeat_every(seconds=settings.REPEAT_COUNT_METRICS_SECONDS)
|
||||
async def repeat_measure_regular_metrics():
|
||||
measure_regular_metrics()
|
||||
measure_regular_metrics(settings.DATABASE_PATH, settings.REPEAT_COUNT_METRICS_SECONDS)
|
||||
|
||||
@@ -5,12 +5,13 @@ from loguru import logger
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from security import ALLOW_ANY_EMAIL
|
||||
from shared.settings import Settings
|
||||
from . import models, schemas
|
||||
import yaml, os
|
||||
import yaml
|
||||
|
||||
DOMAIN_GROUPS = {}
|
||||
DOMAIN_GROUPS_LOADED = False
|
||||
MAX_LIMIT = 100
|
||||
DATABASE_QUERY_LIMIT = Settings().DATABASE_QUERY_LIMIT
|
||||
|
||||
# --------------- TASK = Archive
|
||||
|
||||
@@ -39,12 +40,12 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim
|
||||
query = query.filter(models.Archive.created_at >= archived_after)
|
||||
if archived_before:
|
||||
query = query.filter(models.Archive.created_at <= archived_before)
|
||||
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, MAX_LIMIT)).all()
|
||||
return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
|
||||
|
||||
|
||||
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
|
||||
email = email.lower()
|
||||
return base_query(db).filter(models.Archive.author.has(email=email)).offset(skip).limit(min(limit, MAX_LIMIT)).all()
|
||||
return base_query(db).filter(models.Archive.author.has(email=email)).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
|
||||
|
||||
|
||||
def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]):
|
||||
@@ -76,7 +77,7 @@ def count_by_user_since(db:Session, seconds_delta: int = 15):
|
||||
return db.query(models.Archive.author_id,func.count().label('total'))\
|
||||
.filter(models.Archive.created_at >= time_threshold)\
|
||||
.group_by(models.Archive.author_id)\
|
||||
.order_by(func.count().desc()).limit(5 * MAX_LIMIT).all()
|
||||
.order_by(func.count().desc()).limit(5 * DATABASE_QUERY_LIMIT).all()
|
||||
|
||||
def base_query(db: Session):
|
||||
# allow only some fields to be returned, for example author should remain hidden
|
||||
@@ -98,7 +99,7 @@ def create_tag(db: Session, tag: str):
|
||||
|
||||
|
||||
def search_tags(db: Session, tag: str, skip: int = 0, limit: int = 100):
|
||||
return db.query(models.Tag).filter(models.Tag.url.like(f'%{tag}%')).offset(skip).limit(min(limit, MAX_LIMIT)).all()
|
||||
return db.query(models.Tag).filter(models.Tag.url.like(f'%{tag}%')).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
|
||||
|
||||
|
||||
def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group:
|
||||
@@ -148,7 +149,7 @@ def upsert_user_groups(db: Session):
|
||||
along with new participation of users in groups
|
||||
"""
|
||||
logger.debug("Updating user-groups configuration.")
|
||||
filename = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml")
|
||||
filename = Settings().USER_GROUPS_FILENAME
|
||||
|
||||
# read yaml safely
|
||||
with open(filename) as inf:
|
||||
|
||||
@@ -1,17 +1,36 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import Engine, create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
from core.config import SQLALCHEMY_DATABASE_URL
|
||||
from shared.settings import Settings
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
settings = Settings()
|
||||
|
||||
Base = declarative_base()
|
||||
def make_engine(database_url: str):
|
||||
engine = create_engine(database_url, connect_args={"check_same_thread": False})
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(conn, _) -> None:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def make_session_local(engine: Engine):
|
||||
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
return session_local
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db():
|
||||
session = SessionLocal()
|
||||
session = make_session_local(make_engine(settings.DATABASE_PATH))()
|
||||
try: yield session
|
||||
finally: session.close()
|
||||
|
||||
|
||||
def get_db_dependency():
|
||||
# to use with Depends and ensure proper session closing
|
||||
with get_db() as db:
|
||||
yield db
|
||||
@@ -1,8 +1,10 @@
|
||||
from sqlalchemy import Column, String, JSON, DateTime, Boolean, Table, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
import uuid
|
||||
from .database import Base
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
def generate_uuid():
|
||||
return str(uuid.uuid4())
|
||||
@@ -59,7 +61,6 @@ class Tag(Base):
|
||||
|
||||
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
|
||||
@@ -2,6 +2,13 @@ from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class Tag(BaseModel):
|
||||
id: str
|
||||
created_at: datetime
|
||||
|
||||
model_config = { "from_attributes": True }
|
||||
__hash__ = object.__hash__
|
||||
|
||||
class ArchiveCreate(BaseModel):
|
||||
id: str | None = None
|
||||
url: str
|
||||
@@ -9,7 +16,7 @@ class ArchiveCreate(BaseModel):
|
||||
public: bool = True
|
||||
author_id: str | None = None
|
||||
group_id: str | None = None
|
||||
tags: set = set()
|
||||
tags: set[Tag] | None = set()
|
||||
rearchive: bool = True
|
||||
# urls: list = []
|
||||
|
||||
@@ -28,7 +35,7 @@ class SubmitSheet(BaseModel):
|
||||
public: bool = False
|
||||
author_id: str | None = None
|
||||
group_id: str | None = None
|
||||
tags: set | None = set()
|
||||
tags: set[Tag] | None = set()
|
||||
columns: dict | None = {} # TODO: implement
|
||||
|
||||
class SubmitManual(BaseModel):
|
||||
@@ -36,7 +43,7 @@ class SubmitManual(BaseModel):
|
||||
public: bool = False
|
||||
author_id: str | None = None
|
||||
group_id: str | None = None
|
||||
tags: set | None = set()
|
||||
tags: set[Tag] | None = set()
|
||||
|
||||
class Task(BaseModel):
|
||||
id: str
|
||||
|
||||
@@ -6,20 +6,20 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.config import VERSION, BREAKING_CHANGES
|
||||
from db import crud
|
||||
from db.database import get_db
|
||||
from db.database import get_db_dependency, get_db
|
||||
from security import get_user_auth, bearer_security
|
||||
|
||||
|
||||
default_router = APIRouter()
|
||||
|
||||
|
||||
@default_router.get("/")
|
||||
async def home(request: Request):
|
||||
# TODO: maybe split into 2 routes: one non authenticated and one authenticated for the groups info only
|
||||
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
|
||||
try:
|
||||
email = await get_user_auth(await bearer_security(request))
|
||||
db: Session = next(get_db())
|
||||
status["groups"] = crud.get_user_groups(db, email)
|
||||
with get_db() as db:
|
||||
status["groups"] = crud.get_user_groups(db, email)
|
||||
except HTTPException: pass # not authenticated is fine
|
||||
except Exception as e: logger.error(e)
|
||||
return JSONResponse(status)
|
||||
@@ -29,8 +29,9 @@ async def home(request: Request):
|
||||
async def health():
|
||||
return JSONResponse({"status": "ok"})
|
||||
|
||||
|
||||
@default_router.get("/groups", response_model=list[str])
|
||||
def get_user_groups(db: Session = Depends(get_db), email=Depends(get_user_auth)):
|
||||
def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
return crud.get_user_groups(db, email)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from security import get_user_auth, get_token_or_user_auth
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from db import crud, schemas
|
||||
from db.database import get_db
|
||||
from db.database import get_db_dependency
|
||||
|
||||
from worker import create_archive_task
|
||||
|
||||
@@ -32,23 +32,23 @@ def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_
|
||||
def search_by_url(
|
||||
url: str, skip: int = 0, limit: int = 25,
|
||||
archived_after: datetime = None, archived_before: datetime = None,
|
||||
db: Session = Depends(get_db),
|
||||
db: Session = Depends(get_db_dependency),
|
||||
email=Depends(get_token_or_user_auth)):
|
||||
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
||||
|
||||
|
||||
@url_router.get("/latest", response_model=list[schemas.Archive], summary="Fetch latest URL archives for the authenticated user.")
|
||||
def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db), email=Depends(get_user_auth)):
|
||||
def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
return crud.search_archives_by_email(db, email, skip=skip, limit=limit)
|
||||
|
||||
|
||||
@url_router.get("/{id}", response_model=schemas.Archive, summary="Fetch a single URL archive by the associated id.")
|
||||
def lookup(id, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)):
|
||||
def lookup(id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
||||
return crud.get_archive(db, id, email)
|
||||
|
||||
|
||||
@url_router.delete("/{id}", response_model=schemas.TaskDelete, summary="Delete a single URL archive by id.")
|
||||
def delete_task(id, db: Session = Depends(get_db), email=Depends(get_user_auth)):
|
||||
def delete_task(id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
logger.info(f"deleting url archive task {id} request by {email}")
|
||||
#TODO: use response model?
|
||||
return JSONResponse({
|
||||
|
||||
28
src/main.py
28
src/main.py
@@ -16,14 +16,16 @@ from worker import create_archive_task, create_sheet_task, celery, insert_result
|
||||
|
||||
from db import crud, models, schemas
|
||||
from security import get_user_auth, token_api_key_auth, get_token_or_user_auth
|
||||
from core.config import ALLOWED_ORIGINS, VERSION, SERVE_LOCAL_ARCHIVE, API_DESCRIPTION
|
||||
from db.database import get_db
|
||||
from core.config import VERSION, API_DESCRIPTION
|
||||
from db.database import get_db_dependency
|
||||
from core.events import lifespan
|
||||
from shared.settings import Settings
|
||||
|
||||
from auto_archiver import Metadata
|
||||
|
||||
from endpoints import default_router, url_router, sheet_router, task_router, interoperability_router
|
||||
|
||||
settings = Settings()
|
||||
|
||||
app = FastAPI(
|
||||
title="Auto-Archiver API",
|
||||
@@ -35,7 +37,7 @@ app = FastAPI(
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_origins=settings.ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
@@ -48,17 +50,15 @@ app.include_router(task_router)
|
||||
app.include_router(interoperability_router)
|
||||
|
||||
# prometheus exposed in /metrics with authentication
|
||||
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
|
||||
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
|
||||
|
||||
def setup_local_archive_serve():
|
||||
# if env SERVE_LOCAL_ARCHIVE is set it serves files from that dir, useful for development and using local_archive
|
||||
SERVE_LOCAL_ARCHIVE = os.environ.get("SERVE_LOCAL_ARCHIVE", "")
|
||||
local_dir = SERVE_LOCAL_ARCHIVE
|
||||
local_dir = settings.SERVE_LOCAL_ARCHIVE
|
||||
if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")):
|
||||
local_dir = local_dir.replace("/app", ".")
|
||||
if len(SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
|
||||
logger.warning(f"MOUNTing local archive {SERVE_LOCAL_ARCHIVE}")
|
||||
app.mount(SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=SERVE_LOCAL_ARCHIVE)
|
||||
if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
|
||||
logger.warning(f"MOUNTing local archive {settings.SERVE_LOCAL_ARCHIVE}")
|
||||
app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE)
|
||||
setup_local_archive_serve()
|
||||
|
||||
|
||||
@@ -68,12 +68,12 @@ app.middleware("http")(logging_middleware)
|
||||
|
||||
|
||||
@app.get("/tasks/search-url", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
|
||||
def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)):
|
||||
def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
||||
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
||||
|
||||
|
||||
@app.get("/tasks/sync", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
|
||||
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email=Depends(get_user_auth)):
|
||||
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
return crud.search_archives_by_email(db, email, skip=skip, limit=limit)
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ def archive_tasks(archive: schemas.ArchiveCreate, email=Depends(get_token_or_use
|
||||
|
||||
|
||||
@app.get("/archive/{task_id}", deprecated=True) # DEPRECATED
|
||||
def lookup(task_id, db: Session = Depends(get_db), email=Depends(get_token_or_user_auth)):
|
||||
def lookup(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
||||
return crud.get_archive(db, task_id, email)
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ def get_status(task_id, email=Depends(get_token_or_user_auth)):
|
||||
|
||||
|
||||
@app.delete("/tasks/{task_id}", deprecated=True) # DEPRECATED
|
||||
def delete_task(task_id, db: Session = Depends(get_db), email=Depends(get_user_auth)):
|
||||
def delete_task(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
logger.info(f"deleting task {task_id} request by {email}")
|
||||
return JSONResponse({
|
||||
"id": task_id,
|
||||
|
||||
@@ -5,10 +5,12 @@ from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
from shared.settings import Settings
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
config.set_main_option('sqlalchemy.url', os.environ.get("DATABASE_PATH"))
|
||||
config.set_main_option('sqlalchemy.url', Settings().DATABASE_PATH)
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
|
||||
@@ -2,21 +2,18 @@ from loguru import logger
|
||||
import requests, os, secrets
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from core.config import CHROME_APP_IDS, BLOCKED_EMAILS
|
||||
|
||||
# Configuration checks
|
||||
assert len(CHROME_APP_IDS) > 0, "CHROME_APP_IDS env variable not properly set, it's a csv"
|
||||
for app_id in CHROME_APP_IDS:
|
||||
assert len(app_id) > 10, f"CHROME_APP_IDS got invalid id: {app_id} env variable not set"
|
||||
|
||||
# Auth logic
|
||||
bearer_security = HTTPBearer()
|
||||
from shared.settings import Settings
|
||||
|
||||
ALLOW_ANY_EMAIL = "*"
|
||||
|
||||
settings = Settings()
|
||||
bearer_security = HTTPBearer()
|
||||
|
||||
|
||||
def secure_compare(token, api_key):
|
||||
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
|
||||
|
||||
|
||||
# Factory method to create an authentication dependency for a specific key
|
||||
def api_key_auth(api_key):
|
||||
|
||||
@@ -35,9 +32,10 @@ def api_key_auth(api_key):
|
||||
|
||||
return auth
|
||||
|
||||
|
||||
# --------------------- Token Auth for AA itself to query the API, AA setup tool and Prometheus
|
||||
API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars
|
||||
token_api_key_auth = api_key_auth(API_BEARER_TOKEN)
|
||||
token_api_key_auth = api_key_auth(settings.API_BEARER_TOKEN)
|
||||
|
||||
|
||||
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||
# tries to use the static API_KEY and defaults to google JWT auth
|
||||
@@ -45,6 +43,7 @@ async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Dep
|
||||
if token_api_key_auth(access_token, auto_error=False): return ALLOW_ANY_EMAIL
|
||||
return await get_user_auth(credentials)
|
||||
|
||||
|
||||
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||
# validates the Bearer token in the case that it requires it
|
||||
valid_user, info = authenticate_user(credentials.credentials)
|
||||
@@ -56,6 +55,7 @@ async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bear
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def authenticate_user(access_token):
|
||||
# https://cloud.google.com/docs/authentication/token-types#access
|
||||
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
|
||||
@@ -63,9 +63,9 @@ def authenticate_user(access_token):
|
||||
if r.status_code != 200: return False, "error occurred"
|
||||
try:
|
||||
j = r.json()
|
||||
if j.get("azp") not in CHROME_APP_IDS and j.get("aud") not in CHROME_APP_IDS:
|
||||
if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS:
|
||||
return False, f"token does not belong to valid APP_ID"
|
||||
if j.get("email") in BLOCKED_EMAILS:
|
||||
if j.get("email") in settings.BLOCKED_EMAILS:
|
||||
return False, f"email '{j.get('email')}' not allowed"
|
||||
if j.get("email_verified") != "true":
|
||||
return False, f"email '{j.get('email')}' not verified"
|
||||
@@ -75,4 +75,3 @@ def authenticate_user(access_token):
|
||||
except Exception as e:
|
||||
logger.warning(f"EXCEPTION occurred: {e}")
|
||||
return False, f"EXCEPTION occurred"
|
||||
|
||||
|
||||
31
src/shared/settings.py
Normal file
31
src/shared/settings.py
Normal file
@@ -0,0 +1,31 @@
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import ConfigDict
|
||||
from typing import Annotated, Set
|
||||
from annotated_types import Len
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = ConfigDict(extra='ignore', str_strip_whitespace=True)
|
||||
|
||||
# general
|
||||
SERVE_LOCAL_ARCHIVE: str = ""
|
||||
USER_GROUPS_FILENAME: str = "user-groups.yaml"
|
||||
|
||||
# database
|
||||
DATABASE_PATH: str
|
||||
DATABASE_QUERY_LIMIT: int = 100
|
||||
# redis
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379"
|
||||
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
|
||||
|
||||
# observability
|
||||
REPEAT_COUNT_METRICS_SECONDS: int = 15
|
||||
|
||||
# security
|
||||
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
|
||||
ALLOWED_ORIGINS: Annotated[set[str], Len(min_length=1)]
|
||||
CHROME_APP_IDS: Annotated[set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
|
||||
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from shared.settings import Settings
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_logger_add():
|
||||
@@ -8,6 +9,52 @@ def mock_logger_add():
|
||||
with patch('loguru.logger.add') as mock_add:
|
||||
yield mock_add # This makes the mock available to tests
|
||||
|
||||
os.environ["CHROME_APP_IDS"] = 'test_app_id_1,test_app_id_2'
|
||||
os.environ["DATABASE_PATH"] = "sqlite:////app/auto-archiver.test.db"
|
||||
os.environ["BLOCKED_EMAILS"] = "blocked@example.com"
|
||||
# @pytest.fixture(autouse=True)
|
||||
# def settings():
|
||||
# return Settings(_env_file=".env.test")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def settings():
|
||||
with patch('shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings:
|
||||
yield mock_settings
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_db(settings):
|
||||
from db.database import make_engine, make_session_local
|
||||
from db import models
|
||||
|
||||
engine = make_engine(settings.DATABASE_PATH)
|
||||
|
||||
if not os.path.exists(settings.DATABASE_PATH):
|
||||
open(settings.DATABASE_PATH, 'w').close()
|
||||
|
||||
models.Base.metadata.create_all(engine)
|
||||
|
||||
connection = engine.connect()
|
||||
yield connection
|
||||
connection.close()
|
||||
|
||||
models.Base.metadata.drop_all(bind=engine)
|
||||
os.remove(settings.DATABASE_PATH)
|
||||
|
||||
# @pytest.fixture()
|
||||
# def db_session(test_db):
|
||||
# session_local = make_session_local(test_db)
|
||||
# with session_local() as session:
|
||||
# yield session
|
||||
|
||||
# # create test data and insert it into the database
|
||||
# def create_test_data():
|
||||
# from db.database import SessionLocal
|
||||
# from db.models import Task
|
||||
|
||||
# db = SessionLocal()
|
||||
# task = Task(id="test-task-id", status="PENDING")
|
||||
# db.add(task)
|
||||
# db.commit()
|
||||
# db.refresh(task)
|
||||
# db.close()
|
||||
|
||||
# return task.id
|
||||
@@ -1,8 +1,7 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
|
||||
def setup_client():
|
||||
from main import app
|
||||
from security import get_token_or_user_auth
|
||||
@@ -10,6 +9,7 @@ def setup_client():
|
||||
app.dependency_overrides[get_token_or_user_auth] = mock_get_token_or_user_auth
|
||||
return TestClient(app), app
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
def test_get_status_success(mock_async_result):
|
||||
client, app = setup_client()
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_serve_local_archive_logic():
|
||||
os.environ["SERVE_LOCAL_ARCHIVE"] = "/app/local_archive_test"
|
||||
with patch("main.settings.SERVE_LOCAL_ARCHIVE", "/app/local_archive_test"):
|
||||
|
||||
# create a test file
|
||||
os.makedirs("local_archive_test", exist_ok=True)
|
||||
with open("local_archive_test/temp.txt", "w") as f:
|
||||
f.write("test")
|
||||
# create a test file
|
||||
os.makedirs("local_archive_test", exist_ok=True)
|
||||
with open("local_archive_test/temp.txt", "w") as f:
|
||||
f.write("test")
|
||||
|
||||
from main import app, setup_local_archive_serve
|
||||
setup_local_archive_serve()
|
||||
client = TestClient(app)
|
||||
from main import app, setup_local_archive_serve
|
||||
setup_local_archive_serve()
|
||||
client = TestClient(app)
|
||||
|
||||
r = client.get("/app/local_archive_test/temp.txt")
|
||||
assert r.status_code == 200
|
||||
assert r.text == "test"
|
||||
r = client.get("/app/local_archive_test/temp.txt")
|
||||
assert r.status_code == 200
|
||||
assert r.text == "test"
|
||||
|
||||
os.remove("local_archive_test/temp.txt")
|
||||
os.rmdir("local_archive_test")
|
||||
os.remove("local_archive_test/temp.txt")
|
||||
os.rmdir("local_archive_test")
|
||||
|
||||
@@ -6,10 +6,8 @@ from loguru import logger
|
||||
from prometheus_client import Counter, Gauge
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.config import REPEAT_COUNT_METRICS_SECONDS
|
||||
from db import crud
|
||||
from db.database import get_db
|
||||
from core.config import SQLALCHEMY_DATABASE_URL
|
||||
from worker import REDIS_EXCEPTIONS_CHANNEL, Rdis
|
||||
|
||||
|
||||
@@ -47,20 +45,20 @@ async def redis_subscribe_worker_exceptions():
|
||||
WORKER_EXCEPTION.labels(exception=data["exception"], task=data["task"]).inc()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def measure_regular_metrics():
|
||||
async def measure_regular_metrics(sqlite_db_url:str, repeat_in_seconds:int):
|
||||
_total, used, free = shutil.disk_usage("/")
|
||||
DISK_UTILIZATION.labels(type="used").set(used / (2**30))
|
||||
DISK_UTILIZATION.labels(type="free").set(free / (2**30))
|
||||
try:
|
||||
fs = os.stat(SQLALCHEMY_DATABASE_URL.replace("sqlite:///", ""))
|
||||
fs = os.stat(sqlite_db_url.replace("sqlite:///", ""))
|
||||
DISK_UTILIZATION.labels(type="database").set(fs.st_size / (2**30))
|
||||
except Exception as e: logger.error(e)
|
||||
|
||||
session: Session = next(get_db())
|
||||
count_archives = crud.count_archives(session)
|
||||
count_archive_urls = crud.count_archive_urls(session)
|
||||
DATABASE_METRICS.labels(query="count_archives", user="-").set(count_archives)
|
||||
DATABASE_METRICS.labels(query="count_archive_urls", user="-").set(count_archive_urls)
|
||||
with get_db as db:
|
||||
count_archives = crud.count_archives(db)
|
||||
count_archive_urls = crud.count_archive_urls(db)
|
||||
DATABASE_METRICS.labels(query="count_archives", user="-").set(count_archives)
|
||||
DATABASE_METRICS.labels(query="count_archive_urls", user="-").set(count_archive_urls)
|
||||
|
||||
for user in crud.count_by_user_since(session, REPEAT_COUNT_METRICS_SECONDS):
|
||||
DATABASE_METRICS.labels(query="count_by_user", user=user.author_id).set(user.total)
|
||||
for user in crud.count_by_user_since(db, repeat_in_seconds):
|
||||
DATABASE_METRICS.labels(query="count_by_user", user=user.author_id).set(user.total)
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
import os, traceback, yaml, datetime, sys
|
||||
import traceback, yaml, datetime
|
||||
from typing import List, Set
|
||||
|
||||
from celery import Celery
|
||||
@@ -9,29 +9,25 @@ from auto_archiver.core import Media
|
||||
from loguru import logger
|
||||
|
||||
from db import crud, schemas, models
|
||||
from db.database import SessionLocal
|
||||
from contextlib import contextmanager
|
||||
from db.database import get_db
|
||||
from shared.settings import Settings
|
||||
import json
|
||||
import redis
|
||||
from sqlalchemy import exc
|
||||
|
||||
settings = Settings()
|
||||
|
||||
celery = Celery(__name__)
|
||||
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379")
|
||||
celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379")
|
||||
USER_GROUPS_FILENAME = os.environ.get("USER_GROUPS_FILENAME", "user-groups.yaml")
|
||||
REDIS_EXCEPTIONS_CHANNEL = "exceptions-channel"
|
||||
celery.conf.broker_url = settings.CELERY_BROKER_URL
|
||||
celery.conf.result_backend = settings.CELERY_RESULT_BACKEND
|
||||
USER_GROUPS_FILENAME = settings.USER_GROUPS_FILENAME
|
||||
REDIS_EXCEPTIONS_CHANNEL = settings.REDIS_EXCEPTIONS_CHANNEL
|
||||
|
||||
Rdis = redis.Redis.from_url(celery.conf.broker_url)
|
||||
|
||||
@contextmanager
|
||||
def get_db():
|
||||
session = SessionLocal()
|
||||
try: yield session
|
||||
finally: session.close()
|
||||
|
||||
|
||||
@celery.task(name="create_archive_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 3})
|
||||
def create_archive_task(self, archive_json: str):
|
||||
archive = schemas.ArchiveCreate.parse_raw(archive_json)
|
||||
archive = schemas.ArchiveCreate.model_validate_json(archive_json)
|
||||
logger.info(f"Archiving {archive.url=} {archive.tags=} {archive.public=} {archive.group_id=} {archive.author_id=}")
|
||||
invalid = is_group_invalid_for_user(archive.public, archive.group_id, archive.author_id)
|
||||
if invalid:
|
||||
@@ -63,7 +59,7 @@ def create_archive_task(self, archive_json: str):
|
||||
|
||||
@celery.task(name="create_sheet_task", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={'max_retries': 0})
|
||||
def create_sheet_task(self, sheet_json: str):
|
||||
sheet = schemas.SubmitSheet.parse_raw(sheet_json)
|
||||
sheet = schemas.SubmitSheet.model_validate_json(sheet_json)
|
||||
sheet.tags.add("gsheet")
|
||||
logger.info(f"SHEET START {sheet=}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user