Refactor auth methods

De duplicate some common codepaths. Also, for routes accepting basic
authentication, allow bearer auth as an alternative. This allows
clients to switch to bearer auth opportunistically, but we won't
have to coordinate deployments.

Basic auth should be deprecated since we don't really use a
user/password auth scheme.
This commit is contained in:
Lilia Kai
2023-12-21 14:05:15 +01:00
parent b83e51de68
commit e37b848ef5
2 changed files with 37 additions and 31 deletions

View File

@@ -17,18 +17,19 @@ logger.info(f"{len(BLOCKED_EMAILS)=}")
basic_security = HTTPBasic()
bearer_security = HTTPBearer()
# --------------------- Bearer Auth
ALLOW_ANY_EMAIL = "*"
def secure_compare(token, api_key):
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
# --------------------- Bearer Auth
API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars
async def get_bearer_auth_token_or_jwt(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# tries to use the static API_KEY and defaults to google JWT auth
access_token = credentials.credentials
if len(API_BEARER_TOKEN) >= 20:
current_token_bytes = access_token.encode("utf8")
is_correct_token = secrets.compare_digest(current_token_bytes, API_BEARER_TOKEN.encode("utf8"))
if is_correct_token: return ALLOW_ANY_EMAIL # any email works
if len(API_BEARER_TOKEN) >= 20:
is_correct_token = secure_compare(access_token, API_BEARER_TOKEN)
if is_correct_token: return ALLOW_ANY_EMAIL
return await get_bearer_auth(credentials)
async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
@@ -64,35 +65,38 @@ def authenticate_user(access_token):
logger.warning(f"EXCEPTION occurred: {e}")
return False, f"EXCEPTION occurred"
# Temporary method until all clients migrate from basic to bearer
async def bearer_or_basic_auth(bearer: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error = False)), basic: HTTPBasicCredentials = Depends(HTTPBasic(auto_error = False))):
if bearer: return bearer.credentials
if basic: return basic.password
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
)
# Factory method to create an authentication dependency for a specific key
def api_key_auth(api_key):
async def auth(challenge = Depends(bearer_or_basic_auth)):
assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars"
is_correct = secure_compare(challenge, api_key)
if is_correct: return True
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Wrong auth credentials",
)
return auth
# --------------------- Basic Auth
SFP = os.environ.get("STATIC_FILE_PASSWORD", "") # min length is 20 chars
async def get_basic_auth(credentials: HTTPBasicCredentials = Depends(basic_security)):
# validates that the Basic token in the case that it requires it
assert len(SFP) >= 20, "Invalid STATIC_FILE_PASSWORD, must be at least 20 chars"
current_password_bytes = credentials.password.encode("utf8")
is_correct_password = secrets.compare_digest(current_password_bytes, SFP.encode("utf8"))
if is_correct_password: return True
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Wrong auth credentials",
headers={"WWW-Authenticate": "Basic"}
)
get_basic_auth = api_key_auth(SFP)
# --------------------- Server-side Auth
SERVICE_PASSWORD = os.environ.get("SERVICE_PASSWORD", "") # min length is 20 chars
get_server_auth = api_key_auth(SERVICE_PASSWORD)
async def get_server_auth(credentials: HTTPBasicCredentials = Depends(basic_security)):
# validates that the Basic token in the case that it requires it
assert len(SERVICE_PASSWORD) >= 20, "Invalid SERVICE_PASSWORD, must be at least 20 chars"
current_password_bytes = credentials.password.encode("utf8")
is_correct_password = secrets.compare_digest(current_password_bytes, SERVICE_PASSWORD.encode("utf8"))
if is_correct_password: return True
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Wrong auth credentials",
headers={"WWW-Authenticate": "Basic"}
)