diff --git a/.gitignore b/.gitignore index b84dba5..2e8319b 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,7 @@ __pycache .pytest_cach logs .env -.db +*.db #temp static templates diff --git a/src/Pipfile b/src/Pipfile index 30f4144..371d7bb 100644 --- a/src/Pipfile +++ b/src/Pipfile @@ -13,12 +13,11 @@ pytest = "==6.2.4" redis = "==3.5.3" requests = "==2.25.1" uvicorn = "==0.13.4" -fastapi-users = {extras = ["sqlalchemy"], version = "*"} aiosqlite = "*" -httpx-oauth = "*" python-dotenv = "*" auto-archiver = {editable = true, path = "./../../auto-archiver"} loguru = "*" +sqlalchemy = "*" [dev-packages] diff --git a/src/auth/__init__.py b/src/auth/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/auth/db.py b/src/auth/db.py deleted file mode 100644 index 93b0461..0000000 --- a/src/auth/db.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import AsyncGenerator, List - -from fastapi import Depends -from fastapi_users.db import ( - SQLAlchemyBaseOAuthAccountTableUUID, - SQLAlchemyBaseUserTableUUID, - SQLAlchemyUserDatabase, -) -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.orm import DeclarativeBase, Mapped, relationship - -DATABASE_URL = "sqlite+aiosqlite:///./users.db" - - -class Base(DeclarativeBase): - pass - - -class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): - pass - - -class User(SQLAlchemyBaseUserTableUUID, Base): - oauth_accounts: Mapped[List[OAuthAccount]] = relationship( - "OAuthAccount", lazy="joined" - ) - - -engine = create_async_engine(DATABASE_URL) -async_session_maker = async_sessionmaker(engine, expire_on_commit=False) - - -async def create_db_and_tables(): - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - -async def get_async_session() -> AsyncGenerator[AsyncSession, None]: - async with async_session_maker() as session: - yield session - - -async def get_user_db(session: AsyncSession = Depends(get_async_session)): - yield SQLAlchemyUserDatabase(session, User, OAuthAccount) \ No newline at end of file diff --git a/src/auth/schemas.py b/src/auth/schemas.py deleted file mode 100644 index 852355a..0000000 --- a/src/auth/schemas.py +++ /dev/null @@ -1,15 +0,0 @@ -# import uuid - -# from fastapi_users import schemas - - -# class UserRead(schemas.BaseUser[uuid.UUID]): -# pass - - -# class UserCreate(schemas.BaseUserCreate): -# pass - - -# class UserUpdate(schemas.BaseUserUpdate): -# pass \ No newline at end of file diff --git a/src/auth/users.py b/src/auth/users.py deleted file mode 100644 index 0969ec3..0000000 --- a/src/auth/users.py +++ /dev/null @@ -1,62 +0,0 @@ -import os -import uuid -from typing import Optional - -from fastapi import Depends, Request -from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin -from fastapi_users.authentication import ( - AuthenticationBackend, - BearerTransport, - JWTStrategy, -) -from fastapi_users.db import SQLAlchemyUserDatabase -from httpx_oauth.clients.google import GoogleOAuth2 - -from .db import User, get_user_db - -SECRET_KEY = os.getenv("SECRET_KEY", os.urandom(24)) - -google_oauth_client = GoogleOAuth2( - os.getenv("GOOGLE_CLIENT_ID", ""), - os.getenv("GOOGLE_CLIENT_SECRET", ""), -) - - -class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): - reset_password_token_secret = SECRET_KEY - verification_token_secret = SECRET_KEY - - async def on_after_register(self, user: User, request: Optional[Request] = None): - print(f"User {user.id} has registered.") - - async def on_after_forgot_password( - self, user: User, token: str, request: Optional[Request] = None - ): - print(f"User {user.id} has forgot their password. Reset token: {token}") - - async def on_after_request_verify( - self, user: User, token: str, request: Optional[Request] = None - ): - print(f"Verification requested for user {user.id}. Verification token: {token}") - - -async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): - yield UserManager(user_db) - - -bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") - - -def get_jwt_strategy() -> JWTStrategy: - return JWTStrategy(secret=SECRET_KEY, lifetime_seconds=3600) - - -auth_backend = AuthenticationBackend( - name="jwt", - transport=bearer_transport, - get_strategy=get_jwt_strategy, -) - -fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) - -current_active_user = fastapi_users.current_user(active=True) \ No newline at end of file diff --git a/src/auth/README.md b/src/db/README.md similarity index 100% rename from src/auth/README.md rename to src/db/README.md diff --git a/src/db/__init__.py b/src/db/__init__.py new file mode 100644 index 0000000..2901713 --- /dev/null +++ b/src/db/__init__.py @@ -0,0 +1 @@ +# https://fastapi.tiangolo.com/tutorial/sql-databases/#review-all-the-files \ No newline at end of file diff --git a/src/db/crud.py b/src/db/crud.py new file mode 100644 index 0000000..359c908 --- /dev/null +++ b/src/db/crud.py @@ -0,0 +1,24 @@ +from sqlalchemy.orm import Session +from loguru import logger + +from . import models, schemas + + +def get_task(db: Session, task_id: str): + return db.query(models.Task).filter(models.Task.id == task_id).first() + + +# def get_user_by_email(db: Session, email: str): +# return db.query(models.User).filter(models.User.email == email).first() + + +def get_tasks(db: Session, skip: int = 0, limit: int = 100): + return db.query(models.Task).offset(skip).limit(limit).all() + + +def create_task(db: Session, task: schemas.TaskCreate): + db_task = models.Task(id=task.id, url=task.url, author=task.author, result=task.result) + db.add(db_task) + db.commit() + db.refresh(db_task) + return db_task diff --git a/src/db/database.py b/src/db/database.py new file mode 100644 index 0000000..416fc5e --- /dev/null +++ b/src/db/database.py @@ -0,0 +1,14 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +SQLALCHEMY_DATABASE_URL = "sqlite:///./auto-archiver.db" +# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db" + +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() + diff --git a/src/db/models.py b/src/db/models.py new file mode 100644 index 0000000..22e6653 --- /dev/null +++ b/src/db/models.py @@ -0,0 +1,17 @@ +from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, JSON, TIMESTAMP, DateTime +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func +from .database import Base + + +class Task(Base): + __tablename__ = "tasks" + + id = Column(String, primary_key=True, index=True) + url = Column(String, index=True) + author = Column(String, index=True) + result = Column(JSON, default=None) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + # updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + + # items = relationship("Item", back_populates="owner") diff --git a/src/db/schemas.py b/src/db/schemas.py new file mode 100644 index 0000000..c7cbf93 --- /dev/null +++ b/src/db/schemas.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel +from datetime import datetime + +class TaskCreate(BaseModel): + id: str + url: str + author: str + result: dict + + +class Task(TaskCreate): + created_at: datetime + + class Config: + orm_mode = True \ No newline at end of file diff --git a/src/main.py b/src/main.py index 390490e..78cd87e 100644 --- a/src/main.py +++ b/src/main.py @@ -1,73 +1,56 @@ from celery.result import AsyncResult -from fastapi import Body, FastAPI, Form, Request, Depends, HTTPException +from fastapi import Body, FastAPI, Request, HTTPException, status, Depends from fastapi.encoders import jsonable_encoder -from fastapi.responses import JSONResponse, RedirectResponse +from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates -from pydantic.json import pydantic_encoder +# from pydantic.json import pydantic_encoder from dotenv import load_dotenv -import json -import sys, traceback +import traceback, os, requests from loguru import logger -# import os, requests -# from google.oauth2.credentials import Credentials -# from google_auth_oauthlib.flow import Flow -# from google_auth_oauthlib.flow import InstalledAppFlow -# from oauthlib.oauth2 import WebApplicationClient -from typing import Optional +from worker import create_archive_task, celery -from fastapi import Depends -# from fastapi.security import OAuth2PasswordRequestForm -# from fastapi.security import OAuth2AuthorizationCodeBearer +from db import crud, models, schemas +from db.database import engine, SessionLocal +from sqlalchemy.orm import Session -from auth.db import User, create_db_and_tables -# from app.schemas import UserCreate, UserRead, UserUpdate -from auth.users import ( - SECRET_KEY, - auth_backend, - current_active_user, - fastapi_users, - google_oauth_client, -) - -from worker import create_task, create_archive_task, celery +# models.Base.metadata.create_all(bind=engine) load_dotenv() -app = FastAPI() - - # Configuration -# GOOGLE_CLIENT_ID = os.environ.get("GOOGLE_CLIENT_ID", None) -# GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET", None) -# GOOGLE_DISCOVERY_URL = ("https://accounts.google.com/.well-known/openid-configuration") -# GOOGLE_LOGIN_CALLBACK = os.environ.get("GOOGLE_LOGIN_CALLBACK", "http://localhost:5000/login/callback") -# SECRET_KEY = os.environ.get("SECRET_KEY", os.urandom(24)) +GOOGLE_CHROME_APP_ID = os.environ.get("GOOGLE_CHROME_APP_ID") +assert len(GOOGLE_CHROME_APP_ID)>10, "GOOGLE_CHROME_APP_ID env variable not set" +ALLOWED_EMAILS = set(os.environ.get("ALLOWED_EMAILS", "").split(",")) +assert len(GOOGLE_CHROME_APP_ID)>=1, "at least one ALLOWED_EMAILS is required from the env variable" -# Authentication logic for OAUTH2 -app.include_router( - fastapi_users.get_oauth_router(google_oauth_client, auth_backend, SECRET_KEY), - prefix="/auth/google", - tags=["auth"], -) -@app.get("/authenticated-route") -async def authenticated_route(user: User = Depends(current_active_user)): - return {"message": f"Hello {user.email}!"} +app = FastAPI() +app.mount("/static", StaticFiles(directory="static"), name="static") -# protected version -@app.post("/tasks-auth", status_code=201) -def run_task(payload = Body(...), user: User = Depends(current_active_user)): - logger.info(f"new task for user {user.email}: {payload['url']}") - # task_type = payload["type"] - # task = create_task.delay(int(task_type)) +def get_db(): + session = SessionLocal() + try: yield session + finally: session.close() + + +@app.get("/tasks/search", response_model=list[schemas.Task]) +def search(access_token:str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): + validate_user_get_email(access_token) + return crud.get_tasks(db, skip=skip, limit=limit) + +@app.post("/tasks", status_code=201) +def run_task(payload = Body(...)): + email = validate_user_get_email(payload["access_token"]) + logger.info(f"new task for user {email}: {payload['url']}") task = create_archive_task.delay(payload["url"]) return JSONResponse({"task_id": task.id}) -@app.get("/tasks-auth/{task_id}") -def get_status(task_id, user: User = Depends(current_active_user)): - logger.info(f"status check for user {user.email}") +@app.get("/tasks/{task_id}") +def get_status(task_id, access_token:str): + email = validate_user_get_email(access_token) + logger.info(f"status check for user {email}") task_result = AsyncResult(task_id, app=celery) result = { "task_id": task_id, @@ -87,45 +70,45 @@ def get_status(task_id, user: User = Depends(current_active_user)): }) -@app.on_event("startup") -async def on_startup(): - # Not needed if you setup a migration system like Alembic - await create_db_and_tables() -#### -app.mount("/static", StaticFiles(directory="static"), name="static") -templates = Jinja2Templates(directory="templates") @app.get("/") -def home(request: Request): - return templates.TemplateResponse("home.html", context={"request": request}) +def home(): + return JSONResponse({"message": "Hello"}) +@app.on_event("startup") +async def on_startup(): +# # Not needed if you setup a migration system like Alembic +# await create_db_and_tables() + models.Base.metadata.create_all(bind=engine) -#TODO: deprecate -# @app.post("/tasks", status_code=201) -# def run_task(payload = Body(...)): -# # task_type = payload["type"] -# # task = create_task.delay(int(task_type)) -# task = create_archive_task.delay(payload["url"]) -# return JSONResponse({"task_id": task.id}) +#### helper methods +def authenticate_user(access_token): + # https://cloud.google.com/docs/authentication/token-types#access + if type(access_token)!=str or len(access_token)<10: return False, "invalid access_token" + r = requests.get("https://oauth2.googleapis.com/tokeninfo", {"access_token":access_token}) + if r.status_code!=200: return False, "error occurred" + try: + j = r.json() + if j.get("azp") != GOOGLE_CHROME_APP_ID and j.get("aud")!=GOOGLE_CHROME_APP_ID: + return False, f"token does not belong to correct APP_ID" + if j.get("email") not in ALLOWED_EMAILS: + return False, f"email '{j.get('email')}' not in ALLOWED" + if j.get("email_verified") != "true": + return False, f"email '{j.get('email')}' not verified" + if int(j.get("expires_in", -1)) <= 0: + return False, "Token expired" + return True, j.get('email') + except Exception as e: + logger.warning(f"EXCEPTION occurred: {e}") + return False, f"EXCEPTION occurred" -# @app.get("/tasks/{task_id}") -# def get_status(task_id): -# task_result = AsyncResult(task_id, app=celery) -# result = { -# "task_id": task_id, -# "task_status": task_result.status, -# "task_result": task_result.result -# } -# try: -# json_result = jsonable_encoder(result) -# # json_result = jsonable_encoder(result, custom_encoder=pydantic_encoder) # causes error -# return JSONResponse(json_result) -# except Exception as e: -# logger.error(e) -# logger.error(traceback.format_exc()) -# return JSONResponse({ -# "task_id": task_id, -# "task_status": "FAILURE", -# }) +def validate_user_get_email(access_token): + valid_user, info = authenticate_user(access_token) + if valid_user != True: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=info + ) + return info \ No newline at end of file diff --git a/src/worker.py b/src/worker.py index 15e347a..8eafade 100644 --- a/src/worker.py +++ b/src/worker.py @@ -1,31 +1,39 @@ + import os -import time from celery import Celery from dataclasses import asdict from auto_archiver import Config, ArchivingOrchestrator, Metadata +from loguru import logger +from db import crud, models, schemas +from db.database import engine, SessionLocal +from contextlib import contextmanager +import json + +# models.Base.metadata.create_all(bind=engine) celery = Celery(__name__) celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379") celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379") -@celery.task(name="create_task") -def create_task(task_type): - print("DEV MODE") - time.sleep(int(task_type) * 10) - return True +@contextmanager +def get_db(): + session = SessionLocal() + try: yield session + finally: session.close() - -# from configs.v2config import ConfigV2 -# from auto_archiver import ArchivingOrchestrator config = Config() -config.parse(use_cli=False, yaml_config_filename="orchestration.yaml") +config.parse(use_cli=False, yaml_config_filename="secrets/orchestration.yaml") orchestrator = None -@celery.task(name="create_archive_task") -def create_archive_task(url: str , user_email:str=""): +@celery.task(name="create_archive_task", bind=True) +def create_archive_task(self, url: str , user_email:str=""): global orchestrator if not orchestrator: orchestrator = ArchivingOrchestrator(config) - return orchestrator.feed_item(Metadata().set_url(url)).to_json() - #TODO: associate user with url (?) + result = orchestrator.feed_item(Metadata().set_url(url)).to_json() + # result = orchestrator.feed_item(Metadata().set_url(url)) + with get_db() as session: + db_task = crud.create_task(session, task=schemas.TaskCreate(id=self.request.id, url=url, author=user_email, result=json.loads(result))) + logger.debug(f"Added {db_task.id=} to database on {db_task.created_at}") + return result