refactor: Centralize strix model resolution with separate API and capability names
- Replace fragile prefix matching with explicit STRIX_MODEL_MAP
- Add resolve_strix_model() returning (api_model, canonical_model)
- api_model (openai/ prefix) for API calls to OpenAI-compatible Strix API
- canonical_model (actual provider name) for litellm capability lookups
- Centralize resolution in LLMConfig instead of scattered call sites
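For context, a minimal sketch of what the centralized resolution layer might look like. The map entries, constructor arguments, and internals below are assumptions inferred from this commit message and the call sites in the diff, not the actual implementation; only the names STRIX_MODEL_MAP, resolve_strix_model, LLMConfig, litellm_model, and canonical_model come from the commit itself.

# Hypothetical sketch only -- entries and signatures are illustrative.

# Explicit alias map instead of prefix matching:
#   strix alias -> (api_model sent to the OpenAI-compatible Strix API,
#                   canonical provider name used for litellm capability lookups)
STRIX_MODEL_MAP: dict[str, tuple[str, str]] = {
    "strix-example-model": ("openai/strix-example-model", "someprovider/example-model"),
}


def resolve_strix_model(model_name: str) -> tuple[str, str]:
    """Resolve a configured model name to (api_model, canonical_model).

    Names that are not strix aliases pass through unchanged, so non-Strix
    providers keep working as before.
    """
    return STRIX_MODEL_MAP.get(model_name, (model_name, model_name))


class LLMConfig:
    def __init__(self, model_name: str, timeout: float = 600.0) -> None:
        if not model_name:
            raise ValueError("model_name must be non-empty")
        self.model_name = model_name
        self.timeout = timeout
        # Resolve once, centrally; LLM call sites then read these attributes
        # instead of re-deriving the name (see _build_completion_args,
        # _supports_vision, and _add_cache_control in the diff below).
        # litellm_model is assumed here to carry the openai/-prefixed api_model.
        self.litellm_model, self.canonical_model = resolve_strix_model(model_name)

With this split, capability checks such as supports_vision(model=config.canonical_model) see the real provider name, while requests are still sent under the openai/-prefixed name expected by the Strix API.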
@@ -14,7 +14,6 @@ from strix.llm.memory_compressor import MemoryCompressor
 from strix.llm.utils import (
     _truncate_to_first_function,
     fix_incomplete_tool_call,
-    get_litellm_model_name,
     parse_tool_invocations,
 )
 from strix.skills import load_skills
@@ -64,7 +63,7 @@ class LLM:
         self.agent_name = agent_name
         self.agent_id: str | None = None
         self._total_stats = RequestStats()
-        self.memory_compressor = MemoryCompressor(model_name=config.model_name)
+        self.memory_compressor = MemoryCompressor(model_name=config.litellm_model)
         self.system_prompt = self._load_system_prompt(agent_name)

         reasoning = Config.get("strix_reasoning_effort")
@@ -190,16 +189,12 @@ class LLM:

         return messages

-    def _get_litellm_model_name(self) -> str:
-        model = self.config.model_name  # Validated non-empty in LLMConfig.__init__
-        return get_litellm_model_name(model) or model  # type: ignore[return-value]
-
     def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
         if not self._supports_vision():
             messages = self._strip_images(messages)

         args: dict[str, Any] = {
-            "model": self._get_litellm_model_name(),
+            "model": self.config.litellm_model,
             "messages": messages,
             "timeout": self.config.timeout,
             "stream_options": {"include_usage": True},
@@ -264,7 +259,7 @@ class LLM:
         if direct_cost is not None:
             return float(direct_cost)
         try:
-            return completion_cost(response, model=self._get_litellm_model_name()) or 0.0
+            return completion_cost(response, model=self.config.canonical_model) or 0.0
         except Exception:  # noqa: BLE001
             return 0.0

@@ -287,13 +282,13 @@ class LLM:

     def _supports_vision(self) -> bool:
         try:
-            return bool(supports_vision(model=self._get_litellm_model_name()))
+            return bool(supports_vision(model=self.config.canonical_model))
         except Exception:  # noqa: BLE001
             return False

     def _supports_reasoning(self) -> bool:
         try:
-            return bool(supports_reasoning(model=self._get_litellm_model_name()))
+            return bool(supports_reasoning(model=self.config.canonical_model))
         except Exception:  # noqa: BLE001
             return False

@@ -314,7 +309,7 @@ class LLM:
         return result

     def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
-        if not messages or not supports_prompt_caching(self._get_litellm_model_name()):
+        if not messages or not supports_prompt_caching(self.config.canonical_model):
             return messages

         result = list(messages)