From 2a4c6444b0a5e386ff6a599dd83a065b97cc6095 Mon Sep 17 00:00:00 2001 From: msramalho <19508417+msramalho@users.noreply.github.com> Date: Tue, 29 Oct 2024 17:24:03 +0000 Subject: [PATCH] fixes bad input for limit --- src/db/crud.py | 6 ++++-- src/tests/db/test_crud.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/db/crud.py b/src/db/crud.py index c3c069f..4d03318 100644 --- a/src/db/crud.py +++ b/src/db/crud.py @@ -15,6 +15,8 @@ DATABASE_QUERY_LIMIT = get_settings().DATABASE_QUERY_LIMIT # --------------- TASK = Archive +def get_limit(user_limit:int): + return max(1, min(user_limit, DATABASE_QUERY_LIMIT)) def get_archive(db: Session, id: str, email: str): email = email.lower() @@ -40,12 +42,12 @@ def search_archives_by_url(db: Session, url: str, email: str, skip: int = 0, lim query = query.filter(models.Archive.created_at > archived_after) if archived_before: query = query.filter(models.Archive.created_at < archived_before) - return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all() + return query.order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all() def search_archives_by_email(db: Session, email: str, skip: int = 0, limit: int = 100): email = email.lower() - return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(min(limit, DATABASE_QUERY_LIMIT)).all() + return base_query(db).filter(models.Archive.author_id == email).order_by(models.Archive.created_at.desc()).offset(skip).limit(get_limit(limit)).all() def create_task(db: Session, task: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]): diff --git a/src/tests/db/test_crud.py b/src/tests/db/test_crud.py index 3913548..f293cf8 100644 --- a/src/tests/db/test_crud.py +++ b/src/tests/db/test_crud.py @@ -127,6 +127,7 @@ def test_search_archives_by_url(test_data, db_session): # limit assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, limit=10)) == 10 + assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, limit=-1)) == 1 # skip assert len(crud.search_archives_by_url(db_session, "https://example", ALLOW_ANY_EMAIL, skip=10)) == 90