From 56a81f6ec0809bbaedec24b502a8d0e078245434 Mon Sep 17 00:00:00 2001 From: msramalho <19508417+msramalho@users.noreply.github.com> Date: Fri, 18 Oct 2024 19:16:05 +0100 Subject: [PATCH] adds new tests and improvements --- src/db/crud.py | 6 +- src/endpoints/url.py | 8 +- src/tests/conftest.py | 30 +++-- src/tests/endpoints/test_default.py | 44 +++++--- src/tests/endpoints/test_task.py | 12 +- src/tests/endpoints/test_url.py | 165 +++++++++++++++++++++++++++- 6 files changed, 217 insertions(+), 48 deletions(-) diff --git a/src/db/crud.py b/src/db/crud.py index a62113f..097e556 100644 --- a/src/db/crud.py +++ b/src/db/crud.py @@ -16,9 +16,9 @@ DATABASE_QUERY_LIMIT = Settings().DATABASE_QUERY_LIMIT # --------------- TASK = Archive -def get_archive(db: Session, task_id: str, email: str): +def get_archive(db: Session, id: str, email: str): email = email.lower() - query = base_query(db).filter(models.Archive.id == task_id) + query = base_query(db).filter(models.Archive.id == id) if email != ALLOW_ANY_EMAIL: groups = get_user_groups(db, email) query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups))) @@ -45,7 +45,7 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100): email = email.lower() - return base_query(db).filter(models.Archive.author.has(email=email)).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all() + return base_query(db).filter(models.Archive.author.has(email=email)).order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all() def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]): diff --git a/src/endpoints/url.py b/src/endpoints/url.py index dc1ca68..0771ed6 100644 --- a/src/endpoints/url.py +++ b/src/endpoints/url.py @@ -25,7 +25,7 @@ def archive_url(archive: schemas.ArchiveCreate, email=Depends(get_token_or_user_ logger.info("creating task") task = create_archive_task.delay(archive.model_dump_json()) task_response = schemas.Task(id=task.id) - return JSONResponse(task_response.model_dump()) + return JSONResponse(task_response.model_dump(), status_code=201) @url_router.get("/search", response_model=list[schemas.Archive], summary="Search for archive entries by URL.") @@ -44,13 +44,15 @@ def latest(skip: int = 0, limit: int = 25, db: Session = Depends(get_db_dependen @url_router.get("/{id}", response_model=schemas.Archive, summary="Fetch a single URL archive by the associated id.") def lookup(id, db: Session = Depends(get_db_dependency), email=Depends(get_token_or_user_auth)): - return crud.get_archive(db, id, email) + archive = crud.get_archive(db, id, email) + if archive is None: + raise HTTPException(status_code=404, detail="Archive not found") + return archive @url_router.delete("/{id}", response_model=schemas.TaskDelete, summary="Delete a single URL archive by id.") def delete_task(id, db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)): logger.info(f"deleting url archive task {id} request by {email}") - #TODO: use response model? return JSONResponse({ "id": id, "deleted": crud.soft_delete_task(db, id, email) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 6a417de..a63b5ba 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -43,7 +43,7 @@ def test_db(settings: Settings): models.Base.metadata.drop_all(bind=engine) for suffix in ["", "-wal", "-shm"]: new_fs = fs + suffix - if os.path.exists(new_fs): + if os.path.exists(new_fs): os.remove(new_fs) @@ -56,13 +56,9 @@ def db_session(test_db): @pytest.fixture() -def app(db_session, settings): +def app(db_session): from web.main import app_factory app = app_factory() - from security import get_token_or_user_auth - app.dependency_overrides[get_token_or_user_auth] = lambda: "example@email.com" - # app.dependency_overrides[settings] = lambda: settings - # app.dependency_overrides[get_session] = lambda: db_session return app @@ -71,16 +67,16 @@ def client(app): client = TestClient(app) return client -# # create test data and insert it into the database -# def create_test_data(): -# from db.database import SessionLocal -# from db.models import Task -# db = SessionLocal() -# task = Task(id="test-task-id", status="PENDING") -# db.add(task) -# db.commit() -# db.refresh(task) -# db.close() +@pytest.fixture() +def app_with_auth(app): + from security import get_token_or_user_auth, get_user_auth + app.dependency_overrides[get_token_or_user_auth] = lambda: "rick@example.com" + app.dependency_overrides[get_user_auth] = lambda: "morty@example.com" + return app -# return task.id + +@pytest.fixture() +def client_with_auth(app_with_auth): + client = TestClient(app_with_auth) + return client diff --git a/src/tests/endpoints/test_default.py b/src/tests/endpoints/test_default.py index 4e049c6..0d46a99 100644 --- a/src/tests/endpoints/test_default.py +++ b/src/tests/endpoints/test_default.py @@ -3,8 +3,8 @@ from fastapi.testclient import TestClient from core.config import VERSION -def test_endpoint_home(client): - r = client.get("/") +def test_endpoint_home(client_with_auth): + r = client_with_auth.get("/") assert r.status_code == 200 j = r.json() assert "version" in j and j["version"] == VERSION @@ -15,8 +15,8 @@ def test_endpoint_home(client): @patch("endpoints.default.bearer_security", new_callable=AsyncMock) @patch("endpoints.default.get_user_auth", new_callable=AsyncMock, return_value="test@example.com") @patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"]) -def test_endpoint_home_with_groups(m1, m2, m3, client): - r = client.get("/") +def test_endpoint_home_with_groups(m1, m2, m3, client_with_auth): + r = client_with_auth.get("/") assert r.status_code == 200 j = r.json() assert "version" in j and j["version"] == VERSION @@ -24,9 +24,20 @@ def test_endpoint_home_with_groups(m1, m2, m3, client): assert "groups" in j assert j["groups"] == ["group1", "group2"] +@patch("endpoints.default.bearer_security", new_callable=AsyncMock) +@patch("endpoints.default.get_user_auth", new_callable=AsyncMock, return_value="test@example.com") +@patch("endpoints.default.crud.get_user_groups", side_effect=Exception('mocked error')) +def test_endpoint_home_with_groups_exception(m1, m2, m3, client_with_auth): # mocks call that triggers an internal error + r = client_with_auth.get("/") + assert r.status_code == 200 + j = r.json() + assert "version" in j and j["version"] == VERSION + assert "breakingChanges" in j + assert "groups" not in j -def test_endpoint_health(client): - r = client.get("/health") + +def test_endpoint_health(client_with_auth): + r = client_with_auth.get("/health") assert r.status_code == 200 assert r.json() == {"status": "ok"} @@ -36,27 +47,30 @@ def test_endpoint_groups_403(client): assert r.status_code == 403 +def test_endpoint_groups_empty(client_with_auth): + r = client_with_auth.get("/groups") + assert r.status_code == 200 + assert r.json() == [] + + @patch("endpoints.default.crud.get_user_groups", return_value=["group1", "group2"]) def test_endpoint_groups(m1, app): - async def mock_get_user_auth(): return True from security import get_user_auth - app.dependency_overrides[get_user_auth] = mock_get_user_auth + app.dependency_overrides[get_user_auth] = lambda: True client = TestClient(app) r = client.get("/groups") assert r.status_code == 200 - j = r.json() - assert j == ["group1", "group2"] - app.dependency_overrides = {} + assert r.json() == ["group1", "group2"] -def test_no_serve_local_archive_by_default(client): - r = client.get("/app/local_archive_test/temp.txt") +def test_no_serve_local_archive_by_default(client_with_auth): + r = client_with_auth.get("/app/local_archive_test/temp.txt") assert r.status_code == 404 -def test_favicon(client): - r = client.get("/favicon.ico") +def test_favicon(client_with_auth): + r = client_with_auth.get("/favicon.ico") assert r.status_code == 200 assert r.headers["content-type"] == "image/vnd.microsoft.icon" diff --git a/src/tests/endpoints/test_task.py b/src/tests/endpoints/test_task.py index e6ff8be..36285f5 100644 --- a/src/tests/endpoints/test_task.py +++ b/src/tests/endpoints/test_task.py @@ -3,11 +3,11 @@ from fastapi.testclient import TestClient @patch("endpoints.task.AsyncResult") -def test_get_status_success(mock_async_result, client): +def test_get_status_success(mock_async_result, client_with_auth): mock_async_result.return_value.status = "SUCCESS" mock_async_result.return_value.result = {"data": "some result"} - response = client.get("/task/test-task-id") + response = client_with_auth.get("/task/test-task-id") assert response.status_code == 200 assert response.json() == { @@ -18,12 +18,12 @@ def test_get_status_success(mock_async_result, client): @patch("endpoints.task.AsyncResult") -def test_get_status_failure(mock_async_result, client): +def test_get_status_failure(mock_async_result, client_with_auth): mock_async_result.return_value.status = "FAILURE" mock_async_result.return_value.result = Exception("Some error") - response = client.get("/task/test-task-id") + response = client_with_auth.get("/task/test-task-id") assert response.status_code == 200 assert response.json() == { @@ -34,11 +34,11 @@ def test_get_status_failure(mock_async_result, client): @patch("endpoints.task.AsyncResult") -def test_get_status_pending(mock_async_result, client): +def test_get_status_pending(mock_async_result, client_with_auth): mock_async_result.return_value.status = "PENDING" mock_async_result.return_value.result = None - response = client.get("/task/test-task-id") + response = client_with_auth.get("/task/test-task-id") assert response.status_code == 200 assert response.json() == { diff --git a/src/tests/endpoints/test_url.py b/src/tests/endpoints/test_url.py index 911554a..e5f0afa 100644 --- a/src/tests/endpoints/test_url.py +++ b/src/tests/endpoints/test_url.py @@ -1,6 +1,163 @@ -# def test_archive_url(client): -# response = client.get("/archive/url") -# assert response.status_code == 200 -# assert response.json() == {"message": "Archive URL"} +import json +import time +from unittest.mock import patch + +from db.schemas import ArchiveCreate, TaskResult + +NO_AUTH = {'detail': 'Not authenticated'} + + +def test_archive_url_unauthenticated(client): + response = client.post("/url/archive") + assert response.status_code == 403 + assert response.json() == NO_AUTH + + # this will call archive/{id} + response = client.get("/url/archive") + assert response.status_code == 403 + assert response.json() == NO_AUTH + + +@patch("worker.create_archive_task.delay", return_value=TaskResult(id="123-456-789", status="PENDING", result="")) +def test_archive_url(m1, client_with_auth): + response = client_with_auth.post("/url/archive", json={"url": "bad"}) + assert response.status_code == 422 + assert response.json() == {'detail': 'Invalid URL received: bad'} + m1.assert_not_called() + + response = client_with_auth.post("/url/archive", json={"url": "https://example.com"}) + assert response.status_code == 201 + assert response.json() == {'id': '123-456-789'} + + m1.assert_called_once() + called_val = m1.call_args.args[0] + assert json.loads(called_val) == {"id": None, "url": "https://example.com", "result": None, "public": True, "author_id": "rick@example.com", "group_id": None, "tags": [], "rearchive": True} + + +def test_search_by_url_unauthenticated(client): + response = client.get("/url/search") + assert response.status_code == 403 + assert response.json() == NO_AUTH + + +def test_search_by_url(client_with_auth, db_session): + # tests the search endpoint, including through some db data for the endpoint params + response = client_with_auth.get("/url/search") + assert response.status_code == 422 + assert response.json()["detail"][0]["msg"] == "Field required" + + response = client_with_auth.get("/url/search?url=https://example.com") + assert response.status_code == 200 + assert response.json() == [] + + from db import crud + for i in range(11): + crud.create_task(db_session, ArchiveCreate(id=f"url-456-{i}", url="https://example.com" if i < 10 else "https://something-else.com", result={}, public=True, author_id="rick@example.com", group_id=None), [], []) + #NB: this insertion is too fast for the ordering to be correct as they are within the same second + + response = client_with_auth.get("/url/search?url=https://example.com") + assert response.status_code == 200 + assert len(j := response.json()) == 10 + assert "url-456-0" in [i["id"] for i in j] + assert "url-456-9" in [i["id"] for i in j] + assert "url-456-10" not in [i["id"] for i in j] + + response = client_with_auth.get("/url/search?url=https://example.com&limit=5") + assert response.status_code == 200 + assert len(response.json()) == 5 + + response = client_with_auth.get("/url/search?url=https://example.com&skip=5&limit=2") + assert response.status_code == 200 + assert len(response.json()) == 2 + + response = client_with_auth.get("/url/search?url=https://example.com&archived_before=2010-01-01") + assert response.status_code == 200 + assert len(response.json()) == 0 + + response = client_with_auth.get("/url/search?url=https://example.com&archived_affter=2010-01-01") + assert response.status_code == 200 + assert len(response.json()) == 10 + + +def test_latest_unauthenticated(client): + response = client.get("/url/latest") + assert response.status_code == 403 + assert response.json() == NO_AUTH + + +def test_latest(client_with_auth, db_session): + response = client_with_auth.get("/url/latest") + assert response.status_code == 200 + assert response.json() == [] + + from db import crud + for i in range(11): + crud.create_task(db_session, ArchiveCreate(id=f"latest-456-{i}", url="https://example.com", result={}, public=True, author_id="morty@example.com" if i < 10 else "rick@example.com", group_id=None), [], []) + #NB: this insertion is too fast for the ordering to be correct as they are within the same second + + # user must exist for /latest to work + crud.get_user(db_session, "morty@example.com") + + response = client_with_auth.get("/url/latest") + assert response.status_code == 200 + assert len(j := response.json()) == 10 + assert "latest-456-0" in [i["id"] for i in j] + assert "latest-456-9" in [i["id"] for i in j] + assert "latest-456-10" not in [i["id"] for i in j] + + response = client_with_auth.get("/url/latest?limit=5") + assert response.status_code == 200 + assert len(response.json()) == 5 + + response = client_with_auth.get("/url/latest?skip=5&limit=2") + assert response.status_code == 200 + assert len(response.json()) == 2 + + +def test_lookup_unauthenticated(client): + response = client.get("/url/123-456-789") + assert response.status_code == 403 + assert response.json() == NO_AUTH + + +def test_lookup(client_with_auth, db_session): + response = client_with_auth.get("/url/lookup-123-456-789") + assert response.status_code == 404 + assert response.json() == {"detail": "Archive not found"} + + from db import crud + crud.create_task(db_session, ArchiveCreate(id="lookup-123-456-789", url="https://example.com", result={}, public=True, author_id="rick@example.com", group_id=None), [], []) + + response = client_with_auth.get("/url/lookup-123-456-789") + assert response.status_code == 200 + j = response.json() + assert j["id"] == "lookup-123-456-789" + assert j["url"] == "https://example.com" + assert j["result"] == {} + assert j["public"] == True + assert j["author_id"] == "rick@example.com" + assert j["group_id"] == None + assert j["tags"] == [] + assert j["updated_at"] == None + assert j["rearchive"] == True + + +def test_delete_task_unauthenticated(client): + response = client.delete("/url/123-456-789") + assert response.status_code == 403 + assert response.json() == NO_AUTH + + +def test_delete_task(client_with_auth, db_session): + response = client_with_auth.delete("/url/delete-123-456-789") + assert response.status_code == 200 + assert response.json() == {"id": "delete-123-456-789", "deleted": False} + + from db import crud + crud.create_task(db_session, ArchiveCreate(id="delete-123-456-789", url="https://example.com", result={}, public=True, author_id="morty@example.com", group_id=None), [], []) + + response = client_with_auth.delete("/url/delete-123-456-789") + assert response.status_code == 200 + assert response.json() == {"id": "delete-123-456-789", "deleted": True}