fix: filter out image_url content for non-vision models
This commit is contained in:
@@ -13,7 +13,7 @@ from jinja2 import (
|
|||||||
select_autoescape,
|
select_autoescape,
|
||||||
)
|
)
|
||||||
from litellm import ModelResponse, completion_cost
|
from litellm import ModelResponse, completion_cost
|
||||||
from litellm.utils import supports_prompt_caching
|
from litellm.utils import supports_prompt_caching, supports_vision
|
||||||
|
|
||||||
from strix.llm.config import LLMConfig
|
from strix.llm.config import LLMConfig
|
||||||
from strix.llm.memory_compressor import MemoryCompressor
|
from strix.llm.memory_compressor import MemoryCompressor
|
||||||
@@ -388,10 +388,55 @@ class LLM:
|
|||||||
|
|
||||||
return model_matches(self.config.model_name, REASONING_EFFORT_PATTERNS)
|
return model_matches(self.config.model_name, REASONING_EFFORT_PATTERNS)
|
||||||
|
|
||||||
|
def _model_supports_vision(self) -> bool:
    """Report whether the configured model accepts image input.

    Returns:
        ``False`` when no model name is configured or when the
        ``supports_vision`` lookup raises; otherwise whatever
        ``supports_vision`` returns for the configured model.
    """
    model_name = self.config.model_name
    if not model_name:
        return False

    try:
        vision_ok = supports_vision(model=model_name)
    except Exception:  # noqa: BLE001
        # Best-effort capability probe: any lookup failure is treated as
        # "no vision" rather than propagated.
        vision_ok = False
    return vision_ok
|
||||||
|
|
||||||
|
def _filter_images_from_messages(
|
||||||
|
self, messages: list[dict[str, Any]]
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
filtered_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
filtered_content = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
if item.get("type") == "image_url":
|
||||||
|
filtered_content.append({
|
||||||
|
"type": "text",
|
||||||
|
"text": "[Screenshot removed - model does not support vision. "
|
||||||
|
"Use view_source or execute_js to interact with the page instead.]",
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
filtered_content.append(item)
|
||||||
|
else:
|
||||||
|
filtered_content.append(item)
|
||||||
|
if filtered_content:
|
||||||
|
text_parts = [
|
||||||
|
item.get("text", "") if isinstance(item, dict) else str(item)
|
||||||
|
for item in filtered_content
|
||||||
|
]
|
||||||
|
if all(isinstance(item, dict) and item.get("type") == "text" for item in filtered_content):
|
||||||
|
msg = {**msg, "content": "\n".join(text_parts)}
|
||||||
|
else:
|
||||||
|
msg = {**msg, "content": filtered_content}
|
||||||
|
else:
|
||||||
|
msg = {**msg, "content": ""}
|
||||||
|
filtered_messages.append(msg)
|
||||||
|
return filtered_messages
|
||||||
|
|
||||||
async def _make_request(
|
async def _make_request(
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
|
if not self._model_supports_vision():
|
||||||
|
messages = self._filter_images_from_messages(messages)
|
||||||
|
|
||||||
completion_args: dict[str, Any] = {
|
completion_args: dict[str, Any] = {
|
||||||
"model": self.config.model_name,
|
"model": self.config.model_name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
|
|||||||
Reference in New Issue
Block a user