From 06ae3d3860c16ada8167ea42b7a8a07f9c1542d4 Mon Sep 17 00:00:00 2001 From: octovimmer Date: Thu, 19 Feb 2026 17:25:10 -0800 Subject: [PATCH 1/3] fix: linting errors --- strix/interface/main.py | 5 +++- strix/llm/dedupe.py | 5 +++- strix/llm/llm.py | 40 +++++++++++++++++++++---------- strix/llm/memory_compressor.py | 8 +++++-- strix/llm/utils.py | 44 ++++++++++++++++++++++++++++++++++ 5 files changed, 86 insertions(+), 16 deletions(-) diff --git a/strix/interface/main.py b/strix/interface/main.py index edd7dd5..58d52a0 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -18,6 +18,7 @@ from rich.panel import Panel from rich.text import Text from strix.config import Config, apply_saved_config, save_current_config +from strix.llm.utils import get_litellm_model_name, get_strix_api_base apply_saved_config() @@ -208,6 +209,7 @@ async def warm_up_llm() -> None: or Config.get("openai_api_base") or Config.get("litellm_base_url") or Config.get("ollama_api_base") + or get_strix_api_base(model_name) ) test_messages = [ @@ -217,8 +219,9 @@ async def warm_up_llm() -> None: llm_timeout = int(Config.get("llm_timeout") or "300") + litellm_model = get_litellm_model_name(model_name) or model_name completion_kwargs: dict[str, Any] = { - "model": model_name, + "model": litellm_model, "messages": test_messages, "timeout": llm_timeout, } diff --git a/strix/llm/dedupe.py b/strix/llm/dedupe.py index 9edd6b7..f8cdb08 100644 --- a/strix/llm/dedupe.py +++ b/strix/llm/dedupe.py @@ -6,6 +6,7 @@ from typing import Any import litellm from strix.config import Config +from strix.llm.utils import get_litellm_model_name, get_strix_api_base logger = logging.getLogger(__name__) @@ -162,6 +163,7 @@ def check_duplicate( or Config.get("openai_api_base") or Config.get("litellm_base_url") or Config.get("ollama_api_base") + or get_strix_api_base(model_name) ) messages = [ @@ -176,8 +178,9 @@ def check_duplicate( }, ] + litellm_model = get_litellm_model_name(model_name) or model_name completion_kwargs: dict[str, Any] = { - "model": model_name, + "model": litellm_model, "messages": messages, "timeout": 120, } diff --git a/strix/llm/llm.py b/strix/llm/llm.py index 311de35..d1b6370 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -14,6 +14,8 @@ from strix.llm.memory_compressor import MemoryCompressor from strix.llm.utils import ( _truncate_to_first_function, fix_incomplete_tool_call, + get_litellm_model_name, + get_strix_api_base, parse_tool_invocations, ) from strix.skills import load_skills @@ -189,12 +191,16 @@ class LLM: return messages + def _get_litellm_model_name(self) -> str: + model = self.config.model_name # Validated non-empty in LLMConfig.__init__ + return get_litellm_model_name(model) or model # type: ignore[return-value] + def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]: if not self._supports_vision(): messages = self._strip_images(messages) args: dict[str, Any] = { - "model": self.config.model_name, + "model": self._get_litellm_model_name(), "messages": messages, "timeout": self.config.timeout, "stream_options": {"include_usage": True}, @@ -202,12 +208,15 @@ class LLM: if api_key := Config.get("llm_api_key"): args["api_key"] = api_key - if api_base := ( + + api_base = ( Config.get("llm_api_base") or Config.get("openai_api_base") or Config.get("litellm_base_url") or Config.get("ollama_api_base") - ): + or get_strix_api_base(self.config.model_name) + ) + if api_base: args["api_base"] = api_base if self._supports_reasoning(): args["reasoning_effort"] = self._reasoning_effort @@ -234,8 +243,8 @@ class LLM: def _update_usage_stats(self, response: Any) -> None: try: if hasattr(response, "usage") and response.usage: - input_tokens = getattr(response.usage, "prompt_tokens", 0) - output_tokens = getattr(response.usage, "completion_tokens", 0) + input_tokens = getattr(response.usage, "prompt_tokens", 0) or 0 + output_tokens = getattr(response.usage, "completion_tokens", 0) or 0 cached_tokens = 0 if hasattr(response.usage, "prompt_tokens_details"): @@ -243,14 +252,11 @@ class LLM: if hasattr(prompt_details, "cached_tokens"): cached_tokens = prompt_details.cached_tokens or 0 + cost = self._extract_cost(response) else: input_tokens = 0 output_tokens = 0 cached_tokens = 0 - - try: - cost = completion_cost(response) or 0.0 - except Exception: # noqa: BLE001 cost = 0.0 self._total_stats.input_tokens += input_tokens @@ -261,6 +267,16 @@ class LLM: except Exception: # noqa: BLE001, S110 # nosec B110 pass + def _extract_cost(self, response: Any) -> float: + if hasattr(response, "usage") and response.usage: + direct_cost = getattr(response.usage, "cost", None) + if direct_cost is not None: + return float(direct_cost) + try: + return completion_cost(response, model=self._get_litellm_model_name()) or 0.0 + except Exception: # noqa: BLE001 + return 0.0 + def _should_retry(self, e: Exception) -> bool: code = getattr(e, "status_code", None) or getattr( getattr(e, "response", None), "status_code", None @@ -280,13 +296,13 @@ class LLM: def _supports_vision(self) -> bool: try: - return bool(supports_vision(model=self.config.model_name)) + return bool(supports_vision(model=self._get_litellm_model_name())) except Exception: # noqa: BLE001 return False def _supports_reasoning(self) -> bool: try: - return bool(supports_reasoning(model=self.config.model_name)) + return bool(supports_reasoning(model=self._get_litellm_model_name())) except Exception: # noqa: BLE001 return False @@ -307,7 +323,7 @@ class LLM: return result def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - if not messages or not supports_prompt_caching(self.config.model_name): + if not messages or not supports_prompt_caching(self._get_litellm_model_name()): return messages result = list(messages) diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py index ef0b9ab..e46b331 100644 --- a/strix/llm/memory_compressor.py +++ b/strix/llm/memory_compressor.py @@ -4,6 +4,7 @@ from typing import Any import litellm from strix.config import Config +from strix.llm.utils import get_litellm_model_name, get_strix_api_base logger = logging.getLogger(__name__) @@ -45,7 +46,8 @@ keeping the summary concise and to the point.""" def _count_tokens(text: str, model: str) -> int: try: - count = litellm.token_counter(model=model, text=text) + litellm_model = get_litellm_model_name(model) or model + count = litellm.token_counter(model=litellm_model, text=text) return int(count) except Exception: logger.exception("Failed to count tokens") @@ -110,11 +112,13 @@ def _summarize_messages( or Config.get("openai_api_base") or Config.get("litellm_base_url") or Config.get("ollama_api_base") + or get_strix_api_base(model) ) try: + litellm_model = get_litellm_model_name(model) or model completion_args: dict[str, Any] = { - "model": model, + "model": litellm_model, "messages": [{"role": "user", "content": prompt}], "timeout": timeout, } diff --git a/strix/llm/utils.py b/strix/llm/utils.py index 81431f0..8abe6ba 100644 --- a/strix/llm/utils.py +++ b/strix/llm/utils.py @@ -3,6 +3,50 @@ import re from typing import Any +STRIX_API_BASE = "https://models.strix.ai/api/v1" + +STRIX_PROVIDER_PREFIXES: dict[str, str] = { + "claude-": "anthropic", + "gpt-": "openai", + "gemini-": "gemini", +} + + +def is_strix_model(model_name: str | None) -> bool: + """Check if model uses strix/ prefix.""" + return bool(model_name and model_name.startswith("strix/")) + + +def get_strix_api_base(model_name: str | None) -> str | None: + """Return Strix API base URL if using strix/ model, None otherwise.""" + if is_strix_model(model_name): + return STRIX_API_BASE + return None + + +def get_litellm_model_name(model_name: str | None) -> str | None: + """Convert strix/ prefixed model to litellm-compatible provider/model format. + + Maps strix/ models to their corresponding litellm provider: + - strix/claude-* -> anthropic/claude-* + - strix/gpt-* -> openai/gpt-* + - strix/gemini-* -> gemini/gemini-* + - Other models -> openai/ (routed via Strix API) + """ + if not model_name: + return model_name + if not model_name.startswith("strix/"): + return model_name + + base_model = model_name[6:] + + for prefix, provider in STRIX_PROVIDER_PREFIXES.items(): + if base_model.startswith(prefix): + return f"{provider}/{base_model}" + + return f"openai/{base_model}" + + def _truncate_to_first_function(content: str) -> str: if not content: return content From 3b3576b024a8bd4f88ce36877a517cb4c3b6c944 Mon Sep 17 00:00:00 2001 From: 0xallam Date: Fri, 20 Feb 2026 04:40:04 -0800 Subject: [PATCH 2/3] refactor: Centralize strix model resolution with separate API and capability names - Replace fragile prefix matching with explicit STRIX_MODEL_MAP - Add resolve_strix_model() returning (api_model, canonical_model) - api_model (openai/ prefix) for API calls to OpenAI-compatible Strix API - canonical_model (actual provider name) for litellm capability lookups - Centralize resolution in LLMConfig instead of scattered call sites --- strix/interface/main.py | 5 ++-- strix/llm/config.py | 5 ++++ strix/llm/dedupe.py | 5 ++-- strix/llm/llm.py | 17 +++++-------- strix/llm/memory_compressor.py | 9 +++---- strix/llm/utils.py | 46 ++++++++++++++++++---------------- 6 files changed, 45 insertions(+), 42 deletions(-) diff --git a/strix/interface/main.py b/strix/interface/main.py index 5df4ac5..f049bbf 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -19,7 +19,7 @@ from rich.text import Text from strix.config import Config, apply_saved_config, save_current_config from strix.config.config import resolve_llm_config -from strix.llm.utils import get_litellm_model_name +from strix.llm.utils import resolve_strix_model apply_saved_config() @@ -210,6 +210,8 @@ async def warm_up_llm() -> None: try: model_name, api_key, api_base = resolve_llm_config() + litellm_model, _ = resolve_strix_model(model_name) + litellm_model = litellm_model or model_name test_messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -218,7 +220,6 @@ async def warm_up_llm() -> None: llm_timeout = int(Config.get("llm_timeout") or "300") - litellm_model = get_litellm_model_name(model_name) or model_name completion_kwargs: dict[str, Any] = { "model": litellm_model, "messages": test_messages, diff --git a/strix/llm/config.py b/strix/llm/config.py index 1ee2ddd..a2217bb 100644 --- a/strix/llm/config.py +++ b/strix/llm/config.py @@ -1,5 +1,6 @@ from strix.config import Config from strix.config.config import resolve_llm_config +from strix.llm.utils import resolve_strix_model class LLMConfig: @@ -17,6 +18,10 @@ class LLMConfig: if not self.model_name: raise ValueError("STRIX_LLM environment variable must be set and not empty") + api_model, canonical = resolve_strix_model(self.model_name) + self.litellm_model: str = api_model or self.model_name + self.canonical_model: str = canonical or self.model_name + self.enable_prompt_caching = enable_prompt_caching self.skills = skills or [] diff --git a/strix/llm/dedupe.py b/strix/llm/dedupe.py index 33b3bc9..0ea6088 100644 --- a/strix/llm/dedupe.py +++ b/strix/llm/dedupe.py @@ -6,7 +6,7 @@ from typing import Any import litellm from strix.config.config import resolve_llm_config -from strix.llm.utils import get_litellm_model_name +from strix.llm.utils import resolve_strix_model logger = logging.getLogger(__name__) @@ -157,6 +157,8 @@ def check_duplicate( comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned} model_name, api_key, api_base = resolve_llm_config() + litellm_model, _ = resolve_strix_model(model_name) + litellm_model = litellm_model or model_name messages = [ {"role": "system", "content": DEDUPE_SYSTEM_PROMPT}, @@ -170,7 +172,6 @@ def check_duplicate( }, ] - litellm_model = get_litellm_model_name(model_name) or model_name completion_kwargs: dict[str, Any] = { "model": litellm_model, "messages": messages, diff --git a/strix/llm/llm.py b/strix/llm/llm.py index d6373ec..c38bbe1 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -14,7 +14,6 @@ from strix.llm.memory_compressor import MemoryCompressor from strix.llm.utils import ( _truncate_to_first_function, fix_incomplete_tool_call, - get_litellm_model_name, parse_tool_invocations, ) from strix.skills import load_skills @@ -64,7 +63,7 @@ class LLM: self.agent_name = agent_name self.agent_id: str | None = None self._total_stats = RequestStats() - self.memory_compressor = MemoryCompressor(model_name=config.model_name) + self.memory_compressor = MemoryCompressor(model_name=config.litellm_model) self.system_prompt = self._load_system_prompt(agent_name) reasoning = Config.get("strix_reasoning_effort") @@ -190,16 +189,12 @@ class LLM: return messages - def _get_litellm_model_name(self) -> str: - model = self.config.model_name # Validated non-empty in LLMConfig.__init__ - return get_litellm_model_name(model) or model # type: ignore[return-value] - def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]: if not self._supports_vision(): messages = self._strip_images(messages) args: dict[str, Any] = { - "model": self._get_litellm_model_name(), + "model": self.config.litellm_model, "messages": messages, "timeout": self.config.timeout, "stream_options": {"include_usage": True}, @@ -264,7 +259,7 @@ class LLM: if direct_cost is not None: return float(direct_cost) try: - return completion_cost(response, model=self._get_litellm_model_name()) or 0.0 + return completion_cost(response, model=self.config.canonical_model) or 0.0 except Exception: # noqa: BLE001 return 0.0 @@ -287,13 +282,13 @@ class LLM: def _supports_vision(self) -> bool: try: - return bool(supports_vision(model=self._get_litellm_model_name())) + return bool(supports_vision(model=self.config.canonical_model)) except Exception: # noqa: BLE001 return False def _supports_reasoning(self) -> bool: try: - return bool(supports_reasoning(model=self._get_litellm_model_name())) + return bool(supports_reasoning(model=self.config.canonical_model)) except Exception: # noqa: BLE001 return False @@ -314,7 +309,7 @@ class LLM: return result def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - if not messages or not supports_prompt_caching(self._get_litellm_model_name()): + if not messages or not supports_prompt_caching(self.config.canonical_model): return messages result = list(messages) diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py index 4590972..28730e8 100644 --- a/strix/llm/memory_compressor.py +++ b/strix/llm/memory_compressor.py @@ -4,7 +4,6 @@ from typing import Any import litellm from strix.config.config import Config, resolve_llm_config -from strix.llm.utils import get_litellm_model_name logger = logging.getLogger(__name__) @@ -46,8 +45,7 @@ keeping the summary concise and to the point.""" def _count_tokens(text: str, model: str) -> int: try: - litellm_model = get_litellm_model_name(model) or model - count = litellm.token_counter(model=litellm_model, text=text) + count = litellm.token_counter(model=model, text=text) return int(count) except Exception: logger.exception("Failed to count tokens") @@ -109,9 +107,8 @@ def _summarize_messages( _, api_key, api_base = resolve_llm_config() try: - litellm_model = get_litellm_model_name(model) or model completion_args: dict[str, Any] = { - "model": litellm_model, + "model": model, "messages": [{"role": "user", "content": prompt}], "timeout": timeout, } @@ -161,7 +158,7 @@ class MemoryCompressor: ): self.max_images = max_images self.model_name = model_name or Config.get("strix_llm") - self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30") + self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "120") if not self.model_name: raise ValueError("STRIX_LLM environment variable must be set and not empty") diff --git a/strix/llm/utils.py b/strix/llm/utils.py index c7d83a9..bef04ce 100644 --- a/strix/llm/utils.py +++ b/strix/llm/utils.py @@ -3,34 +3,38 @@ import re from typing import Any -STRIX_PROVIDER_PREFIXES: dict[str, str] = { - "claude-": "anthropic", - "gpt-": "openai", - "gemini-": "gemini", +STRIX_MODEL_MAP: dict[str, str] = { + "claude-sonnet-4.6": "anthropic/claude-sonnet-4-6", + "claude-opus-4.6": "anthropic/claude-opus-4-6", + "gpt-5.2": "openai/gpt-5.2", + "gpt-5.1": "openai/gpt-5.1", + "gpt-5": "openai/gpt-5", + "gpt-5.2-codex": "openai/gpt-5.2-codex", + "gpt-5.1-codex-max": "openai/gpt-5.1-codex-max", + "gpt-5.1-codex": "openai/gpt-5.1-codex", + "gpt-5-codex": "openai/gpt-5-codex", + "gemini-3-pro-preview": "gemini/gemini-3-pro-preview", + "gemini-3-flash-preview": "gemini/gemini-3-flash-preview", + "glm-5": "openrouter/z-ai/glm-5", + "glm-4.7": "openrouter/z-ai/glm-4.7", } -def get_litellm_model_name(model_name: str | None) -> str | None: - """Convert strix/ prefixed model to litellm-compatible provider/model format. +def resolve_strix_model(model_name: str | None) -> tuple[str | None, str | None]: + """Resolve a strix/ model into names for API calls and capability lookups. - Maps strix/ models to their corresponding litellm provider: - - strix/claude-* -> anthropic/claude-* - - strix/gpt-* -> openai/gpt-* - - strix/gemini-* -> gemini/gemini-* - - Other models -> openai/ (routed via Strix API) + Returns (api_model, canonical_model): + - api_model: openai/ for API calls (Strix API is OpenAI-compatible) + - canonical_model: actual provider model name for litellm capability lookups + Non-strix models return the same name for both. """ - if not model_name: - return model_name - if not model_name.startswith("strix/"): - return model_name + if not model_name or not model_name.startswith("strix/"): + return model_name, model_name base_model = model_name[6:] - - for prefix, provider in STRIX_PROVIDER_PREFIXES.items(): - if base_model.startswith(prefix): - return f"{provider}/{base_model}" - - return f"openai/{base_model}" + api_model = f"openai/{base_model}" + canonical_model = STRIX_MODEL_MAP.get(base_model, api_model) + return api_model, canonical_model def _truncate_to_first_function(content: str) -> str: From bf8020fafb9d530794b5934375024d3b384884be Mon Sep 17 00:00:00 2001 From: 0xallam Date: Fri, 20 Feb 2026 06:52:27 -0800 Subject: [PATCH 3/3] fix: Strip custom_llm_provider before cost lookup for proxied models --- strix/llm/llm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/strix/llm/llm.py b/strix/llm/llm.py index c38bbe1..50501aa 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -259,6 +259,8 @@ class LLM: if direct_cost is not None: return float(direct_cost) try: + if hasattr(response, "_hidden_params"): + response._hidden_params.pop("custom_llm_provider", None) return completion_cost(response, model=self.config.canonical_model) or 0.0 except Exception: # noqa: BLE001 return 0.0