fix: pass api_key directly to litellm completion calls

This commit is contained in:
0xallam
2025-12-06 23:22:32 +02:00
committed by Ahmed Allam
parent 286d53384a
commit 4297c8f6e4
3 changed files with 21 additions and 21 deletions

View File

@@ -10,6 +10,7 @@ import os
 import shutil
 import sys
 from pathlib import Path
+from typing import Any

 import litellm
 from docker.errors import DockerException
@@ -189,19 +190,12 @@ async def warm_up_llm() -> None:
     try:
         model_name = os.getenv("STRIX_LLM", "openai/gpt-5")
         api_key = os.getenv("LLM_API_KEY")
-        if api_key:
-            os.environ.setdefault("LITELLM_API_KEY", api_key)
-            litellm.api_key = api_key
-
         api_base = (
             os.getenv("LLM_API_BASE")
             or os.getenv("OPENAI_API_BASE")
             or os.getenv("LITELLM_BASE_URL")
             or os.getenv("OLLAMA_API_BASE")
         )
-        if api_base:
-            litellm.api_base = api_base
-
         test_messages = [
             {"role": "system", "content": "You are a helpful assistant."},
@@ -210,11 +204,17 @@ async def warm_up_llm() -> None:
         llm_timeout = int(os.getenv("LLM_TIMEOUT", "600"))

-        response = litellm.completion(
-            model=model_name,
-            messages=test_messages,
-            timeout=llm_timeout,
-        )
+        completion_kwargs: dict[str, Any] = {
+            "model": model_name,
+            "messages": test_messages,
+            "timeout": llm_timeout,
+        }
+        if api_key:
+            completion_kwargs["api_key"] = api_key
+        if api_base:
+            completion_kwargs["api_base"] = api_base
+
+        response = litellm.completion(**completion_kwargs)
         validate_llm_response(response)

View File

@@ -11,5 +11,3 @@ __all__ = [
 ]

 litellm._logging._disable_debugging()
-
-litellm.drop_params = True

View File

@@ -25,19 +25,16 @@ from strix.tools import get_tools_prompt

 logger = logging.getLogger(__name__)

-api_key = os.getenv("LLM_API_KEY")
-if api_key:
-    os.environ.setdefault("LITELLM_API_KEY", api_key)
-    litellm.api_key = api_key
+litellm.drop_params = True
+litellm.modify_params = True

-api_base = (
+_LLM_API_KEY = os.getenv("LLM_API_KEY")
+_LLM_API_BASE = (
     os.getenv("LLM_API_BASE")
     or os.getenv("OPENAI_API_BASE")
     or os.getenv("LITELLM_BASE_URL")
     or os.getenv("OLLAMA_API_BASE")
 )
-if api_base:
-    litellm.api_base = api_base


 class LLMRequestFailedError(Exception):
@@ -401,6 +398,11 @@ class LLM:
             "timeout": self.config.timeout,
         }

+        if _LLM_API_KEY:
+            completion_args["api_key"] = _LLM_API_KEY
+        if _LLM_API_BASE:
+            completion_args["api_base"] = _LLM_API_BASE
+
         if self._should_include_stop_param():
             completion_args["stop"] = ["</function>"]