fix: pass api_key directly to litellm completion calls
This commit is contained in:
@@ -10,6 +10,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from docker.errors import DockerException
|
from docker.errors import DockerException
|
||||||
@@ -189,19 +190,12 @@ async def warm_up_llm() -> None:
|
|||||||
try:
|
try:
|
||||||
model_name = os.getenv("STRIX_LLM", "openai/gpt-5")
|
model_name = os.getenv("STRIX_LLM", "openai/gpt-5")
|
||||||
api_key = os.getenv("LLM_API_KEY")
|
api_key = os.getenv("LLM_API_KEY")
|
||||||
|
|
||||||
if api_key:
|
|
||||||
os.environ.setdefault("LITELLM_API_KEY", api_key)
|
|
||||||
litellm.api_key = api_key
|
|
||||||
|
|
||||||
api_base = (
|
api_base = (
|
||||||
os.getenv("LLM_API_BASE")
|
os.getenv("LLM_API_BASE")
|
||||||
or os.getenv("OPENAI_API_BASE")
|
or os.getenv("OPENAI_API_BASE")
|
||||||
or os.getenv("LITELLM_BASE_URL")
|
or os.getenv("LITELLM_BASE_URL")
|
||||||
or os.getenv("OLLAMA_API_BASE")
|
or os.getenv("OLLAMA_API_BASE")
|
||||||
)
|
)
|
||||||
if api_base:
|
|
||||||
litellm.api_base = api_base
|
|
||||||
|
|
||||||
test_messages = [
|
test_messages = [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
@@ -210,11 +204,17 @@ async def warm_up_llm() -> None:
|
|||||||
|
|
||||||
llm_timeout = int(os.getenv("LLM_TIMEOUT", "600"))
|
llm_timeout = int(os.getenv("LLM_TIMEOUT", "600"))
|
||||||
|
|
||||||
response = litellm.completion(
|
completion_kwargs: dict[str, Any] = {
|
||||||
model=model_name,
|
"model": model_name,
|
||||||
messages=test_messages,
|
"messages": test_messages,
|
||||||
timeout=llm_timeout,
|
"timeout": llm_timeout,
|
||||||
)
|
}
|
||||||
|
if api_key:
|
||||||
|
completion_kwargs["api_key"] = api_key
|
||||||
|
if api_base:
|
||||||
|
completion_kwargs["api_base"] = api_base
|
||||||
|
|
||||||
|
response = litellm.completion(**completion_kwargs)
|
||||||
|
|
||||||
validate_llm_response(response)
|
validate_llm_response(response)
|
||||||
|
|
||||||
|
|||||||
@@ -11,5 +11,3 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
litellm._logging._disable_debugging()
|
litellm._logging._disable_debugging()
|
||||||
|
|
||||||
litellm.drop_params = True
|
|
||||||
|
|||||||
@@ -25,19 +25,16 @@ from strix.tools import get_tools_prompt
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
api_key = os.getenv("LLM_API_KEY")
|
litellm.drop_params = True
|
||||||
if api_key:
|
litellm.modify_params = True
|
||||||
os.environ.setdefault("LITELLM_API_KEY", api_key)
|
|
||||||
litellm.api_key = api_key
|
|
||||||
|
|
||||||
api_base = (
|
_LLM_API_KEY = os.getenv("LLM_API_KEY")
|
||||||
|
_LLM_API_BASE = (
|
||||||
os.getenv("LLM_API_BASE")
|
os.getenv("LLM_API_BASE")
|
||||||
or os.getenv("OPENAI_API_BASE")
|
or os.getenv("OPENAI_API_BASE")
|
||||||
or os.getenv("LITELLM_BASE_URL")
|
or os.getenv("LITELLM_BASE_URL")
|
||||||
or os.getenv("OLLAMA_API_BASE")
|
or os.getenv("OLLAMA_API_BASE")
|
||||||
)
|
)
|
||||||
if api_base:
|
|
||||||
litellm.api_base = api_base
|
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestFailedError(Exception):
|
class LLMRequestFailedError(Exception):
|
||||||
@@ -401,6 +398,11 @@ class LLM:
|
|||||||
"timeout": self.config.timeout,
|
"timeout": self.config.timeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _LLM_API_KEY:
|
||||||
|
completion_args["api_key"] = _LLM_API_KEY
|
||||||
|
if _LLM_API_BASE:
|
||||||
|
completion_args["api_base"] = _LLM_API_BASE
|
||||||
|
|
||||||
if self._should_include_stop_param():
|
if self._should_include_stop_param():
|
||||||
completion_args["stop"] = ["</function>"]
|
completion_args["stop"] = ["</function>"]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user