diff --git a/src/db/crud.py b/src/db/crud.py index 97dacf3..ddd0412 100644 --- a/src/db/crud.py +++ b/src/db/crud.py @@ -104,6 +104,9 @@ def create_tag(db: Session, tag: str): db.refresh(db_tag) return db_tag +def is_active_user(db: Session, email: str) -> bool: + email = email.lower() + return len(email) and db.query(models.User).filter(models.User.email == email).count() > 0 def is_user_in_group(db: Session, group_name: str, email: str) -> models.Group: if email == ALLOW_ANY_EMAIL: return True diff --git a/src/db/schemas.py b/src/db/schemas.py index 2f04462..aa9abd7 100644 --- a/src/db/schemas.py +++ b/src/db/schemas.py @@ -53,4 +53,7 @@ class TaskResult(Task): result: str class TaskDelete(Task): - deleted: bool \ No newline at end of file + deleted: bool + +class ActiveUser(BaseModel): + active: bool \ No newline at end of file diff --git a/src/endpoints/default.py b/src/endpoints/default.py index 269fa3f..a294fda 100644 --- a/src/endpoints/default.py +++ b/src/endpoints/default.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from core.config import VERSION, BREAKING_CHANGES from core.logging import log_error -from db import crud +from db import crud, schemas from db.database import get_db_dependency, get_db from web.security import get_user_auth, bearer_security @@ -30,6 +30,11 @@ async def health(): return JSONResponse({"status": "ok"}) +@default_router.get("/user/active", summary="Check if the user is active and can use the tool.", response_model=schemas.ActiveUser) +async def active(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)): + return {"active": crud.is_active_user(db, email)} + + @default_router.get("/groups", response_model=list[str]) def get_user_groups(db: Session = Depends(get_db_dependency), email=Depends(get_user_auth)): return crud.get_user_groups(db, email) diff --git a/src/tests/db/test_crud.py b/src/tests/db/test_crud.py index daad766..edd9abf 100644 --- a/src/tests/db/test_crud.py +++ b/src/tests/db/test_crud.py @@ -293,6 +293,15 @@ def test_create_tag(db_session): assert second_tag.id == "tag-102" assert db_session.query(models.Tag).count() == 2 +def test_is_active_user(test_data, db_session): + from db import crud + + assert crud.is_active_user(db_session, "") == False + assert crud.is_active_user(db_session, "example.com") == False + assert crud.is_active_user(db_session, "unknown@example.com") == False + assert crud.is_active_user(db_session, "rick@example.com") == True + assert crud.is_active_user(db_session, "RICK@example.com") == True + def test_is_user_in_group(test_data, db_session): from db import crud diff --git a/src/tests/endpoints/test_default.py b/src/tests/endpoints/test_default.py index b840260..b5e2e77 100644 --- a/src/tests/endpoints/test_default.py +++ b/src/tests/endpoints/test_default.py @@ -2,6 +2,8 @@ from unittest.mock import AsyncMock, patch from fastapi.testclient import TestClient import pytest from core.config import VERSION +from tests.db.test_crud import test_data + def test_endpoint_home(client_with_auth): @@ -44,6 +46,24 @@ def test_endpoint_health(client_with_auth): assert r.json() == {"status": "ok"} +def test_endpoint_active_no_auth(client, test_no_auth): + test_no_auth(client.get, "/user/active") + + +def test_endpoint_active_true_user(client_with_auth): + r = client_with_auth.get("/user/active") + assert r.status_code == 200 + assert r.json() == {"active": True} + +def test_endpoint_active_true_user(client_with_auth, db_session): + from db import models + db_session.query(models.User).delete() + db_session.commit() + r = client_with_auth.get("/user/active") + assert r.status_code == 200 + assert r.json() == {"active": False} + + def test_endpoint_groups_no_auth(client, test_no_auth): test_no_auth(client.get, "/groups") @@ -79,9 +99,6 @@ def test_favicon(client_with_auth): assert r.headers["content-type"] == "image/vnd.microsoft.icon" -from tests.db.test_crud import test_data - - @pytest.mark.asyncio async def test_prometheus_metrics(test_data, client_with_auth, get_settings): # before metrics calculation @@ -117,4 +134,4 @@ async def test_prometheus_metrics(test_data, client_with_auth, get_settings): assert 'database_metrics{query="count_users"} 4.0' in r3.text assert 'database_metrics_counter_total{query="count_by_user",user="rick@example.com"} 34.0' in r3.text assert 'database_metrics_counter_total{query="count_by_user",user="morty@example.com"} 33.0' in r3.text - assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r3.text \ No newline at end of file + assert 'database_metrics_counter_total{query="count_by_user",user="jerry@example.com"} 33.0' in r3.text