allows search to happen with API_TOKEN

This commit is contained in:
msramalho
2023-09-20 11:30:57 +01:00
parent f7160aad91
commit c6cd027e13
4 changed files with 27 additions and 6 deletions

View File

@@ -2,3 +2,8 @@ DATABASE_PATH="sqlite:///./auto-archiver.db"
USER_GROUPS_FILENAME=user-groups.yaml USER_GROUPS_FILENAME=user-groups.yaml
CHROME_APP_IDS=000000000000000000000000000000000000000000000.apps.googleusercontent.com,000000000000000000000000000000000000000000001.apps.googleusercontent.com CHROME_APP_IDS=000000000000000000000000000000000000000000000.apps.googleusercontent.com,000000000000000000000000000000000000000000001.apps.googleusercontent.com
#ALLOWED_ORIGINS="http://localhost:8004" # dev only #ALLOWED_ORIGINS="http://localhost:8004" # dev only
STATIC_FILE="/app/your-file.txt"
STATIC_FILE_PASSWORD=TODO
API_BEARER_TOKEN=TODO

View File

@@ -3,6 +3,8 @@ from sqlalchemy.orm import Session, load_only
from sqlalchemy import Column, or_ from sqlalchemy import Column, or_
from loguru import logger from loguru import logger
from datetime import datetime from datetime import datetime
from security import ALLOW_ANY_EMAIL
from . import models, schemas from . import models, schemas
import yaml, os import yaml, os
@@ -21,9 +23,13 @@ def get_tasks(db: Session, skip: int = 0, limit: int = 100):
def search_tasks_by_url(db: Session, url: str, email: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None): def search_tasks_by_url(db: Session, url: str, email: str, skip: int = 0, limit: int = 100, archived_after: datetime = None, archived_before: datetime = None):
email = email.lower() # searches for partial URLs, if email is * no ownership filtering happens
groups = get_user_groups(db, email) query = base_query(db)
query = base_query(db).filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups))).filter(models.Archive.url.like(f'%{url}%')) if email != ALLOW_ANY_EMAIL:
email = email.lower()
groups = get_user_groups(db, email)
query = query.filter(or_(models.Archive.public == True, models.Archive.author_id == email, models.Archive.group_id.in_(groups)))
query = query.filter(models.Archive.url.like(f'%{url}%'))
if archived_after: if archived_after:
query = query.filter(models.Archive.created_at >= archived_after) query = query.filter(models.Archive.created_at >= archived_after)
if archived_before: if archived_before:

View File

@@ -18,7 +18,7 @@ from worker import create_archive_task, create_sheet_task, celery, insert_result
from db import crud, models, schemas from db import crud, models, schemas
from db.database import engine, SessionLocal from db.database import engine, SessionLocal
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from security import get_bearer_auth, get_basic_auth, get_server_auth, bearer_security from security import get_bearer_auth, get_basic_auth, get_server_auth, bearer_security, get_bearer_auth_token_or_jwt
from auto_archiver import Metadata from auto_archiver import Metadata
load_dotenv() load_dotenv()
@@ -82,8 +82,7 @@ def get_user_groups(db: Session = Depends(get_db), email = Depends(get_bearer_au
return crud.get_user_groups(db, email) return crud.get_user_groups(db, email)
@app.get("/tasks/search-url", response_model=list[schemas.Archive]) @app.get("/tasks/search-url", response_model=list[schemas.Archive])
def search_by_url(url:str, skip: int = 0, limit: int = 100, archived_after:datetime=None, archived_before:datetime=None, db: Session = Depends(get_db), email = Depends(get_bearer_auth)): def search_by_url(url:str, skip: int = 0, limit: int = 100, archived_after:datetime=None, archived_before:datetime=None, db: Session = Depends(get_db), email = Depends(get_bearer_auth_token_or_jwt)):
#TODO: test strip
return crud.search_tasks_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before) return crud.search_tasks_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
@app.get("/tasks/sync", response_model=list[schemas.Archive]) @app.get("/tasks/sync", response_model=list[schemas.Archive])

View File

@@ -18,8 +18,19 @@ basic_security = HTTPBasic()
bearer_security = HTTPBearer() bearer_security = HTTPBearer()
# --------------------- Bearer Auth # --------------------- Bearer Auth
ALLOW_ANY_EMAIL = "*"
API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars
async def get_bearer_auth_token_or_jwt(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# tries to use the static API_KEY and defaults to google JWT auth
access_token = credentials.credentials
if len(API_BEARER_TOKEN) >= 20:
current_token_bytes = access_token.encode("utf8")
is_correct_token = secrets.compare_digest(current_token_bytes, API_BEARER_TOKEN.encode("utf8"))
if is_correct_token: return ALLOW_ANY_EMAIL # any email works
return await get_bearer_auth(credentials)
async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# validates the Bearer token in the case that it requires it # validates the Bearer token in the case that it requires it
access_token = credentials.credentials access_token = credentials.credentials