mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-12 21:48:35 +03:00
Merge pull request #39 from bellingcat/auth
Remove basic auth and rename methods
This commit is contained in:
@@ -8,6 +8,7 @@ services:
|
|||||||
- ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp
|
- ALLOWED_ORIGINS=http://localhost:8004,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp
|
||||||
- SERVICE_PASSWORD=dev-service-password
|
- SERVICE_PASSWORD=dev-service-password
|
||||||
- STATIC_FILE_PASSWORD=dev-static-file-password
|
- STATIC_FILE_PASSWORD=dev-static-file-password
|
||||||
|
- API_BEARER_TOKEN=dev-api-bearer-token
|
||||||
|
|
||||||
worker:
|
worker:
|
||||||
restart: "no"
|
restart: "no"
|
||||||
|
|||||||
26
src/main.py
26
src/main.py
@@ -19,7 +19,7 @@ from worker import create_archive_task, create_sheet_task, celery, insert_result
|
|||||||
from db import crud, models, schemas
|
from db import crud, models, schemas
|
||||||
from db.database import engine, SessionLocal
|
from db.database import engine, SessionLocal
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from security import get_bearer_auth, get_basic_auth, get_server_auth, bearer_security, get_bearer_auth_token_or_jwt
|
from security import get_user_auth, static_api_key_auth, service_api_key_auth, bearer_security, get_token_or_user_auth
|
||||||
from auto_archiver import Metadata
|
from auto_archiver import Metadata
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -46,7 +46,7 @@ EXCEPTION_COUNTER = Counter(
|
|||||||
labelnames=("types",)
|
labelnames=("types",)
|
||||||
)
|
)
|
||||||
# prometheus exposed in /metrics with authentication
|
# prometheus exposed in /metrics with authentication
|
||||||
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(get_server_auth)])
|
Instrumentator(should_group_status_codes=False, excluded_handlers=["/metrics"]).instrument(app).expose(app, dependencies=[Depends(service_api_key_auth)])
|
||||||
|
|
||||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||||
|
|
||||||
@@ -78,7 +78,7 @@ async def home(request: Request):
|
|||||||
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
|
status = {"version": VERSION, "breakingChanges": BREAKING_CHANGES}
|
||||||
try:
|
try:
|
||||||
# if authenticated will load available groups
|
# if authenticated will load available groups
|
||||||
email = await get_bearer_auth(await bearer_security(request))
|
email = await get_user_auth(await bearer_security(request))
|
||||||
db: Session = next(get_db())
|
db: Session = next(get_db())
|
||||||
status["groups"] = crud.get_user_groups(db, email)
|
status["groups"] = crud.get_user_groups(db, email)
|
||||||
except HTTPException: pass
|
except HTTPException: pass
|
||||||
@@ -89,19 +89,19 @@ async def home(request: Request):
|
|||||||
#-----Submit URL and manipulate tasks. Bearer protected below
|
#-----Submit URL and manipulate tasks. Bearer protected below
|
||||||
|
|
||||||
@app.get("/groups", response_model=list[str])
|
@app.get("/groups", response_model=list[str])
|
||||||
def get_user_groups(db: Session = Depends(get_db), email = Depends(get_bearer_auth)):
|
def get_user_groups(db: Session = Depends(get_db), email = Depends(get_user_auth)):
|
||||||
return crud.get_user_groups(db, email)
|
return crud.get_user_groups(db, email)
|
||||||
|
|
||||||
@app.get("/tasks/search-url", response_model=list[schemas.Archive])
|
@app.get("/tasks/search-url", response_model=list[schemas.Archive])
|
||||||
def search_by_url(url:str, skip: int = 0, limit: int = 100, archived_after:datetime=None, archived_before:datetime=None, db: Session = Depends(get_db), email = Depends(get_bearer_auth_token_or_jwt)):
|
def search_by_url(url:str, skip: int = 0, limit: int = 100, archived_after:datetime=None, archived_before:datetime=None, db: Session = Depends(get_db), email = Depends(get_token_or_user_auth)):
|
||||||
return crud.search_tasks_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
return crud.search_tasks_by_url(db, url.strip(), email, skip=skip, limit=limit, archived_after=archived_after, archived_before=archived_before)
|
||||||
|
|
||||||
@app.get("/tasks/sync", response_model=list[schemas.Archive])
|
@app.get("/tasks/sync", response_model=list[schemas.Archive])
|
||||||
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_bearer_auth)):
|
def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_user_auth)):
|
||||||
return crud.search_tasks_by_email(db, email, skip=skip, limit=limit)
|
return crud.search_tasks_by_email(db, email, skip=skip, limit=limit)
|
||||||
|
|
||||||
@app.post("/tasks", status_code=201)
|
@app.post("/tasks", status_code=201)
|
||||||
def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_bearer_auth_token_or_jwt)):
|
def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_token_or_user_auth)):
|
||||||
archive.author_id = email
|
archive.author_id = email
|
||||||
url = archive.url
|
url = archive.url
|
||||||
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
|
logger.info(f"new {archive.public=} task for {email=} and {archive.group_id=}: {url}")
|
||||||
@@ -112,11 +112,11 @@ def archive_tasks(archive:schemas.ArchiveCreate, email = Depends(get_bearer_auth
|
|||||||
return JSONResponse({"id": task.id})
|
return JSONResponse({"id": task.id})
|
||||||
|
|
||||||
@app.get("/archive/{task_id}")
|
@app.get("/archive/{task_id}")
|
||||||
def lookup(task_id, db: Session = Depends(get_db), email = Depends(get_bearer_auth_token_or_jwt)):
|
def lookup(task_id, db: Session = Depends(get_db), email = Depends(get_token_or_user_auth)):
|
||||||
return crud.get_task(db, task_id, email)
|
return crud.get_task(db, task_id, email)
|
||||||
|
|
||||||
@app.get("/tasks/{task_id}")
|
@app.get("/tasks/{task_id}")
|
||||||
def get_status(task_id, email = Depends(get_bearer_auth)):
|
def get_status(task_id, email = Depends(get_user_auth)):
|
||||||
logger.info(f"status check for user {email} task {task_id}")
|
logger.info(f"status check for user {email} task {task_id}")
|
||||||
task = AsyncResult(task_id, app=celery)
|
task = AsyncResult(task_id, app=celery)
|
||||||
try:
|
try:
|
||||||
@@ -143,7 +143,7 @@ def get_status(task_id, email = Depends(get_bearer_auth)):
|
|||||||
})
|
})
|
||||||
|
|
||||||
@app.delete("/tasks/{task_id}")
|
@app.delete("/tasks/{task_id}")
|
||||||
def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_bearer_auth)):
|
def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_user_auth)):
|
||||||
logger.info(f"deleting task {task_id} request by {email}")
|
logger.info(f"deleting task {task_id} request by {email}")
|
||||||
return JSONResponse({
|
return JSONResponse({
|
||||||
"id": task_id,
|
"id": task_id,
|
||||||
@@ -152,7 +152,7 @@ def delete_task(task_id, db: Session = Depends(get_db), email = Depends(get_bear
|
|||||||
|
|
||||||
#----- Google Sheets Logic
|
#----- Google Sheets Logic
|
||||||
@app.post("/sheet", status_code=201)
|
@app.post("/sheet", status_code=201)
|
||||||
def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_bearer_auth)):
|
def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_user_auth)):
|
||||||
logger.info(f"SHEET TASK for {sheet=}")
|
logger.info(f"SHEET TASK for {sheet=}")
|
||||||
sheet.author_id = email
|
sheet.author_id = email
|
||||||
if not sheet.sheet_name and not sheet.sheet_id:
|
if not sheet.sheet_name and not sheet.sheet_id:
|
||||||
@@ -161,7 +161,7 @@ def archive_sheet(sheet:schemas.SubmitSheet, email = Depends(get_bearer_auth)):
|
|||||||
return JSONResponse({"id": task.id})
|
return JSONResponse({"id": task.id})
|
||||||
|
|
||||||
@app.post("/sheet_service", status_code=201)
|
@app.post("/sheet_service", status_code=201)
|
||||||
def archive_sheet_service(sheet:schemas.SubmitSheet, basic_auth = Depends(get_server_auth)):
|
def archive_sheet_service(sheet:schemas.SubmitSheet, auth = Depends(service_api_key_auth)):
|
||||||
logger.info(f"SHEET TASK for {sheet=}")
|
logger.info(f"SHEET TASK for {sheet=}")
|
||||||
sheet.author_id = sheet.author_id or "api-endpoint"
|
sheet.author_id = sheet.author_id or "api-endpoint"
|
||||||
if not sheet.sheet_name and not sheet.sheet_id:
|
if not sheet.sheet_name and not sheet.sheet_id:
|
||||||
@@ -171,7 +171,7 @@ def archive_sheet_service(sheet:schemas.SubmitSheet, basic_auth = Depends(get_se
|
|||||||
|
|
||||||
#----- endpoint to submit data archived elsewhere
|
#----- endpoint to submit data archived elsewhere
|
||||||
@app.post("/submit-archive", status_code=201)
|
@app.post("/submit-archive", status_code=201)
|
||||||
def submit_manual_archive(manual:schemas.SubmitManual, basic_auth = Depends(get_basic_auth)):
|
def submit_manual_archive(manual:schemas.SubmitManual, auth = Depends(static_api_key_auth)):
|
||||||
result = Metadata.from_json(manual.result)
|
result = Metadata.from_json(manual.result)
|
||||||
logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
|
logger.info(f"MANUAL SUBMIT {result.get_url()} {manual.author_id}")
|
||||||
manual.tags.add("manual")
|
manual.tags.add("manual")
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ logger.info(f"{CHROME_APP_IDS=}")
|
|||||||
BLOCKED_EMAILS = set([e.strip().lower() for e in os.environ.get("BLOCKED_EMAILS", "").split(",")])
|
BLOCKED_EMAILS = set([e.strip().lower() for e in os.environ.get("BLOCKED_EMAILS", "").split(",")])
|
||||||
logger.info(f"{len(BLOCKED_EMAILS)=}")
|
logger.info(f"{len(BLOCKED_EMAILS)=}")
|
||||||
|
|
||||||
basic_security = HTTPBasic()
|
|
||||||
bearer_security = HTTPBearer()
|
bearer_security = HTTPBearer()
|
||||||
|
|
||||||
ALLOW_ANY_EMAIL = "*"
|
ALLOW_ANY_EMAIL = "*"
|
||||||
@@ -22,20 +21,45 @@ ALLOW_ANY_EMAIL = "*"
|
|||||||
def secure_compare(token, api_key):
|
def secure_compare(token, api_key):
|
||||||
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
|
return secrets.compare_digest(token.encode("utf8"), api_key.encode("utf8"))
|
||||||
|
|
||||||
# --------------------- Bearer Auth
|
# Factory method to create an authentication dependency for a specific key
|
||||||
|
def api_key_auth(api_key):
|
||||||
|
|
||||||
|
async def auth(bearer: HTTPAuthorizationCredentials = Depends(bearer_security), auto_error=True):
|
||||||
|
assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars"
|
||||||
|
|
||||||
|
is_correct = secure_compare(bearer.credentials, api_key)
|
||||||
|
if is_correct: return True
|
||||||
|
|
||||||
|
if auto_error:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Wrong auth credentials",
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return auth
|
||||||
|
|
||||||
|
# --------------------- Static Auth for local AA deployments to add archives to the API
|
||||||
|
SFP = os.environ.get("STATIC_FILE_PASSWORD", "") # min length is 20 chars
|
||||||
|
static_api_key_auth = api_key_auth(SFP)
|
||||||
|
|
||||||
|
# --------------------- Service Auth for the AA setup tool and Prometheus
|
||||||
|
SERVICE_PASSWORD = os.environ.get("SERVICE_PASSWORD", "") # min length is 20 chars
|
||||||
|
service_api_key_auth = api_key_auth(SERVICE_PASSWORD)
|
||||||
|
|
||||||
|
# --------------------- Token Auth for AA itself to query the API
|
||||||
API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars
|
API_BEARER_TOKEN = os.environ.get("API_BEARER_TOKEN", "") # min length is 20 chars
|
||||||
async def get_bearer_auth_token_or_jwt(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
token_api_key_auth = api_key_auth(API_BEARER_TOKEN)
|
||||||
|
|
||||||
|
async def get_token_or_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||||
# tries to use the static API_KEY and defaults to google JWT auth
|
# tries to use the static API_KEY and defaults to google JWT auth
|
||||||
access_token = credentials.credentials
|
access_token = credentials.credentials
|
||||||
if len(API_BEARER_TOKEN) >= 20:
|
if token_api_key_auth(access_token, auto_error=False): return ALLOW_ANY_EMAIL
|
||||||
is_correct_token = secure_compare(access_token, API_BEARER_TOKEN)
|
return await get_user_auth(credentials)
|
||||||
if is_correct_token: return ALLOW_ANY_EMAIL
|
|
||||||
return await get_bearer_auth(credentials)
|
|
||||||
|
|
||||||
async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
async def get_user_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||||
# validates the Bearer token in the case that it requires it
|
# validates the Bearer token in the case that it requires it
|
||||||
access_token = credentials.credentials
|
valid_user, info = authenticate_user(credentials.credentials)
|
||||||
valid_user, info = authenticate_user(access_token)
|
|
||||||
if valid_user: return info
|
if valid_user: return info
|
||||||
logger.debug(f"TOKEN FAILURE: {valid_user=} {info=}")
|
logger.debug(f"TOKEN FAILURE: {valid_user=} {info=}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -44,7 +68,6 @@ async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(be
|
|||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(access_token):
|
def authenticate_user(access_token):
|
||||||
# https://cloud.google.com/docs/authentication/token-types#access
|
# https://cloud.google.com/docs/authentication/token-types#access
|
||||||
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
|
if type(access_token) != str or len(access_token) < 10: return False, "invalid access_token"
|
||||||
@@ -65,38 +88,3 @@ def authenticate_user(access_token):
|
|||||||
logger.warning(f"EXCEPTION occurred: {e}")
|
logger.warning(f"EXCEPTION occurred: {e}")
|
||||||
return False, f"EXCEPTION occurred"
|
return False, f"EXCEPTION occurred"
|
||||||
|
|
||||||
# Temporary method until all clients migrate from basic to bearer
|
|
||||||
async def bearer_or_basic_auth(bearer: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error = False)), basic: HTTPBasicCredentials = Depends(HTTPBasic(auto_error = False))):
|
|
||||||
|
|
||||||
if bearer: return bearer.credentials
|
|
||||||
if basic: return basic.password
|
|
||||||
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Not authenticated",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Factory method to create an authentication dependency for a specific key
|
|
||||||
def api_key_auth(api_key):
|
|
||||||
|
|
||||||
async def auth(challenge = Depends(bearer_or_basic_auth)):
|
|
||||||
assert len(api_key) >= 20, "Invalid API key, must be at least 20 chars"
|
|
||||||
|
|
||||||
is_correct = secure_compare(challenge, api_key)
|
|
||||||
if is_correct: return True
|
|
||||||
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Wrong auth credentials",
|
|
||||||
)
|
|
||||||
|
|
||||||
return auth
|
|
||||||
|
|
||||||
# --------------------- Basic Auth
|
|
||||||
SFP = os.environ.get("STATIC_FILE_PASSWORD", "") # min length is 20 chars
|
|
||||||
get_basic_auth = api_key_auth(SFP)
|
|
||||||
|
|
||||||
# --------------------- Server-side Auth
|
|
||||||
SERVICE_PASSWORD = os.environ.get("SERVICE_PASSWORD", "") # min length is 20 chars
|
|
||||||
get_server_auth = api_key_auth(SERVICE_PASSWORD)
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user