fix: linting errors
This commit is contained in:
@@ -18,6 +18,7 @@ from rich.panel import Panel
|
|||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
|
|
||||||
from strix.config import Config, apply_saved_config, save_current_config
|
from strix.config import Config, apply_saved_config, save_current_config
|
||||||
|
from strix.llm.utils import get_litellm_model_name, get_strix_api_base
|
||||||
|
|
||||||
|
|
||||||
apply_saved_config()
|
apply_saved_config()
|
||||||
@@ -208,6 +209,7 @@ async def warm_up_llm() -> None:
|
|||||||
or Config.get("openai_api_base")
|
or Config.get("openai_api_base")
|
||||||
or Config.get("litellm_base_url")
|
or Config.get("litellm_base_url")
|
||||||
or Config.get("ollama_api_base")
|
or Config.get("ollama_api_base")
|
||||||
|
or get_strix_api_base(model_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
test_messages = [
|
test_messages = [
|
||||||
@@ -217,8 +219,9 @@ async def warm_up_llm() -> None:
|
|||||||
|
|
||||||
llm_timeout = int(Config.get("llm_timeout") or "300")
|
llm_timeout = int(Config.get("llm_timeout") or "300")
|
||||||
|
|
||||||
|
litellm_model = get_litellm_model_name(model_name) or model_name
|
||||||
completion_kwargs: dict[str, Any] = {
|
completion_kwargs: dict[str, Any] = {
|
||||||
"model": model_name,
|
"model": litellm_model,
|
||||||
"messages": test_messages,
|
"messages": test_messages,
|
||||||
"timeout": llm_timeout,
|
"timeout": llm_timeout,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import Any
|
|||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
from strix.config import Config
|
from strix.config import Config
|
||||||
|
from strix.llm.utils import get_litellm_model_name, get_strix_api_base
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -162,6 +163,7 @@ def check_duplicate(
|
|||||||
or Config.get("openai_api_base")
|
or Config.get("openai_api_base")
|
||||||
or Config.get("litellm_base_url")
|
or Config.get("litellm_base_url")
|
||||||
or Config.get("ollama_api_base")
|
or Config.get("ollama_api_base")
|
||||||
|
or get_strix_api_base(model_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -176,8 +178,9 @@ def check_duplicate(
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
litellm_model = get_litellm_model_name(model_name) or model_name
|
||||||
completion_kwargs: dict[str, Any] = {
|
completion_kwargs: dict[str, Any] = {
|
||||||
"model": model_name,
|
"model": litellm_model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ from strix.llm.memory_compressor import MemoryCompressor
|
|||||||
from strix.llm.utils import (
|
from strix.llm.utils import (
|
||||||
_truncate_to_first_function,
|
_truncate_to_first_function,
|
||||||
fix_incomplete_tool_call,
|
fix_incomplete_tool_call,
|
||||||
|
get_litellm_model_name,
|
||||||
|
get_strix_api_base,
|
||||||
parse_tool_invocations,
|
parse_tool_invocations,
|
||||||
)
|
)
|
||||||
from strix.skills import load_skills
|
from strix.skills import load_skills
|
||||||
@@ -189,12 +191,16 @@ class LLM:
|
|||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
def _get_litellm_model_name(self) -> str:
|
||||||
|
model = self.config.model_name # Validated non-empty in LLMConfig.__init__
|
||||||
|
return get_litellm_model_name(model) or model # type: ignore[return-value]
|
||||||
|
|
||||||
def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
|
def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
|
||||||
if not self._supports_vision():
|
if not self._supports_vision():
|
||||||
messages = self._strip_images(messages)
|
messages = self._strip_images(messages)
|
||||||
|
|
||||||
args: dict[str, Any] = {
|
args: dict[str, Any] = {
|
||||||
"model": self.config.model_name,
|
"model": self._get_litellm_model_name(),
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"timeout": self.config.timeout,
|
"timeout": self.config.timeout,
|
||||||
"stream_options": {"include_usage": True},
|
"stream_options": {"include_usage": True},
|
||||||
@@ -202,12 +208,15 @@ class LLM:
|
|||||||
|
|
||||||
if api_key := Config.get("llm_api_key"):
|
if api_key := Config.get("llm_api_key"):
|
||||||
args["api_key"] = api_key
|
args["api_key"] = api_key
|
||||||
if api_base := (
|
|
||||||
|
api_base = (
|
||||||
Config.get("llm_api_base")
|
Config.get("llm_api_base")
|
||||||
or Config.get("openai_api_base")
|
or Config.get("openai_api_base")
|
||||||
or Config.get("litellm_base_url")
|
or Config.get("litellm_base_url")
|
||||||
or Config.get("ollama_api_base")
|
or Config.get("ollama_api_base")
|
||||||
):
|
or get_strix_api_base(self.config.model_name)
|
||||||
|
)
|
||||||
|
if api_base:
|
||||||
args["api_base"] = api_base
|
args["api_base"] = api_base
|
||||||
if self._supports_reasoning():
|
if self._supports_reasoning():
|
||||||
args["reasoning_effort"] = self._reasoning_effort
|
args["reasoning_effort"] = self._reasoning_effort
|
||||||
@@ -234,8 +243,8 @@ class LLM:
|
|||||||
def _update_usage_stats(self, response: Any) -> None:
|
def _update_usage_stats(self, response: Any) -> None:
|
||||||
try:
|
try:
|
||||||
if hasattr(response, "usage") and response.usage:
|
if hasattr(response, "usage") and response.usage:
|
||||||
input_tokens = getattr(response.usage, "prompt_tokens", 0)
|
input_tokens = getattr(response.usage, "prompt_tokens", 0) or 0
|
||||||
output_tokens = getattr(response.usage, "completion_tokens", 0)
|
output_tokens = getattr(response.usage, "completion_tokens", 0) or 0
|
||||||
|
|
||||||
cached_tokens = 0
|
cached_tokens = 0
|
||||||
if hasattr(response.usage, "prompt_tokens_details"):
|
if hasattr(response.usage, "prompt_tokens_details"):
|
||||||
@@ -243,14 +252,11 @@ class LLM:
|
|||||||
if hasattr(prompt_details, "cached_tokens"):
|
if hasattr(prompt_details, "cached_tokens"):
|
||||||
cached_tokens = prompt_details.cached_tokens or 0
|
cached_tokens = prompt_details.cached_tokens or 0
|
||||||
|
|
||||||
|
cost = self._extract_cost(response)
|
||||||
else:
|
else:
|
||||||
input_tokens = 0
|
input_tokens = 0
|
||||||
output_tokens = 0
|
output_tokens = 0
|
||||||
cached_tokens = 0
|
cached_tokens = 0
|
||||||
|
|
||||||
try:
|
|
||||||
cost = completion_cost(response) or 0.0
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
cost = 0.0
|
cost = 0.0
|
||||||
|
|
||||||
self._total_stats.input_tokens += input_tokens
|
self._total_stats.input_tokens += input_tokens
|
||||||
@@ -261,6 +267,16 @@ class LLM:
|
|||||||
except Exception: # noqa: BLE001, S110 # nosec B110
|
except Exception: # noqa: BLE001, S110 # nosec B110
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _extract_cost(self, response: Any) -> float:
|
||||||
|
if hasattr(response, "usage") and response.usage:
|
||||||
|
direct_cost = getattr(response.usage, "cost", None)
|
||||||
|
if direct_cost is not None:
|
||||||
|
return float(direct_cost)
|
||||||
|
try:
|
||||||
|
return completion_cost(response, model=self._get_litellm_model_name()) or 0.0
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return 0.0
|
||||||
|
|
||||||
def _should_retry(self, e: Exception) -> bool:
|
def _should_retry(self, e: Exception) -> bool:
|
||||||
code = getattr(e, "status_code", None) or getattr(
|
code = getattr(e, "status_code", None) or getattr(
|
||||||
getattr(e, "response", None), "status_code", None
|
getattr(e, "response", None), "status_code", None
|
||||||
@@ -280,13 +296,13 @@ class LLM:
|
|||||||
|
|
||||||
def _supports_vision(self) -> bool:
|
def _supports_vision(self) -> bool:
|
||||||
try:
|
try:
|
||||||
return bool(supports_vision(model=self.config.model_name))
|
return bool(supports_vision(model=self._get_litellm_model_name()))
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _supports_reasoning(self) -> bool:
|
def _supports_reasoning(self) -> bool:
|
||||||
try:
|
try:
|
||||||
return bool(supports_reasoning(model=self.config.model_name))
|
return bool(supports_reasoning(model=self._get_litellm_model_name()))
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -307,7 +323,7 @@ class LLM:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
if not messages or not supports_prompt_caching(self.config.model_name):
|
if not messages or not supports_prompt_caching(self._get_litellm_model_name()):
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
result = list(messages)
|
result = list(messages)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Any
|
|||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
from strix.config import Config
|
from strix.config import Config
|
||||||
|
from strix.llm.utils import get_litellm_model_name, get_strix_api_base
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -45,7 +46,8 @@ keeping the summary concise and to the point."""
|
|||||||
|
|
||||||
def _count_tokens(text: str, model: str) -> int:
|
def _count_tokens(text: str, model: str) -> int:
|
||||||
try:
|
try:
|
||||||
count = litellm.token_counter(model=model, text=text)
|
litellm_model = get_litellm_model_name(model) or model
|
||||||
|
count = litellm.token_counter(model=litellm_model, text=text)
|
||||||
return int(count)
|
return int(count)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to count tokens")
|
logger.exception("Failed to count tokens")
|
||||||
@@ -110,11 +112,13 @@ def _summarize_messages(
|
|||||||
or Config.get("openai_api_base")
|
or Config.get("openai_api_base")
|
||||||
or Config.get("litellm_base_url")
|
or Config.get("litellm_base_url")
|
||||||
or Config.get("ollama_api_base")
|
or Config.get("ollama_api_base")
|
||||||
|
or get_strix_api_base(model)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
litellm_model = get_litellm_model_name(model) or model
|
||||||
completion_args: dict[str, Any] = {
|
completion_args: dict[str, Any] = {
|
||||||
"model": model,
|
"model": litellm_model,
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
"timeout": timeout,
|
"timeout": timeout,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,50 @@ import re
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
STRIX_API_BASE = "https://models.strix.ai/api/v1"
|
||||||
|
|
||||||
|
STRIX_PROVIDER_PREFIXES: dict[str, str] = {
|
||||||
|
"claude-": "anthropic",
|
||||||
|
"gpt-": "openai",
|
||||||
|
"gemini-": "gemini",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_strix_model(model_name: str | None) -> bool:
|
||||||
|
"""Check if model uses strix/ prefix."""
|
||||||
|
return bool(model_name and model_name.startswith("strix/"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_strix_api_base(model_name: str | None) -> str | None:
|
||||||
|
"""Return Strix API base URL if using strix/ model, None otherwise."""
|
||||||
|
if is_strix_model(model_name):
|
||||||
|
return STRIX_API_BASE
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_litellm_model_name(model_name: str | None) -> str | None:
|
||||||
|
"""Convert strix/ prefixed model to litellm-compatible provider/model format.
|
||||||
|
|
||||||
|
Maps strix/ models to their corresponding litellm provider:
|
||||||
|
- strix/claude-* -> anthropic/claude-*
|
||||||
|
- strix/gpt-* -> openai/gpt-*
|
||||||
|
- strix/gemini-* -> gemini/gemini-*
|
||||||
|
- Other models -> openai/<model> (routed via Strix API)
|
||||||
|
"""
|
||||||
|
if not model_name:
|
||||||
|
return model_name
|
||||||
|
if not model_name.startswith("strix/"):
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
base_model = model_name[6:]
|
||||||
|
|
||||||
|
for prefix, provider in STRIX_PROVIDER_PREFIXES.items():
|
||||||
|
if base_model.startswith(prefix):
|
||||||
|
return f"{provider}/{base_model}"
|
||||||
|
|
||||||
|
return f"openai/{base_model}"
|
||||||
|
|
||||||
|
|
||||||
def _truncate_to_first_function(content: str) -> str:
|
def _truncate_to_first_function(content: str) -> str:
|
||||||
if not content:
|
if not content:
|
||||||
return content
|
return content
|
||||||
|
|||||||
Reference in New Issue
Block a user