refactor(llm): streamline reasoning effort handling and remove unused patterns

Author: 0xallam
Date: 2026-01-06 10:49:04 -08:00
Committed by: Ahmed Allam
Parent: 45bb0ae8d8
Commit: 2777ae3fe8


@@ -4,7 +4,6 @@ import os
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
 from enum import Enum
-from fnmatch import fnmatch
 from pathlib import Path
 from typing import Any
@@ -14,7 +13,7 @@ from jinja2 import (
     FileSystemLoader,
     select_autoescape,
 )
-from litellm import completion_cost, stream_chunk_builder
+from litellm import completion_cost, stream_chunk_builder, supports_reasoning
 from litellm.utils import supports_prompt_caching, supports_vision
 from strix.llm.config import LLMConfig
@@ -63,57 +62,6 @@ class LLMRequestFailedError(Exception):
         self.details = details
 
 
-SUPPORTS_STOP_WORDS_FALSE_PATTERNS: list[str] = [
-    "o1*",
-    "grok-4-0709",
-    "grok-code-fast-1",
-    "deepseek-r1-0528*",
-]
-
-REASONING_EFFORT_PATTERNS: list[str] = [
-    "o1-2024-12-17",
-    "o1",
-    "o3",
-    "o3-2025-04-16",
-    "o3-mini-2025-01-31",
-    "o3-mini",
-    "o4-mini",
-    "o4-mini-2025-04-16",
-    "gemini-2.5-flash",
-    "gemini-2.5-pro",
-    "gpt-5*",
-    "deepseek-r1-0528*",
-    "claude-sonnet-4-5*",
-    "claude-haiku-4-5*",
-]
-
-
-def normalize_model_name(model: str) -> str:
-    raw = (model or "").strip().lower()
-    if "/" in raw:
-        name = raw.split("/")[-1]
-        if ":" in name:
-            name = name.split(":", 1)[0]
-    else:
-        name = raw
-    if name.endswith("-gguf"):
-        name = name[: -len("-gguf")]
-    return name
-
-
-def model_matches(model: str, patterns: list[str]) -> bool:
-    raw = (model or "").strip().lower()
-    name = normalize_model_name(model)
-    for pat in patterns:
-        pat_l = pat.lower()
-        if "/" in pat_l:
-            if fnmatch(raw, pat_l):
-                return True
-        elif fnmatch(name, pat_l):
-            return True
-    return False
-
-
 class StepRole(str, Enum):
     AGENT = "agent"
     USER = "user"
@@ -421,17 +369,13 @@ class LLM:
             "supported": supports_prompt_caching(self.config.model_name),
         }
 
-    def _should_include_stop_param(self) -> bool:
-        if not self.config.model_name:
-            return True
-        return not model_matches(self.config.model_name, SUPPORTS_STOP_WORDS_FALSE_PATTERNS)
-
     def _should_include_reasoning_effort(self) -> bool:
         if not self.config.model_name:
             return False
-        try:
-            return model_matches(self.config.model_name, REASONING_EFFORT_PATTERNS)
-        except Exception:  # noqa: BLE001
-            return False
+        return bool(supports_reasoning(model=self.config.model_name))
 
     def _model_supports_vision(self) -> bool:
         if not self.config.model_name:
@@ -499,8 +443,7 @@ class LLM:
         if _LLM_API_BASE:
             completion_args["api_base"] = _LLM_API_BASE
 
-        if self._should_include_stop_param():
-            completion_args["stop"] = ["</function>"]
+        completion_args["stop"] = ["</function>"]
 
         if self._should_include_reasoning_effort():
             completion_args["reasoning_effort"] = "high"