fix: pass api_key directly to litellm completion calls

This commit is contained in:
0xallam
2025-12-06 23:22:32 +02:00
committed by Ahmed Allam
parent 286d53384a
commit 4297c8f6e4
3 changed files with 21 additions and 21 deletions

View File

@@ -10,6 +10,7 @@ import os
import shutil
import sys
from pathlib import Path
from typing import Any
import litellm
from docker.errors import DockerException
@@ -189,19 +190,12 @@ async def warm_up_llm() -> None:
try:
model_name = os.getenv("STRIX_LLM", "openai/gpt-5")
api_key = os.getenv("LLM_API_KEY")
if api_key:
os.environ.setdefault("LITELLM_API_KEY", api_key)
litellm.api_key = api_key
api_base = (
os.getenv("LLM_API_BASE")
or os.getenv("OPENAI_API_BASE")
or os.getenv("LITELLM_BASE_URL")
or os.getenv("OLLAMA_API_BASE")
)
if api_base:
litellm.api_base = api_base
test_messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -210,11 +204,17 @@ async def warm_up_llm() -> None:
llm_timeout = int(os.getenv("LLM_TIMEOUT", "600"))
response = litellm.completion(
model=model_name,
messages=test_messages,
timeout=llm_timeout,
)
completion_kwargs: dict[str, Any] = {
"model": model_name,
"messages": test_messages,
"timeout": llm_timeout,
}
if api_key:
completion_kwargs["api_key"] = api_key
if api_base:
completion_kwargs["api_base"] = api_base
response = litellm.completion(**completion_kwargs)
validate_llm_response(response)

View File

@@ -11,5 +11,3 @@ __all__ = [
]
litellm._logging._disable_debugging()
litellm.drop_params = True

View File

@@ -25,19 +25,16 @@ from strix.tools import get_tools_prompt
logger = logging.getLogger(__name__)
api_key = os.getenv("LLM_API_KEY")
if api_key:
os.environ.setdefault("LITELLM_API_KEY", api_key)
litellm.api_key = api_key
litellm.drop_params = True
litellm.modify_params = True
api_base = (
_LLM_API_KEY = os.getenv("LLM_API_KEY")
_LLM_API_BASE = (
os.getenv("LLM_API_BASE")
or os.getenv("OPENAI_API_BASE")
or os.getenv("LITELLM_BASE_URL")
or os.getenv("OLLAMA_API_BASE")
)
if api_base:
litellm.api_base = api_base
class LLMRequestFailedError(Exception):
@@ -401,6 +398,11 @@ class LLM:
"timeout": self.config.timeout,
}
if _LLM_API_KEY:
completion_args["api_key"] = _LLM_API_KEY
if _LLM_API_BASE:
completion_args["api_base"] = _LLM_API_BASE
if self._should_include_stop_param():
completion_args["stop"] = ["</function>"]