diff --git a/src/main.py b/src/main.py index 6ec78bf..f311d91 100644 --- a/src/main.py +++ b/src/main.py @@ -1,14 +1,12 @@ from celery.result import AsyncResult -from fastapi import Body, FastAPI, Request, HTTPException, status, Depends +from fastapi import Body, FastAPI, Depends from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse, FileResponse from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware -from fastapi.security import HTTPBasic, HTTPBasicCredentials -# from fastapi.templating import Jinja2Templates -# from pydantic.json import pydantic_encoder +import alembic.config from dotenv import load_dotenv -import traceback, os, requests, re, secrets +import traceback, os from loguru import logger from worker import create_archive_task, celery @@ -16,16 +14,13 @@ from worker import create_archive_task, celery from db import crud, models, schemas from db.database import engine, SessionLocal from sqlalchemy.orm import Session +from security import get_bearer_auth, get_basic_auth load_dotenv() # Configuration -GOOGLE_CHROME_APP_ID = os.environ.get("GOOGLE_CHROME_APP_ID") -assert len(GOOGLE_CHROME_APP_ID)>10, "GOOGLE_CHROME_APP_ID env variable not set" -ALLOWED_EMAILS = set([e.strip().lower() for e in os.environ.get("ALLOWED_EMAILS", "").split(",")]) -assert len(GOOGLE_CHROME_APP_ID)>=1, "at least one ALLOWED_EMAILS is required from the env variable" ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "chrome-extension://ondkcheoicfckabcnkdgbepofpjmjcmb,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp").split(",") -VERSION = "0.1.8" +VERSION = "0.2.0" app = FastAPI() app.add_middleware( @@ -41,33 +36,33 @@ def get_db(): session = SessionLocal() try: yield session finally: session.close() + + +@app.get("/") +def home(): return JSONResponse({"version": VERSION}) + +# Bearer protected below - @app.get("/tasks/search-url", response_model=list[schemas.Task]) -def search(access_token:str, url:str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): - validate_user_get_email(access_token) +def search(url:str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_bearer_auth)): return crud.search_tasks_by_url(db, url, skip=skip, limit=limit) - -@app.get("/tasks/search", response_model=list[schemas.Task]) -def search(access_token:str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): - validate_user_get_email(access_token) - return crud.get_tasks(db, skip=skip, limit=limit) + +# @app.get("/tasks/search", response_model=list[schemas.Task]) +# def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_bearer_auth)): +# return crud.get_tasks(db, skip=skip, limit=limit) @app.get("/tasks/sync", response_model=list[schemas.Task]) -def search(access_token:str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): - email = validate_user_get_email(access_token) +def search(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), email = Depends(get_bearer_auth)): return crud.search_tasks_by_email(db, email, skip=skip, limit=limit) @app.post("/tasks", status_code=201) -def run_task(payload = Body(...)): - email = validate_user_get_email(payload.get("access_token")) +def run_task(payload = Body(...), email = Depends(get_bearer_auth)): logger.info(f"new task for user {email}: {payload.get('url')}") task = create_archive_task.delay(url=payload.get('url'), email=email) return JSONResponse({"id": task.id}) @app.get("/tasks/{task_id}") -def get_status(task_id, access_token:str): - email = validate_user_get_email(access_token) +def get_status(task_id, email = Depends(get_bearer_auth)): logger.info(f"status check for user {email}") task_result = AsyncResult(task_id, app=celery) result = { @@ -77,7 +72,6 @@ def get_status(task_id, access_token:str): } try: json_result = jsonable_encoder(result, exclude_unset=True) - # json_result = jsonable_encoder(result, custom_encoder={"pydantic_encoder": pydantic_encoder}) # causes error return JSONResponse(json_result) except Exception as e: logger.error(e) @@ -89,74 +83,25 @@ def get_status(task_id, access_token:str): @app.delete("/tasks/{task_id}") -def get_status(task_id, access_token:str, db: Session = Depends(get_db)): - email = validate_user_get_email(access_token) +def get_status(task_id, db: Session = Depends(get_db), email = Depends(get_bearer_auth)): logger.info(f"deleting task {task_id} request by {email}") return JSONResponse({ "id": task_id, "deleted": crud.soft_delete_task(db, task_id, email) }) -# logic to allow access to 1 static file +# Basic protected logic to allow access to 1 static file SF = os.environ.get("STATIC_FILE", "") -SFP = os.environ.get("STATIC_FILE_PASSWORD", "") # min length is 20 chars -security = HTTPBasic() -if len(SF) > 1 and len(SFP) >= 20 and os.path.isfile(SF): +if len(SF) > 1 and os.path.isfile(SF): @app.get("/static-file") - def static_file(credentials: HTTPBasicCredentials = Depends(security)): - current_password_bytes = credentials.password.encode("utf8") - is_correct_password = secrets.compare_digest(current_password_bytes, SFP.encode("utf8")) - if is_correct_password: - return FileResponse(SF, filename=os.path.basename(SF)) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Wrong static file access credentials", - headers={"WWW-Authenticate": "Basic"} - ) + def static_file(basic_auth = Depends(get_basic_auth)): + return FileResponse(SF, filename=os.path.basename(SF)) -@app.get("/") -def home(): - return JSONResponse({"status": "good", "version": VERSION}) +# on startup - -import alembic.config @app.on_event("startup") async def on_startup(): # # Not needed if you setup a migration system like Alembic # await create_db_and_tables()https://github.com/bellingcat/auto-archiver/tree/dockerize models.Base.metadata.create_all(bind=engine) alembic.config.main(argv=['--raiseerr', 'upgrade', 'head']) - -#### helper methods -def authenticate_user(access_token): - # https://cloud.google.com/docs/authentication/token-types#access - if type(access_token)!=str or len(access_token)<10: return False, "invalid access_token" - r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token":access_token}) - if r.status_code!=200: return False, "error occurred" - try: - j = r.json() - if j.get("azp") != GOOGLE_CHROME_APP_ID and j.get("aud")!=GOOGLE_CHROME_APP_ID: - return False, f"token does not belong to correct APP_ID" - # if j.get("email") not in ALLOWED_EMAILS: - if not custom_is_email_allowed(j.get("email"), any_bellingcat_email=True): - return False, f"email '{j.get('email')}' not allowed" - if j.get("email_verified") != "true": - return False, f"email '{j.get('email')}' not verified" - if int(j.get("expires_in", -1)) <= 0: - return False, "Token expired" - return True, j.get('email') - except Exception as e: - logger.warning(f"EXCEPTION occurred: {e}") - return False, f"EXCEPTION occurred" - -def custom_is_email_allowed(email, any_bellingcat_email=False): - return email.lower() in ALLOWED_EMAILS or (any_bellingcat_email and re.match(r'^[\w.]+@bellingcat\.com$', email)) - -def validate_user_get_email(access_token): - valid_user, info = authenticate_user(access_token) - if valid_user != True: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=info - ) - return info \ No newline at end of file diff --git a/src/security.py b/src/security.py new file mode 100644 index 0000000..f04c616 --- /dev/null +++ b/src/security.py @@ -0,0 +1,67 @@ +from loguru import logger +import requests, os, re, secrets +from fastapi import HTTPException, status, Depends +from fastapi.security import HTTPBasic, HTTPBasicCredentials, HTTPBearer, HTTPAuthorizationCredentials + + + +# Configuration +GOOGLE_CHROME_APP_ID = os.environ.get("GOOGLE_CHROME_APP_ID") +assert len(GOOGLE_CHROME_APP_ID)>10, "GOOGLE_CHROME_APP_ID env variable not set" +ALLOWED_EMAILS = set([e.strip().lower() for e in os.environ.get("ALLOWED_EMAILS", "").split(",")]) +assert len(GOOGLE_CHROME_APP_ID)>=1, "at least one ALLOWED_EMAILS is required from the env variable" + +basic_security = HTTPBasic() +bearer_security = HTTPBearer() + +#--------------------- Bearer Auth + +async def get_bearer_auth(credentials: HTTPAuthorizationCredentials = Depends(bearer_security)): + # validates the Bearer token in the case that it requires it + access_token = credentials.credentials + valid_user, info = authenticate_user(access_token) + if valid_user: return info + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=info, + headers={"WWW-Authenticate": "Bearer"}, + ) + +def authenticate_user(access_token): + # https://cloud.google.com/docs/authentication/token-types#access + if type(access_token)!=str or len(access_token)<10: return False, "invalid access_token" + r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token":access_token}) + if r.status_code!=200: return False, "error occurred" + try: + j = r.json() + if j.get("azp") != GOOGLE_CHROME_APP_ID and j.get("aud")!=GOOGLE_CHROME_APP_ID: + return False, f"token does not belong to correct APP_ID" + # if j.get("email") not in ALLOWED_EMAILS: + if not custom_is_email_allowed(j.get("email"), any_bellingcat_email=True): + return False, f"email '{j.get('email')}' not allowed" + if j.get("email_verified") != "true": + return False, f"email '{j.get('email')}' not verified" + if int(j.get("expires_in", -1)) <= 0: + return False, "Token expired" + return True, j.get('email') + except Exception as e: + logger.warning(f"EXCEPTION occurred: {e}") + return False, f"EXCEPTION occurred" + +def custom_is_email_allowed(email, any_bellingcat_email=False): + return email.lower() in ALLOWED_EMAILS or (any_bellingcat_email and re.match(r'^[\w.]+@bellingcat\.com$', email)) + + +#--------------------- Basic Auth +SFP = os.environ.get("STATIC_FILE_PASSWORD", "") # min length is 20 chars +async def get_basic_auth(credentials: HTTPBasicCredentials = Depends(basic_security)): + # validates that the Basic token in the case that it requires it + assert len(SFP) >= 20, "Invalid STATIC_FILE_PASSWORD, must be at least 20 chars" + current_password_bytes = credentials.password.encode("utf8") + is_correct_password = secrets.compare_digest(current_password_bytes, SFP.encode("utf8")) + if is_correct_password: return True + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Wrong static file access credentials", + headers={"WWW-Authenticate": "Basic"} + ) \ No newline at end of file