From 3ef47873133dbb15ee13919eaef9b04001b4ec9f Mon Sep 17 00:00:00 2001 From: msramalho <19508417+msramalho@users.noreply.github.com> Date: Sun, 26 Feb 2023 20:22:20 +0100 Subject: [PATCH] API updates --- src/db/crud.py | 20 ++++++++++++-------- src/db/models.py | 3 +-- src/main.py | 30 ++++++++++++++++++++---------- src/worker.py | 6 ++++-- 4 files changed, 37 insertions(+), 22 deletions(-) diff --git a/src/db/crud.py b/src/db/crud.py index 359c908..cf7c7d4 100644 --- a/src/db/crud.py +++ b/src/db/crud.py @@ -1,20 +1,24 @@ -from sqlalchemy.orm import Session -from loguru import logger +from sqlalchemy.orm import Session, load_only from . import models, schemas def get_task(db: Session, task_id: str): - return db.query(models.Task).filter(models.Task.id == task_id).first() - - -# def get_user_by_email(db: Session, email: str): -# return db.query(models.User).filter(models.User.email == email).first() + return base_query(db).filter(models.Task.id == task_id).first() def get_tasks(db: Session, skip: int = 0, limit: int = 100): - return db.query(models.Task).offset(skip).limit(limit).all() + return base_query(db).offset(skip).limit(limit).all() +def search_tasks_by_url(db: Session, url:str, skip: int = 0, limit: int = 100): + return base_query(db).filter(models.Task.url.like(f'%{url}%')).offset(skip).limit(limit).all() + +def search_tasks_by_email(db: Session, email:str, skip: int = 0, limit: int = 100): + return base_query(db).filter(models.Task.author==email).offset(skip).limit(limit).all() + +def base_query(db:Session): + # allow only some fields to be returned, for example author should remain hidden + return db.query(models.Task).options(load_only(models.Task.id, models.Task.created_at, models.Task.url, models.Task.result)) def create_task(db: Session, task: schemas.TaskCreate): db_task = models.Task(id=task.id, url=task.url, author=task.author, result=task.result) diff --git a/src/db/models.py b/src/db/models.py index 22e6653..78d9df9 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -1,5 +1,4 @@ -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, JSON, TIMESTAMP, DateTime -from sqlalchemy.orm import relationship +from sqlalchemy import Column, String, JSON, DateTime from sqlalchemy.sql import func from .database import Base diff --git a/src/main.py b/src/main.py index f6a98dd..c4499f3 100644 --- a/src/main.py +++ b/src/main.py @@ -26,7 +26,7 @@ assert len(GOOGLE_CHROME_APP_ID)>10, "GOOGLE_CHROME_APP_ID env variable not set" ALLOWED_EMAILS = set(os.environ.get("ALLOWED_EMAILS", "").split(",")) assert len(GOOGLE_CHROME_APP_ID)>=1, "at least one ALLOWED_EMAILS is required from the env variable" ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "chrome-extension://ondkcheoicfckabcnkdgbepofpjmjcmb,chrome-extension://ojcimmjndnlmmlgnjaeojoebaceokpdp").split(",") -VERSION = "0.1.1" +VERSION = "0.1.3" app = FastAPI() app.add_middleware( @@ -44,17 +44,27 @@ def get_db(): finally: session.close() +@app.get("/tasks/search-url", response_model=list[schemas.Task]) +def search(access_token:str, url:str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): + validate_user_get_email(access_token) + return crud.search_tasks_by_url(db, url, skip=skip, limit=limit) + @app.get("/tasks/search", response_model=list[schemas.Task]) def search(access_token:str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): validate_user_get_email(access_token) return crud.get_tasks(db, skip=skip, limit=limit) + +@app.get("/tasks/sync", response_model=list[schemas.Task]) +def search(access_token:str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): + email = validate_user_get_email(access_token) + return crud.search_tasks_by_email(db, email, skip=skip, limit=limit) @app.post("/tasks", status_code=201) def run_task(payload = Body(...)): - email = validate_user_get_email(payload["access_token"]) - logger.info(f"new task for user {email}: {payload['url']}") - task = create_archive_task.delay(payload["url"]) - return JSONResponse({"task_id": task.id}) + email = validate_user_get_email(payload.get("access_token")) + logger.info(f"new task for user {email}: {payload.get('url')}") + task = create_archive_task.delay(url=payload.get('url'), email=email) + return JSONResponse({"id": task.id}) @app.get("/tasks/{task_id}") def get_status(task_id, access_token:str): @@ -62,9 +72,9 @@ def get_status(task_id, access_token:str): logger.info(f"status check for user {email}") task_result = AsyncResult(task_id, app=celery) result = { - "task_id": task_id, - "task_status": task_result.status, - "task_result": task_result.result + "id": task_id, + "status": task_result.status, + "result": task_result.result } try: json_result = jsonable_encoder(result, exclude_unset=True) @@ -74,8 +84,8 @@ def get_status(task_id, access_token:str): logger.error(e) logger.error(traceback.format_exc()) return JSONResponse({ - "task_id": task_id, - "task_status": "FAILURE", + "id": task_id, + "status": "FAILURE", }) diff --git a/src/worker.py b/src/worker.py index 4615524..24ee914 100644 --- a/src/worker.py +++ b/src/worker.py @@ -4,6 +4,7 @@ import os from celery import Celery from dataclasses import asdict from auto_archiver import Config, ArchivingOrchestrator, Metadata +from auto_archiver.enrichers import ScreenshotEnricher from loguru import logger from db import crud, models, schemas @@ -28,11 +29,12 @@ config.parse(use_cli=False, yaml_config_filename="secrets/orchestration.yaml") orchestrator = None @celery.task(name="create_archive_task", bind=True) -def create_archive_task(self, url: str , user_email:str=""): +def create_archive_task(self, url: str, email:str=""): + assert type(url)==str and len(url)>5, f"Invalid URL received: {url}" global orchestrator if not orchestrator: orchestrator = ArchivingOrchestrator(config) result = orchestrator.feed_item(Metadata().set_url(url)).to_json() with get_db() as session: - db_task = crud.create_task(session, task=schemas.TaskCreate(id=self.request.id, url=url, author=user_email, result=json.loads(result))) + db_task = crud.create_task(session, task=schemas.TaskCreate(id=self.request.id, url=url, author=email, result=json.loads(result))) logger.debug(f"Added {db_task.id=} to database on {db_task.created_at}") return result