auth from extension

This commit is contained in:
msramalho
2023-02-24 15:39:11 +01:00
parent 75e98f5f31
commit f6b116554f
14 changed files with 164 additions and 224 deletions

View File

@@ -1,73 +1,56 @@
from celery.result import AsyncResult
from fastapi import Body, FastAPI, Form, Request, Depends, HTTPException
from fastapi import Body, FastAPI, Request, HTTPException, status, Depends
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic.json import pydantic_encoder
# from pydantic.json import pydantic_encoder
from dotenv import load_dotenv
import json
import sys, traceback
import traceback, os, requests
from loguru import logger
# import os, requests
# from google.oauth2.credentials import Credentials
# from google_auth_oauthlib.flow import Flow
# from google_auth_oauthlib.flow import InstalledAppFlow
# from oauthlib.oauth2 import WebApplicationClient
from typing import Optional
from worker import create_archive_task, celery
from fastapi import Depends
# from fastapi.security import OAuth2PasswordRequestForm
# from fastapi.security import OAuth2AuthorizationCodeBearer
from db import crud, models, schemas
from db.database import engine, SessionLocal
from sqlalchemy.orm import Session
from auth.db import User, create_db_and_tables
# from app.schemas import UserCreate, UserRead, UserUpdate
from auth.users import (
SECRET_KEY,
auth_backend,
current_active_user,
fastapi_users,
google_oauth_client,
)
from worker import create_task, create_archive_task, celery
# models.Base.metadata.create_all(bind=engine)
load_dotenv()
app = FastAPI()
# Configuration
# GOOGLE_CLIENT_ID = os.environ.get("GOOGLE_CLIENT_ID", None)
# GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET", None)
# GOOGLE_DISCOVERY_URL = ("https://accounts.google.com/.well-known/openid-configuration")
# GOOGLE_LOGIN_CALLBACK = os.environ.get("GOOGLE_LOGIN_CALLBACK", "http://localhost:5000/login/callback")
# SECRET_KEY = os.environ.get("SECRET_KEY", os.urandom(24))
GOOGLE_CHROME_APP_ID = os.environ.get("GOOGLE_CHROME_APP_ID")
assert len(GOOGLE_CHROME_APP_ID)>10, "GOOGLE_CHROME_APP_ID env variable not set"
ALLOWED_EMAILS = set(os.environ.get("ALLOWED_EMAILS", "").split(","))
assert len(GOOGLE_CHROME_APP_ID)>=1, "at least one ALLOWED_EMAILS is required from the env variable"
# Authentication logic for OAUTH2
app.include_router(
fastapi_users.get_oauth_router(google_oauth_client, auth_backend, SECRET_KEY),
prefix="/auth/google",
tags=["auth"],
)
@app.get("/authenticated-route")
async def authenticated_route(user: User = Depends(current_active_user)):
return {"message": f"Hello {user.email}!"}
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
# protected version
@app.post("/tasks-auth", status_code=201)
def run_task(payload = Body(...), user: User = Depends(current_active_user)):
logger.info(f"new task for user {user.email}: {payload['url']}")
# task_type = payload["type"]
# task = create_task.delay(int(task_type))
def get_db():
session = SessionLocal()
try: yield session
finally: session.close()
@app.get("/tasks/search", response_model=list[schemas.Task])
def search(access_token:str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
validate_user_get_email(access_token)
return crud.get_tasks(db, skip=skip, limit=limit)
@app.post("/tasks", status_code=201)
def run_task(payload = Body(...)):
email = validate_user_get_email(payload["access_token"])
logger.info(f"new task for user {email}: {payload['url']}")
task = create_archive_task.delay(payload["url"])
return JSONResponse({"task_id": task.id})
@app.get("/tasks-auth/{task_id}")
def get_status(task_id, user: User = Depends(current_active_user)):
logger.info(f"status check for user {user.email}")
@app.get("/tasks/{task_id}")
def get_status(task_id, access_token:str):
email = validate_user_get_email(access_token)
logger.info(f"status check for user {email}")
task_result = AsyncResult(task_id, app=celery)
result = {
"task_id": task_id,
@@ -87,45 +70,45 @@ def get_status(task_id, user: User = Depends(current_active_user)):
})
@app.on_event("startup")
async def on_startup():
# Not needed if you setup a migration system like Alembic
await create_db_and_tables()
####
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
@app.get("/")
def home(request: Request):
return templates.TemplateResponse("home.html", context={"request": request})
def home():
return JSONResponse({"message": "Hello"})
@app.on_event("startup")
async def on_startup():
# # Not needed if you setup a migration system like Alembic
# await create_db_and_tables()
models.Base.metadata.create_all(bind=engine)
#TODO: deprecate
# @app.post("/tasks", status_code=201)
# def run_task(payload = Body(...)):
# # task_type = payload["type"]
# # task = create_task.delay(int(task_type))
# task = create_archive_task.delay(payload["url"])
# return JSONResponse({"task_id": task.id})
#### helper methods
def authenticate_user(access_token):
# https://cloud.google.com/docs/authentication/token-types#access
if type(access_token)!=str or len(access_token)<10: return False, "invalid access_token"
r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token":access_token})
if r.status_code!=200: return False, "error occurred"
try:
j = r.json()
if j.get("azp") != GOOGLE_CHROME_APP_ID and j.get("aud")!=GOOGLE_CHROME_APP_ID:
return False, f"token does not belong to correct APP_ID"
if j.get("email") not in ALLOWED_EMAILS:
return False, f"email '{j.get('email')}' not in ALLOWED"
if j.get("email_verified") != "true":
return False, f"email '{j.get('email')}' not verified"
if int(j.get("expires_in", -1)) <= 0:
return False, "Token expired"
return True, j.get('email')
except Exception as e:
logger.warning(f"EXCEPTION occurred: {e}")
return False, f"EXCEPTION occurred"
# @app.get("/tasks/{task_id}")
# def get_status(task_id):
# task_result = AsyncResult(task_id, app=celery)
# result = {
# "task_id": task_id,
# "task_status": task_result.status,
# "task_result": task_result.result
# }
# try:
# json_result = jsonable_encoder(result)
# # json_result = jsonable_encoder(result, custom_encoder=pydantic_encoder) # causes error
# return JSONResponse(json_result)
# except Exception as e:
# logger.error(e)
# logger.error(traceback.format_exc())
# return JSONResponse({
# "task_id": task_id,
# "task_status": "FAILURE",
# })
def validate_user_get_email(access_token):
valid_user, info = authenticate_user(access_token)
if valid_user != True:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=info
)
return info