fix: linting errors

This commit is contained in:
octovimmer
2026-02-19 17:25:10 -08:00
parent 40cb705494
commit 06ae3d3860
5 changed files with 86 additions and 16 deletions

View File

@@ -18,6 +18,7 @@ from rich.panel import Panel
from rich.text import Text
from strix.config import Config, apply_saved_config, save_current_config
from strix.llm.utils import get_litellm_model_name, get_strix_api_base
apply_saved_config()
@@ -208,6 +209,7 @@ async def warm_up_llm() -> None:
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
or get_strix_api_base(model_name)
)
test_messages = [
@@ -217,8 +219,9 @@ async def warm_up_llm() -> None:
llm_timeout = int(Config.get("llm_timeout") or "300")
litellm_model = get_litellm_model_name(model_name) or model_name
completion_kwargs: dict[str, Any] = {
"model": model_name,
"model": litellm_model,
"messages": test_messages,
"timeout": llm_timeout,
}

View File

@@ -6,6 +6,7 @@ from typing import Any
import litellm
from strix.config import Config
from strix.llm.utils import get_litellm_model_name, get_strix_api_base
logger = logging.getLogger(__name__)
@@ -162,6 +163,7 @@ def check_duplicate(
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
or get_strix_api_base(model_name)
)
messages = [
@@ -176,8 +178,9 @@ def check_duplicate(
},
]
litellm_model = get_litellm_model_name(model_name) or model_name
completion_kwargs: dict[str, Any] = {
"model": model_name,
"model": litellm_model,
"messages": messages,
"timeout": 120,
}

View File

@@ -14,6 +14,8 @@ from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.utils import (
_truncate_to_first_function,
fix_incomplete_tool_call,
get_litellm_model_name,
get_strix_api_base,
parse_tool_invocations,
)
from strix.skills import load_skills
@@ -189,12 +191,16 @@ class LLM:
return messages
def _get_litellm_model_name(self) -> str:
    """Resolve the configured model name to its litellm-compatible form."""
    # LLMConfig.__init__ guarantees model_name is non-empty, so the
    # fallback only fires when the mapping helper returns None.
    configured = self.config.model_name
    return get_litellm_model_name(configured) or configured  # type: ignore[return-value]
def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
if not self._supports_vision():
messages = self._strip_images(messages)
args: dict[str, Any] = {
"model": self.config.model_name,
"model": self._get_litellm_model_name(),
"messages": messages,
"timeout": self.config.timeout,
"stream_options": {"include_usage": True},
@@ -202,12 +208,15 @@ class LLM:
if api_key := Config.get("llm_api_key"):
args["api_key"] = api_key
if api_base := (
api_base = (
Config.get("llm_api_base")
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
):
or get_strix_api_base(self.config.model_name)
)
if api_base:
args["api_base"] = api_base
if self._supports_reasoning():
args["reasoning_effort"] = self._reasoning_effort
@@ -234,8 +243,8 @@ class LLM:
def _update_usage_stats(self, response: Any) -> None:
try:
if hasattr(response, "usage") and response.usage:
input_tokens = getattr(response.usage, "prompt_tokens", 0)
output_tokens = getattr(response.usage, "completion_tokens", 0)
input_tokens = getattr(response.usage, "prompt_tokens", 0) or 0
output_tokens = getattr(response.usage, "completion_tokens", 0) or 0
cached_tokens = 0
if hasattr(response.usage, "prompt_tokens_details"):
@@ -243,14 +252,11 @@ class LLM:
if hasattr(prompt_details, "cached_tokens"):
cached_tokens = prompt_details.cached_tokens or 0
cost = self._extract_cost(response)
else:
input_tokens = 0
output_tokens = 0
cached_tokens = 0
try:
cost = completion_cost(response) or 0.0
except Exception: # noqa: BLE001
cost = 0.0
self._total_stats.input_tokens += input_tokens
@@ -261,6 +267,16 @@ class LLM:
except Exception: # noqa: BLE001, S110 # nosec B110
pass
def _extract_cost(self, response: Any) -> float:
if hasattr(response, "usage") and response.usage:
direct_cost = getattr(response.usage, "cost", None)
if direct_cost is not None:
return float(direct_cost)
try:
return completion_cost(response, model=self._get_litellm_model_name()) or 0.0
except Exception: # noqa: BLE001
return 0.0
def _should_retry(self, e: Exception) -> bool:
code = getattr(e, "status_code", None) or getattr(
getattr(e, "response", None), "status_code", None
@@ -280,13 +296,13 @@ class LLM:
def _supports_vision(self) -> bool:
try:
return bool(supports_vision(model=self.config.model_name))
return bool(supports_vision(model=self._get_litellm_model_name()))
except Exception: # noqa: BLE001
return False
def _supports_reasoning(self) -> bool:
try:
return bool(supports_reasoning(model=self.config.model_name))
return bool(supports_reasoning(model=self._get_litellm_model_name()))
except Exception: # noqa: BLE001
return False
@@ -307,7 +323,7 @@ class LLM:
return result
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
if not messages or not supports_prompt_caching(self.config.model_name):
if not messages or not supports_prompt_caching(self._get_litellm_model_name()):
return messages
result = list(messages)

View File

@@ -4,6 +4,7 @@ from typing import Any
import litellm
from strix.config import Config
from strix.llm.utils import get_litellm_model_name, get_strix_api_base
logger = logging.getLogger(__name__)
@@ -45,7 +46,8 @@ keeping the summary concise and to the point."""
def _count_tokens(text: str, model: str) -> int:
try:
count = litellm.token_counter(model=model, text=text)
litellm_model = get_litellm_model_name(model) or model
count = litellm.token_counter(model=litellm_model, text=text)
return int(count)
except Exception:
logger.exception("Failed to count tokens")
@@ -110,11 +112,13 @@ def _summarize_messages(
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
or get_strix_api_base(model)
)
try:
litellm_model = get_litellm_model_name(model) or model
completion_args: dict[str, Any] = {
"model": model,
"model": litellm_model,
"messages": [{"role": "user", "content": prompt}],
"timeout": timeout,
}

View File

@@ -3,6 +3,50 @@ import re
from typing import Any
STRIX_API_BASE = "https://models.strix.ai/api/v1"
STRIX_PROVIDER_PREFIXES: dict[str, str] = {
"claude-": "anthropic",
"gpt-": "openai",
"gemini-": "gemini",
}
def is_strix_model(model_name: str | None) -> bool:
"""Check if model uses strix/ prefix."""
return bool(model_name and model_name.startswith("strix/"))
def get_strix_api_base(model_name: str | None) -> str | None:
"""Return Strix API base URL if using strix/ model, None otherwise."""
if is_strix_model(model_name):
return STRIX_API_BASE
return None
def get_litellm_model_name(model_name: str | None) -> str | None:
"""Convert strix/ prefixed model to litellm-compatible provider/model format.
Maps strix/ models to their corresponding litellm provider:
- strix/claude-* -> anthropic/claude-*
- strix/gpt-* -> openai/gpt-*
- strix/gemini-* -> gemini/gemini-*
- Other models -> openai/<model> (routed via Strix API)
"""
if not model_name:
return model_name
if not model_name.startswith("strix/"):
return model_name
base_model = model_name[6:]
for prefix, provider in STRIX_PROVIDER_PREFIXES.items():
if base_model.startswith(prefix):
return f"{provider}/{base_model}"
return f"openai/{base_model}"
def _truncate_to_first_function(content: str) -> str:
if not content:
return content