refactor(llm): streamline reasoning effort handling and remove unused patterns
@@ -4,7 +4,6 @@ import os
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
 from enum import Enum
-from fnmatch import fnmatch
 from pathlib import Path
 from typing import Any

@@ -14,7 +13,7 @@ from jinja2 import (
     FileSystemLoader,
     select_autoescape,
 )
-from litellm import completion_cost, stream_chunk_builder
+from litellm import completion_cost, stream_chunk_builder, supports_reasoning
 from litellm.utils import supports_prompt_caching, supports_vision

 from strix.llm.config import LLMConfig
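
Note: supports_reasoning consults litellm's model capability map, replacing the hand-maintained REASONING_EFFORT_PATTERNS list deleted below. A minimal sketch of the probe (model names are illustrative, and litellm may raise for models missing from its map, which is why the new code guards the call):

    from litellm import supports_reasoning

    # Ask litellm's capability map whether a model advertises reasoning
    # support; treat a failed lookup as "unsupported", mirroring the
    # try/except added to _should_include_reasoning_effort below.
    for model in ("openai/o3-mini", "anthropic/claude-sonnet-4-5"):
        try:
            print(model, supports_reasoning(model=model))
        except Exception:
            print(model, False)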
|
||||
@@ -63,57 +62,6 @@ class LLMRequestFailedError(Exception):
         self.details = details


-SUPPORTS_STOP_WORDS_FALSE_PATTERNS: list[str] = [
-    "o1*",
-    "grok-4-0709",
-    "grok-code-fast-1",
-    "deepseek-r1-0528*",
-]
-
-REASONING_EFFORT_PATTERNS: list[str] = [
-    "o1-2024-12-17",
-    "o1",
-    "o3",
-    "o3-2025-04-16",
-    "o3-mini-2025-01-31",
-    "o3-mini",
-    "o4-mini",
-    "o4-mini-2025-04-16",
-    "gemini-2.5-flash",
-    "gemini-2.5-pro",
-    "gpt-5*",
-    "deepseek-r1-0528*",
-    "claude-sonnet-4-5*",
-    "claude-haiku-4-5*",
-]
-
-
-def normalize_model_name(model: str) -> str:
-    raw = (model or "").strip().lower()
-    if "/" in raw:
-        name = raw.split("/")[-1]
-        if ":" in name:
-            name = name.split(":", 1)[0]
-    else:
-        name = raw
-    if name.endswith("-gguf"):
-        name = name[: -len("-gguf")]
-    return name
-
-
-def model_matches(model: str, patterns: list[str]) -> bool:
-    raw = (model or "").strip().lower()
-    name = normalize_model_name(model)
-    for pat in patterns:
-        pat_l = pat.lower()
-        if "/" in pat_l:
-            if fnmatch(raw, pat_l):
-                return True
-        elif fnmatch(name, pat_l):
-            return True
-    return False
-
-
 class StepRole(str, Enum):
     AGENT = "agent"
     USER = "user"
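
For reference, the deleted helpers lower-cased the model name, stripped the provider prefix, any ":<tag>" suffix, and a trailing "-gguf", then fnmatch-ed the result against the glob lists. Roughly (example inputs are illustrative):

    from fnmatch import fnmatch

    # normalize_model_name("openrouter/deepseek/deepseek-r1-0528:free")
    #   -> "deepseek-r1-0528"
    # model_matches then globbed the normalized name against the patterns:
    assert fnmatch("deepseek-r1-0528", "deepseek-r1-0528*")
    assert not fnmatch("gpt-4o", "gpt-5*")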
|
||||
@@ -421,17 +369,13 @@ class LLM:
             "supported": supports_prompt_caching(self.config.model_name),
         }

-    def _should_include_stop_param(self) -> bool:
-        if not self.config.model_name:
-            return True
-
-        return not model_matches(self.config.model_name, SUPPORTS_STOP_WORDS_FALSE_PATTERNS)
-
     def _should_include_reasoning_effort(self) -> bool:
         if not self.config.model_name:
             return False
-
-        return model_matches(self.config.model_name, REASONING_EFFORT_PATTERNS)
+        try:
+            return bool(supports_reasoning(model=self.config.model_name))
+        except Exception:  # noqa: BLE001
+            return False

     def _model_supports_vision(self) -> bool:
         if not self.config.model_name:
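
_should_include_stop_param disappears because the stop sequence is now attached unconditionally (next hunk), and the reasoning-effort gate delegates to litellm instead of the deleted pattern list, reading lookup failures as "no support". The same gate as a free function, for illustration only:

    from litellm import supports_reasoning

    def should_include_reasoning_effort(model_name: str | None) -> bool:
        # Standalone sketch of the method above; a failed lookup
        # (e.g. a model absent from litellm's map) means "don't send it".
        if not model_name:
            return False
        try:
            return bool(supports_reasoning(model=model_name))
        except Exception:  # noqa: BLE001
            return False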
|
||||
@@ -499,8 +443,7 @@ class LLM:
         if _LLM_API_BASE:
             completion_args["api_base"] = _LLM_API_BASE

-        if self._should_include_stop_param():
-            completion_args["stop"] = ["</function>"]
+        completion_args["stop"] = ["</function>"]

         if self._should_include_reasoning_effort():
             completion_args["reasoning_effort"] = "high"
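
Downstream these kwargs land on the litellm completion call; a rough, self-contained sketch (the model, message, and inlined guard are illustrative, and running it needs provider credentials):

    from typing import Any

    from litellm import completion, supports_reasoning

    model = "openai/o3-mini"  # illustrative
    completion_args: dict[str, Any] = {
        "model": model,
        "messages": [{"role": "user", "content": "ping"}],
        "stop": ["</function>"],  # now sent for every model
    }
    try:
        if supports_reasoning(model=model):
            completion_args["reasoning_effort"] = "high"
    except Exception:  # unknown model: omit reasoning_effort
        pass
    response = completion(**completion_args)
    print(response.choices[0].message.content)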
||||