mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-08 03:28:35 +03:00
refactoring with app_factory
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -20,5 +20,6 @@ src/crawls
|
|||||||
.pytest_cache/*
|
.pytest_cache/*
|
||||||
htmlcov
|
htmlcov
|
||||||
local_archive
|
local_archive
|
||||||
|
local_archive_test
|
||||||
*db-wal
|
*db-wal
|
||||||
*db-shm
|
*db-shm
|
||||||
@@ -5,8 +5,8 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
- SERVE_LOCAL_ARCHIVE=/app/local_archive # See orchestration.yaml local_storage.save_to
|
- SERVE_LOCAL_ARCHIVE=/app/local_archive # See orchestration.yaml local_storage.save_to
|
||||||
- ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp
|
- ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp
|
||||||
- API_BEARER_TOKEN=dev-api-bearer-token
|
|
||||||
- USER_GROUPS_FILENAME=user-groups.dev.yaml
|
- USER_GROUPS_FILENAME=user-groups.dev.yaml
|
||||||
|
- DATABASE_PATH=sqlite:////app/auto-archiver.db
|
||||||
|
|
||||||
worker:
|
worker:
|
||||||
restart: "no"
|
restart: "no"
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ services:
|
|||||||
<<: *base-setup
|
<<: *base-setup
|
||||||
ports:
|
ports:
|
||||||
- "127.0.0.1:8004:8000"
|
- "127.0.0.1:8004:8000"
|
||||||
command: uvicorn main:app --host 0.0.0.0 --reload
|
command: uvicorn web:app --host 0.0.0.0 --reload
|
||||||
volumes:
|
volumes:
|
||||||
- ./src:/app
|
- ./src:/app
|
||||||
depends_on:
|
depends_on:
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ ALLOWED_ORIGINS='["chrome-extension://example1","chrome-extension://example2","h
|
|||||||
BLOCKED_EMAILS='["blocked@example.com"]'
|
BLOCKED_EMAILS='["blocked@example.com"]'
|
||||||
|
|
||||||
|
|
||||||
DATABASE_PATH="sqlite:////app/auto-archiver.test.db"
|
DATABASE_PATH="sqlite:///auto-archiver.test.db"
|
||||||
API_BEARER_TOKEN=this_is_the_test_api_token
|
API_BEARER_TOKEN=this_is_the_test_api_token
|
||||||
168
src/main.py
168
src/main.py
@@ -1,168 +0,0 @@
|
|||||||
import traceback, os
|
|
||||||
from celery.result import AsyncResult
|
|
||||||
from fastapi import FastAPI, Depends, HTTPException
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from prometheus_fastapi_instrumentator import Instrumentator
|
|
||||||
from datetime import datetime
|
|
||||||
import sqlalchemy
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from core.logging import logging_middleware
|
|
||||||
from worker import create_archive_task, create_sheet_task, celery, insert_result_into_db
|
|
||||||
|
|
||||||
from db import crud, models, schemas
|
|
||||||
from security import get_user_auth, token_api_key_auth, get_token_or_user_auth
|
|
||||||
from core.config import VERSION, API_DESCRIPTION
|
|
||||||
from db.database import get_db_dependency
|
|
||||||
from core.events import lifespan
|
|
||||||
from shared.settings import Settings
|
|
||||||
|
|
||||||
from auto_archiver import Metadata
|
|
||||||
|
|
||||||
from endpoints import default_router, url_router, sheet_router, task_router, interoperability_router
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
|
|
||||||
app = FastAPI(
|
|
||||||
title="Auto-Archiver API",
|
|
||||||
description=API_DESCRIPTION,
|
|
||||||
version=VERSION,
|
|
||||||
contact={"name": "GitHub", "url": "https://github.com/bellingcat/auto-archiver-api"},
|
|
||||||
lifespan=lifespan
|
|
||||||
)
|
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=settings.ALLOWED_ORIGINS,
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
app.include_router(default_router)
|
|
||||||
app.include_router(url_router)
|
|
||||||
app.include_router(sheet_router)
|
|
||||||
app.include_router(task_router)
|
|
||||||
app.include_router(interoperability_router)
|
|
||||||
|
|
||||||
# prometheus exposed in /metrics with authentication
|
|
||||||
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
|
|
||||||
|
|
||||||
def setup_local_archive_serve():
|
|
||||||
local_dir = settings.SERVE_LOCAL_ARCHIVE
|
|
||||||
if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")):
|
|
||||||
local_dir = local_dir.replace("/app", ".")
|
|
||||||
if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
|
|
||||||
logger.warning(f"MOUNTing local archive {settings.SERVE_LOCAL_ARCHIVE}")
|
|
||||||
app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE)
|
|
||||||
setup_local_archive_serve()
|
|
||||||
|
|
||||||
|
|
||||||
app.middleware("http")(logging_middleware)
|
|
||||||
|
|
||||||
# -----Submit URL and manipulate tasks. Bearer protected below
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/tasks/search-url", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
|
|
||||||
def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
|
||||||
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/tasks/sync", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
|
|
||||||
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
|
||||||
return crud.search_archives_by_email(db, email, skip=skip, limit=limit)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/tasks", status_code=201, deprecated=True) # DEPRECATED
|
|
||||||
def archive_tasks(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)):
|
|
||||||
archive.author_id = email
|
|
||||||
url = archive.url
|
|
||||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
|
|
||||||
if type(url) != str or len(url) <= 5:
|
|
||||||
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
|
|
||||||
logger.info("creating task")
|
|
||||||
task = create_archive_task.delay(archive.model_dump_json())
|
|
||||||
return JSONResponse({"id": task.id})
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/archive/{task_id}", deprecated=True) # DEPRECATED
|
|
||||||
def lookup(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
|
||||||
return crud.get_archive(db, task_id, email)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/tasks/{task_id}", deprecated=True) # DEPRECATED
|
|
||||||
def get_status(task_id, email=Depends(get_token_or_user_auth)):
|
|
||||||
logger.info(f"status check for user {email} task {task_id}")
|
|
||||||
task = AsyncResult(task_id, app=celery)
|
|
||||||
try:
|
|
||||||
if task.status == "FAILURE":
|
|
||||||
# *FAILURE* The task raised an exception, or has exceeded the retry limit.
|
|
||||||
# The :attr:`result` attribute then contains the exception raised by the task.
|
|
||||||
# https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
|
|
||||||
raise task.result
|
|
||||||
|
|
||||||
response = {
|
|
||||||
"id": task_id,
|
|
||||||
"status": task.status,
|
|
||||||
"result": task.result
|
|
||||||
}
|
|
||||||
return JSONResponse(jsonable_encoder(response, exclude_unset=True))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return JSONResponse({
|
|
||||||
"id": task_id,
|
|
||||||
"status": "FAILURE",
|
|
||||||
"result": {"error": str(e)}
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/tasks/{task_id}", deprecated=True) # DEPRECATED
|
|
||||||
def delete_task(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
|
||||||
logger.info(f"deleting task {task_id} request by {email}")
|
|
||||||
return JSONResponse({
|
|
||||||
"id": task_id,
|
|
||||||
"deleted": crud.soft_delete_task(db, task_id, email)
|
|
||||||
})
|
|
||||||
|
|
||||||
# ----- Google Sheets Logic
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/sheet", status_code=201, deprecated=True) # DEPRECATED
|
|
||||||
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth)):
|
|
||||||
logger.info(f"SHEET TASK for {sheet=}")
|
|
||||||
sheet.author_id = email
|
|
||||||
if not sheet.sheet_name and not sheet.sheet_id:
|
|
||||||
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
|
||||||
task = create_sheet_task.delay(sheet.model_dump_json())
|
|
||||||
return JSONResponse({"id": task.id})
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/sheet_service", status_code=201, deprecated=True) # DEPRECATED
|
|
||||||
def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth)):
|
|
||||||
logger.info(f"SHEET TASK for {sheet=}")
|
|
||||||
sheet.author_id = sheet.author_id or "api-endpoint"
|
|
||||||
if not sheet.sheet_name and not sheet.sheet_id:
|
|
||||||
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
|
||||||
task = create_sheet_task.delay(sheet.model_dump_json())
|
|
||||||
return JSONResponse({"id": task.id})
|
|
||||||
|
|
||||||
# ----- endpoint to submit data archived elsewhere
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/submit-archive", status_code=201, deprecated=True) # DEPRECATED
|
|
||||||
def submit_manual_archive(manual: schemas.SubmitManual, auth=Depends(token_api_key_auth)):
|
|
||||||
result = Metadata.from_json(manual.result)
|
|
||||||
logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
|
|
||||||
manual.tags.add("manual")
|
|
||||||
try:
|
|
||||||
archive_id = insert_result_into_db(result, manual.tags, manual.public, manual.group_id, manual.author_id, models.generate_uuid())
|
|
||||||
except sqlalchemy.exc.IntegrityError as e:
|
|
||||||
logger.error(e)
|
|
||||||
raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error")
|
|
||||||
return JSONResponse({"id": archive_id})
|
|
||||||
@@ -60,7 +60,7 @@ def authenticate_user(access_token):
|
|||||||
# https://cloud.google.com/docs/authentication/token-types#access
|
# https://cloud.google.com/docs/authentication/token-types#access
|
||||||
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
|
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
|
||||||
r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token": access_token})
|
r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token": access_token})
|
||||||
if r.status_code != 200: return False, "error occurred"
|
if r.status_code != 200: return False, "invalid token"
|
||||||
try:
|
try:
|
||||||
j = r.json()
|
j = r.json()
|
||||||
if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS:
|
if j.get("azp") not in settings.CHROME_APP_IDS and j.get("aud") not in settings.CHROME_APP_IDS:
|
||||||
|
|||||||
@@ -28,4 +28,3 @@ class Settings(BaseSettings):
|
|||||||
ALLOWED_ORIGINS: Annotated[set[str], Len(min_length=1)]
|
ALLOWED_ORIGINS: Annotated[set[str], Len(min_length=1)]
|
||||||
CHROME_APP_IDS: Annotated[set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
|
CHROME_APP_IDS: Annotated[set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
|
||||||
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
|
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
|
||||||
|
|
||||||
|
|||||||
@@ -1,34 +1,38 @@
|
|||||||
import os
|
import os
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from shared.settings import Settings
|
from shared.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def mock_logger_add():
|
def mock_logger_add():
|
||||||
"""Fixture to mock loguru.logger.add for all tests."""
|
"""Fixture to mock loguru.logger.add for all tests."""
|
||||||
with patch('loguru.logger.add') as mock_add:
|
with patch('loguru.logger.add') as mock_add:
|
||||||
yield mock_add # This makes the mock available to tests
|
yield mock_add # This makes the mock available to tests
|
||||||
|
|
||||||
# @pytest.fixture(autouse=True)
|
|
||||||
# def settings():
|
@pytest.fixture()
|
||||||
# return Settings(_env_file=".env.test")
|
def settings():
|
||||||
|
return Settings(_env_file=".env.test")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def settings():
|
def mock_settings():
|
||||||
with patch('shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings:
|
with patch('shared.settings.Settings', return_value=Settings(_env_file=".env.test")) as mock_settings:
|
||||||
yield mock_settings
|
yield mock_settings
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def test_db(settings):
|
def test_db(settings: Settings):
|
||||||
from db.database import make_engine, make_session_local
|
from db.database import make_engine
|
||||||
from db import models
|
from db import models
|
||||||
|
|
||||||
engine = make_engine(settings.DATABASE_PATH)
|
engine = make_engine(settings.DATABASE_PATH)
|
||||||
|
|
||||||
if not os.path.exists(settings.DATABASE_PATH):
|
fs = settings.DATABASE_PATH.replace("sqlite:///", "")
|
||||||
open(settings.DATABASE_PATH, 'w').close()
|
if not os.path.exists(fs):
|
||||||
|
open(fs, 'w').close()
|
||||||
|
|
||||||
models.Base.metadata.create_all(engine)
|
models.Base.metadata.create_all(engine)
|
||||||
|
|
||||||
@@ -37,13 +41,35 @@ def test_db(settings):
|
|||||||
connection.close()
|
connection.close()
|
||||||
|
|
||||||
models.Base.metadata.drop_all(bind=engine)
|
models.Base.metadata.drop_all(bind=engine)
|
||||||
os.remove(settings.DATABASE_PATH)
|
for suffix in ["", "-wal", "-shm"]:
|
||||||
|
new_fs = fs + suffix
|
||||||
|
if os.path.exists(new_fs):
|
||||||
|
os.remove(new_fs)
|
||||||
|
|
||||||
# @pytest.fixture()
|
|
||||||
# def db_session(test_db):
|
@pytest.fixture()
|
||||||
# session_local = make_session_local(test_db)
|
def db_session(test_db):
|
||||||
# with session_local() as session:
|
from db.database import make_session_local
|
||||||
# yield session
|
session_local = make_session_local(test_db)
|
||||||
|
with session_local() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def app(db_session, settings):
|
||||||
|
from web.main import app_factory
|
||||||
|
app = app_factory()
|
||||||
|
from security import get_token_or_user_auth
|
||||||
|
app.dependency_overrides[get_token_or_user_auth] = lambda: "example@email.com"
|
||||||
|
# app.dependency_overrides[settings] = lambda: settings
|
||||||
|
# app.dependency_overrides[get_session] = lambda: db_session
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def client(app):
|
||||||
|
client = TestClient(app)
|
||||||
|
return client
|
||||||
|
|
||||||
# # create test data and insert it into the database
|
# # create test data and insert it into the database
|
||||||
# def create_test_data():
|
# def create_test_data():
|
||||||
@@ -57,4 +83,4 @@ def test_db(settings):
|
|||||||
# db.refresh(task)
|
# db.refresh(task)
|
||||||
# db.close()
|
# db.close()
|
||||||
|
|
||||||
# return task.id
|
# return task.id
|
||||||
|
|||||||
@@ -3,10 +3,7 @@ from fastapi.testclient import TestClient
|
|||||||
from core.config import VERSION
|
from core.config import VERSION
|
||||||
|
|
||||||
|
|
||||||
def test_endpoint_home():
|
def test_endpoint_home(client):
|
||||||
from main import app
|
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
r = client.get("/")
|
r = client.get("/")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
j = r.json()
|
j = r.json()
|
||||||
@@ -18,10 +15,7 @@ def test_endpoint_home():
|
|||||||
@patch("endpoints.default.bearer_security", new_callable=AsyncMock)
|
@patch("endpoints.default.bearer_security", new_callable=AsyncMock)
|
||||||
@patch("endpoints.default.get_user_auth", new_callable=AsyncMock, return_value="test@example.com")
|
@patch("endpoints.default.get_user_auth", new_callable=AsyncMock, return_value="test@example.com")
|
||||||
@patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"])
|
@patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"])
|
||||||
def test_endpoint_home_with_groups(m1, m2, m3):
|
def test_endpoint_home_with_groups(m1, m2, m3, client):
|
||||||
from main import app
|
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
r = client.get("/")
|
r = client.get("/")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
j = r.json()
|
j = r.json()
|
||||||
@@ -31,30 +25,24 @@ def test_endpoint_home_with_groups(m1, m2, m3):
|
|||||||
assert j["groups"] == ["group1", "group2"]
|
assert j["groups"] == ["group1", "group2"]
|
||||||
|
|
||||||
|
|
||||||
def test_endpoint_health():
|
def test_endpoint_health(client):
|
||||||
from main import app
|
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
r = client.get("/health")
|
r = client.get("/health")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
assert r.json() == {"status": "ok"}
|
assert r.json() == {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
def test_endpoint_groups_403():
|
def test_endpoint_groups_403(client):
|
||||||
from main import app
|
|
||||||
client = TestClient(app)
|
|
||||||
r = client.get("/groups")
|
r = client.get("/groups")
|
||||||
assert r.status_code == 403
|
assert r.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
@patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"])
|
@patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"])
|
||||||
def test_endpoint_groups(m1):
|
def test_endpoint_groups(m1, app):
|
||||||
async def mock_get_user_auth(): return True
|
async def mock_get_user_auth(): return True
|
||||||
from main import app
|
|
||||||
from security import get_user_auth
|
from security import get_user_auth
|
||||||
app.dependency_overrides[get_user_auth] = mock_get_user_auth
|
app.dependency_overrides[get_user_auth] = mock_get_user_auth
|
||||||
|
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
||||||
r = client.get("/groups")
|
r = client.get("/groups")
|
||||||
|
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
@@ -63,18 +51,12 @@ def test_endpoint_groups(m1):
|
|||||||
app.dependency_overrides = {}
|
app.dependency_overrides = {}
|
||||||
|
|
||||||
|
|
||||||
def test_no_serve_local_archive_by_default():
|
def test_no_serve_local_archive_by_default(client):
|
||||||
from main import app
|
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
r = client.get("/app/local_archive_test/temp.txt")
|
r = client.get("/app/local_archive_test/temp.txt")
|
||||||
assert r.status_code == 404
|
assert r.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
def test_favicon():
|
def test_favicon(client):
|
||||||
from main import app
|
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
r = client.get("/favicon.ico")
|
r = client.get("/favicon.ico")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
assert r.headers["content-type"] == "image/vnd.microsoft.icon"
|
assert r.headers["content-type"] == "image/vnd.microsoft.icon"
|
||||||
|
|||||||
@@ -2,18 +2,8 @@ from unittest.mock import patch
|
|||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
def setup_client():
|
|
||||||
from main import app
|
|
||||||
from security import get_token_or_user_auth
|
|
||||||
async def mock_get_token_or_user_auth(): return "example@email.com"
|
|
||||||
app.dependency_overrides[get_token_or_user_auth] = mock_get_token_or_user_auth
|
|
||||||
return TestClient(app), app
|
|
||||||
|
|
||||||
|
|
||||||
@patch("endpoints.task.AsyncResult")
|
@patch("endpoints.task.AsyncResult")
|
||||||
def test_get_status_success(mock_async_result):
|
def test_get_status_success(mock_async_result, client):
|
||||||
client, app = setup_client()
|
|
||||||
|
|
||||||
mock_async_result.return_value.status = "SUCCESS"
|
mock_async_result.return_value.status = "SUCCESS"
|
||||||
mock_async_result.return_value.result = {"data": "some result"}
|
mock_async_result.return_value.result = {"data": "some result"}
|
||||||
|
|
||||||
@@ -25,12 +15,10 @@ def test_get_status_success(mock_async_result):
|
|||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"result": {"data": "some result"}
|
"result": {"data": "some result"}
|
||||||
}
|
}
|
||||||
app.dependency_overrides = {}
|
|
||||||
|
|
||||||
|
|
||||||
@patch("endpoints.task.AsyncResult")
|
@patch("endpoints.task.AsyncResult")
|
||||||
def test_get_status_failure(mock_async_result):
|
def test_get_status_failure(mock_async_result, client):
|
||||||
client, app = setup_client()
|
|
||||||
|
|
||||||
mock_async_result.return_value.status = "FAILURE"
|
mock_async_result.return_value.status = "FAILURE"
|
||||||
mock_async_result.return_value.result = Exception("Some error")
|
mock_async_result.return_value.result = Exception("Some error")
|
||||||
@@ -43,13 +31,10 @@ def test_get_status_failure(mock_async_result):
|
|||||||
"status": "FAILURE",
|
"status": "FAILURE",
|
||||||
"result": {"error": "Some error"}
|
"result": {"error": "Some error"}
|
||||||
}
|
}
|
||||||
app.dependency_overrides = {}
|
|
||||||
|
|
||||||
|
|
||||||
@patch("endpoints.task.AsyncResult")
|
@patch("endpoints.task.AsyncResult")
|
||||||
def test_get_status_pending(mock_async_result):
|
def test_get_status_pending(mock_async_result, client):
|
||||||
client, app = setup_client()
|
|
||||||
|
|
||||||
mock_async_result.return_value.status = "PENDING"
|
mock_async_result.return_value.status = "PENDING"
|
||||||
mock_async_result.return_value.result = None
|
mock_async_result.return_value.result = None
|
||||||
|
|
||||||
@@ -61,4 +46,3 @@ def test_get_status_pending(mock_async_result):
|
|||||||
"status": "PENDING",
|
"status": "PENDING",
|
||||||
"result": None
|
"result": None
|
||||||
}
|
}
|
||||||
app.dependency_overrides = {}
|
|
||||||
|
|||||||
6
src/tests/endpoints/test_url.py
Normal file
6
src/tests/endpoints/test_url.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
|
||||||
|
# def test_archive_url(client):
|
||||||
|
# response = client.get("/archive/url")
|
||||||
|
# assert response.status_code == 200
|
||||||
|
# assert response.json() == {"message": "Archive URL"}
|
||||||
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
import os
|
|
||||||
from unittest.mock import patch
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
|
|
||||||
def test_serve_local_archive_logic():
|
|
||||||
with patch("main.settings.SERVE_LOCAL_ARCHIVE", "/app/local_archive_test"):
|
|
||||||
|
|
||||||
# create a test file
|
|
||||||
os.makedirs("local_archive_test", exist_ok=True)
|
|
||||||
with open("local_archive_test/temp.txt", "w") as f:
|
|
||||||
f.write("test")
|
|
||||||
|
|
||||||
from main import app, setup_local_archive_serve
|
|
||||||
setup_local_archive_serve()
|
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
r = client.get("/app/local_archive_test/temp.txt")
|
|
||||||
assert r.status_code == 200
|
|
||||||
assert r.text == "test"
|
|
||||||
|
|
||||||
os.remove("local_archive_test/temp.txt")
|
|
||||||
os.rmdir("local_archive_test")
|
|
||||||
28
src/tests/web/test_main.py
Normal file
28
src/tests/web/test_main.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from shared.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
def test_serve_local_archive_logic(settings: Settings):
|
||||||
|
# create a test file first
|
||||||
|
os.makedirs("local_archive_test", exist_ok=True)
|
||||||
|
with open("local_archive_test/temp.txt", "w") as f:
|
||||||
|
f.write("test")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# modify the settings
|
||||||
|
settings.SERVE_LOCAL_ARCHIVE = "/app/local_archive_test"
|
||||||
|
from web.main import app_factory
|
||||||
|
app = app_factory(settings)
|
||||||
|
|
||||||
|
# test
|
||||||
|
client = TestClient(app)
|
||||||
|
r = client.get("/app/local_archive_test/temp.txt")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.text == "test"
|
||||||
|
finally:
|
||||||
|
# cleanup
|
||||||
|
shutil.rmtree("local_archive_test")
|
||||||
4
src/web/__init__.py
Normal file
4
src/web/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from web.main import app_factory
|
||||||
|
|
||||||
|
|
||||||
|
app = app_factory
|
||||||
170
src/web/main.py
Normal file
170
src/web/main.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import traceback, os
|
||||||
|
from celery.result import AsyncResult
|
||||||
|
from fastapi import FastAPI, Depends, HTTPException
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from prometheus_fastapi_instrumentator import Instrumentator
|
||||||
|
from datetime import datetime
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from core.logging import logging_middleware
|
||||||
|
from worker import create_archive_task, create_sheet_task, celery, insert_result_into_db
|
||||||
|
|
||||||
|
from db import crud, models, schemas
|
||||||
|
from security import get_user_auth, token_api_key_auth, get_token_or_user_auth
|
||||||
|
from core.config import VERSION, API_DESCRIPTION
|
||||||
|
from db.database import get_db_dependency
|
||||||
|
from core.events import lifespan
|
||||||
|
from shared.settings import Settings
|
||||||
|
|
||||||
|
from auto_archiver import Metadata
|
||||||
|
|
||||||
|
from endpoints import default_router, url_router, sheet_router, task_router, interoperability_router
|
||||||
|
|
||||||
|
|
||||||
|
def app_factory(settings = Settings()):
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Auto-Archiver API",
|
||||||
|
description=API_DESCRIPTION,
|
||||||
|
version=VERSION,
|
||||||
|
contact={"name": "GitHub", "url": "https://github.com/bellingcat/auto-archiver-api"},
|
||||||
|
lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.ALLOWED_ORIGINS,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(default_router)
|
||||||
|
app.include_router(url_router)
|
||||||
|
app.include_router(sheet_router)
|
||||||
|
app.include_router(task_router)
|
||||||
|
app.include_router(interoperability_router)
|
||||||
|
|
||||||
|
# prometheus exposed in /metrics with authentication
|
||||||
|
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics", "/health"]).instrument(app).expose(app, dependencies=[Depends(token_api_key_auth)])
|
||||||
|
|
||||||
|
local_dir = settings.SERVE_LOCAL_ARCHIVE
|
||||||
|
if not os.path.isdir(local_dir) and os.path.isdir(local_dir.replace("/app", ".")):
|
||||||
|
local_dir = local_dir.replace("/app", ".")
|
||||||
|
if len(settings.SERVE_LOCAL_ARCHIVE) > 1 and os.path.isdir(local_dir):
|
||||||
|
logger.warning(f"MOUNTing local archive {settings.SERVE_LOCAL_ARCHIVE}")
|
||||||
|
app.mount(settings.SERVE_LOCAL_ARCHIVE, StaticFiles(directory=local_dir), name=settings.SERVE_LOCAL_ARCHIVE)
|
||||||
|
|
||||||
|
|
||||||
|
app.middleware("http")(logging_middleware)
|
||||||
|
|
||||||
|
# -----Submit URL and manipulate tasks. Bearer protected below
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/tasks/search-url", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
|
||||||
|
def search_by_url(url: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
||||||
|
return crud.search_archives_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/tasks/sync", response_model=list[schemas.Archive], deprecated=True) # DEPRECATED
|
||||||
|
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||||
|
return crud.search_archives_by_email(db, email, skip=skip, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/tasks", status_code=201, deprecated=True) # DEPRECATED
|
||||||
|
def archive_tasks(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_auth)):
|
||||||
|
archive.author_id = email
|
||||||
|
url = archive.url
|
||||||
|
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
|
||||||
|
if type(url) != str or len(url) <= 5:
|
||||||
|
raise HTTPException(status_code=422, detail=f"Invalid URL received: {url}")
|
||||||
|
logger.info("creating task")
|
||||||
|
task = create_archive_task.delay(archive.model_dump_json())
|
||||||
|
return JSONResponse({"id": task.id})
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/archive/{task_id}", deprecated=True) # DEPRECATED
|
||||||
|
def lookup(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
||||||
|
return crud.get_archive(db, task_id, email)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/tasks/{task_id}", deprecated=True) # DEPRECATED
|
||||||
|
def get_status(task_id, email=Depends(get_token_or_user_auth)):
|
||||||
|
logger.info(f"status check for user {email} task {task_id}")
|
||||||
|
task = AsyncResult(task_id, app=celery)
|
||||||
|
try:
|
||||||
|
if task.status == "FAILURE":
|
||||||
|
# *FAILURE* The task raised an exception, or has exceeded the retry limit.
|
||||||
|
# The :attr:`result` attribute then contains the exception raised by the task.
|
||||||
|
# https://docs.celeryq.dev/en/stable/_modules/celery/result.html#AsyncResult
|
||||||
|
raise task.result
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"id": task_id,
|
||||||
|
"status": task.status,
|
||||||
|
"result": task.result
|
||||||
|
}
|
||||||
|
return JSONResponse(jsonable_encoder(response, exclude_unset=True))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return JSONResponse({
|
||||||
|
"id": task_id,
|
||||||
|
"status": "FAILURE",
|
||||||
|
"result": {"error": str(e)}
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@app.delete("/tasks/{task_id}", deprecated=True) # DEPRECATED
|
||||||
|
def delete_task(task_id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||||
|
logger.info(f"deleting task {task_id} request by {email}")
|
||||||
|
return JSONResponse({
|
||||||
|
"id": task_id,
|
||||||
|
"deleted": crud.soft_delete_task(db, task_id, email)
|
||||||
|
})
|
||||||
|
|
||||||
|
# ----- Google Sheets Logic
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/sheet", status_code=201, deprecated=True) # DEPRECATED
|
||||||
|
def archive_sheet(sheet: schemas.SubmitSheet, email=Depends(get_user_auth)):
|
||||||
|
logger.info(f"SHEET TASK for {sheet=}")
|
||||||
|
sheet.author_id = email
|
||||||
|
if not sheet.sheet_name and not sheet.sheet_id:
|
||||||
|
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
||||||
|
task = create_sheet_task.delay(sheet.model_dump_json())
|
||||||
|
return JSONResponse({"id": task.id})
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/sheet_service", status_code=201, deprecated=True) # DEPRECATED
|
||||||
|
def archive_sheet_service(sheet: schemas.SubmitSheet, auth=Depends(token_api_key_auth)):
|
||||||
|
logger.info(f"SHEET TASK for {sheet=}")
|
||||||
|
sheet.author_id = sheet.author_id or "api-endpoint"
|
||||||
|
if not sheet.sheet_name and not sheet.sheet_id:
|
||||||
|
raise HTTPException(status_code=422, detail=f"sheet name or id is required")
|
||||||
|
task = create_sheet_task.delay(sheet.model_dump_json())
|
||||||
|
return JSONResponse({"id": task.id})
|
||||||
|
|
||||||
|
# ----- endpoint to submit data archived elsewhere
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/submit-archive", status_code=201, deprecated=True) # DEPRECATED
|
||||||
|
def submit_manual_archive(manual: schemas.SubmitManual, auth=Depends(token_api_key_auth)):
|
||||||
|
result = Metadata.from_json(manual.result)
|
||||||
|
logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
|
||||||
|
manual.tags.add("manual")
|
||||||
|
try:
|
||||||
|
archive_id = insert_result_into_db(result, manual.tags, manual.public, manual.group_id, manual.author_id, models.generate_uuid())
|
||||||
|
except sqlalchemy.exc.IntegrityError as e:
|
||||||
|
logger.error(e)
|
||||||
|
raise HTTPException(status_code=422, detail=f"Cannot insert into DB due to integrity error")
|
||||||
|
return JSONResponse({"id": archive_id})
|
||||||
|
|
||||||
|
return app
|
||||||
Reference in New Issue
Block a user