diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 1dc92a0..b79290b 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -8,6 +8,7 @@ services: - ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp - SERVICE_PASSWORD=dev-service-password - STATIC_FILE_PASSWORD=dev-static-file-password + - API_BEARER_TOKEN=dev-api-bearer-token worker: restart: "no" diff --git a/src/main.py b/src/main.py index 4163ecf..027985f 100644 --- a/src/main.py +++ b/src/main.py @@ -19,7 +19,7 @@ from worker import create_archive_task, create_sheet_task, celery, insert_result from db import crud, models, schemas from db.database import engine, SessionLocal from sqlalchemy.orm import Session -from security import get_bearer_auth, get_basic_auth, get_server_auth, bearer_security, get_bearer_auth_token_or_jwt +from security import get_user_auth, static_api_key_auth, service_api_key_auth, bearer_security, get_token_or_user_auth from auto_archiver import Metadata load_dotenv() @@ -46,7 +46,7 @@ EXCEPTION_COUNTER = Counter( labelnames=("types",) ) # prometheus exposed in /metrics with authentication -Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(get_server_auth)]) +Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(service_api_key_auth)]) app.mount("/static", StaticFiles(directory="static"), name="static") @@ -78,7 +78,7 @@ async def home(request: Request): status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES} try: # if authenticated will load available groups - email = await get_bearer_auth(await bearer_security(request)) + email = await get_user_auth(await bearer_security(request)) db: Session = next(get_db()) status["groups"] = crud.get_user_groups(db, email) except HTTPException: pass @@ -89,19 +89,19 @@ async def home(request: Request): #-----Submit URL and manipulate tasks. Bearer protected below @app.get("/groups", response_model=list[str]) -def get_user_groups(db: Session = Depends(get_db), email = Depends(get_bearer_auth)): +def get_user_groups(db: Session = Depends(get_db), email = Depends(get_user_auth)): return crud.get_user_groups(db, email) @app.get("/tasks/search-url", response_model=list[schemas.Archive]) -def search_by_url(url:str, skip: int = 0, limit: int = 100, archived_after:datetime=None, archived_before:datetime=None, db: Session = Depends(get_db), email = Depends(get_bearer_auth_token_or_jwt)): +def search_by_url(url:str, skip: int = 0, limit: int = 100, archived_after:datetime=None, archived_before:datetime=None, db: Session = Depends(get_db), email = Depends(get_token_or_user_auth)): return crud.search_tasks_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before) @app.get("/tasks/sync", response_model=list[schemas.Archive]) -def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_bearer_auth)): +def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_user_auth)): return crud.search_tasks_by_email(db, email, skip=skip, limit=limit) @app.post("/tasks", status_code=201) -def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_bearer_auth_token_or_jwt)): +def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_token_or_user_auth)): archive.author_id = email url = archive.url logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}") @@ -112,11 +112,11 @@ def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_bearer_auth return JSONResponse({"id": task.id}) @app.get("/archive/{task_id}") -def lookup(task_id, db: Session = Depends(get_db), email = Depends(get_bearer_auth_token_or_jwt)): +def lookup(task_id, db: Session = Depends(get_db), email = Depends(get_token_or_user_auth)): return crud.get_task(db, task_id, email) @app.get("/tasks/{task_id}") -def get_status(task_id, email = Depends(get_bearer_auth)): +def get_status(task_id, email = Depends(get_user_auth)): logger.info(f"status check for user {email} task {task_id}") task = AsyncResult(task_id, app=celery) try: @@ -143,7 +143,7 @@ def get_status(task_id, email = Depends(get_bearer_auth)): }) @app.delete("/tasks/{task_id}") -def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_bearer_auth)): +def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_user_auth)): logger.info(f"deleting task {task_id} request by {email}") return JSONResponse({ "id": task_id, @@ -152,7 +152,7 @@ def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_bear #----- Google Sheets Logic @app.post("/sheet", status_code=201) -def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_bearer_auth)): +def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_user_auth)): logger.info(f"SHEET TASK for {sheet=}") sheet.author_id = email if not sheet.sheet_name and not sheet.sheet_id: @@ -161,7 +161,7 @@ def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_bearer_auth)): return JSONResponse({"id": task.id}) @app.post("/sheet_service", status_code=201) -def archive_sheet_service(sheet:schemas.SubmitSheet, basic_auth = Depends(get_server_auth)): +def archive_sheet_service(sheet:schemas.SubmitSheet, auth = Depends(service_api_key_auth)): logger.info(f"SHEET TASK for {sheet=}") sheet.author_id = sheet.author_id or "api-endpoint" if not sheet.sheet_name and not sheet.sheet_id: @@ -171,7 +171,7 @@ def archive_sheet_service(sheet:schemas.SubmitSheet, basic_auth = Depends(get_se #----- endpoint to submit data archived elsewhere @app.post("/submit-archive", status_code=201) -def submit_manual_archive(manual:schemas.SubmitManual, basic_auth = Depends(get_basic_auth)): +def submit_manual_archive(manual:schemas.SubmitManual, auth = Depends(static_api_key_auth)): result = Metadata.from_json(manual.result) logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}") manual.tags.add("manual") diff --git a/src/security.py b/src/security.py index 5749130..061ddfd 100644 --- a/src/security.py +++ b/src/security.py @@ -14,7 +14,6 @@ logger.info(f"{CHROME_APP_IDS=}") BLOCKED_EMAILS = set([e.strip().lower() for e in os.environ.get("BLOCKED_EMAILS", "").split(",")]) logger.info(f"{len(BLOCKED_EMAILS)=}") -basic_security = HTTPBasic() bearer_security = HTTPBearer() ALLOW_ANY_EMAIL = "*" @@ -22,20 +21,45 @@ ALLOW_ANY_EMAIL = "*" def secure_compare(token, api_key): return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8")) -# --------------------- Bearer Auth +# Factory method to create an authentication dependency for a specific key +def api_key_auth(api_key): + + async def auth(bearer: HTTPAuthorizationCredentials = Depends(bearer_security), auto_error=True): + assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars" + + is_correct = secure_compare(bearer.credentials, api_key) + if is_correct: return True + + if auto_error: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Wrong auth credentials", + ) + return False + + return auth + +# --------------------- Static Auth for local AA deployments to add archives to the API +SFP = os.environ.get("STATIC_FILE_PASSWORD", "") # min length is 20 chars +static_api_key_auth = api_key_auth(SFP) + +# --------------------- Service Auth for the AA setup tool and Prometheus +SERVICE_PASSWORD = os.environ.get("SERVICE_PASSWORD", "") # min length is 20 chars +service_api_key_auth = api_key_auth(SERVICE_PASSWORD) + +# --------------------- Token Auth for AA itself to query the API API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars -async def get_bearer_auth_token_or_jwt(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): +token_api_key_auth = api_key_auth(API_BEARER_TOKEN) + +async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): # tries to use the static API_KEY and defaults to google JWT auth access_token = credentials.credentials - if len(API_BEARER_TOKEN) >= 20: - is_correct_token = secure_compare(access_token, API_BEARER_TOKEN) - if is_correct_token: return ALLOW_ANY_EMAIL - return await get_bearer_auth(credentials) + if token_api_key_auth(access_token, auto_error=False): return ALLOW_ANY_EMAIL + return await get_user_auth(credentials) -async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): +async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): # validates the Bearer token in the case that it requires it - access_token = credentials.credentials - valid_user, info = authenticate_user(access_token) + valid_user, info = authenticate_user(credentials.credentials) if valid_user: return info logger.debug(f"TOKEN FAILURE: {valid_user=} {info=}") raise HTTPException( @@ -44,7 +68,6 @@ async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(be headers={"WWW-Authenticate": "Bearer"}, ) - def authenticate_user(access_token): # https://cloud.google.com/docs/authentication/token-types#access if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token" @@ -65,38 +88,3 @@ def authenticate_user(access_token): logger.warning(f"EXCEPTION occurred: {e}") return False, f"EXCEPTION occurred" -# Temporary method until all clients migrate from basic to bearer -async def bearer_or_basic_auth(bearer: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error = False)), basic: HTTPBasicCredentials = Depends(HTTPBasic(auto_error = False))): - - if bearer: return bearer.credentials - if basic: return basic.password - - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not authenticated", - ) - -# Factory method to create an authentication dependency for a specific key -def api_key_auth(api_key): - - async def auth(challenge = Depends(bearer_or_basic_auth)): - assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars" - - is_correct = secure_compare(challenge, api_key) - if is_correct: return True - - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Wrong auth credentials", - ) - - return auth - -# --------------------- Basic Auth -SFP = os.environ.get("STATIC_FILE_PASSWORD", "") # min length is 20 chars -get_basic_auth = api_key_auth(SFP) - -# --------------------- Server-side Auth -SERVICE_PASSWORD = os.environ.get("SERVICE_PASSWORD", "") # min length is 20 chars -get_server_auth = api_key_auth(SERVICE_PASSWORD) -