diff --git a/strix/config/config.py b/strix/config/config.py index 53a3726..f8836b2 100644 --- a/strix/config/config.py +++ b/strix/config/config.py @@ -187,6 +187,9 @@ def resolve_llm_config() -> tuple[str | None, str | None, str | None]: Returns: tuple: (model_name, api_key, api_base) + - model_name: Original model name (strix/ prefix preserved for display) + - api_key: LLM API key + - api_base: API base URL (auto-set to STRIX_API_BASE for strix/ models) """ model = Config.get("strix_llm") if not model: @@ -195,10 +198,8 @@ def resolve_llm_config() -> tuple[str | None, str | None, str | None]: api_key = Config.get("llm_api_key") if model.startswith("strix/"): - model_name = "openai/" + model[6:] api_base: str | None = STRIX_API_BASE else: - model_name = model api_base = ( Config.get("llm_api_base") or Config.get("openai_api_base") @@ -206,4 +207,4 @@ def resolve_llm_config() -> tuple[str | None, str | None, str | None]: or Config.get("ollama_api_base") ) - return model_name, api_key, api_base + return model, api_key, api_base diff --git a/strix/interface/main.py b/strix/interface/main.py index e7ab6c0..f049bbf 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -18,6 +18,8 @@ from rich.panel import Panel from rich.text import Text from strix.config import Config, apply_saved_config, save_current_config +from strix.config.config import resolve_llm_config +from strix.llm.utils import resolve_strix_model apply_saved_config() @@ -204,12 +206,12 @@ def check_docker_installed() -> None: async def warm_up_llm() -> None: - from strix.config.config import resolve_llm_config - console = Console() try: model_name, api_key, api_base = resolve_llm_config() + litellm_model, _ = resolve_strix_model(model_name) + litellm_model = litellm_model or model_name test_messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -219,7 +221,7 @@ async def warm_up_llm() -> None: llm_timeout = int(Config.get("llm_timeout") or "300") completion_kwargs: dict[str, Any] = { - "model": model_name, + "model": litellm_model, "messages": test_messages, "timeout": llm_timeout, } diff --git a/strix/llm/config.py b/strix/llm/config.py index 1ee2ddd..a2217bb 100644 --- a/strix/llm/config.py +++ b/strix/llm/config.py @@ -1,5 +1,6 @@ from strix.config import Config from strix.config.config import resolve_llm_config +from strix.llm.utils import resolve_strix_model class LLMConfig: @@ -17,6 +18,10 @@ class LLMConfig: if not self.model_name: raise ValueError("STRIX_LLM environment variable must be set and not empty") + api_model, canonical = resolve_strix_model(self.model_name) + self.litellm_model: str = api_model or self.model_name + self.canonical_model: str = canonical or self.model_name + self.enable_prompt_caching = enable_prompt_caching self.skills = skills or [] diff --git a/strix/llm/dedupe.py b/strix/llm/dedupe.py index ec15192..0ea6088 100644 --- a/strix/llm/dedupe.py +++ b/strix/llm/dedupe.py @@ -6,6 +6,7 @@ from typing import Any import litellm from strix.config.config import resolve_llm_config +from strix.llm.utils import resolve_strix_model logger = logging.getLogger(__name__) @@ -156,6 +157,8 @@ def check_duplicate( comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned} model_name, api_key, api_base = resolve_llm_config() + litellm_model, _ = resolve_strix_model(model_name) + litellm_model = litellm_model or model_name messages = [ {"role": "system", "content": DEDUPE_SYSTEM_PROMPT}, @@ -170,7 +173,7 @@ def check_duplicate( ] completion_kwargs: dict[str, Any] = { - "model": model_name, + "model": litellm_model, "messages": messages, "timeout": 120, } diff --git a/strix/llm/llm.py b/strix/llm/llm.py index 8133fe3..50501aa 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -63,7 +63,7 @@ class LLM: self.agent_name = agent_name self.agent_id: str | None = None self._total_stats = RequestStats() - self.memory_compressor = MemoryCompressor(model_name=config.model_name) + self.memory_compressor = MemoryCompressor(model_name=config.litellm_model) self.system_prompt = self._load_system_prompt(agent_name) reasoning = Config.get("strix_reasoning_effort") @@ -194,7 +194,7 @@ class LLM: messages = self._strip_images(messages) args: dict[str, Any] = { - "model": self.config.model_name, + "model": self.config.litellm_model, "messages": messages, "timeout": self.config.timeout, "stream_options": {"include_usage": True}, @@ -229,8 +229,8 @@ class LLM: def _update_usage_stats(self, response: Any) -> None: try: if hasattr(response, "usage") and response.usage: - input_tokens = getattr(response.usage, "prompt_tokens", 0) - output_tokens = getattr(response.usage, "completion_tokens", 0) + input_tokens = getattr(response.usage, "prompt_tokens", 0) or 0 + output_tokens = getattr(response.usage, "completion_tokens", 0) or 0 cached_tokens = 0 if hasattr(response.usage, "prompt_tokens_details"): @@ -238,14 +238,11 @@ class LLM: if hasattr(prompt_details, "cached_tokens"): cached_tokens = prompt_details.cached_tokens or 0 + cost = self._extract_cost(response) else: input_tokens = 0 output_tokens = 0 cached_tokens = 0 - - try: - cost = completion_cost(response) or 0.0 - except Exception: # noqa: BLE001 cost = 0.0 self._total_stats.input_tokens += input_tokens @@ -256,6 +253,18 @@ class LLM: except Exception: # noqa: BLE001, S110 # nosec B110 pass + def _extract_cost(self, response: Any) -> float: + if hasattr(response, "usage") and response.usage: + direct_cost = getattr(response.usage, "cost", None) + if direct_cost is not None: + return float(direct_cost) + try: + if hasattr(response, "_hidden_params"): + response._hidden_params.pop("custom_llm_provider", None) + return completion_cost(response, model=self.config.canonical_model) or 0.0 + except Exception: # noqa: BLE001 + return 0.0 + def _should_retry(self, e: Exception) -> bool: code = getattr(e, "status_code", None) or getattr( getattr(e, "response", None), "status_code", None @@ -275,13 +284,13 @@ class LLM: def _supports_vision(self) -> bool: try: - return bool(supports_vision(model=self.config.model_name)) + return bool(supports_vision(model=self.config.canonical_model)) except Exception: # noqa: BLE001 return False def _supports_reasoning(self) -> bool: try: - return bool(supports_reasoning(model=self.config.model_name)) + return bool(supports_reasoning(model=self.config.canonical_model)) except Exception: # noqa: BLE001 return False @@ -302,7 +311,7 @@ class LLM: return result def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - if not messages or not supports_prompt_caching(self.config.model_name): + if not messages or not supports_prompt_caching(self.config.canonical_model): return messages result = list(messages) diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py index f5981f6..28730e8 100644 --- a/strix/llm/memory_compressor.py +++ b/strix/llm/memory_compressor.py @@ -158,7 +158,7 @@ class MemoryCompressor: ): self.max_images = max_images self.model_name = model_name or Config.get("strix_llm") - self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30") + self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "120") if not self.model_name: raise ValueError("STRIX_LLM environment variable must be set and not empty") diff --git a/strix/llm/utils.py b/strix/llm/utils.py index 81431f0..bef04ce 100644 --- a/strix/llm/utils.py +++ b/strix/llm/utils.py @@ -3,6 +3,40 @@ import re from typing import Any +STRIX_MODEL_MAP: dict[str, str] = { + "claude-sonnet-4.6": "anthropic/claude-sonnet-4-6", + "claude-opus-4.6": "anthropic/claude-opus-4-6", + "gpt-5.2": "openai/gpt-5.2", + "gpt-5.1": "openai/gpt-5.1", + "gpt-5": "openai/gpt-5", + "gpt-5.2-codex": "openai/gpt-5.2-codex", + "gpt-5.1-codex-max": "openai/gpt-5.1-codex-max", + "gpt-5.1-codex": "openai/gpt-5.1-codex", + "gpt-5-codex": "openai/gpt-5-codex", + "gemini-3-pro-preview": "gemini/gemini-3-pro-preview", + "gemini-3-flash-preview": "gemini/gemini-3-flash-preview", + "glm-5": "openrouter/z-ai/glm-5", + "glm-4.7": "openrouter/z-ai/glm-4.7", +} + + +def resolve_strix_model(model_name: str | None) -> tuple[str | None, str | None]: + """Resolve a strix/ model into names for API calls and capability lookups. + + Returns (api_model, canonical_model): + - api_model: openai/ for API calls (Strix API is OpenAI-compatible) + - canonical_model: actual provider model name for litellm capability lookups + Non-strix models return the same name for both. + """ + if not model_name or not model_name.startswith("strix/"): + return model_name, model_name + + base_model = model_name[6:] + api_model = f"openai/{base_model}" + canonical_model = STRIX_MODEL_MAP.get(base_model, api_model) + return api_model, canonical_model + + def _truncate_to_first_function(content: str) -> str: if not content: return content