From f01ea48f577a45f8c9b1c81dcb53b7520688cb59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Sp=C3=B6ttel?= <1682504+fspoettel@users.noreply.github.com> Date: Thu, 29 Jun 2023 13:11:23 +0200 Subject: [PATCH] feat: add `ENABLE_SHARING` setting --- .env.example | 9 ++++++++- README.md | 7 +------ app/shared/settings.py | 4 ++-- app/tests/test_auth.py | 10 +++++----- app/web/main.py | 19 +++++++++++++------ 5 files changed, 29 insertions(+), 20 deletions(-) diff --git a/.env.example b/.env.example index 63e271e..d0f0e0f 100644 --- a/.env.example +++ b/.env.example @@ -4,11 +4,18 @@ API_SECRET="change_me" # see https://github.com/openai/whisper#available-models-and-languages WHISPER_MODEL="small" +# If enabled, GET requests to routes `/job/:id` and `/job/:id/artifacts` will be unauthenticated. +ENABLE_SHARING="false" + +# the domain you want to access the service from. Its A records need to point to the host IP. TRAEFIK_DOMAIN="whisperbox-transcribe.localhost" +# an email which is used to verify domain ownership before a TLS certificate is issued. TRAEFIK_SSLEMAIL="" -# you probably do not need to change this. +# --- +# below settings match the default docker-compose configuration. + BROKER_URL="redis://redis:6379/0" DATABASE_URI="sqlite:////etc/whisperbox-transcribe/data/whisperbox-transcribe.sqlite" ENVIRONMENT="production" diff --git a/README.md b/README.md index e7fcf4f..cf3b60c 100644 --- a/README.md +++ b/README.md @@ -32,12 +32,7 @@ This project is intended to be run via [docker compose](https://docs.docker.com/ ### 2. Configure service -2. Create an `.env` file from `.env.example` and configure it: - - - `API_SECRET`: the API key used to authenticate against the API. - - `WHISPER_MODEL`: the whisper model size you want to use. - - `TRAEFIK_DOMAIN`: the domain you want to access the service from. Its A records need to point to the host IP. - - `TRAEFIK_SSLEMAIL`: an email which is used to verify domain ownership before a TLS certificate is issued. +2. Create an `.env` file from `.env.example` and configure it. Refer to comments for available envs and their usage. ### 3. Run service diff --git a/app/shared/settings.py b/app/shared/settings.py index 3c8ee15..39e350f 100644 --- a/app/shared/settings.py +++ b/app/shared/settings.py @@ -5,14 +5,14 @@ from pydantic import BaseSettings class Settings(BaseSettings): API_SECRET: str + BROKER_URL: str DATABASE_URI: str ENVIRONMENT: str TASK_SOFT_TIME_LIMIT: int = 3 * 60 * 60 TASK_HARD_TIME_LIMIT: int = 4 * 60 * 60 - # derived settings - BROKER_URL: str + ENABLE_SHARING: bool = False if "pytest" in sys.modules: diff --git a/app/tests/test_auth.py b/app/tests/test_auth.py index 4161b6e..fadc2e0 100644 --- a/app/tests/test_auth.py +++ b/app/tests/test_auth.py @@ -6,20 +6,20 @@ client = TestClient(app) def test_authorization_header_missing() -> None: - res = client.get("/api/v1") + res = client.get("/api/v1/jobs") assert res.status_code == 401 def test_authorization_header_malformed() -> None: - res = client.get("/api/v1", headers={"Authorization": "Bearer"}) + res = client.get("/api/v1/jobs", headers={"Authorization": "Bearer"}) assert res.status_code == 401 def test_incorrect_api_key() -> None: - res = client.get("/api/v1", headers={"Authorization": "Bearer incorrect"}) + res = client.get("/api/v1/jobs", headers={"Authorization": "Bearer incorrect"}) assert res.status_code == 401 def test_existing_api_key(auth_headers: dict[str, str]) -> None: - res = client.get("/api/v1", headers=auth_headers) - assert res.status_code == 204 + res = client.get("/api/v1/jobs", headers=auth_headers) + assert res.status_code == 200 diff --git a/app/web/main.py b/app/web/main.py index 7aba541..775c315 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -9,6 +9,7 @@ from sqlalchemy.orm import Session import app.shared.db.models as models import app.web.dtos as dtos from app.shared.db.base import SessionLocal, get_session +from app.shared.settings import settings from app.web.security import authenticate_api_key from app.web.task_queue import task_queue @@ -29,10 +30,7 @@ app = FastAPI( ) -api_router = APIRouter( - prefix="/api/v1", - dependencies=[Depends(authenticate_api_key)], -) +api_router = APIRouter(prefix="/api/v1") @api_router.get("/", response_model=None, status_code=204) @@ -41,7 +39,10 @@ def api_root() -> None: @api_router.get( - "/jobs", response_model=list[dtos.Job], summary="Get metadata for all jobs" + "/jobs", + dependencies=[Depends(authenticate_api_key)], + response_model=list[dtos.Job], + summary="Get metadata for all jobs", ) def get_transcripts( session: DatabaseSession, @@ -58,6 +59,7 @@ def get_transcripts( @api_router.get( "/jobs/{id}", + dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], response_model=dtos.Job, summary="Get metadata for one job", ) @@ -78,6 +80,7 @@ def get_transcript( @api_router.get( "/jobs/{id}/artifacts", + dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], response_model=list[dtos.Artifact], summary="Get all artifacts for one job", ) @@ -98,7 +101,10 @@ def get_artifacts_for_job( @api_router.delete( - "/jobs/{id}", status_code=204, summary="Delete a job with all artifacts" + "/jobs/{id}", + dependencies=[Depends(authenticate_api_key)], + status_code=204, + summary="Delete a job with all artifacts", ) def delete_transcript( session: DatabaseSession, @@ -133,6 +139,7 @@ class PostJobPayload(BaseModel): @api_router.post( "/jobs", + dependencies=[Depends(authenticate_api_key)], response_model=dtos.Job, status_code=201, summary="Enqueue a new job",