Merge pull request #39 from bellingcat/auth

Remove basic auth and rename methods
This commit is contained in:
Kai
2024-02-15 10:19:35 -10:00
committed by GitHub
3 changed files with 48 additions and 59 deletions

View File

@@ -8,6 +8,7 @@ services:
- ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp
- SERVICE_PASSWORD=dev-service-password
- STATIC_FILE_PASSWORD=dev-static-file-password
- API_BEARER_TOKEN=dev-api-bearer-token
worker:
restart: "no"

View File

@@ -19,7 +19,7 @@ from worker import create_archive_task, create_sheet_task, celery, insert_result
from db import crud, models, schemas
from db.database import engine, SessionLocal
from sqlalchemy.orm import Session
from security import get_bearer_auth, get_basic_auth, get_server_auth, bearer_security, get_bearer_auth_token_or_jwt
from security import get_user_auth, static_api_key_auth, service_api_key_auth, bearer_security, get_token_or_user_auth
from auto_archiver import Metadata
load_dotenv()
@@ -46,7 +46,7 @@ EXCEPTION_COUNTER = Counter(
labelnames=("types",)
)
# prometheus exposed in /metrics with authentication
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(get_server_auth)])
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(service_api_key_auth)])
app.mount("/static", StaticFiles(directory="static"), name="static")
@@ -78,7 +78,7 @@ async def home(request: Request):
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
try:
# if authenticated will load available groups
email = await get_bearer_auth(await bearer_security(request))
email = await get_user_auth(await bearer_security(request))
db: Session = next(get_db())
status["groups"] = crud.get_user_groups(db, email)
except HTTPException: pass
@@ -89,19 +89,19 @@ async def home(request: Request):
#-----Submit URL and manipulate tasks. Bearer protected below
@app.get("/groups", response_model=list[str])
def get_user_groups(db: Session = Depends(get_db), email = Depends(get_bearer_auth)):
def get_user_groups(db: Session = Depends(get_db), email = Depends(get_user_auth)):
return crud.get_user_groups(db, email)
@app.get("/tasks/search-url", response_model=list[schemas.Archive])
def search_by_url(url:str, skip: int = 0, limit: int = 100, archived_after:datetime=None, archived_before:datetime=None, db: Session = Depends(get_db), email = Depends(get_bearer_auth_token_or_jwt)):
def search_by_url(url:str, skip: int = 0, limit: int = 100, archived_after:datetime=None, archived_before:datetime=None, db: Session = Depends(get_db), email = Depends(get_token_or_user_auth)):
return crud.search_tasks_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
@app.get("/tasks/sync", response_model=list[schemas.Archive])
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_bearer_auth)):
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_user_auth)):
return crud.search_tasks_by_email(db, email, skip=skip, limit=limit)
@app.post("/tasks", status_code=201)
def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_bearer_auth_token_or_jwt)):
def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_token_or_user_auth)):
archive.author_id = email
url = archive.url
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
@@ -112,11 +112,11 @@ def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_bearer_auth
return JSONResponse({"id": task.id})
@app.get("/archive/{task_id}")
def lookup(task_id, db: Session = Depends(get_db), email = Depends(get_bearer_auth_token_or_jwt)):
def lookup(task_id, db: Session = Depends(get_db), email = Depends(get_token_or_user_auth)):
return crud.get_task(db, task_id, email)
@app.get("/tasks/{task_id}")
def get_status(task_id, email = Depends(get_bearer_auth)):
def get_status(task_id, email = Depends(get_user_auth)):
logger.info(f"status check for user {email} task {task_id}")
task = AsyncResult(task_id, app=celery)
try:
@@ -143,7 +143,7 @@ def get_status(task_id, email = Depends(get_bearer_auth)):
})
@app.delete("/tasks/{task_id}")
def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_bearer_auth)):
def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_user_auth)):
logger.info(f"deleting task {task_id} request by {email}")
return JSONResponse({
"id": task_id,
@@ -152,7 +152,7 @@ def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_bear
#----- Google Sheets Logic
@app.post("/sheet", status_code=201)
def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_bearer_auth)):
def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_user_auth)):
logger.info(f"SHEET TASK for {sheet=}")
sheet.author_id = email
if not sheet.sheet_name and not sheet.sheet_id:
@@ -161,7 +161,7 @@ def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_bearer_auth)):
return JSONResponse({"id": task.id})
@app.post("/sheet_service", status_code=201)
def archive_sheet_service(sheet:schemas.SubmitSheet, basic_auth = Depends(get_server_auth)):
def archive_sheet_service(sheet:schemas.SubmitSheet, auth = Depends(service_api_key_auth)):
logger.info(f"SHEET TASK for {sheet=}")
sheet.author_id = sheet.author_id or "api-endpoint"
if not sheet.sheet_name and not sheet.sheet_id:
@@ -171,7 +171,7 @@ def archive_sheet_service(sheet:schemas.SubmitSheet, basic_auth = Depends(get_se
#----- endpoint to submit data archived elsewhere
@app.post("/submit-archive", status_code=201)
def submit_manual_archive(manual:schemas.SubmitManual, basic_auth = Depends(get_basic_auth)):
def submit_manual_archive(manual:schemas.SubmitManual, auth = Depends(static_api_key_auth)):
result = Metadata.from_json(manual.result)
logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
manual.tags.add("manual")

View File

@@ -14,7 +14,6 @@ logger.info(f"{CHROME_APP_IDS=}")
BLOCKED_EMAILS = set([e.strip().lower() for e in os.environ.get("BLOCKED_EMAILS", "").split(",")])
logger.info(f"{len(BLOCKED_EMAILS)=}")
basic_security = HTTPBasic()
bearer_security = HTTPBearer()
ALLOW_ANY_EMAIL = "*"
@@ -22,20 +21,45 @@ ALLOW_ANY_EMAIL = "*"
def secure_compare(token, api_key):
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
# --------------------- Bearer Auth
# Factory method to create an authentication dependency for a specific key
def api_key_auth(api_key):
async def auth(bearer: HTTPAuthorizationCredentials = Depends(bearer_security), auto_error=True):
assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars"
is_correct = secure_compare(bearer.credentials, api_key)
if is_correct: return True
if auto_error:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Wrong auth credentials",
)
return False
return auth
# --------------------- Static Auth for local AA deployments to add archives to the API
SFP = os.environ.get("STATIC_FILE_PASSWORD", "") # min length is 20 chars
static_api_key_auth = api_key_auth(SFP)
# --------------------- Service Auth for the AA setup tool and Prometheus
SERVICE_PASSWORD = os.environ.get("SERVICE_PASSWORD", "") # min length is 20 chars
service_api_key_auth = api_key_auth(SERVICE_PASSWORD)
# --------------------- Token Auth for AA itself to query the API
API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars
async def get_bearer_auth_token_or_jwt(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
token_api_key_auth = api_key_auth(API_BEARER_TOKEN)
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# tries to use the static API_KEY and defaults to google JWT auth
access_token = credentials.credentials
if len(API_BEARER_TOKEN) >= 20:
is_correct_token = secure_compare(access_token, API_BEARER_TOKEN)
if is_correct_token: return ALLOW_ANY_EMAIL
return await get_bearer_auth(credentials)
if token_api_key_auth(access_token, auto_error=False): return ALLOW_ANY_EMAIL
return await get_user_auth(credentials)
async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
# validates the Bearer token in the case that it requires it
access_token = credentials.credentials
valid_user, info = authenticate_user(access_token)
valid_user, info = authenticate_user(credentials.credentials)
if valid_user: return info
logger.debug(f"TOKEN FAILURE: {valid_user=} {info=}")
raise HTTPException(
@@ -44,7 +68,6 @@ async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(be
headers={"WWW-Authenticate": "Bearer"},
)
def authenticate_user(access_token):
# https://cloud.google.com/docs/authentication/token-types#access
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
@@ -65,38 +88,3 @@ def authenticate_user(access_token):
logger.warning(f"EXCEPTION occurred: {e}")
return False, f"EXCEPTION occurred"
# Temporary method until all clients migrate from basic to bearer
async def bearer_or_basic_auth(bearer: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error = False)), basic: HTTPBasicCredentials = Depends(HTTPBasic(auto_error = False))):
if bearer: return bearer.credentials
if basic: return basic.password
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
)
# Factory method to create an authentication dependency for a specific key
def api_key_auth(api_key):
async def auth(challenge = Depends(bearer_or_basic_auth)):
assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars"
is_correct = secure_compare(challenge, api_key)
if is_correct: return True
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Wrong auth credentials",
)
return auth
# --------------------- Basic Auth
SFP = os.environ.get("STATIC_FILE_PASSWORD", "") # min length is 20 chars
get_basic_auth = api_key_auth(SFP)
# --------------------- Server-side Auth
SERVICE_PASSWORD = os.environ.get("SERVICE_PASSWORD", "") # min length is 20 chars
get_server_auth = api_key_auth(SERVICE_PASSWORD)