mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-11 04:58:33 +03:00
adds new tests and improvements
This commit is contained in:
@@ -16,9 +16,9 @@ DATABASE_QUERY_LIMIT = Settings().DATABASE_QUERY_LIMIT
|
||||
# --------------- TASK = Archive
|
||||
|
||||
|
||||
def get_archive(db: Session, task_id: str, email: str):
|
||||
def get_archive(db: Session, id: str, email: str):
|
||||
email = email.lower()
|
||||
query = base_query(db).filter(models.Archive.id == task_id)
|
||||
query = base_query(db).filter(models.Archive.id == id)
|
||||
if email != ALLOW_ANY_EMAIL:
|
||||
groups = get_user_groups(db, email)
|
||||
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
|
||||
@@ -45,7 +45,7 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim
|
||||
|
||||
def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100):
|
||||
email = email.lower()
|
||||
return base_query(db).filter(models.Archive.author.has(email=email)).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
|
||||
return base_query(db).filter(models.Archive.author.has(email=email)).order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all()
|
||||
|
||||
|
||||
def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]):
|
||||
|
||||
@@ -25,7 +25,7 @@ def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_
|
||||
logger.info("creating task")
|
||||
task = create_archive_task.delay(archive.model_dump_json())
|
||||
task_response = schemas.Task(id=task.id)
|
||||
return JSONResponse(task_response.model_dump())
|
||||
return JSONResponse(task_response.model_dump(), status_code=201)
|
||||
|
||||
|
||||
@url_router.get("/search", response_model=list[schemas.Archive], summary="Search for archive entries by URL.")
|
||||
@@ -44,13 +44,15 @@ def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db_dependen
|
||||
|
||||
@url_router.get("/{id}", response_model=schemas.Archive, summary="Fetch a single URL archive by the associated id.")
|
||||
def lookup(id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)):
|
||||
return crud.get_archive(db, id, email)
|
||||
archive = crud.get_archive(db, id, email)
|
||||
if archive is None:
|
||||
raise HTTPException(status_code=404, detail="Archive not found")
|
||||
return archive
|
||||
|
||||
|
||||
@url_router.delete("/{id}", response_model=schemas.TaskDelete, summary="Delete a single URL archive by id.")
|
||||
def delete_task(id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)):
|
||||
logger.info(f"deleting url archive task {id} request by {email}")
|
||||
#TODO: use response model?
|
||||
return JSONResponse({
|
||||
"id": id,
|
||||
"deleted": crud.soft_delete_task(db, id, email)
|
||||
|
||||
@@ -43,7 +43,7 @@ def test_db(settings: Settings):
|
||||
models.Base.metadata.drop_all(bind=engine)
|
||||
for suffix in ["", "-wal", "-shm"]:
|
||||
new_fs = fs + suffix
|
||||
if os.path.exists(new_fs):
|
||||
if os.path.exists(new_fs):
|
||||
os.remove(new_fs)
|
||||
|
||||
|
||||
@@ -56,13 +56,9 @@ def db_session(test_db):
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app(db_session, settings):
|
||||
def app(db_session):
|
||||
from web.main import app_factory
|
||||
app = app_factory()
|
||||
from security import get_token_or_user_auth
|
||||
app.dependency_overrides[get_token_or_user_auth] = lambda: "example@email.com"
|
||||
# app.dependency_overrides[settings] = lambda: settings
|
||||
# app.dependency_overrides[get_session] = lambda: db_session
|
||||
return app
|
||||
|
||||
|
||||
@@ -71,16 +67,16 @@ def client(app):
|
||||
client = TestClient(app)
|
||||
return client
|
||||
|
||||
# # create test data and insert it into the database
|
||||
# def create_test_data():
|
||||
# from db.database import SessionLocal
|
||||
# from db.models import Task
|
||||
|
||||
# db = SessionLocal()
|
||||
# task = Task(id="test-task-id", status="PENDING")
|
||||
# db.add(task)
|
||||
# db.commit()
|
||||
# db.refresh(task)
|
||||
# db.close()
|
||||
@pytest.fixture()
def app_with_auth(app):
    """Return ``app`` with both auth dependencies replaced by fixed identities.

    ``get_token_or_user_auth`` resolves to ``rick@example.com`` and
    ``get_user_auth`` resolves to ``morty@example.com``, so tests can tell
    which of the two dependencies an endpoint relies on.
    """
    # imported lazily, like the other project imports in these fixtures
    from security import get_token_or_user_auth, get_user_auth

    app.dependency_overrides.update({
        get_token_or_user_auth: lambda: "rick@example.com",
        get_user_auth: lambda: "morty@example.com",
    })
    return app
|
||||
|
||||
# return task.id
|
||||
|
||||
@pytest.fixture()
def client_with_auth(app_with_auth):
    """Return a ``TestClient`` bound to the app whose auth dependencies are overridden."""
    return TestClient(app_with_auth)
|
||||
|
||||
@@ -3,8 +3,8 @@ from fastapi.testclient import TestClient
|
||||
from core.config import VERSION
|
||||
|
||||
|
||||
def test_endpoint_home(client):
|
||||
r = client.get("/")
|
||||
def test_endpoint_home(client_with_auth):
|
||||
r = client_with_auth.get("/")
|
||||
assert r.status_code == 200
|
||||
j = r.json()
|
||||
assert "version" in j and j["version"] == VERSION
|
||||
@@ -15,8 +15,8 @@ def test_endpoint_home(client):
|
||||
@patch("endpoints.default.bearer_security", new_callable=AsyncMock)
|
||||
@patch("endpoints.default.get_user_auth", new_callable=AsyncMock, return_value="test@example.com")
|
||||
@patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"])
|
||||
def test_endpoint_home_with_groups(m1, m2, m3, client):
|
||||
r = client.get("/")
|
||||
def test_endpoint_home_with_groups(m1, m2, m3, client_with_auth):
|
||||
r = client_with_auth.get("/")
|
||||
assert r.status_code == 200
|
||||
j = r.json()
|
||||
assert "version" in j and j["version"] == VERSION
|
||||
@@ -24,9 +24,20 @@ def test_endpoint_home_with_groups(m1, m2, m3, client):
|
||||
assert "groups" in j
|
||||
assert j["groups"] == ["group1", "group2"]
|
||||
|
||||
@patch("endpoints.default.bearer_security", new_callable=AsyncMock)
@patch("endpoints.default.get_user_auth", new_callable=AsyncMock, return_value="test@example.com")
@patch("endpoints.default.crud.get_user_groups", side_effect=Exception('mocked error'))
def test_endpoint_home_with_groups_exception(m1, m2, m3, client_with_auth):
    """The home endpoint still answers 200 when the group lookup raises.

    ``crud.get_user_groups`` is patched to raise, simulating an internal
    error; the payload must then carry the version information but no
    ``groups`` key.
    """
    home = client_with_auth.get("/")
    assert home.status_code == 200

    body = home.json()
    assert "version" in body
    assert body["version"] == VERSION
    assert "breakingChanges" in body
    # the failed lookup must not leak a (partial) groups entry
    assert "groups" not in body
|
||||
|
||||
def test_endpoint_health(client):
|
||||
r = client.get("/health")
|
||||
|
||||
def test_endpoint_health(client_with_auth):
|
||||
r = client_with_auth.get("/health")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"status": "ok"}
|
||||
|
||||
@@ -36,27 +47,30 @@ def test_endpoint_groups_403(client):
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
def test_endpoint_groups_empty(client_with_auth):
    """An authenticated user without any groups gets an empty list from /groups."""
    groups_response = client_with_auth.get("/groups")
    assert groups_response.status_code == 200
    assert groups_response.json() == []
|
||||
|
||||
|
||||
@patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"])
|
||||
def test_endpoint_groups(m1, app):
|
||||
async def mock_get_user_auth(): return True
|
||||
from security import get_user_auth
|
||||
app.dependency_overrides[get_user_auth] = mock_get_user_auth
|
||||
app.dependency_overrides[get_user_auth] = lambda: True
|
||||
client = TestClient(app)
|
||||
|
||||
r = client.get("/groups")
|
||||
|
||||
assert r.status_code == 200
|
||||
j = r.json()
|
||||
assert j == ["group1", "group2"]
|
||||
app.dependency_overrides = {}
|
||||
assert r.json() == ["group1", "group2"]
|
||||
|
||||
|
||||
def test_no_serve_local_archive_by_default(client):
|
||||
r = client.get("/app/local_archive_test/temp.txt")
|
||||
def test_no_serve_local_archive_by_default(client_with_auth):
|
||||
r = client_with_auth.get("/app/local_archive_test/temp.txt")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_favicon(client):
|
||||
r = client.get("/favicon.ico")
|
||||
def test_favicon(client_with_auth):
|
||||
r = client_with_auth.get("/favicon.ico")
|
||||
assert r.status_code == 200
|
||||
assert r.headers["content-type"] == "image/vnd.microsoft.icon"
|
||||
|
||||
@@ -3,11 +3,11 @@ from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
def test_get_status_success(mock_async_result, client):
|
||||
def test_get_status_success(mock_async_result, client_with_auth):
|
||||
mock_async_result.return_value.status = "SUCCESS"
|
||||
mock_async_result.return_value.result = {"data": "some result"}
|
||||
|
||||
response = client.get("/task/test-task-id")
|
||||
response = client_with_auth.get("/task/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
@@ -18,12 +18,12 @@ def test_get_status_success(mock_async_result, client):
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
def test_get_status_failure(mock_async_result, client):
|
||||
def test_get_status_failure(mock_async_result, client_with_auth):
|
||||
|
||||
mock_async_result.return_value.status = "FAILURE"
|
||||
mock_async_result.return_value.result = Exception("Some error")
|
||||
|
||||
response = client.get("/task/test-task-id")
|
||||
response = client_with_auth.get("/task/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
@@ -34,11 +34,11 @@ def test_get_status_failure(mock_async_result, client):
|
||||
|
||||
|
||||
@patch("endpoints.task.AsyncResult")
|
||||
def test_get_status_pending(mock_async_result, client):
|
||||
def test_get_status_pending(mock_async_result, client_with_auth):
|
||||
mock_async_result.return_value.status = "PENDING"
|
||||
mock_async_result.return_value.result = None
|
||||
|
||||
response = client.get("/task/test-task-id")
|
||||
response = client_with_auth.get("/task/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
|
||||
@@ -1,6 +1,163 @@
|
||||
|
||||
# def test_archive_url(client):
|
||||
# response = client.get("/archive/url")
|
||||
# assert response.status_code == 200
|
||||
# assert response.json() == {"message": "Archive URL"}
|
||||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from db.schemas import ArchiveCreate, TaskResult
|
||||
|
||||
NO_AUTH = {'detail': 'Not authenticated'}
|
||||
|
||||
|
||||
def test_archive_url_unauthenticated(client):
    """Both POST and GET on /url/archive are rejected with 403 without credentials."""
    # POST targets the archive endpoint itself; the GET on the same path
    # presumably falls through to the /{id} lookup route with id="archive".
    for send in (client.post, client.get):
        reply = send("/url/archive")
        assert reply.status_code == 403
        assert reply.json() == NO_AUTH
|
||||
|
||||
|
||||
@patch("worker.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result=""))
def test_archive_url(m1, client_with_auth):
    """POST /url/archive rejects an invalid URL and queues a task for a valid one."""
    # An invalid URL is answered with 422 and the patched task is never queued.
    rejected = client_with_auth.post("/url/archive", json={"url": "bad"})
    assert rejected.status_code == 422
    assert rejected.json() == {'detail': 'Invalid URL received: bad'}
    m1.assert_not_called()

    # A valid URL is answered with 201 and the id of the patched task result.
    accepted = client_with_auth.post("/url/archive", json={"url": "https://example.com"})
    assert accepted.status_code == 201
    assert accepted.json() == {'id': '123-456-789'}

    # The first positional argument handed to the task is the archive as a
    # JSON string; the author is the identity injected by the auth override.
    m1.assert_called_once()
    expected_payload = {
        "id": None,
        "url": "https://example.com",
        "result": None,
        "public": True,
        "author_id": "rick@example.com",
        "group_id": None,
        "tags": [],
        "rearchive": True,
    }
    serialised = m1.call_args.args[0]
    assert json.loads(serialised) == expected_payload
|
||||
|
||||
|
||||
def test_search_by_url_unauthenticated(client):
    """GET /url/search without credentials is rejected with 403."""
    reply = client.get("/url/search")
    assert reply.status_code == 403
    assert reply.json() == NO_AUTH
|
||||
|
||||
|
||||
def test_search_by_url(client_with_auth, db_session):
    """Exercise GET /url/search: required ``url`` parameter, paging and date filters.

    Ten archives for https://example.com and one for another URL are
    inserted directly through ``crud`` so that the URL filter, ``limit``,
    ``skip`` and the ``archived_before`` / ``archived_after`` filters can
    be checked against known data.
    """
    # ``url`` is a required query parameter
    response = client_with_auth.get("/url/search")
    assert response.status_code == 422
    assert response.json()["detail"][0]["msg"] == "Field required"

    # nothing has been archived yet
    response = client_with_auth.get("/url/search?url=https://example.com")
    assert response.status_code == 200
    assert response.json() == []

    from db import crud
    for i in range(11):
        # entries 0..9 match the searched URL, entry 10 does not
        url = "https://example.com" if i < 10 else "https://something-else.com"
        crud.create_task(db_session, ArchiveCreate(id=f"url-456-{i}", url=url, result={}, public=True, author_id="rick@example.com", group_id=None), [], [])
    # NB: this insertion is too fast for the ordering to be correct as they
    # are within the same second, so only membership is asserted below.

    response = client_with_auth.get("/url/search?url=https://example.com")
    assert response.status_code == 200
    found_ids = [entry["id"] for entry in response.json()]
    assert len(found_ids) == 10
    assert "url-456-0" in found_ids
    assert "url-456-9" in found_ids
    assert "url-456-10" not in found_ids

    response = client_with_auth.get("/url/search?url=https://example.com&limit=5")
    assert response.status_code == 200
    assert len(response.json()) == 5

    response = client_with_auth.get("/url/search?url=https://example.com&skip=5&limit=2")
    assert response.status_code == 200
    assert len(response.json()) == 2

    # everything was created just now, so nothing predates 2010 ...
    response = client_with_auth.get("/url/search?url=https://example.com&archived_before=2010-01-01")
    assert response.status_code == 200
    assert len(response.json()) == 0

    # ... and all ten matching entries come after it.
    # NOTE(review): this parameter was previously misspelled `archived_affter`,
    # which would be ignored as an unknown query parameter and make the check
    # vacuous. `archived_after` is assumed to be the endpoint's parameter name
    # (mirroring `archived_before`) - confirm against the route signature.
    response = client_with_auth.get("/url/search?url=https://example.com&archived_after=2010-01-01")
    assert response.status_code == 200
    assert len(response.json()) == 10
|
||||
|
||||
|
||||
def test_latest_unauthenticated(client):
    """GET /url/latest without credentials is rejected with 403."""
    reply = client.get("/url/latest")
    assert reply.status_code == 403
    assert reply.json() == NO_AUTH
|
||||
|
||||
|
||||
def test_latest(client_with_auth, db_session):
    """Exercise GET /url/latest: empty result, per-author filtering and paging."""
    # no archives exist yet
    initial = client_with_auth.get("/url/latest")
    assert initial.status_code == 200
    assert initial.json() == []

    from db import crud
    for n in range(11):
        # entries 0..9 belong to morty, entry 10 to rick
        owner = "morty@example.com" if n < 10 else "rick@example.com"
        crud.create_task(db_session, ArchiveCreate(id=f"latest-456-{n}", url="https://example.com", result={}, public=True, author_id=owner, group_id=None), [], [])
    # NB: this insertion is too fast for the ordering to be correct as they
    # are within the same second, so only membership is asserted below.

    # user must exist for /latest to work
    crud.get_user(db_session, "morty@example.com")

    listing = client_with_auth.get("/url/latest")
    assert listing.status_code == 200
    entries = listing.json()
    assert len(entries) == 10
    latest_ids = [entry["id"] for entry in entries]
    assert "latest-456-0" in latest_ids
    assert "latest-456-9" in latest_ids
    assert "latest-456-10" not in latest_ids

    limited = client_with_auth.get("/url/latest?limit=5")
    assert limited.status_code == 200
    assert len(limited.json()) == 5

    paged = client_with_auth.get("/url/latest?skip=5&limit=2")
    assert paged.status_code == 200
    assert len(paged.json()) == 2
|
||||
|
||||
|
||||
def test_lookup_unauthenticated(client):
    """GET /url/{id} without credentials is rejected with 403."""
    reply = client.get("/url/123-456-789")
    assert reply.status_code == 403
    assert reply.json() == NO_AUTH
|
||||
|
||||
|
||||
def test_lookup(client_with_auth, db_session):
    """GET /url/{id} answers 404 for an unknown id and returns the archive once it exists."""
    # unknown id -> explicit 404 rather than an empty 200 body
    response = client_with_auth.get("/url/lookup-123-456-789")
    assert response.status_code == 404
    assert response.json() == {"detail": "Archive not found"}

    from db import crud
    crud.create_task(db_session, ArchiveCreate(id="lookup-123-456-789", url="https://example.com", result={}, public=True, author_id="rick@example.com", group_id=None), [], [])

    response = client_with_auth.get("/url/lookup-123-456-789")
    assert response.status_code == 200
    j = response.json()
    assert j["id"] == "lookup-123-456-789"
    assert j["url"] == "https://example.com"
    assert j["result"] == {}
    assert j["author_id"] == "rick@example.com"
    assert j["tags"] == []
    # Identity checks (PEP 8) instead of ``== True`` / ``== None``: they also
    # reject look-alike values such as 1 or 0 in the serialised payload.
    assert j["public"] is True
    assert j["rearchive"] is True
    assert j["group_id"] is None
    assert j["updated_at"] is None
|
||||
|
||||
|
||||
def test_delete_task_unauthenticated(client):
    """DELETE /url/{id} without credentials is rejected with 403."""
    reply = client.delete("/url/123-456-789")
    assert reply.status_code == 403
    assert reply.json() == NO_AUTH
|
||||
|
||||
|
||||
def test_delete_task(client_with_auth, db_session):
    """DELETE /url/{id} reports ``deleted: False`` for a missing archive and ``True`` once it exists."""
    target = "/url/delete-123-456-789"

    # nothing to delete yet: still a 200, but flagged as not deleted
    missing = client_with_auth.delete(target)
    assert missing.status_code == 200
    assert missing.json() == {"id": "delete-123-456-789", "deleted": False}

    from db import crud
    # authored by the identity the get_user_auth override resolves to
    crud.create_task(db_session, ArchiveCreate(id="delete-123-456-789", url="https://example.com", result={}, public=True, author_id="morty@example.com", group_id=None), [], [])

    removed = client_with_auth.delete(target)
    assert removed.status_code == 200
    assert removed.json() == {"id": "delete-123-456-789", "deleted": True}
|
||||
|
||||
Reference in New Issue
Block a user