From e37b848ef51b526a7a59acef0d085619832953de Mon Sep 17 00:00:00 2001 From: Lilia Kai Date: Thu, 21 Dec 2023 14:05:15 +0100 Subject: [PATCH] Refactor auth methods De duplicate some common codepaths. Also, for routes accepting basic authentication, allow bearer auth as an alternative. This allows clients to switch to bearer auth opportunistically, but we won't have to coordinate deployments. Basic auth should be deprecated since we don't really use a user/password auth scheme. --- docker-compose.dev.yml | 4 ++- src/security.py | 64 ++++++++++++++++++++++-------------------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index cb51d8d..1dc92a0 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -5,7 +5,9 @@ services: restart: "no" environment: - SERVE_LOCAL_ARCHIVE=/app/local_archive # See orchestration.yaml local_storage.save_to - - ALLOWED_ORIGINS=http://localhost:8004 + - ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp + - SERVICE_PASSWORD=dev-service-password + - STATIC_FILE_PASSWORD=dev-static-file-password worker: restart: "no" diff --git a/src/security.py b/src/security.py index cfbd83d..5749130 100644 --- a/src/security.py +++ b/src/security.py @@ -17,18 +17,19 @@ logger.info(f"{len(BLOCKED_EMAILS)=}") basic_security = HTTPBasic() bearer_security = HTTPBearer() -# --------------------- Bearer Auth ALLOW_ANY_EMAIL = "*" +def secure_compare(token, api_key): + return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8")) +# --------------------- Bearer Auth API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars async def get_bearer_auth_token_or_jwt(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): # tries to use the static API_KEY and defaults to google JWT auth access_token = credentials.credentials - if len(API_BEARER_TOKEN) >= 20: - current_token_bytes = access_token.encode("utf8") - is_correct_token = secrets.compare_digest(current_token_bytes, API_BEARER_TOKEN.encode("utf8")) - if is_correct_token: return ALLOW_ANY_EMAIL # any email works + if len(API_BEARER_TOKEN) >= 20: + is_correct_token = secure_compare(access_token, API_BEARER_TOKEN) + if is_correct_token: return ALLOW_ANY_EMAIL return await get_bearer_auth(credentials) async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): @@ -64,35 +65,38 @@ def authenticate_user(access_token): logger.warning(f"EXCEPTION occurred: {e}") return False, f"EXCEPTION occurred" +# Temporary method until all clients migrate from basic to bearer +async def bearer_or_basic_auth(bearer: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error = False)), basic: HTTPBasicCredentials = Depends(HTTPBasic(auto_error = False))): + + if bearer: return bearer.credentials + if basic: return basic.password + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + ) + +# Factory method to create an authentication dependency for a specific key +def api_key_auth(api_key): + + async def auth(challenge = Depends(bearer_or_basic_auth)): + assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars" + + is_correct = secure_compare(challenge, api_key) + if is_correct: return True + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Wrong auth credentials", + ) + + return auth # --------------------- Basic Auth SFP = os.environ.get("STATIC_FILE_PASSWORD", "") # min length is 20 chars - - -async def get_basic_auth(credentials: HTTPBasicCredentials = Depends(basic_security)): - # validates that the Basic token in the case that it requires it - assert len(SFP) >= 20, "Invalid STATIC_FILE_PASSWORD, must be at least 20 chars" - current_password_bytes = credentials.password.encode("utf8") - is_correct_password = secrets.compare_digest(current_password_bytes, SFP.encode("utf8")) - if is_correct_password: return True - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Wrong auth credentials", - headers={"WWW-Authenticate": "Basic"} - ) +get_basic_auth = api_key_auth(SFP) # --------------------- Server-side Auth SERVICE_PASSWORD = os.environ.get("SERVICE_PASSWORD", "") # min length is 20 chars +get_server_auth = api_key_auth(SERVICE_PASSWORD) - -async def get_server_auth(credentials: HTTPBasicCredentials = Depends(basic_security)): - # validates that the Basic token in the case that it requires it - assert len(SERVICE_PASSWORD) >= 20, "Invalid SERVICE_PASSWORD, must be at least 20 chars" - current_password_bytes = credentials.password.encode("utf8") - is_correct_password = secrets.compare_digest(current_password_bytes, SERVICE_PASSWORD.encode("utf8")) - if is_correct_password: return True - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Wrong auth credentials", - headers={"WWW-Authenticate": "Basic"} - )