feat(llm): enhance model features handling with pattern matching

This commit is contained in:
Ahmed Allam
2025-11-15 12:43:43 +04:00
parent 821929cd3e
commit d1f7741965

View File

@@ -2,6 +2,7 @@ import logging
import os
from dataclasses import dataclass
from enum import Enum
from fnmatch import fnmatch
from pathlib import Path
from typing import Any
@@ -45,27 +46,14 @@ class LLMRequestFailedError(Exception):
self.details = details
MODELS_WITHOUT_STOP_WORDS = [
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
"o1-mini",
"o1-preview",
"o1",
"o1-2024-12-17",
"o3",
"o3-2025-04-16",
"o3-mini-2025-01-31",
"o3-mini",
"o4-mini",
"o4-mini-2025-04-16",
SUPPORTS_STOP_WORDS_FALSE_PATTERNS: list[str] = [
"o1*",
"grok-4-0709",
"grok-code-fast-1",
"deepseek-r1-0528*",
]
REASONING_EFFORT_SUPPORTED_MODELS = [
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
REASONING_EFFORT_PATTERNS: list[str] = [
"o1-2024-12-17",
"o1",
"o3",
@@ -76,9 +64,39 @@ REASONING_EFFORT_SUPPORTED_MODELS = [
"o4-mini-2025-04-16",
"gemini-2.5-flash",
"gemini-2.5-pro",
"gpt-5*",
"deepseek-r1-0528*",
"claude-sonnet-4-5*",
"claude-haiku-4-5*",
]
def normalize_model_name(model: str) -> str:
    """Return the bare, lower-cased model identifier used for pattern matching.

    Steps, in order:

    1. Treat a falsy *model* as ``""``, strip surrounding whitespace and
       lower-case it.
    2. If the result contains ``/``, keep only the text after the last
       ``/`` and drop everything from the first ``:`` in that segment.
    3. Remove one trailing ``-gguf`` suffix, if present.

    A ``:`` suffix is only removed when a ``/`` prefix was present; a name
    such as ``"foo:bar"`` is returned unchanged apart from case/whitespace.

    >>> normalize_model_name("Provider/Some-Model:tag")
    'some-model'
    >>> normalize_model_name("some-model:tag")
    'some-model:tag'
    >>> normalize_model_name("org/Model-GGUF")
    'model'
    """
    raw = (model or "").strip().lower()

    # rpartition yields an empty separator when there is no "/", in which
    # case `name` is simply the whole string.
    _prefix, slash, name = raw.rpartition("/")
    if slash:
        # Only prefixed names have their ":..." suffix trimmed.
        name = name.partition(":")[0]

    return name.removesuffix("-gguf")
def model_matches(model: str, patterns: list[str]) -> bool:
    """Return True if *model* matches at least one glob in *patterns*.

    Matching is case-insensitive: both the model string and every pattern
    are lower-cased before being compared with ``fnmatch``.

    Each pattern chooses what it is compared against:

    * a pattern containing ``/`` is matched against the full model string
      (whitespace-stripped and lower-cased);
    * any other pattern is matched against ``normalize_model_name(model)``.

    Args:
        model: Model identifier; a falsy value is treated as ``""``.
        patterns: ``fnmatch``-style glob patterns.

    Returns:
        True on the first matching pattern, False if none match (including
        when *patterns* is empty).
    """
    raw = (model or "").strip().lower()
    name = normalize_model_name(model)

    lowered = (pat.lower() for pat in patterns)
    # any() short-circuits on the first hit, same as the explicit loop.
    return any(fnmatch(raw if "/" in pat else name, pat) for pat in lowered)
class StepRole(str, Enum):
AGENT = "agent"
USER = "user"
@@ -332,27 +350,13 @@ class LLM:
if not self.config.model_name:
return True
actual_model_name = self.config.model_name.split("/")[-1].lower()
model_name_lower = self.config.model_name.lower()
return not any(
actual_model_name == unsupported_model.lower()
or model_name_lower == unsupported_model.lower()
for unsupported_model in MODELS_WITHOUT_STOP_WORDS
)
return not model_matches(self.config.model_name, SUPPORTS_STOP_WORDS_FALSE_PATTERNS)
def _should_include_reasoning_effort(self) -> bool:
if not self.config.model_name:
return False
actual_model_name = self.config.model_name.split("/")[-1].lower()
model_name_lower = self.config.model_name.lower()
return any(
actual_model_name == supported_model.lower()
or model_name_lower == supported_model.lower()
for supported_model in REASONING_EFFORT_SUPPORTED_MODELS
)
return model_matches(self.config.model_name, REASONING_EFFORT_PATTERNS)
async def _make_request(
self,
@@ -361,7 +365,6 @@ class LLM:
completion_args: dict[str, Any] = {
"model": self.config.model_name,
"messages": messages,
"temperature": self.config.temperature,
"timeout": self.config.timeout,
}