feat: add ENABLE_SHARING setting

This commit is contained in:
Felix Spöttel
2023-06-29 13:11:23 +02:00
parent 238a694f72
commit f01ea48f57
5 changed files with 29 additions and 20 deletions

View File

@@ -4,11 +4,18 @@ API_SECRET="change_me"
# see https://github.com/openai/whisper#available-models-and-languages
WHISPER_MODEL="small"
# If enabled, GET requests to routes `/job/:id` and `/job/:id/artifacts` will be unauthenticated.
ENABLE_SHARING="false"
# the domain you want to access the service from. Its A records need to point to the host IP.
TRAEFIK_DOMAIN="whisperbox-transcribe.localhost"
# an email which is used to verify domain ownership before a TLS certificate is issued.
TRAEFIK_SSLEMAIL=""
# you probably do not need to change this.
# ---
# below settings match the default docker-compose configuration.
BROKER_URL="redis://redis:6379/0"
DATABASE_URI="sqlite:////etc/whisperbox-transcribe/data/whisperbox-transcribe.sqlite"
ENVIRONMENT="production"

View File

@@ -32,12 +32,7 @@ This project is intended to be run via [docker compose](https://docs.docker.com/
### 2. Configure service
2. Create an `.env` file from `.env.example` and configure it:
- `API_SECRET`: the API key used to authenticate against the API.
- `WHISPER_MODEL`: the whisper model size you want to use.
- `TRAEFIK_DOMAIN`: the domain you want to access the service from. Its A records need to point to the host IP.
- `TRAEFIK_SSLEMAIL`: an email which is used to verify domain ownership before a TLS certificate is issued.
2. Create an `.env` file from `.env.example` and configure it. Refer to comments for available envs and their usage.
### 3. Run service

View File

@@ -5,14 +5,14 @@ from pydantic import BaseSettings
class Settings(BaseSettings):
API_SECRET: str
BROKER_URL: str
DATABASE_URI: str
ENVIRONMENT: str
TASK_SOFT_TIME_LIMIT: int = 3 * 60 * 60
TASK_HARD_TIME_LIMIT: int = 4 * 60 * 60
# derived settings
BROKER_URL: str
ENABLE_SHARING: bool = False
if "pytest" in sys.modules:

View File

@@ -6,20 +6,20 @@ client = TestClient(app)
def test_authorization_header_missing() -> None:
res = client.get("/api/v1")
res = client.get("/api/v1/jobs")
assert res.status_code == 401
def test_authorization_header_malformed() -> None:
res = client.get("/api/v1", headers={"Authorization": "Bearer"})
res = client.get("/api/v1/jobs", headers={"Authorization": "Bearer"})
assert res.status_code == 401
def test_incorrect_api_key() -> None:
res = client.get("/api/v1", headers={"Authorization": "Bearer incorrect"})
res = client.get("/api/v1/jobs", headers={"Authorization": "Bearer incorrect"})
assert res.status_code == 401
def test_existing_api_key(auth_headers: dict[str, str]) -> None:
res = client.get("/api/v1", headers=auth_headers)
assert res.status_code == 204
res = client.get("/api/v1/jobs", headers=auth_headers)
assert res.status_code == 200

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
import app.shared.db.models as models
import app.web.dtos as dtos
from app.shared.db.base import SessionLocal, get_session
from app.shared.settings import settings
from app.web.security import authenticate_api_key
from app.web.task_queue import task_queue
@@ -29,10 +30,7 @@ app = FastAPI(
)
api_router = APIRouter(
prefix="/api/v1",
dependencies=[Depends(authenticate_api_key)],
)
api_router = APIRouter(prefix="/api/v1")
@api_router.get("/", response_model=None, status_code=204)
@@ -41,7 +39,10 @@ def api_root() -> None:
@api_router.get(
"/jobs", response_model=list[dtos.Job], summary="Get metadata for all jobs"
"/jobs",
dependencies=[Depends(authenticate_api_key)],
response_model=list[dtos.Job],
summary="Get metadata for all jobs",
)
def get_transcripts(
session: DatabaseSession,
@@ -58,6 +59,7 @@ def get_transcripts(
@api_router.get(
"/jobs/{id}",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
response_model=dtos.Job,
summary="Get metadata for one job",
)
@@ -78,6 +80,7 @@ def get_transcript(
@api_router.get(
"/jobs/{id}/artifacts",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
response_model=list[dtos.Artifact],
summary="Get all artifacts for one job",
)
@@ -98,7 +101,10 @@ def get_artifacts_for_job(
@api_router.delete(
"/jobs/{id}", status_code=204, summary="Delete a job with all artifacts"
"/jobs/{id}",
dependencies=[Depends(authenticate_api_key)],
status_code=204,
summary="Delete a job with all artifacts",
)
def delete_transcript(
session: DatabaseSession,
@@ -133,6 +139,7 @@ class PostJobPayload(BaseModel):
@api_router.post(
"/jobs",
dependencies=[Depends(authenticate_api_key)],
response_model=dtos.Job,
status_code=201,
summary="Enqueue a new job",