From 4297c8f6e4848462f0e54e411aaaac89ebae2f28 Mon Sep 17 00:00:00 2001
From: 0xallam
Date: Sat, 6 Dec 2025 23:22:32 +0200
Subject: [PATCH] fix: pass api_key directly to litellm completion calls

---
 strix/interface/main.py | 24 ++++++++++++------------
 strix/llm/__init__.py   |  2 --
 strix/llm/llm.py        | 16 +++++++++-------
 3 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/strix/interface/main.py b/strix/interface/main.py
index 7f8370c..6ed3405 100644
--- a/strix/interface/main.py
+++ b/strix/interface/main.py
@@ -10,6 +10,7 @@ import os
 import shutil
 import sys
 from pathlib import Path
+from typing import Any
 
 import litellm
 from docker.errors import DockerException
@@ -189,19 +190,12 @@ async def warm_up_llm() -> None:
     try:
         model_name = os.getenv("STRIX_LLM", "openai/gpt-5")
         api_key = os.getenv("LLM_API_KEY")
-
-        if api_key:
-            os.environ.setdefault("LITELLM_API_KEY", api_key)
-            litellm.api_key = api_key
-
         api_base = (
             os.getenv("LLM_API_BASE")
             or os.getenv("OPENAI_API_BASE")
             or os.getenv("LITELLM_BASE_URL")
             or os.getenv("OLLAMA_API_BASE")
         )
-        if api_base:
-            litellm.api_base = api_base
 
         test_messages = [
             {"role": "system", "content": "You are a helpful assistant."},
@@ -210,11 +204,17 @@ async def warm_up_llm() -> None:
 
         llm_timeout = int(os.getenv("LLM_TIMEOUT", "600"))
 
-        response = litellm.completion(
-            model=model_name,
-            messages=test_messages,
-            timeout=llm_timeout,
-        )
+        completion_kwargs: dict[str, Any] = {
+            "model": model_name,
+            "messages": test_messages,
+            "timeout": llm_timeout,
+        }
+        if api_key:
+            completion_kwargs["api_key"] = api_key
+        if api_base:
+            completion_kwargs["api_base"] = api_base
+
+        response = litellm.completion(**completion_kwargs)
 
         validate_llm_response(response)
 
diff --git a/strix/llm/__init__.py b/strix/llm/__init__.py
index 6dde525..f3f8b67 100644
--- a/strix/llm/__init__.py
+++ b/strix/llm/__init__.py
@@ -11,5 +11,3 @@ __all__ = [
 ]
 
 litellm._logging._disable_debugging()
-
-litellm.drop_params = True
diff --git a/strix/llm/llm.py b/strix/llm/llm.py
index 98103b2..8044130 100644
--- a/strix/llm/llm.py
+++ b/strix/llm/llm.py
@@ -25,19 +25,16 @@ from strix.tools import get_tools_prompt
 
 logger = logging.getLogger(__name__)
 
-api_key = os.getenv("LLM_API_KEY")
-if api_key:
-    os.environ.setdefault("LITELLM_API_KEY", api_key)
-    litellm.api_key = api_key
+litellm.drop_params = True
+litellm.modify_params = True
 
-api_base = (
+_LLM_API_KEY = os.getenv("LLM_API_KEY")
+_LLM_API_BASE = (
     os.getenv("LLM_API_BASE")
     or os.getenv("OPENAI_API_BASE")
     or os.getenv("LITELLM_BASE_URL")
     or os.getenv("OLLAMA_API_BASE")
 )
-if api_base:
-    litellm.api_base = api_base
 
 
 class LLMRequestFailedError(Exception):
@@ -401,6 +398,11 @@ class LLM:
             "timeout": self.config.timeout,
         }
 
+        if _LLM_API_KEY:
+            completion_args["api_key"] = _LLM_API_KEY
+        if _LLM_API_BASE:
+            completion_args["api_base"] = _LLM_API_BASE
+
         if self._should_include_stop_param():
             completion_args["stop"] = [""]
 
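
Reviewer sketch (not applied by git am): a minimal standalone example of the
pattern this patch adopts, assembling the kwargs dict per call and passing
credentials through litellm.completion()'s api_key/api_base keyword arguments
instead of mutating the process-wide litellm.api_key / litellm.api_base
globals. The helper name `complete` is hypothetical; the env var names and
defaults are the ones the patch itself reads.

    import os
    from typing import Any

    import litellm

    def complete(messages: list[dict[str, str]]) -> Any:
        # Build kwargs for this call only; omit unset keys so litellm
        # falls back to its own provider-specific credential resolution.
        kwargs: dict[str, Any] = {
            "model": os.getenv("STRIX_LLM", "openai/gpt-5"),
            "messages": messages,
            "timeout": int(os.getenv("LLM_TIMEOUT", "600")),
        }
        if api_key := os.getenv("LLM_API_KEY"):
            kwargs["api_key"] = api_key
        if api_base := os.getenv("LLM_API_BASE"):
            kwargs["api_base"] = api_base
        return litellm.completion(**kwargs)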