fix: linting errors
@@ -14,6 +14,8 @@ from strix.llm.memory_compressor import MemoryCompressor
 from strix.llm.utils import (
     _truncate_to_first_function,
     fix_incomplete_tool_call,
+    get_litellm_model_name,
+    get_strix_api_base,
     parse_tool_invocations,
 )
 from strix.skills import load_skills
@@ -189,12 +191,16 @@ class LLM:
 
         return messages
 
+    def _get_litellm_model_name(self) -> str:
+        model = self.config.model_name  # Validated non-empty in LLMConfig.__init__
+        return get_litellm_model_name(model) or model  # type: ignore[return-value]
+
     def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
         if not self._supports_vision():
             messages = self._strip_images(messages)
 
         args: dict[str, Any] = {
-            "model": self.config.model_name,
+            "model": self._get_litellm_model_name(),
             "messages": messages,
             "timeout": self.config.timeout,
             "stream_options": {"include_usage": True},
@@ -202,12 +208,15 @@ class LLM:
 
         if api_key := Config.get("llm_api_key"):
             args["api_key"] = api_key
-        if api_base := (
+
+        api_base = (
             Config.get("llm_api_base")
             or Config.get("openai_api_base")
             or Config.get("litellm_base_url")
             or Config.get("ollama_api_base")
-        ):
+            or get_strix_api_base(self.config.model_name)
+        )
+        if api_base:
             args["api_base"] = api_base
         if self._supports_reasoning():
             args["reasoning_effort"] = self._reasoning_effort
@@ -234,8 +243,8 @@ class LLM:
     def _update_usage_stats(self, response: Any) -> None:
         try:
             if hasattr(response, "usage") and response.usage:
-                input_tokens = getattr(response.usage, "prompt_tokens", 0)
-                output_tokens = getattr(response.usage, "completion_tokens", 0)
+                input_tokens = getattr(response.usage, "prompt_tokens", 0) or 0
+                output_tokens = getattr(response.usage, "completion_tokens", 0) or 0
 
                 cached_tokens = 0
                 if hasattr(response.usage, "prompt_tokens_details"):
@@ -243,14 +252,11 @@ class LLM:
                     if hasattr(prompt_details, "cached_tokens"):
                         cached_tokens = prompt_details.cached_tokens or 0
 
+                cost = self._extract_cost(response)
             else:
                 input_tokens = 0
                 output_tokens = 0
                 cached_tokens = 0
-
-            try:
-                cost = completion_cost(response) or 0.0
-            except Exception:  # noqa: BLE001
                 cost = 0.0
 
             self._total_stats.input_tokens += input_tokens
@@ -261,6 +267,16 @@ class LLM:
         except Exception:  # noqa: BLE001, S110 # nosec B110
             pass
 
+    def _extract_cost(self, response: Any) -> float:
+        if hasattr(response, "usage") and response.usage:
+            direct_cost = getattr(response.usage, "cost", None)
+            if direct_cost is not None:
+                return float(direct_cost)
+        try:
+            return completion_cost(response, model=self._get_litellm_model_name()) or 0.0
+        except Exception:  # noqa: BLE001
+            return 0.0
+
     def _should_retry(self, e: Exception) -> bool:
         code = getattr(e, "status_code", None) or getattr(
             getattr(e, "response", None), "status_code", None
@@ -280,13 +296,13 @@ class LLM:
 
     def _supports_vision(self) -> bool:
         try:
-            return bool(supports_vision(model=self.config.model_name))
+            return bool(supports_vision(model=self._get_litellm_model_name()))
         except Exception:  # noqa: BLE001
             return False
 
     def _supports_reasoning(self) -> bool:
         try:
-            return bool(supports_reasoning(model=self.config.model_name))
+            return bool(supports_reasoning(model=self._get_litellm_model_name()))
         except Exception:  # noqa: BLE001
             return False
 
@@ -307,7 +323,7 @@ class LLM:
         return result
 
     def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
-        if not messages or not supports_prompt_caching(self.config.model_name):
+        if not messages or not supports_prompt_caching(self._get_litellm_model_name()):
             return messages
 
         result = list(messages)
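
Context for the recurring change above: litellm's capability helpers (supports_vision, supports_reasoning, supports_prompt_caching) and completion_cost look models up by their LiteLLM identifier, so the checks now go through the normalized name and fall back to the raw configured name when no mapping exists. A minimal sketch of that lookup-with-fallback pattern follows, using a hypothetical alias table rather than the actual strix.llm.utils implementation:

# Hypothetical sketch only; the real strix.llm.utils.get_litellm_model_name may differ.
_LITELLM_ALIASES: dict[str, str] = {
    # Illustrative entry, not taken from the strix codebase.
    "my-gateway-gpt": "openai/gpt-4o",
}


def get_litellm_model_name(model: str) -> str | None:
    """Return the LiteLLM identifier for a configured model name, or None if unknown."""
    return _LITELLM_ALIASES.get(model)


# Callers fall back to the raw configured name when no mapping exists,
# mirroring `return get_litellm_model_name(model) or model` in the diff.
configured = "my-gateway-gpt"
litellm_name = get_litellm_model_name(configured) or configured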