feat(llm): enhance model features handling with pattern matching

This commit is contained in:
Ahmed Allam
2025-11-15 12:43:43 +04:00
parent 821929cd3e
commit d1f7741965

View File

@@ -2,6 +2,7 @@ import logging
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from fnmatch import fnmatch
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -45,27 +46,14 @@ class LLMRequestFailedError(Exception):
self.details = details self.details = details
MODELS_WITHOUT_STOP_WORDS = [ SUPPORTS_STOP_WORDS_FALSE_PATTERNS: list[str] = [
"gpt-5", "o1*",
"gpt-5-mini",
"gpt-5-nano",
"o1-mini",
"o1-preview",
"o1",
"o1-2024-12-17",
"o3",
"o3-2025-04-16",
"o3-mini-2025-01-31",
"o3-mini",
"o4-mini",
"o4-mini-2025-04-16",
"grok-4-0709", "grok-4-0709",
"grok-code-fast-1",
"deepseek-r1-0528*",
] ]
REASONING_EFFORT_SUPPORTED_MODELS = [ REASONING_EFFORT_PATTERNS: list[str] = [
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
"o1-2024-12-17", "o1-2024-12-17",
"o1", "o1",
"o3", "o3",
@@ -76,9 +64,39 @@ REASONING_EFFORT_SUPPORTED_MODELS = [
"o4-mini-2025-04-16", "o4-mini-2025-04-16",
"gemini-2.5-flash", "gemini-2.5-flash",
"gemini-2.5-pro", "gemini-2.5-pro",
"gpt-5*",
"deepseek-r1-0528*",
"claude-sonnet-4-5*",
"claude-haiku-4-5*",
] ]
def normalize_model_name(model: str) -> str:
    """Return a lowercase, provider-free form of *model* for pattern matching.

    The input is stripped and lowercased (a falsy value is treated as an
    empty string).  When it contains a ``/``, only the text after the last
    slash is kept, and anything from the first ``:`` in that segment onward
    is dropped.  A name without a slash is kept whole, including any ``:``
    part.  Finally, a trailing ``-gguf`` is removed.
    """
    cleaned = (model or "").strip().lower()

    # rpartition yields an empty separator when there is no "/" at all.
    _, slash, tail = cleaned.rpartition("/")
    if slash:
        # Only slash-qualified names have their ":..." suffix removed.
        base = tail.partition(":")[0]
    else:
        base = cleaned

    suffix = "-gguf"
    return base[: -len(suffix)] if base.endswith(suffix) else base
def model_matches(model: str, patterns: list[str]) -> bool:
    """Report whether *model* matches any glob in *patterns*.

    Each pattern is lowercased before comparison.  A pattern containing
    ``/`` is tested with ``fnmatch`` against the full stripped, lowercased
    model string; any other pattern is tested against the result of
    ``normalize_model_name(model)``.  Returns ``True`` on the first match
    and ``False`` when nothing matches (including an empty pattern list).
    """
    full_name = (model or "").strip().lower()
    short_name = normalize_model_name(model)

    def _hit(pattern: str) -> bool:
        lowered = pattern.lower()
        # Slash-qualified patterns target the full identifier; plain
        # patterns target the normalized name.
        subject = full_name if "/" in lowered else short_name
        return fnmatch(subject, lowered)

    return any(_hit(candidate) for candidate in patterns)
class StepRole(str, Enum): class StepRole(str, Enum):
AGENT = "agent" AGENT = "agent"
USER = "user" USER = "user"
@@ -332,27 +350,13 @@ class LLM:
if not self.config.model_name: if not self.config.model_name:
return True return True
actual_model_name = self.config.model_name.split("/")[-1].lower() return not model_matches(self.config.model_name, SUPPORTS_STOP_WORDS_FALSE_PATTERNS)
model_name_lower = self.config.model_name.lower()
return not any(
actual_model_name == unsupported_model.lower()
or model_name_lower == unsupported_model.lower()
for unsupported_model in MODELS_WITHOUT_STOP_WORDS
)
def _should_include_reasoning_effort(self) -> bool: def _should_include_reasoning_effort(self) -> bool:
if not self.config.model_name: if not self.config.model_name:
return False return False
actual_model_name = self.config.model_name.split("/")[-1].lower() return model_matches(self.config.model_name, REASONING_EFFORT_PATTERNS)
model_name_lower = self.config.model_name.lower()
return any(
actual_model_name == supported_model.lower()
or model_name_lower == supported_model.lower()
for supported_model in REASONING_EFFORT_SUPPORTED_MODELS
)
async def _make_request( async def _make_request(
self, self,
@@ -361,7 +365,6 @@ class LLM:
completion_args: dict[str, Any] = { completion_args: dict[str, Any] = {
"model": self.config.model_name, "model": self.config.model_name,
"messages": messages, "messages": messages,
"temperature": self.config.temperature,
"timeout": self.config.timeout, "timeout": self.config.timeout,
} }