diff --git a/strix/interface/main.py b/strix/interface/main.py
index 5df4ac5..f049bbf 100644
--- a/strix/interface/main.py
+++ b/strix/interface/main.py
@@ -19,7 +19,7 @@ from rich.text import Text
 
 from strix.config import Config, apply_saved_config, save_current_config
 from strix.config.config import resolve_llm_config
-from strix.llm.utils import get_litellm_model_name
+from strix.llm.utils import resolve_strix_model
 
 
 apply_saved_config()
@@ -210,6 +210,8 @@ async def warm_up_llm() -> None:
 
     try:
         model_name, api_key, api_base = resolve_llm_config()
+        litellm_model, _ = resolve_strix_model(model_name)
+        litellm_model = litellm_model or model_name
 
         test_messages = [
            {"role": "system", "content": "You are a helpful assistant."},
@@ -218,7 +220,6 @@
 
         llm_timeout = int(Config.get("llm_timeout") or "300")
 
-        litellm_model = get_litellm_model_name(model_name) or model_name
         completion_kwargs: dict[str, Any] = {
            "model": litellm_model,
            "messages": test_messages,
diff --git a/strix/llm/config.py b/strix/llm/config.py
index 1ee2ddd..a2217bb 100644
--- a/strix/llm/config.py
+++ b/strix/llm/config.py
@@ -1,5 +1,6 @@
 from strix.config import Config
 from strix.config.config import resolve_llm_config
+from strix.llm.utils import resolve_strix_model
 
 
 class LLMConfig:
@@ -17,6 +18,10 @@ class LLMConfig:
 
         if not self.model_name:
             raise ValueError("STRIX_LLM environment variable must be set and not empty")
 
+        api_model, canonical = resolve_strix_model(self.model_name)
+        self.litellm_model: str = api_model or self.model_name
+        self.canonical_model: str = canonical or self.model_name
+
         self.enable_prompt_caching = enable_prompt_caching
         self.skills = skills or []
diff --git a/strix/llm/dedupe.py b/strix/llm/dedupe.py
index 33b3bc9..0ea6088 100644
--- a/strix/llm/dedupe.py
+++ b/strix/llm/dedupe.py
@@ -6,7 +6,7 @@ from typing import Any
 
 import litellm
 
 from strix.config.config import resolve_llm_config
-from strix.llm.utils import get_litellm_model_name
+from strix.llm.utils import resolve_strix_model
 
 logger = logging.getLogger(__name__)
@@ -157,6 +157,8 @@ def check_duplicate(
     comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}
 
     model_name, api_key, api_base = resolve_llm_config()
+    litellm_model, _ = resolve_strix_model(model_name)
+    litellm_model = litellm_model or model_name
 
     messages = [
         {"role": "system", "content": DEDUPE_SYSTEM_PROMPT},
@@ -170,7 +172,6 @@
         },
     ]
 
-    litellm_model = get_litellm_model_name(model_name) or model_name
     completion_kwargs: dict[str, Any] = {
         "model": litellm_model,
         "messages": messages,
diff --git a/strix/llm/llm.py b/strix/llm/llm.py
index d6373ec..c38bbe1 100644
--- a/strix/llm/llm.py
+++ b/strix/llm/llm.py
@@ -14,7 +14,6 @@ from strix.llm.memory_compressor import MemoryCompressor
 from strix.llm.utils import (
     _truncate_to_first_function,
     fix_incomplete_tool_call,
-    get_litellm_model_name,
     parse_tool_invocations,
 )
 from strix.skills import load_skills
@@ -64,7 +63,7 @@
         self.agent_name = agent_name
         self.agent_id: str | None = None
         self._total_stats = RequestStats()
-        self.memory_compressor = MemoryCompressor(model_name=config.model_name)
+        self.memory_compressor = MemoryCompressor(model_name=config.litellm_model)
         self.system_prompt = self._load_system_prompt(agent_name)
 
         reasoning = Config.get("strix_reasoning_effort")
@@ -190,16 +189,12 @@ class LLM:
 
         return messages
 
-    def _get_litellm_model_name(self) -> str:
-        model = self.config.model_name  # Validated non-empty in LLMConfig.__init__
-        return get_litellm_model_name(model) or model  # type: ignore[return-value]
-
     def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
         if not self._supports_vision():
             messages = self._strip_images(messages)
 
         args: dict[str, Any] = {
-            "model": self._get_litellm_model_name(),
+            "model": self.config.litellm_model,
             "messages": messages,
             "timeout": self.config.timeout,
             "stream_options": {"include_usage": True},
@@ -264,7 +259,7 @@ class LLM:
         if direct_cost is not None:
             return float(direct_cost)
         try:
-            return completion_cost(response, model=self._get_litellm_model_name()) or 0.0
+            return completion_cost(response, model=self.config.canonical_model) or 0.0
         except Exception:  # noqa: BLE001
             return 0.0
 
@@ -287,13 +282,13 @@ class LLM:
 
     def _supports_vision(self) -> bool:
         try:
-            return bool(supports_vision(model=self._get_litellm_model_name()))
+            return bool(supports_vision(model=self.config.canonical_model))
         except Exception:  # noqa: BLE001
             return False
 
     def _supports_reasoning(self) -> bool:
         try:
-            return bool(supports_reasoning(model=self._get_litellm_model_name()))
+            return bool(supports_reasoning(model=self.config.canonical_model))
         except Exception:  # noqa: BLE001
             return False
 
@@ -314,7 +309,7 @@ class LLM:
         return result
 
     def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
-        if not messages or not supports_prompt_caching(self._get_litellm_model_name()):
+        if not messages or not supports_prompt_caching(self.config.canonical_model):
             return messages
 
         result = list(messages)
diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py
index 4590972..28730e8 100644
--- a/strix/llm/memory_compressor.py
+++ b/strix/llm/memory_compressor.py
@@ -4,7 +4,6 @@ from typing import Any
 
 import litellm
 
 from strix.config.config import Config, resolve_llm_config
-from strix.llm.utils import get_litellm_model_name
 
 logger = logging.getLogger(__name__)
@@ -46,8 +45,7 @@ keeping the summary concise and to the point."""
 
 def _count_tokens(text: str, model: str) -> int:
     try:
-        litellm_model = get_litellm_model_name(model) or model
-        count = litellm.token_counter(model=litellm_model, text=text)
+        count = litellm.token_counter(model=model, text=text)
         return int(count)
     except Exception:
         logger.exception("Failed to count tokens")
@@ -109,9 +107,8 @@ def _summarize_messages(
     _, api_key, api_base = resolve_llm_config()
 
     try:
-        litellm_model = get_litellm_model_name(model) or model
         completion_args: dict[str, Any] = {
-            "model": litellm_model,
+            "model": model,
             "messages": [{"role": "user", "content": prompt}],
             "timeout": timeout,
         }
@@ -161,7 +158,7 @@
     ):
         self.max_images = max_images
         self.model_name = model_name or Config.get("strix_llm")
-        self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")
+        self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "120")
 
         if not self.model_name:
             raise ValueError("STRIX_LLM environment variable must be set and not empty")
diff --git a/strix/llm/utils.py b/strix/llm/utils.py
index c7d83a9..bef04ce 100644
--- a/strix/llm/utils.py
+++ b/strix/llm/utils.py
@@ -3,34 +3,38 @@
 import re
 from typing import Any
 
-STRIX_PROVIDER_PREFIXES: dict[str, str] = {
-    "claude-": "anthropic",
-    "gpt-": "openai",
-    "gemini-": "gemini",
+STRIX_MODEL_MAP: dict[str, str] = {
+    "claude-sonnet-4.6": "anthropic/claude-sonnet-4-6",
+    "claude-opus-4.6": "anthropic/claude-opus-4-6",
+    "gpt-5.2": "openai/gpt-5.2",
+    "gpt-5.1": "openai/gpt-5.1",
+    "gpt-5": "openai/gpt-5",
+    "gpt-5.2-codex": "openai/gpt-5.2-codex",
+    "gpt-5.1-codex-max": "openai/gpt-5.1-codex-max",
+    "gpt-5.1-codex": "openai/gpt-5.1-codex",
+    "gpt-5-codex": "openai/gpt-5-codex",
+    "gemini-3-pro-preview": "gemini/gemini-3-pro-preview",
+    "gemini-3-flash-preview": "gemini/gemini-3-flash-preview",
+    "glm-5": "openrouter/z-ai/glm-5",
+    "glm-4.7": "openrouter/z-ai/glm-4.7",
 }
 
 
-def get_litellm_model_name(model_name: str | None) -> str | None:
-    """Convert strix/ prefixed model to litellm-compatible provider/model format.
+def resolve_strix_model(model_name: str | None) -> tuple[str | None, str | None]:
+    """Resolve a strix/ model into names for API calls and capability lookups.
 
-    Maps strix/ models to their corresponding litellm provider:
-    - strix/claude-* -> anthropic/claude-*
-    - strix/gpt-* -> openai/gpt-*
-    - strix/gemini-* -> gemini/gemini-*
-    - Other models -> openai/ (routed via Strix API)
+    Returns (api_model, canonical_model):
+    - api_model: openai/ for API calls (Strix API is OpenAI-compatible)
+    - canonical_model: actual provider model name for litellm capability lookups
+    Non-strix models return the same name for both.
     """
-    if not model_name:
-        return model_name
-    if not model_name.startswith("strix/"):
-        return model_name
+    if not model_name or not model_name.startswith("strix/"):
+        return model_name, model_name
 
     base_model = model_name[6:]
-
-    for prefix, provider in STRIX_PROVIDER_PREFIXES.items():
-        if base_model.startswith(prefix):
-            return f"{provider}/{base_model}"
-
-    return f"openai/{base_model}"
+    api_model = f"openai/{base_model}"
+    canonical_model = STRIX_MODEL_MAP.get(base_model, api_model)
+    return api_model, canonical_model
 
 
 def _truncate_to_first_function(content: str) -> str: