mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-08 03:28:35 +03:00
auth from extension
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -8,7 +8,7 @@ __pycache
|
||||
.pytest_cach
|
||||
logs
|
||||
.env
|
||||
.db
|
||||
*.db
|
||||
#temp
|
||||
static
|
||||
templates
|
||||
|
||||
@@ -13,12 +13,11 @@ pytest = "==6.2.4"
|
||||
redis = "==3.5.3"
|
||||
requests = "==2.25.1"
|
||||
uvicorn = "==0.13.4"
|
||||
fastapi-users = {extras = ["sqlalchemy"], version = "*"}
|
||||
aiosqlite = "*"
|
||||
httpx-oauth = "*"
|
||||
python-dotenv = "*"
|
||||
auto-archiver = {editable = true, path = "./../../auto-archiver"}
|
||||
loguru = "*"
|
||||
sqlalchemy = "*"
|
||||
|
||||
[dev-packages]
|
||||
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi_users.db import (
|
||||
SQLAlchemyBaseOAuthAccountTableUUID,
|
||||
SQLAlchemyBaseUserTableUUID,
|
||||
SQLAlchemyUserDatabase,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, relationship
|
||||
|
||||
DATABASE_URL = "sqlite+aiosqlite:///./users.db"
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||
pass
|
||||
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
oauth_accounts: Mapped[List[OAuthAccount]] = relationship(
|
||||
"OAuthAccount", lazy="joined"
|
||||
)
|
||||
|
||||
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def create_db_and_tables():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||||
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
|
||||
@@ -1,15 +0,0 @@
|
||||
# import uuid
|
||||
|
||||
# from fastapi_users import schemas
|
||||
|
||||
|
||||
# class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
# pass
|
||||
|
||||
|
||||
# class UserCreate(schemas.BaseUserCreate):
|
||||
# pass
|
||||
|
||||
|
||||
# class UserUpdate(schemas.BaseUserUpdate):
|
||||
# pass
|
||||
@@ -1,62 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, Request
|
||||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin
|
||||
from fastapi_users.authentication import (
|
||||
AuthenticationBackend,
|
||||
BearerTransport,
|
||||
JWTStrategy,
|
||||
)
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
from .db import User, get_user_db
|
||||
|
||||
SECRET_KEY = os.getenv("SECRET_KEY", os.urandom(24))
|
||||
|
||||
google_oauth_client = GoogleOAuth2(
|
||||
os.getenv("GOOGLE_CLIENT_ID", ""),
|
||||
os.getenv("GOOGLE_CLIENT_SECRET", ""),
|
||||
)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = SECRET_KEY
|
||||
verification_token_secret = SECRET_KEY
|
||||
|
||||
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
||||
print(f"User {user.id} has registered.")
|
||||
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
):
|
||||
print(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
):
|
||||
print(f"Verification requested for user {user.id}. Verification token: {token}")
|
||||
|
||||
|
||||
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
|
||||
yield UserManager(user_db)
|
||||
|
||||
|
||||
bearer_transport = BearerTransport(tokenUrl="auth/jwt/login")
|
||||
|
||||
|
||||
def get_jwt_strategy() -> JWTStrategy:
|
||||
return JWTStrategy(secret=SECRET_KEY, lifetime_seconds=3600)
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="jwt",
|
||||
transport=bearer_transport,
|
||||
get_strategy=get_jwt_strategy,
|
||||
)
|
||||
|
||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||
|
||||
current_active_user = fastapi_users.current_user(active=True)
|
||||
1
src/db/__init__.py
Normal file
1
src/db/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# https://fastapi.tiangolo.com/tutorial/sql-databases/#review-all-the-files
|
||||
24
src/db/crud.py
Normal file
24
src/db/crud.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
from . import models, schemas
|
||||
|
||||
|
||||
def get_task(db: Session, task_id: str):
|
||||
return db.query(models.Task).filter(models.Task.id == task_id).first()
|
||||
|
||||
|
||||
# def get_user_by_email(db: Session, email: str):
|
||||
# return db.query(models.User).filter(models.User.email == email).first()
|
||||
|
||||
|
||||
def get_tasks(db: Session, skip: int = 0, limit: int = 100):
|
||||
return db.query(models.Task).offset(skip).limit(limit).all()
|
||||
|
||||
|
||||
def create_task(db: Session, task: schemas.TaskCreate):
|
||||
db_task = models.Task(id=task.id, url=task.url, author=task.author, result=task.result)
|
||||
db.add(db_task)
|
||||
db.commit()
|
||||
db.refresh(db_task)
|
||||
return db_task
|
||||
14
src/db/database.py
Normal file
14
src/db/database.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./auto-archiver.db"
|
||||
# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db"
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
17
src/db/models.py
Normal file
17
src/db/models.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, JSON, TIMESTAMP, DateTime
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from .database import Base
|
||||
|
||||
|
||||
class Task(Base):
|
||||
__tablename__ = "tasks"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
url = Column(String, index=True)
|
||||
author = Column(String, index=True)
|
||||
result = Column(JSON, default=None)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
# updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
# items = relationship("Item", back_populates="owner")
|
||||
15
src/db/schemas.py
Normal file
15
src/db/schemas.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
class TaskCreate(BaseModel):
|
||||
id: str
|
||||
url: str
|
||||
author: str
|
||||
result: dict
|
||||
|
||||
|
||||
class Task(TaskCreate):
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
155
src/main.py
155
src/main.py
@@ -1,73 +1,56 @@
|
||||
from celery.result import AsyncResult
|
||||
from fastapi import Body, FastAPI, Form, Request, Depends, HTTPException
|
||||
from fastapi import Body, FastAPI, Request, HTTPException, status, Depends
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from pydantic.json import pydantic_encoder
|
||||
# from pydantic.json import pydantic_encoder
|
||||
from dotenv import load_dotenv
|
||||
import json
|
||||
import sys, traceback
|
||||
import traceback, os, requests
|
||||
from loguru import logger
|
||||
# import os, requests
|
||||
|
||||
# from google.oauth2.credentials import Credentials
|
||||
# from google_auth_oauthlib.flow import Flow
|
||||
# from google_auth_oauthlib.flow import InstalledAppFlow
|
||||
# from oauthlib.oauth2 import WebApplicationClient
|
||||
from typing import Optional
|
||||
from worker import create_archive_task, celery
|
||||
|
||||
from fastapi import Depends
|
||||
# from fastapi.security import OAuth2PasswordRequestForm
|
||||
# from fastapi.security import OAuth2AuthorizationCodeBearer
|
||||
from db import crud, models, schemas
|
||||
from db.database import engine, SessionLocal
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from auth.db import User, create_db_and_tables
|
||||
# from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
from auth.users import (
|
||||
SECRET_KEY,
|
||||
auth_backend,
|
||||
current_active_user,
|
||||
fastapi_users,
|
||||
google_oauth_client,
|
||||
)
|
||||
|
||||
from worker import create_task, create_archive_task, celery
|
||||
# models.Base.metadata.create_all(bind=engine)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Configuration
|
||||
# GOOGLE_CLIENT_ID = os.environ.get("GOOGLE_CLIENT_ID", None)
|
||||
# GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET", None)
|
||||
# GOOGLE_DISCOVERY_URL = ("https://accounts.google.com/.well-known/openid-configuration")
|
||||
# GOOGLE_LOGIN_CALLBACK = os.environ.get("GOOGLE_LOGIN_CALLBACK", "http://localhost:5000/login/callback")
|
||||
# SECRET_KEY = os.environ.get("SECRET_KEY", os.urandom(24))
|
||||
GOOGLE_CHROME_APP_ID = os.environ.get("GOOGLE_CHROME_APP_ID")
|
||||
assert len(GOOGLE_CHROME_APP_ID)>10, "GOOGLE_CHROME_APP_ID env variable not set"
|
||||
ALLOWED_EMAILS = set(os.environ.get("ALLOWED_EMAILS", "").split(","))
|
||||
assert len(GOOGLE_CHROME_APP_ID)>=1, "at least one ALLOWED_EMAILS is required from the env variable"
|
||||
|
||||
# Authentication logic for OAUTH2
|
||||
app.include_router(
|
||||
fastapi_users.get_oauth_router(google_oauth_client, auth_backend, SECRET_KEY),
|
||||
prefix="/auth/google",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
@app.get("/authenticated-route")
|
||||
async def authenticated_route(user: User = Depends(current_active_user)):
|
||||
return {"message": f"Hello {user.email}!"}
|
||||
app = FastAPI()
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
# protected version
|
||||
@app.post("/tasks-auth", status_code=201)
|
||||
def run_task(payload = Body(...), user: User = Depends(current_active_user)):
|
||||
logger.info(f"new task for user {user.email}: {payload['url']}")
|
||||
# task_type = payload["type"]
|
||||
# task = create_task.delay(int(task_type))
|
||||
def get_db():
|
||||
session = SessionLocal()
|
||||
try: yield session
|
||||
finally: session.close()
|
||||
|
||||
|
||||
@app.get("/tasks/search", response_model=list[schemas.Task])
|
||||
def search(access_token:str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
||||
validate_user_get_email(access_token)
|
||||
return crud.get_tasks(db, skip=skip, limit=limit)
|
||||
|
||||
@app.post("/tasks", status_code=201)
|
||||
def run_task(payload = Body(...)):
|
||||
email = validate_user_get_email(payload["access_token"])
|
||||
logger.info(f"new task for user {email}: {payload['url']}")
|
||||
task = create_archive_task.delay(payload["url"])
|
||||
return JSONResponse({"task_id": task.id})
|
||||
|
||||
@app.get("/tasks-auth/{task_id}")
|
||||
def get_status(task_id, user: User = Depends(current_active_user)):
|
||||
logger.info(f"status check for user {user.email}")
|
||||
@app.get("/tasks/{task_id}")
|
||||
def get_status(task_id, access_token:str):
|
||||
email = validate_user_get_email(access_token)
|
||||
logger.info(f"status check for user {email}")
|
||||
task_result = AsyncResult(task_id, app=celery)
|
||||
result = {
|
||||
"task_id": task_id,
|
||||
@@ -87,45 +70,45 @@ def get_status(task_id, user: User = Depends(current_active_user)):
|
||||
})
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
# Not needed if you setup a migration system like Alembic
|
||||
await create_db_and_tables()
|
||||
|
||||
####
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def home(request: Request):
|
||||
return templates.TemplateResponse("home.html", context={"request": request})
|
||||
def home():
|
||||
return JSONResponse({"message": "Hello"})
|
||||
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
# # Not needed if you setup a migration system like Alembic
|
||||
# await create_db_and_tables()
|
||||
models.Base.metadata.create_all(bind=engine)
|
||||
|
||||
#TODO: deprecate
|
||||
# @app.post("/tasks", status_code=201)
|
||||
# def run_task(payload = Body(...)):
|
||||
# # task_type = payload["type"]
|
||||
# # task = create_task.delay(int(task_type))
|
||||
# task = create_archive_task.delay(payload["url"])
|
||||
# return JSONResponse({"task_id": task.id})
|
||||
#### helper methods
|
||||
def authenticate_user(access_token):
|
||||
# https://cloud.google.com/docs/authentication/token-types#access
|
||||
if type(access_token)!=str or len(access_token)<10: return False, "invalid access_token"
|
||||
r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token":access_token})
|
||||
if r.status_code!=200: return False, "error occurred"
|
||||
try:
|
||||
j = r.json()
|
||||
if j.get("azp") != GOOGLE_CHROME_APP_ID and j.get("aud")!=GOOGLE_CHROME_APP_ID:
|
||||
return False, f"token does not belong to correct APP_ID"
|
||||
if j.get("email") not in ALLOWED_EMAILS:
|
||||
return False, f"email '{j.get('email')}' not in ALLOWED"
|
||||
if j.get("email_verified") != "true":
|
||||
return False, f"email '{j.get('email')}' not verified"
|
||||
if int(j.get("expires_in", -1)) <= 0:
|
||||
return False, "Token expired"
|
||||
return True, j.get('email')
|
||||
except Exception as e:
|
||||
logger.warning(f"EXCEPTION occurred: {e}")
|
||||
return False, f"EXCEPTION occurred"
|
||||
|
||||
# @app.get("/tasks/{task_id}")
|
||||
# def get_status(task_id):
|
||||
# task_result = AsyncResult(task_id, app=celery)
|
||||
# result = {
|
||||
# "task_id": task_id,
|
||||
# "task_status": task_result.status,
|
||||
# "task_result": task_result.result
|
||||
# }
|
||||
# try:
|
||||
# json_result = jsonable_encoder(result)
|
||||
# # json_result = jsonable_encoder(result, custom_encoder=pydantic_encoder) # causes error
|
||||
# return JSONResponse(json_result)
|
||||
# except Exception as e:
|
||||
# logger.error(e)
|
||||
# logger.error(traceback.format_exc())
|
||||
# return JSONResponse({
|
||||
# "task_id": task_id,
|
||||
# "task_status": "FAILURE",
|
||||
# })
|
||||
def validate_user_get_email(access_token):
|
||||
valid_user, info = authenticate_user(access_token)
|
||||
if valid_user != True:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=info
|
||||
)
|
||||
return info
|
||||
@@ -1,31 +1,39 @@
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from celery import Celery
|
||||
from dataclasses import asdict
|
||||
from auto_archiver import Config, ArchivingOrchestrator, Metadata
|
||||
from loguru import logger
|
||||
|
||||
from db import crud, models, schemas
|
||||
from db.database import engine, SessionLocal
|
||||
from contextlib import contextmanager
|
||||
import json
|
||||
|
||||
# models.Base.metadata.create_all(bind=engine)
|
||||
|
||||
celery = Celery(__name__)
|
||||
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379")
|
||||
celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379")
|
||||
|
||||
@celery.task(name="create_task")
|
||||
def create_task(task_type):
|
||||
print("DEV MODE")
|
||||
time.sleep(int(task_type) * 10)
|
||||
return True
|
||||
@contextmanager
|
||||
def get_db():
|
||||
session = SessionLocal()
|
||||
try: yield session
|
||||
finally: session.close()
|
||||
|
||||
|
||||
# from configs.v2config import ConfigV2
|
||||
# from auto_archiver import ArchivingOrchestrator
|
||||
config = Config()
|
||||
config.parse(use_cli=False, yaml_config_filename="orchestration.yaml")
|
||||
config.parse(use_cli=False, yaml_config_filename="secrets/orchestration.yaml")
|
||||
orchestrator = None
|
||||
|
||||
@celery.task(name="create_archive_task")
|
||||
def create_archive_task(url: str , user_email:str=""):
|
||||
@celery.task(name="create_archive_task", bind=True)
|
||||
def create_archive_task(self, url: str , user_email:str=""):
|
||||
global orchestrator
|
||||
if not orchestrator: orchestrator = ArchivingOrchestrator(config)
|
||||
return orchestrator.feed_item(Metadata().set_url(url)).to_json()
|
||||
#TODO: associate user with url (?)
|
||||
result = orchestrator.feed_item(Metadata().set_url(url)).to_json()
|
||||
# result = orchestrator.feed_item(Metadata().set_url(url))
|
||||
with get_db() as session:
|
||||
db_task = crud.create_task(session, task=schemas.TaskCreate(id=self.request.id, url=url, author=user_email, result=json.loads(result)))
|
||||
logger.debug(f"Added {db_task.id=} to database on {db_task.created_at}")
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user