refactor: Centralize strix model resolution with separate API and capability names

- Replace fragile prefix matching with an explicit STRIX_MODEL_MAP
- Add resolve_strix_model() returning (api_model, canonical_model), sketched below
  - api_model (openai/ prefix) for API calls to the OpenAI-compatible Strix API
  - canonical_model (actual provider name) for litellm capability lookups
- Centralize resolution in LLMConfig instead of scattered call sites
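
A minimal sketch of the resolution described above. The map entries and the pass-through fallback are illustrative assumptions; only the STRIX_MODEL_MAP name, resolve_strix_model(), and the (api_model, canonical_model) contract come from this commit:

# Hypothetical sketch: entries and fallback are assumptions, not the real map.
STRIX_MODEL_MAP: dict[str, str] = {
    # Strix alias -> canonical provider name that litellm knows about
    "strix-large": "anthropic/claude-sonnet-4",  # illustrative entry
    "strix-mini": "gpt-4o-mini",  # illustrative entry
}


def resolve_strix_model(model_name: str) -> tuple[str, str]:
    """Return (api_model, canonical_model) for a configured model name.

    api_model carries the openai/ prefix so the request is routed to the
    OpenAI-compatible Strix API; canonical_model is the actual provider
    name used for litellm capability and cost lookups.
    """
    canonical = STRIX_MODEL_MAP.get(model_name)
    if canonical is None:
        # Not a Strix alias: assume the name is already a litellm model.
        return model_name, model_name
    return f"openai/{model_name}", canonical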
Author: 0xallam
Date: 2026-02-20 04:40:04 -08:00
Commit: 3b3576b024 (parent: d2c99ea4df)
6 changed files with 45 additions and 42 deletions

@@ -14,7 +14,6 @@ from strix.llm.memory_compressor import MemoryCompressor
 from strix.llm.utils import (
     _truncate_to_first_function,
     fix_incomplete_tool_call,
-    get_litellm_model_name,
     parse_tool_invocations,
 )
 from strix.skills import load_skills
@@ -64,7 +63,7 @@ class LLM:
         self.agent_name = agent_name
         self.agent_id: str | None = None
         self._total_stats = RequestStats()
-        self.memory_compressor = MemoryCompressor(model_name=config.model_name)
+        self.memory_compressor = MemoryCompressor(model_name=config.litellm_model)
         self.system_prompt = self._load_system_prompt(agent_name)

         reasoning = Config.get("strix_reasoning_effort")
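
The LLMConfig side of the change is not in this diff. A plausible sketch of how the config could expose both names, reusing resolve_strix_model() from above (the litellm_model and canonical_model attribute names are confirmed by the call sites below; the dataclass shape and defaults are assumptions):

from dataclasses import dataclass, field


@dataclass
class LLMConfig:
    model_name: str
    timeout: float = 120.0  # illustrative default

    # Derived names, resolved once here instead of at every call site.
    litellm_model: str = field(init=False)
    canonical_model: str = field(init=False)

    def __post_init__(self) -> None:
        if not self.model_name:  # the removed helper's comment notes this validation
            raise ValueError("model_name must be non-empty")
        self.litellm_model, self.canonical_model = resolve_strix_model(self.model_name)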
@@ -190,16 +189,12 @@ class LLM:
         return messages

-    def _get_litellm_model_name(self) -> str:
-        model = self.config.model_name  # Validated non-empty in LLMConfig.__init__
-        return get_litellm_model_name(model) or model  # type: ignore[return-value]
-
     def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
         if not self._supports_vision():
             messages = self._strip_images(messages)

         args: dict[str, Any] = {
-            "model": self._get_litellm_model_name(),
+            "model": self.config.litellm_model,
             "messages": messages,
             "timeout": self.config.timeout,
             "stream_options": {"include_usage": True},
@@ -264,7 +259,7 @@ class LLM:
         if direct_cost is not None:
             return float(direct_cost)

         try:
-            return completion_cost(response, model=self._get_litellm_model_name()) or 0.0
+            return completion_cost(response, model=self.config.canonical_model) or 0.0
         except Exception:  # noqa: BLE001
             return 0.0
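
Using canonical_model here matters because litellm's completion_cost() resolves pricing from its model-cost table, which is keyed by real provider model names; an openai/-prefixed Strix alias would likely miss the table and land in the 0.0 fallback.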
@@ -287,13 +282,13 @@ class LLM:
     def _supports_vision(self) -> bool:
         try:
-            return bool(supports_vision(model=self._get_litellm_model_name()))
+            return bool(supports_vision(model=self.config.canonical_model))
         except Exception:  # noqa: BLE001
             return False

     def _supports_reasoning(self) -> bool:
         try:
-            return bool(supports_reasoning(model=self._get_litellm_model_name()))
+            return bool(supports_reasoning(model=self.config.canonical_model))
         except Exception:  # noqa: BLE001
             return False
@@ -314,7 +309,7 @@ class LLM:
         return result

     def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
-        if not messages or not supports_prompt_caching(self._get_litellm_model_name()):
+        if not messages or not supports_prompt_caching(self.config.canonical_model):
             return messages

         result = list(messages)