refactor: Centralize strix model resolution with separate API and capability names
- Replace fragile prefix matching with explicit STRIX_MODEL_MAP - Add resolve_strix_model() returning (api_model, canonical_model) - api_model (openai/ prefix) for API calls to OpenAI-compatible Strix API - canonical_model (actual provider name) for litellm capability lookups - Centralize resolution in LLMConfig instead of scattered call sites
This commit is contained in:
@@ -19,7 +19,7 @@ from rich.text import Text
|
||||
|
||||
from strix.config import Config, apply_saved_config, save_current_config
|
||||
from strix.config.config import resolve_llm_config
|
||||
from strix.llm.utils import get_litellm_model_name
|
||||
from strix.llm.utils import resolve_strix_model
|
||||
|
||||
|
||||
apply_saved_config()
|
||||
@@ -210,6 +210,8 @@ async def warm_up_llm() -> None:
|
||||
|
||||
try:
|
||||
model_name, api_key, api_base = resolve_llm_config()
|
||||
litellm_model, _ = resolve_strix_model(model_name)
|
||||
litellm_model = litellm_model or model_name
|
||||
|
||||
test_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
@@ -218,7 +220,6 @@ async def warm_up_llm() -> None:
|
||||
|
||||
llm_timeout = int(Config.get("llm_timeout") or "300")
|
||||
|
||||
litellm_model = get_litellm_model_name(model_name) or model_name
|
||||
completion_kwargs: dict[str, Any] = {
|
||||
"model": litellm_model,
|
||||
"messages": test_messages,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from strix.config import Config
|
||||
from strix.config.config import resolve_llm_config
|
||||
from strix.llm.utils import resolve_strix_model
|
||||
|
||||
|
||||
class LLMConfig:
|
||||
@@ -17,6 +18,10 @@ class LLMConfig:
|
||||
if not self.model_name:
|
||||
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
||||
|
||||
api_model, canonical = resolve_strix_model(self.model_name)
|
||||
self.litellm_model: str = api_model or self.model_name
|
||||
self.canonical_model: str = canonical or self.model_name
|
||||
|
||||
self.enable_prompt_caching = enable_prompt_caching
|
||||
self.skills = skills or []
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
import litellm
|
||||
|
||||
from strix.config.config import resolve_llm_config
|
||||
from strix.llm.utils import get_litellm_model_name
|
||||
from strix.llm.utils import resolve_strix_model
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -157,6 +157,8 @@ def check_duplicate(
|
||||
comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}
|
||||
|
||||
model_name, api_key, api_base = resolve_llm_config()
|
||||
litellm_model, _ = resolve_strix_model(model_name)
|
||||
litellm_model = litellm_model or model_name
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": DEDUPE_SYSTEM_PROMPT},
|
||||
@@ -170,7 +172,6 @@ def check_duplicate(
|
||||
},
|
||||
]
|
||||
|
||||
litellm_model = get_litellm_model_name(model_name) or model_name
|
||||
completion_kwargs: dict[str, Any] = {
|
||||
"model": litellm_model,
|
||||
"messages": messages,
|
||||
|
||||
@@ -14,7 +14,6 @@ from strix.llm.memory_compressor import MemoryCompressor
|
||||
from strix.llm.utils import (
|
||||
_truncate_to_first_function,
|
||||
fix_incomplete_tool_call,
|
||||
get_litellm_model_name,
|
||||
parse_tool_invocations,
|
||||
)
|
||||
from strix.skills import load_skills
|
||||
@@ -64,7 +63,7 @@ class LLM:
|
||||
self.agent_name = agent_name
|
||||
self.agent_id: str | None = None
|
||||
self._total_stats = RequestStats()
|
||||
self.memory_compressor = MemoryCompressor(model_name=config.model_name)
|
||||
self.memory_compressor = MemoryCompressor(model_name=config.litellm_model)
|
||||
self.system_prompt = self._load_system_prompt(agent_name)
|
||||
|
||||
reasoning = Config.get("strix_reasoning_effort")
|
||||
@@ -190,16 +189,12 @@ class LLM:
|
||||
|
||||
return messages
|
||||
|
||||
def _get_litellm_model_name(self) -> str:
|
||||
model = self.config.model_name # Validated non-empty in LLMConfig.__init__
|
||||
return get_litellm_model_name(model) or model # type: ignore[return-value]
|
||||
|
||||
def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
if not self._supports_vision():
|
||||
messages = self._strip_images(messages)
|
||||
|
||||
args: dict[str, Any] = {
|
||||
"model": self._get_litellm_model_name(),
|
||||
"model": self.config.litellm_model,
|
||||
"messages": messages,
|
||||
"timeout": self.config.timeout,
|
||||
"stream_options": {"include_usage": True},
|
||||
@@ -264,7 +259,7 @@ class LLM:
|
||||
if direct_cost is not None:
|
||||
return float(direct_cost)
|
||||
try:
|
||||
return completion_cost(response, model=self._get_litellm_model_name()) or 0.0
|
||||
return completion_cost(response, model=self.config.canonical_model) or 0.0
|
||||
except Exception: # noqa: BLE001
|
||||
return 0.0
|
||||
|
||||
@@ -287,13 +282,13 @@ class LLM:
|
||||
|
||||
def _supports_vision(self) -> bool:
|
||||
try:
|
||||
return bool(supports_vision(model=self._get_litellm_model_name()))
|
||||
return bool(supports_vision(model=self.config.canonical_model))
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
def _supports_reasoning(self) -> bool:
|
||||
try:
|
||||
return bool(supports_reasoning(model=self._get_litellm_model_name()))
|
||||
return bool(supports_reasoning(model=self.config.canonical_model))
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
@@ -314,7 +309,7 @@ class LLM:
|
||||
return result
|
||||
|
||||
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
if not messages or not supports_prompt_caching(self._get_litellm_model_name()):
|
||||
if not messages or not supports_prompt_caching(self.config.canonical_model):
|
||||
return messages
|
||||
|
||||
result = list(messages)
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any
|
||||
import litellm
|
||||
|
||||
from strix.config.config import Config, resolve_llm_config
|
||||
from strix.llm.utils import get_litellm_model_name
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -46,8 +45,7 @@ keeping the summary concise and to the point."""
|
||||
|
||||
def _count_tokens(text: str, model: str) -> int:
|
||||
try:
|
||||
litellm_model = get_litellm_model_name(model) or model
|
||||
count = litellm.token_counter(model=litellm_model, text=text)
|
||||
count = litellm.token_counter(model=model, text=text)
|
||||
return int(count)
|
||||
except Exception:
|
||||
logger.exception("Failed to count tokens")
|
||||
@@ -109,9 +107,8 @@ def _summarize_messages(
|
||||
_, api_key, api_base = resolve_llm_config()
|
||||
|
||||
try:
|
||||
litellm_model = get_litellm_model_name(model) or model
|
||||
completion_args: dict[str, Any] = {
|
||||
"model": litellm_model,
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"timeout": timeout,
|
||||
}
|
||||
@@ -161,7 +158,7 @@ class MemoryCompressor:
|
||||
):
|
||||
self.max_images = max_images
|
||||
self.model_name = model_name or Config.get("strix_llm")
|
||||
self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")
|
||||
self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "120")
|
||||
|
||||
if not self.model_name:
|
||||
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
||||
|
||||
@@ -3,34 +3,38 @@ import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
STRIX_PROVIDER_PREFIXES: dict[str, str] = {
|
||||
"claude-": "anthropic",
|
||||
"gpt-": "openai",
|
||||
"gemini-": "gemini",
|
||||
STRIX_MODEL_MAP: dict[str, str] = {
|
||||
"claude-sonnet-4.6": "anthropic/claude-sonnet-4-6",
|
||||
"claude-opus-4.6": "anthropic/claude-opus-4-6",
|
||||
"gpt-5.2": "openai/gpt-5.2",
|
||||
"gpt-5.1": "openai/gpt-5.1",
|
||||
"gpt-5": "openai/gpt-5",
|
||||
"gpt-5.2-codex": "openai/gpt-5.2-codex",
|
||||
"gpt-5.1-codex-max": "openai/gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex": "openai/gpt-5.1-codex",
|
||||
"gpt-5-codex": "openai/gpt-5-codex",
|
||||
"gemini-3-pro-preview": "gemini/gemini-3-pro-preview",
|
||||
"gemini-3-flash-preview": "gemini/gemini-3-flash-preview",
|
||||
"glm-5": "openrouter/z-ai/glm-5",
|
||||
"glm-4.7": "openrouter/z-ai/glm-4.7",
|
||||
}
|
||||
|
||||
|
||||
def get_litellm_model_name(model_name: str | None) -> str | None:
|
||||
"""Convert strix/ prefixed model to litellm-compatible provider/model format.
|
||||
def resolve_strix_model(model_name: str | None) -> tuple[str | None, str | None]:
|
||||
"""Resolve a strix/ model into names for API calls and capability lookups.
|
||||
|
||||
Maps strix/ models to their corresponding litellm provider:
|
||||
- strix/claude-* -> anthropic/claude-*
|
||||
- strix/gpt-* -> openai/gpt-*
|
||||
- strix/gemini-* -> gemini/gemini-*
|
||||
- Other models -> openai/<model> (routed via Strix API)
|
||||
Returns (api_model, canonical_model):
|
||||
- api_model: openai/<base> for API calls (Strix API is OpenAI-compatible)
|
||||
- canonical_model: actual provider model name for litellm capability lookups
|
||||
Non-strix models return the same name for both.
|
||||
"""
|
||||
if not model_name:
|
||||
return model_name
|
||||
if not model_name.startswith("strix/"):
|
||||
return model_name
|
||||
if not model_name or not model_name.startswith("strix/"):
|
||||
return model_name, model_name
|
||||
|
||||
base_model = model_name[6:]
|
||||
|
||||
for prefix, provider in STRIX_PROVIDER_PREFIXES.items():
|
||||
if base_model.startswith(prefix):
|
||||
return f"{provider}/{base_model}"
|
||||
|
||||
return f"openai/{base_model}"
|
||||
api_model = f"openai/{base_model}"
|
||||
canonical_model = STRIX_MODEL_MAP.get(base_model, api_model)
|
||||
return api_model, canonical_model
|
||||
|
||||
|
||||
def _truncate_to_first_function(content: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user