fix: linting errors

This commit is contained in:
octovimmer
2026-02-19 17:25:10 -08:00
parent 40cb705494
commit 06ae3d3860
5 changed files with 86 additions and 16 deletions

View File

@@ -14,6 +14,8 @@ from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.utils import (
_truncate_to_first_function,
fix_incomplete_tool_call,
get_litellm_model_name,
get_strix_api_base,
parse_tool_invocations,
)
from strix.skills import load_skills
@@ -189,12 +191,16 @@ class LLM:
return messages
def _get_litellm_model_name(self) -> str:
model = self.config.model_name # Validated non-empty in LLMConfig.__init__
return get_litellm_model_name(model) or model # type: ignore[return-value]
def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
if not self._supports_vision():
messages = self._strip_images(messages)
args: dict[str, Any] = {
"model": self.config.model_name,
"model": self._get_litellm_model_name(),
"messages": messages,
"timeout": self.config.timeout,
"stream_options": {"include_usage": True},
@@ -202,12 +208,15 @@ class LLM:
if api_key := Config.get("llm_api_key"):
args["api_key"] = api_key
if api_base := (
api_base = (
Config.get("llm_api_base")
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
):
or get_strix_api_base(self.config.model_name)
)
if api_base:
args["api_base"] = api_base
if self._supports_reasoning():
args["reasoning_effort"] = self._reasoning_effort
@@ -234,8 +243,8 @@ class LLM:
def _update_usage_stats(self, response: Any) -> None:
try:
if hasattr(response, "usage") and response.usage:
input_tokens = getattr(response.usage, "prompt_tokens", 0)
output_tokens = getattr(response.usage, "completion_tokens", 0)
input_tokens = getattr(response.usage, "prompt_tokens", 0) or 0
output_tokens = getattr(response.usage, "completion_tokens", 0) or 0
cached_tokens = 0
if hasattr(response.usage, "prompt_tokens_details"):
@@ -243,14 +252,11 @@ class LLM:
if hasattr(prompt_details, "cached_tokens"):
cached_tokens = prompt_details.cached_tokens or 0
cost = self._extract_cost(response)
else:
input_tokens = 0
output_tokens = 0
cached_tokens = 0
try:
cost = completion_cost(response) or 0.0
except Exception: # noqa: BLE001
cost = 0.0
self._total_stats.input_tokens += input_tokens
@@ -261,6 +267,16 @@ class LLM:
except Exception: # noqa: BLE001, S110 # nosec B110
pass
def _extract_cost(self, response: Any) -> float:
if hasattr(response, "usage") and response.usage:
direct_cost = getattr(response.usage, "cost", None)
if direct_cost is not None:
return float(direct_cost)
try:
return completion_cost(response, model=self._get_litellm_model_name()) or 0.0
except Exception: # noqa: BLE001
return 0.0
def _should_retry(self, e: Exception) -> bool:
code = getattr(e, "status_code", None) or getattr(
getattr(e, "response", None), "status_code", None
@@ -280,13 +296,13 @@ class LLM:
def _supports_vision(self) -> bool:
try:
return bool(supports_vision(model=self.config.model_name))
return bool(supports_vision(model=self._get_litellm_model_name()))
except Exception: # noqa: BLE001
return False
def _supports_reasoning(self) -> bool:
try:
return bool(supports_reasoning(model=self.config.model_name))
return bool(supports_reasoning(model=self._get_litellm_model_name()))
except Exception: # noqa: BLE001
return False
@@ -307,7 +323,7 @@ class LLM:
return result
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
if not messages or not supports_prompt_caching(self.config.model_name):
if not messages or not supports_prompt_caching(self._get_litellm_model_name()):
return messages
result = list(messages)