mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-08 03:28:35 +03:00
Merge pull request #39 from bellingcat/auth
Remove basic auth and rename methods
This commit is contained in:
@@ -8,6 +8,7 @@ services:
|
||||
- ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp
|
||||
- SERVICE_PASSWORD=dev-service-password
|
||||
- STATIC_FILE_PASSWORD=dev-static-file-password
|
||||
- API_BEARER_TOKEN=dev-api-bearer-token
|
||||
|
||||
worker:
|
||||
restart: "no"
|
||||
|
||||
26
src/main.py
26
src/main.py
@@ -19,7 +19,7 @@ from worker import create_archive_task, create_sheet_task, celery, insert_result
|
||||
from db import crud, models, schemas
|
||||
from db.database import engine, SessionLocal
|
||||
from sqlalchemy.orm import Session
|
||||
from security import get_bearer_auth, get_basic_auth, get_server_auth, bearer_security, get_bearer_auth_token_or_jwt
|
||||
from security import get_user_auth, static_api_key_auth, service_api_key_auth, bearer_security, get_token_or_user_auth
|
||||
from auto_archiver import Metadata
|
||||
|
||||
load_dotenv()
|
||||
@@ -46,7 +46,7 @@ EXCEPTION_COUNTER = Counter(
|
||||
labelnames=("types",)
|
||||
)
|
||||
# prometheus exposed in /metrics with authentication
|
||||
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(get_server_auth)])
|
||||
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(service_api_key_auth)])
|
||||
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
@@ -78,7 +78,7 @@ async def home(request: Request):
|
||||
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
|
||||
try:
|
||||
# if authenticated will load available groups
|
||||
email = await get_bearer_auth(await bearer_security(request))
|
||||
email = await get_user_auth(await bearer_security(request))
|
||||
db: Session = next(get_db())
|
||||
status["groups"] = crud.get_user_groups(db, email)
|
||||
except HTTPException: pass
|
||||
@@ -89,19 +89,19 @@ async def home(request: Request):
|
||||
#-----Submit URL and manipulate tasks. Bearer protected below
|
||||
|
||||
@app.get("/groups", response_model=list[str])
|
||||
def get_user_groups(db: Session = Depends(get_db), email = Depends(get_bearer_auth)):
|
||||
def get_user_groups(db: Session = Depends(get_db), email = Depends(get_user_auth)):
|
||||
return crud.get_user_groups(db, email)
|
||||
|
||||
@app.get("/tasks/search-url", response_model=list[schemas.Archive])
|
||||
def search_by_url(url:str, skip: int = 0, limit: int = 100, archived_after:datetime=None, archived_before:datetime=None, db: Session = Depends(get_db), email = Depends(get_bearer_auth_token_or_jwt)):
|
||||
def search_by_url(url:str, skip: int = 0, limit: int = 100, archived_after:datetime=None, archived_before:datetime=None, db: Session = Depends(get_db), email = Depends(get_token_or_user_auth)):
|
||||
return crud.search_tasks_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
||||
|
||||
@app.get("/tasks/sync", response_model=list[schemas.Archive])
|
||||
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_bearer_auth)):
|
||||
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_user_auth)):
|
||||
return crud.search_tasks_by_email(db, email, skip=skip, limit=limit)
|
||||
|
||||
@app.post("/tasks", status_code=201)
|
||||
def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_bearer_auth_token_or_jwt)):
|
||||
def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_token_or_user_auth)):
|
||||
archive.author_id = email
|
||||
url = archive.url
|
||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
|
||||
@@ -112,11 +112,11 @@ def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_bearer_auth
|
||||
return JSONResponse({"id": task.id})
|
||||
|
||||
@app.get("/archive/{task_id}")
|
||||
def lookup(task_id, db: Session = Depends(get_db), email = Depends(get_bearer_auth_token_or_jwt)):
|
||||
def lookup(task_id, db: Session = Depends(get_db), email = Depends(get_token_or_user_auth)):
|
||||
return crud.get_task(db, task_id, email)
|
||||
|
||||
@app.get("/tasks/{task_id}")
|
||||
def get_status(task_id, email = Depends(get_bearer_auth)):
|
||||
def get_status(task_id, email = Depends(get_user_auth)):
|
||||
logger.info(f"status check for user {email} task {task_id}")
|
||||
task = AsyncResult(task_id, app=celery)
|
||||
try:
|
||||
@@ -143,7 +143,7 @@ def get_status(task_id, email = Depends(get_bearer_auth)):
|
||||
})
|
||||
|
||||
@app.delete("/tasks/{task_id}")
|
||||
def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_bearer_auth)):
|
||||
def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_user_auth)):
|
||||
logger.info(f"deleting task {task_id} request by {email}")
|
||||
return JSONResponse({
|
||||
"id": task_id,
|
||||
@@ -152,7 +152,7 @@ def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_bear
|
||||
|
||||
#----- Google Sheets Logic
|
||||
@app.post("/sheet", status_code=201)
|
||||
def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_bearer_auth)):
|
||||
def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_user_auth)):
|
||||
logger.info(f"SHEET TASK for {sheet=}")
|
||||
sheet.author_id = email
|
||||
if not sheet.sheet_name and not sheet.sheet_id:
|
||||
@@ -161,7 +161,7 @@ def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_bearer_auth)):
|
||||
return JSONResponse({"id": task.id})
|
||||
|
||||
@app.post("/sheet_service", status_code=201)
|
||||
def archive_sheet_service(sheet:schemas.SubmitSheet, basic_auth = Depends(get_server_auth)):
|
||||
def archive_sheet_service(sheet:schemas.SubmitSheet, auth = Depends(service_api_key_auth)):
|
||||
logger.info(f"SHEET TASK for {sheet=}")
|
||||
sheet.author_id = sheet.author_id or "api-endpoint"
|
||||
if not sheet.sheet_name and not sheet.sheet_id:
|
||||
@@ -171,7 +171,7 @@ def archive_sheet_service(sheet:schemas.SubmitSheet, basic_auth = Depends(get_se
|
||||
|
||||
#----- endpoint to submit data archived elsewhere
|
||||
@app.post("/submit-archive", status_code=201)
|
||||
def submit_manual_archive(manual:schemas.SubmitManual, basic_auth = Depends(get_basic_auth)):
|
||||
def submit_manual_archive(manual:schemas.SubmitManual, auth = Depends(static_api_key_auth)):
|
||||
result = Metadata.from_json(manual.result)
|
||||
logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
|
||||
manual.tags.add("manual")
|
||||
|
||||
@@ -14,7 +14,6 @@ logger.info(f"{CHROME_APP_IDS=}")
|
||||
BLOCKED_EMAILS = set([e.strip().lower() for e in os.environ.get("BLOCKED_EMAILS", "").split(",")])
|
||||
logger.info(f"{len(BLOCKED_EMAILS)=}")
|
||||
|
||||
basic_security = HTTPBasic()
|
||||
bearer_security = HTTPBearer()
|
||||
|
||||
ALLOW_ANY_EMAIL = "*"
|
||||
@@ -22,20 +21,45 @@ ALLOW_ANY_EMAIL = "*"
|
||||
def secure_compare(token, api_key):
    """Return True when `token` and `api_key` are equal.

    Both strings are UTF-8 encoded and compared with
    `secrets.compare_digest` rather than `==`.
    """
    presented = token.encode("utf8")
    expected = api_key.encode("utf8")
    return secrets.compare_digest(presented, expected)
|
||||
|
||||
# --------------------- Bearer Auth
|
||||
# Factory method to create an authentication dependency for a specific key
|
||||
def api_key_auth(api_key):
    """Build an authentication dependency bound to one specific API key.

    Args:
        api_key: The expected secret. It must be at least 20 characters long;
            the length is checked every time the returned dependency runs.

    Returns:
        An async callable `auth(bearer, auto_error=True)` intended for use
        with `Depends(...)`.
    """

    async def auth(bearer: HTTPAuthorizationCredentials = Depends(bearer_security), auto_error=True):
        """Check the presented Bearer token against `api_key`.

        Returns True on a match. On a mismatch, raises HTTP 401 when
        `auto_error` is truthy and returns False otherwise. Raises HTTP 500
        when the configured key is shorter than 20 characters.
        """
        # NOTE(review): FastAPI normally treats plain (non-Depends) parameters
        # of a dependency as request parameters, so `auto_error` is presumably
        # settable by the client via the query string - confirm, and consider
        # moving this flag onto the factory instead.

        # Explicit check rather than `assert`: assertions are removed when
        # Python runs with -O, which would silently accept a short/empty key.
        if len(api_key) < 20:
            logger.error("Invalid API key, must be at least 20 chars")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Server authentication is not configured",
            )

        # Accept either the credentials object injected by `bearer_security`
        # or a raw token string handed over by a direct caller.
        presented_token = getattr(bearer, "credentials", bearer)

        if secure_compare(presented_token, api_key):
            return True

        if auto_error:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Wrong auth credentials",
            )
        return False

    return auth
|
||||
|
||||
# --------------------- Static Auth for local AA deployments to add archives to the API
|
||||
SFP = os.environ.get("STATIC_FILE_PASSWORD", "") # min length is 20 chars
|
||||
static_api_key_auth = api_key_auth(SFP)
|
||||
|
||||
# --------------------- Service Auth for the AA setup tool and Prometheus
|
||||
SERVICE_PASSWORD = os.environ.get("SERVICE_PASSWORD", "") # min length is 20 chars
|
||||
service_api_key_auth = api_key_auth(SERVICE_PASSWORD)
|
||||
|
||||
# --------------------- Token Auth for AA itself to query the API
|
||||
API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars
|
||||
async def get_bearer_auth_token_or_jwt(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||
token_api_key_auth = api_key_auth(API_BEARER_TOKEN)
|
||||
|
||||
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
    """Authenticate with the static API token, falling back to user auth.

    Args:
        credentials: Bearer credentials extracted by `bearer_security`.

    Returns:
        `ALLOW_ANY_EMAIL` when the Bearer token matches `API_BEARER_TOKEN`;
        otherwise whatever `get_user_auth` returns for these credentials.

    Raises:
        HTTPException: propagated from `get_user_auth` when the token is
            neither the static API token nor a valid user token.
    """
    # Only try the static token when one is actually configured (min length
    # is 20 chars); otherwise go straight to user authentication instead of
    # tripping the key-length check inside the API-key dependency.
    if len(API_BEARER_TOKEN) >= 20:
        # `token_api_key_auth` is a coroutine function: it MUST be awaited.
        # Without `await` the call yields a coroutine object, which is always
        # truthy, so every Bearer token would be granted ALLOW_ANY_EMAIL.
        # It also reads `.credentials`, so pass the credentials object rather
        # than the raw token string.
        if await token_api_key_auth(credentials, auto_error=False):
            return ALLOW_ANY_EMAIL

    return await get_user_auth(credentials)
|
||||
|
||||
async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||
# validates the Bearer token in the case that it requires it
|
||||
access_token = credentials.credentials
|
||||
valid_user, info = authenticate_user(access_token)
|
||||
valid_user, info = authenticate_user(credentials.credentials)
|
||||
if valid_user: return info
|
||||
logger.debug(f"TOKEN FAILURE: {valid_user=} {info=}")
|
||||
raise HTTPException(
|
||||
@@ -44,7 +68,6 @@ async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(be
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def authenticate_user(access_token):
|
||||
# https://cloud.google.com/docs/authentication/token-types#access
|
||||
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
|
||||
@@ -65,38 +88,3 @@ def authenticate_user(access_token):
|
||||
logger.warning(f"EXCEPTION occurred: {e}")
|
||||
return False, f"EXCEPTION occurred"
|
||||
|
||||
# Temporary method until all clients migrate from basic to bearer
|
||||
async def bearer_or_basic_auth(bearer: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error = False)), basic: HTTPBasicCredentials = Depends(HTTPBasic(auto_error = False))):
|
||||
|
||||
if bearer: return bearer.credentials
|
||||
if basic: return basic.password
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
)
|
||||
|
||||
# Factory method to create an authentication dependency for a specific key
|
||||
def api_key_auth(api_key):
|
||||
|
||||
async def auth(challenge = Depends(bearer_or_basic_auth)):
|
||||
assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars"
|
||||
|
||||
is_correct = secure_compare(challenge, api_key)
|
||||
if is_correct: return True
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Wrong auth credentials",
|
||||
)
|
||||
|
||||
return auth
|
||||
|
||||
# --------------------- Basic Auth
|
||||
SFP = os.environ.get("STATIC_FILE_PASSWORD", "") # min length is 20 chars
|
||||
get_basic_auth = api_key_auth(SFP)
|
||||
|
||||
# --------------------- Server-side Auth
|
||||
SERVICE_PASSWORD = os.environ.get("SERVICE_PASSWORD", "") # min length is 20 chars
|
||||
get_server_auth = api_key_auth(SERVICE_PASSWORD)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user