Resolve LLM API Base and Models (#317)
This commit is contained in:
@@ -187,6 +187,9 @@ def resolve_llm_config() -> tuple[str | None, str | None, str | None]:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (model_name, api_key, api_base)
|
tuple: (model_name, api_key, api_base)
|
||||||
|
- model_name: Original model name (strix/ prefix preserved for display)
|
||||||
|
- api_key: LLM API key
|
||||||
|
- api_base: API base URL (auto-set to STRIX_API_BASE for strix/ models)
|
||||||
"""
|
"""
|
||||||
model = Config.get("strix_llm")
|
model = Config.get("strix_llm")
|
||||||
if not model:
|
if not model:
|
||||||
@@ -195,10 +198,8 @@ def resolve_llm_config() -> tuple[str | None, str | None, str | None]:
|
|||||||
api_key = Config.get("llm_api_key")
|
api_key = Config.get("llm_api_key")
|
||||||
|
|
||||||
if model.startswith("strix/"):
|
if model.startswith("strix/"):
|
||||||
model_name = "openai/" + model[6:]
|
|
||||||
api_base: str | None = STRIX_API_BASE
|
api_base: str | None = STRIX_API_BASE
|
||||||
else:
|
else:
|
||||||
model_name = model
|
|
||||||
api_base = (
|
api_base = (
|
||||||
Config.get("llm_api_base")
|
Config.get("llm_api_base")
|
||||||
or Config.get("openai_api_base")
|
or Config.get("openai_api_base")
|
||||||
@@ -206,4 +207,4 @@ def resolve_llm_config() -> tuple[str | None, str | None, str | None]:
|
|||||||
or Config.get("ollama_api_base")
|
or Config.get("ollama_api_base")
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_name, api_key, api_base
|
return model, api_key, api_base
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ from rich.panel import Panel
|
|||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
|
|
||||||
from strix.config import Config, apply_saved_config, save_current_config
|
from strix.config import Config, apply_saved_config, save_current_config
|
||||||
|
from strix.config.config import resolve_llm_config
|
||||||
|
from strix.llm.utils import resolve_strix_model
|
||||||
|
|
||||||
|
|
||||||
apply_saved_config()
|
apply_saved_config()
|
||||||
@@ -204,12 +206,12 @@ def check_docker_installed() -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def warm_up_llm() -> None:
|
async def warm_up_llm() -> None:
|
||||||
from strix.config.config import resolve_llm_config
|
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_name, api_key, api_base = resolve_llm_config()
|
model_name, api_key, api_base = resolve_llm_config()
|
||||||
|
litellm_model, _ = resolve_strix_model(model_name)
|
||||||
|
litellm_model = litellm_model or model_name
|
||||||
|
|
||||||
test_messages = [
|
test_messages = [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
@@ -219,7 +221,7 @@ async def warm_up_llm() -> None:
|
|||||||
llm_timeout = int(Config.get("llm_timeout") or "300")
|
llm_timeout = int(Config.get("llm_timeout") or "300")
|
||||||
|
|
||||||
completion_kwargs: dict[str, Any] = {
|
completion_kwargs: dict[str, Any] = {
|
||||||
"model": model_name,
|
"model": litellm_model,
|
||||||
"messages": test_messages,
|
"messages": test_messages,
|
||||||
"timeout": llm_timeout,
|
"timeout": llm_timeout,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from strix.config import Config
|
from strix.config import Config
|
||||||
from strix.config.config import resolve_llm_config
|
from strix.config.config import resolve_llm_config
|
||||||
|
from strix.llm.utils import resolve_strix_model
|
||||||
|
|
||||||
|
|
||||||
class LLMConfig:
|
class LLMConfig:
|
||||||
@@ -17,6 +18,10 @@ class LLMConfig:
|
|||||||
if not self.model_name:
|
if not self.model_name:
|
||||||
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
||||||
|
|
||||||
|
api_model, canonical = resolve_strix_model(self.model_name)
|
||||||
|
self.litellm_model: str = api_model or self.model_name
|
||||||
|
self.canonical_model: str = canonical or self.model_name
|
||||||
|
|
||||||
self.enable_prompt_caching = enable_prompt_caching
|
self.enable_prompt_caching = enable_prompt_caching
|
||||||
self.skills = skills or []
|
self.skills = skills or []
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import Any
|
|||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
from strix.config.config import resolve_llm_config
|
from strix.config.config import resolve_llm_config
|
||||||
|
from strix.llm.utils import resolve_strix_model
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -156,6 +157,8 @@ def check_duplicate(
|
|||||||
comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}
|
comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}
|
||||||
|
|
||||||
model_name, api_key, api_base = resolve_llm_config()
|
model_name, api_key, api_base = resolve_llm_config()
|
||||||
|
litellm_model, _ = resolve_strix_model(model_name)
|
||||||
|
litellm_model = litellm_model or model_name
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": DEDUPE_SYSTEM_PROMPT},
|
{"role": "system", "content": DEDUPE_SYSTEM_PROMPT},
|
||||||
@@ -170,7 +173,7 @@ def check_duplicate(
|
|||||||
]
|
]
|
||||||
|
|
||||||
completion_kwargs: dict[str, Any] = {
|
completion_kwargs: dict[str, Any] = {
|
||||||
"model": model_name,
|
"model": litellm_model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class LLM:
|
|||||||
self.agent_name = agent_name
|
self.agent_name = agent_name
|
||||||
self.agent_id: str | None = None
|
self.agent_id: str | None = None
|
||||||
self._total_stats = RequestStats()
|
self._total_stats = RequestStats()
|
||||||
self.memory_compressor = MemoryCompressor(model_name=config.model_name)
|
self.memory_compressor = MemoryCompressor(model_name=config.litellm_model)
|
||||||
self.system_prompt = self._load_system_prompt(agent_name)
|
self.system_prompt = self._load_system_prompt(agent_name)
|
||||||
|
|
||||||
reasoning = Config.get("strix_reasoning_effort")
|
reasoning = Config.get("strix_reasoning_effort")
|
||||||
@@ -194,7 +194,7 @@ class LLM:
|
|||||||
messages = self._strip_images(messages)
|
messages = self._strip_images(messages)
|
||||||
|
|
||||||
args: dict[str, Any] = {
|
args: dict[str, Any] = {
|
||||||
"model": self.config.model_name,
|
"model": self.config.litellm_model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"timeout": self.config.timeout,
|
"timeout": self.config.timeout,
|
||||||
"stream_options": {"include_usage": True},
|
"stream_options": {"include_usage": True},
|
||||||
@@ -229,8 +229,8 @@ class LLM:
|
|||||||
def _update_usage_stats(self, response: Any) -> None:
|
def _update_usage_stats(self, response: Any) -> None:
|
||||||
try:
|
try:
|
||||||
if hasattr(response, "usage") and response.usage:
|
if hasattr(response, "usage") and response.usage:
|
||||||
input_tokens = getattr(response.usage, "prompt_tokens", 0)
|
input_tokens = getattr(response.usage, "prompt_tokens", 0) or 0
|
||||||
output_tokens = getattr(response.usage, "completion_tokens", 0)
|
output_tokens = getattr(response.usage, "completion_tokens", 0) or 0
|
||||||
|
|
||||||
cached_tokens = 0
|
cached_tokens = 0
|
||||||
if hasattr(response.usage, "prompt_tokens_details"):
|
if hasattr(response.usage, "prompt_tokens_details"):
|
||||||
@@ -238,14 +238,11 @@ class LLM:
|
|||||||
if hasattr(prompt_details, "cached_tokens"):
|
if hasattr(prompt_details, "cached_tokens"):
|
||||||
cached_tokens = prompt_details.cached_tokens or 0
|
cached_tokens = prompt_details.cached_tokens or 0
|
||||||
|
|
||||||
|
cost = self._extract_cost(response)
|
||||||
else:
|
else:
|
||||||
input_tokens = 0
|
input_tokens = 0
|
||||||
output_tokens = 0
|
output_tokens = 0
|
||||||
cached_tokens = 0
|
cached_tokens = 0
|
||||||
|
|
||||||
try:
|
|
||||||
cost = completion_cost(response) or 0.0
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
cost = 0.0
|
cost = 0.0
|
||||||
|
|
||||||
self._total_stats.input_tokens += input_tokens
|
self._total_stats.input_tokens += input_tokens
|
||||||
@@ -256,6 +253,18 @@ class LLM:
|
|||||||
except Exception: # noqa: BLE001, S110 # nosec B110
|
except Exception: # noqa: BLE001, S110 # nosec B110
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _extract_cost(self, response: Any) -> float:
|
||||||
|
if hasattr(response, "usage") and response.usage:
|
||||||
|
direct_cost = getattr(response.usage, "cost", None)
|
||||||
|
if direct_cost is not None:
|
||||||
|
return float(direct_cost)
|
||||||
|
try:
|
||||||
|
if hasattr(response, "_hidden_params"):
|
||||||
|
response._hidden_params.pop("custom_llm_provider", None)
|
||||||
|
return completion_cost(response, model=self.config.canonical_model) or 0.0
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return 0.0
|
||||||
|
|
||||||
def _should_retry(self, e: Exception) -> bool:
|
def _should_retry(self, e: Exception) -> bool:
|
||||||
code = getattr(e, "status_code", None) or getattr(
|
code = getattr(e, "status_code", None) or getattr(
|
||||||
getattr(e, "response", None), "status_code", None
|
getattr(e, "response", None), "status_code", None
|
||||||
@@ -275,13 +284,13 @@ class LLM:
|
|||||||
|
|
||||||
def _supports_vision(self) -> bool:
|
def _supports_vision(self) -> bool:
|
||||||
try:
|
try:
|
||||||
return bool(supports_vision(model=self.config.model_name))
|
return bool(supports_vision(model=self.config.canonical_model))
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _supports_reasoning(self) -> bool:
|
def _supports_reasoning(self) -> bool:
|
||||||
try:
|
try:
|
||||||
return bool(supports_reasoning(model=self.config.model_name))
|
return bool(supports_reasoning(model=self.config.canonical_model))
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -302,7 +311,7 @@ class LLM:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
if not messages or not supports_prompt_caching(self.config.model_name):
|
if not messages or not supports_prompt_caching(self.config.canonical_model):
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
result = list(messages)
|
result = list(messages)
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ class MemoryCompressor:
|
|||||||
):
|
):
|
||||||
self.max_images = max_images
|
self.max_images = max_images
|
||||||
self.model_name = model_name or Config.get("strix_llm")
|
self.model_name = model_name or Config.get("strix_llm")
|
||||||
self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")
|
self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "120")
|
||||||
|
|
||||||
if not self.model_name:
|
if not self.model_name:
|
||||||
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
||||||
|
|||||||
@@ -3,6 +3,40 @@ import re
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
STRIX_MODEL_MAP: dict[str, str] = {
|
||||||
|
"claude-sonnet-4.6": "anthropic/claude-sonnet-4-6",
|
||||||
|
"claude-opus-4.6": "anthropic/claude-opus-4-6",
|
||||||
|
"gpt-5.2": "openai/gpt-5.2",
|
||||||
|
"gpt-5.1": "openai/gpt-5.1",
|
||||||
|
"gpt-5": "openai/gpt-5",
|
||||||
|
"gpt-5.2-codex": "openai/gpt-5.2-codex",
|
||||||
|
"gpt-5.1-codex-max": "openai/gpt-5.1-codex-max",
|
||||||
|
"gpt-5.1-codex": "openai/gpt-5.1-codex",
|
||||||
|
"gpt-5-codex": "openai/gpt-5-codex",
|
||||||
|
"gemini-3-pro-preview": "gemini/gemini-3-pro-preview",
|
||||||
|
"gemini-3-flash-preview": "gemini/gemini-3-flash-preview",
|
||||||
|
"glm-5": "openrouter/z-ai/glm-5",
|
||||||
|
"glm-4.7": "openrouter/z-ai/glm-4.7",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_strix_model(model_name: str | None) -> tuple[str | None, str | None]:
|
||||||
|
"""Resolve a strix/ model into names for API calls and capability lookups.
|
||||||
|
|
||||||
|
Returns (api_model, canonical_model):
|
||||||
|
- api_model: openai/<base> for API calls (Strix API is OpenAI-compatible)
|
||||||
|
- canonical_model: actual provider model name for litellm capability lookups
|
||||||
|
Non-strix models return the same name for both.
|
||||||
|
"""
|
||||||
|
if not model_name or not model_name.startswith("strix/"):
|
||||||
|
return model_name, model_name
|
||||||
|
|
||||||
|
base_model = model_name[6:]
|
||||||
|
api_model = f"openai/{base_model}"
|
||||||
|
canonical_model = STRIX_MODEL_MAP.get(base_model, api_model)
|
||||||
|
return api_model, canonical_model
|
||||||
|
|
||||||
|
|
||||||
def _truncate_to_first_function(content: str) -> str:
|
def _truncate_to_first_function(content: str) -> str:
|
||||||
if not content:
|
if not content:
|
||||||
return content
|
return content
|
||||||
|
|||||||
Reference in New Issue
Block a user