From 06ae3d3860c16ada8167ea42b7a8a07f9c1542d4 Mon Sep 17 00:00:00 2001 From: octovimmer Date: Thu, 19 Feb 2026 17:25:10 -0800 Subject: [PATCH] fix: linting errors --- strix/interface/main.py | 5 +++- strix/llm/dedupe.py | 5 +++- strix/llm/llm.py | 40 +++++++++++++++++++++---------- strix/llm/memory_compressor.py | 8 +++++-- strix/llm/utils.py | 44 ++++++++++++++++++++++++++++++++++ 5 files changed, 86 insertions(+), 16 deletions(-) diff --git a/strix/interface/main.py b/strix/interface/main.py index edd7dd5..58d52a0 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -18,6 +18,7 @@ from rich.panel import Panel from rich.text import Text from strix.config import Config, apply_saved_config, save_current_config +from strix.llm.utils import get_litellm_model_name, get_strix_api_base apply_saved_config() @@ -208,6 +209,7 @@ async def warm_up_llm() -> None: or Config.get("openai_api_base") or Config.get("litellm_base_url") or Config.get("ollama_api_base") + or get_strix_api_base(model_name) ) test_messages = [ @@ -217,8 +219,9 @@ async def warm_up_llm() -> None: llm_timeout = int(Config.get("llm_timeout") or "300") + litellm_model = get_litellm_model_name(model_name) or model_name completion_kwargs: dict[str, Any] = { - "model": model_name, + "model": litellm_model, "messages": test_messages, "timeout": llm_timeout, } diff --git a/strix/llm/dedupe.py b/strix/llm/dedupe.py index 9edd6b7..f8cdb08 100644 --- a/strix/llm/dedupe.py +++ b/strix/llm/dedupe.py @@ -6,6 +6,7 @@ from typing import Any import litellm from strix.config import Config +from strix.llm.utils import get_litellm_model_name, get_strix_api_base logger = logging.getLogger(__name__) @@ -162,6 +163,7 @@ def check_duplicate( or Config.get("openai_api_base") or Config.get("litellm_base_url") or Config.get("ollama_api_base") + or get_strix_api_base(model_name) ) messages = [ @@ -176,8 +178,9 @@ def check_duplicate( }, ] + litellm_model = get_litellm_model_name(model_name) or model_name completion_kwargs: dict[str, Any] = { - "model": model_name, + "model": litellm_model, "messages": messages, "timeout": 120, } diff --git a/strix/llm/llm.py b/strix/llm/llm.py index 311de35..d1b6370 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -14,6 +14,8 @@ from strix.llm.memory_compressor import MemoryCompressor from strix.llm.utils import ( _truncate_to_first_function, fix_incomplete_tool_call, + get_litellm_model_name, + get_strix_api_base, parse_tool_invocations, ) from strix.skills import load_skills @@ -189,12 +191,16 @@ class LLM: return messages + def _get_litellm_model_name(self) -> str: + model = self.config.model_name # Validated non-empty in LLMConfig.__init__ + return get_litellm_model_name(model) or model # type: ignore[return-value] + def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]: if not self._supports_vision(): messages = self._strip_images(messages) args: dict[str, Any] = { - "model": self.config.model_name, + "model": self._get_litellm_model_name(), "messages": messages, "timeout": self.config.timeout, "stream_options": {"include_usage": True}, @@ -202,12 +208,15 @@ class LLM: if api_key := Config.get("llm_api_key"): args["api_key"] = api_key - if api_base := ( + + api_base = ( Config.get("llm_api_base") or Config.get("openai_api_base") or Config.get("litellm_base_url") or Config.get("ollama_api_base") - ): + or get_strix_api_base(self.config.model_name) + ) + if api_base: args["api_base"] = api_base if self._supports_reasoning(): args["reasoning_effort"] = self._reasoning_effort @@ -234,8 +243,8 @@ class LLM: def _update_usage_stats(self, response: Any) -> None: try: if hasattr(response, "usage") and response.usage: - input_tokens = getattr(response.usage, "prompt_tokens", 0) - output_tokens = getattr(response.usage, "completion_tokens", 0) + input_tokens = getattr(response.usage, "prompt_tokens", 0) or 0 + output_tokens = getattr(response.usage, "completion_tokens", 0) or 0 cached_tokens = 0 if hasattr(response.usage, "prompt_tokens_details"): @@ -243,14 +252,11 @@ class LLM: if hasattr(prompt_details, "cached_tokens"): cached_tokens = prompt_details.cached_tokens or 0 + cost = self._extract_cost(response) else: input_tokens = 0 output_tokens = 0 cached_tokens = 0 - - try: - cost = completion_cost(response) or 0.0 - except Exception: # noqa: BLE001 cost = 0.0 self._total_stats.input_tokens += input_tokens @@ -261,6 +267,16 @@ class LLM: except Exception: # noqa: BLE001, S110 # nosec B110 pass + def _extract_cost(self, response: Any) -> float: + if hasattr(response, "usage") and response.usage: + direct_cost = getattr(response.usage, "cost", None) + if direct_cost is not None: + return float(direct_cost) + try: + return completion_cost(response, model=self._get_litellm_model_name()) or 0.0 + except Exception: # noqa: BLE001 + return 0.0 + def _should_retry(self, e: Exception) -> bool: code = getattr(e, "status_code", None) or getattr( getattr(e, "response", None), "status_code", None @@ -280,13 +296,13 @@ class LLM: def _supports_vision(self) -> bool: try: - return bool(supports_vision(model=self.config.model_name)) + return bool(supports_vision(model=self._get_litellm_model_name())) except Exception: # noqa: BLE001 return False def _supports_reasoning(self) -> bool: try: - return bool(supports_reasoning(model=self.config.model_name)) + return bool(supports_reasoning(model=self._get_litellm_model_name())) except Exception: # noqa: BLE001 return False @@ -307,7 +323,7 @@ class LLM: return result def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - if not messages or not supports_prompt_caching(self.config.model_name): + if not messages or not supports_prompt_caching(self._get_litellm_model_name()): return messages result = list(messages) diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py index ef0b9ab..e46b331 100644 --- a/strix/llm/memory_compressor.py +++ b/strix/llm/memory_compressor.py @@ -4,6 +4,7 @@ from typing import Any import litellm from strix.config import Config +from strix.llm.utils import get_litellm_model_name, get_strix_api_base logger = logging.getLogger(__name__) @@ -45,7 +46,8 @@ keeping the summary concise and to the point.""" def _count_tokens(text: str, model: str) -> int: try: - count = litellm.token_counter(model=model, text=text) + litellm_model = get_litellm_model_name(model) or model + count = litellm.token_counter(model=litellm_model, text=text) return int(count) except Exception: logger.exception("Failed to count tokens") @@ -110,11 +112,13 @@ def _summarize_messages( or Config.get("openai_api_base") or Config.get("litellm_base_url") or Config.get("ollama_api_base") + or get_strix_api_base(model) ) try: + litellm_model = get_litellm_model_name(model) or model completion_args: dict[str, Any] = { - "model": model, + "model": litellm_model, "messages": [{"role": "user", "content": prompt}], "timeout": timeout, } diff --git a/strix/llm/utils.py b/strix/llm/utils.py index 81431f0..8abe6ba 100644 --- a/strix/llm/utils.py +++ b/strix/llm/utils.py @@ -3,6 +3,50 @@ import re from typing import Any +STRIX_API_BASE = "https://models.strix.ai/api/v1" + +STRIX_PROVIDER_PREFIXES: dict[str, str] = { + "claude-": "anthropic", + "gpt-": "openai", + "gemini-": "gemini", +} + + +def is_strix_model(model_name: str | None) -> bool: + """Check if model uses strix/ prefix.""" + return bool(model_name and model_name.startswith("strix/")) + + +def get_strix_api_base(model_name: str | None) -> str | None: + """Return Strix API base URL if using strix/ model, None otherwise.""" + if is_strix_model(model_name): + return STRIX_API_BASE + return None + + +def get_litellm_model_name(model_name: str | None) -> str | None: + """Convert strix/ prefixed model to litellm-compatible provider/model format. + + Maps strix/ models to their corresponding litellm provider: + - strix/claude-* -> anthropic/claude-* + - strix/gpt-* -> openai/gpt-* + - strix/gemini-* -> gemini/gemini-* + - Other models -> openai/ (routed via Strix API) + """ + if not model_name: + return model_name + if not model_name.startswith("strix/"): + return model_name + + base_model = model_name[6:] + + for prefix, provider in STRIX_PROVIDER_PREFIXES.items(): + if base_model.startswith(prefix): + return f"{provider}/{base_model}" + + return f"openai/{base_model}" + + def _truncate_to_first_function(content: str) -> str: if not content: return content