fix: linting errors
This commit is contained in:
@@ -18,6 +18,7 @@ from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from strix.config import Config, apply_saved_config, save_current_config
|
||||
from strix.llm.utils import get_litellm_model_name, get_strix_api_base
|
||||
|
||||
|
||||
apply_saved_config()
|
||||
@@ -208,6 +209,7 @@ async def warm_up_llm() -> None:
|
||||
or Config.get("openai_api_base")
|
||||
or Config.get("litellm_base_url")
|
||||
or Config.get("ollama_api_base")
|
||||
or get_strix_api_base(model_name)
|
||||
)
|
||||
|
||||
test_messages = [
|
||||
@@ -217,8 +219,9 @@ async def warm_up_llm() -> None:
|
||||
|
||||
llm_timeout = int(Config.get("llm_timeout") or "300")
|
||||
|
||||
litellm_model = get_litellm_model_name(model_name) or model_name
|
||||
completion_kwargs: dict[str, Any] = {
|
||||
"model": model_name,
|
||||
"model": litellm_model,
|
||||
"messages": test_messages,
|
||||
"timeout": llm_timeout,
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
import litellm
|
||||
|
||||
from strix.config import Config
|
||||
from strix.llm.utils import get_litellm_model_name, get_strix_api_base
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -162,6 +163,7 @@ def check_duplicate(
|
||||
or Config.get("openai_api_base")
|
||||
or Config.get("litellm_base_url")
|
||||
or Config.get("ollama_api_base")
|
||||
or get_strix_api_base(model_name)
|
||||
)
|
||||
|
||||
messages = [
|
||||
@@ -176,8 +178,9 @@ def check_duplicate(
|
||||
},
|
||||
]
|
||||
|
||||
litellm_model = get_litellm_model_name(model_name) or model_name
|
||||
completion_kwargs: dict[str, Any] = {
|
||||
"model": model_name,
|
||||
"model": litellm_model,
|
||||
"messages": messages,
|
||||
"timeout": 120,
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ from strix.llm.memory_compressor import MemoryCompressor
|
||||
from strix.llm.utils import (
|
||||
_truncate_to_first_function,
|
||||
fix_incomplete_tool_call,
|
||||
get_litellm_model_name,
|
||||
get_strix_api_base,
|
||||
parse_tool_invocations,
|
||||
)
|
||||
from strix.skills import load_skills
|
||||
@@ -189,12 +191,16 @@ class LLM:
|
||||
|
||||
return messages
|
||||
|
||||
def _get_litellm_model_name(self) -> str:
|
||||
model = self.config.model_name # Validated non-empty in LLMConfig.__init__
|
||||
return get_litellm_model_name(model) or model # type: ignore[return-value]
|
||||
|
||||
def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
if not self._supports_vision():
|
||||
messages = self._strip_images(messages)
|
||||
|
||||
args: dict[str, Any] = {
|
||||
"model": self.config.model_name,
|
||||
"model": self._get_litellm_model_name(),
|
||||
"messages": messages,
|
||||
"timeout": self.config.timeout,
|
||||
"stream_options": {"include_usage": True},
|
||||
@@ -202,12 +208,15 @@ class LLM:
|
||||
|
||||
if api_key := Config.get("llm_api_key"):
|
||||
args["api_key"] = api_key
|
||||
if api_base := (
|
||||
|
||||
api_base = (
|
||||
Config.get("llm_api_base")
|
||||
or Config.get("openai_api_base")
|
||||
or Config.get("litellm_base_url")
|
||||
or Config.get("ollama_api_base")
|
||||
):
|
||||
or get_strix_api_base(self.config.model_name)
|
||||
)
|
||||
if api_base:
|
||||
args["api_base"] = api_base
|
||||
if self._supports_reasoning():
|
||||
args["reasoning_effort"] = self._reasoning_effort
|
||||
@@ -234,8 +243,8 @@ class LLM:
|
||||
def _update_usage_stats(self, response: Any) -> None:
|
||||
try:
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
input_tokens = getattr(response.usage, "prompt_tokens", 0)
|
||||
output_tokens = getattr(response.usage, "completion_tokens", 0)
|
||||
input_tokens = getattr(response.usage, "prompt_tokens", 0) or 0
|
||||
output_tokens = getattr(response.usage, "completion_tokens", 0) or 0
|
||||
|
||||
cached_tokens = 0
|
||||
if hasattr(response.usage, "prompt_tokens_details"):
|
||||
@@ -243,14 +252,11 @@ class LLM:
|
||||
if hasattr(prompt_details, "cached_tokens"):
|
||||
cached_tokens = prompt_details.cached_tokens or 0
|
||||
|
||||
cost = self._extract_cost(response)
|
||||
else:
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
cached_tokens = 0
|
||||
|
||||
try:
|
||||
cost = completion_cost(response) or 0.0
|
||||
except Exception: # noqa: BLE001
|
||||
cost = 0.0
|
||||
|
||||
self._total_stats.input_tokens += input_tokens
|
||||
@@ -261,6 +267,16 @@ class LLM:
|
||||
except Exception: # noqa: BLE001, S110 # nosec B110
|
||||
pass
|
||||
|
||||
def _extract_cost(self, response: Any) -> float:
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
direct_cost = getattr(response.usage, "cost", None)
|
||||
if direct_cost is not None:
|
||||
return float(direct_cost)
|
||||
try:
|
||||
return completion_cost(response, model=self._get_litellm_model_name()) or 0.0
|
||||
except Exception: # noqa: BLE001
|
||||
return 0.0
|
||||
|
||||
def _should_retry(self, e: Exception) -> bool:
|
||||
code = getattr(e, "status_code", None) or getattr(
|
||||
getattr(e, "response", None), "status_code", None
|
||||
@@ -280,13 +296,13 @@ class LLM:
|
||||
|
||||
def _supports_vision(self) -> bool:
|
||||
try:
|
||||
return bool(supports_vision(model=self.config.model_name))
|
||||
return bool(supports_vision(model=self._get_litellm_model_name()))
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
def _supports_reasoning(self) -> bool:
|
||||
try:
|
||||
return bool(supports_reasoning(model=self.config.model_name))
|
||||
return bool(supports_reasoning(model=self._get_litellm_model_name()))
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
@@ -307,7 +323,7 @@ class LLM:
|
||||
return result
|
||||
|
||||
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
if not messages or not supports_prompt_caching(self.config.model_name):
|
||||
if not messages or not supports_prompt_caching(self._get_litellm_model_name()):
|
||||
return messages
|
||||
|
||||
result = list(messages)
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any
|
||||
import litellm
|
||||
|
||||
from strix.config import Config
|
||||
from strix.llm.utils import get_litellm_model_name, get_strix_api_base
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -45,7 +46,8 @@ keeping the summary concise and to the point."""
|
||||
|
||||
def _count_tokens(text: str, model: str) -> int:
|
||||
try:
|
||||
count = litellm.token_counter(model=model, text=text)
|
||||
litellm_model = get_litellm_model_name(model) or model
|
||||
count = litellm.token_counter(model=litellm_model, text=text)
|
||||
return int(count)
|
||||
except Exception:
|
||||
logger.exception("Failed to count tokens")
|
||||
@@ -110,11 +112,13 @@ def _summarize_messages(
|
||||
or Config.get("openai_api_base")
|
||||
or Config.get("litellm_base_url")
|
||||
or Config.get("ollama_api_base")
|
||||
or get_strix_api_base(model)
|
||||
)
|
||||
|
||||
try:
|
||||
litellm_model = get_litellm_model_name(model) or model
|
||||
completion_args: dict[str, Any] = {
|
||||
"model": model,
|
||||
"model": litellm_model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"timeout": timeout,
|
||||
}
|
||||
|
||||
@@ -3,6 +3,50 @@ import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
STRIX_API_BASE = "https://models.strix.ai/api/v1"
|
||||
|
||||
STRIX_PROVIDER_PREFIXES: dict[str, str] = {
|
||||
"claude-": "anthropic",
|
||||
"gpt-": "openai",
|
||||
"gemini-": "gemini",
|
||||
}
|
||||
|
||||
|
||||
def is_strix_model(model_name: str | None) -> bool:
|
||||
"""Check if model uses strix/ prefix."""
|
||||
return bool(model_name and model_name.startswith("strix/"))
|
||||
|
||||
|
||||
def get_strix_api_base(model_name: str | None) -> str | None:
|
||||
"""Return Strix API base URL if using strix/ model, None otherwise."""
|
||||
if is_strix_model(model_name):
|
||||
return STRIX_API_BASE
|
||||
return None
|
||||
|
||||
|
||||
def get_litellm_model_name(model_name: str | None) -> str | None:
|
||||
"""Convert strix/ prefixed model to litellm-compatible provider/model format.
|
||||
|
||||
Maps strix/ models to their corresponding litellm provider:
|
||||
- strix/claude-* -> anthropic/claude-*
|
||||
- strix/gpt-* -> openai/gpt-*
|
||||
- strix/gemini-* -> gemini/gemini-*
|
||||
- Other models -> openai/<model> (routed via Strix API)
|
||||
"""
|
||||
if not model_name:
|
||||
return model_name
|
||||
if not model_name.startswith("strix/"):
|
||||
return model_name
|
||||
|
||||
base_model = model_name[6:]
|
||||
|
||||
for prefix, provider in STRIX_PROVIDER_PREFIXES.items():
|
||||
if base_model.startswith(prefix):
|
||||
return f"{provider}/{base_model}"
|
||||
|
||||
return f"openai/{base_model}"
|
||||
|
||||
|
||||
def _truncate_to_first_function(content: str) -> str:
|
||||
if not content:
|
||||
return content
|
||||
|
||||
Reference in New Issue
Block a user