feat(tui): add real-time streaming LLM output with full content display
- Convert LiteLLM requests to streaming mode with stream_request() - Add streaming parser to handle live LLM output segments - Update TUI for real-time streaming content rendering - Add tracer methods for streaming content tracking - Clean function tags from streamed content to prevent display - Remove all truncation from tool renderers for full content visibility
This commit is contained in:
@@ -351,9 +351,16 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
self.state.add_message("user", task)
|
self.state.add_message("user", task)
|
||||||
|
|
||||||
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
|
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
|
||||||
response = await self.llm.generate(self.state.get_conversation_history())
|
final_response = None
|
||||||
|
async for response in self.llm.generate(self.state.get_conversation_history()):
|
||||||
|
final_response = response
|
||||||
|
if tracer and response.content:
|
||||||
|
tracer.update_streaming_content(self.state.agent_id, response.content)
|
||||||
|
|
||||||
content_stripped = (response.content or "").strip()
|
if final_response is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
content_stripped = (final_response.content or "").strip()
|
||||||
|
|
||||||
if not content_stripped:
|
if not content_stripped:
|
||||||
corrective_message = (
|
corrective_message = (
|
||||||
@@ -369,17 +376,18 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
self.state.add_message("user", corrective_message)
|
self.state.add_message("user", corrective_message)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self.state.add_message("assistant", response.content)
|
self.state.add_message("assistant", final_response.content)
|
||||||
if tracer:
|
if tracer:
|
||||||
|
tracer.clear_streaming_content(self.state.agent_id)
|
||||||
tracer.log_chat_message(
|
tracer.log_chat_message(
|
||||||
content=clean_content(response.content),
|
content=clean_content(final_response.content),
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id=self.state.agent_id,
|
agent_id=self.state.agent_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
actions = (
|
actions = (
|
||||||
response.tool_invocations
|
final_response.tool_invocations
|
||||||
if hasattr(response, "tool_invocations") and response.tool_invocations
|
if hasattr(final_response, "tool_invocations") and final_response.tool_invocations
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
119
strix/interface/streaming_parser.py
Normal file
119
strix/interface/streaming_parser.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import html
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
_FUNCTION_TAG_PREFIX = "<function="
|
||||||
|
|
||||||
|
|
||||||
|
def _get_safe_content(content: str) -> tuple[str, str]:
|
||||||
|
if not content:
|
||||||
|
return "", ""
|
||||||
|
|
||||||
|
last_lt = content.rfind("<")
|
||||||
|
if last_lt == -1:
|
||||||
|
return content, ""
|
||||||
|
|
||||||
|
suffix = content[last_lt:]
|
||||||
|
target = _FUNCTION_TAG_PREFIX # "<function="
|
||||||
|
|
||||||
|
if target.startswith(suffix):
|
||||||
|
return content[:last_lt], suffix
|
||||||
|
|
||||||
|
return content, ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StreamSegment:
|
||||||
|
type: Literal["text", "tool"]
|
||||||
|
content: str
|
||||||
|
tool_name: str | None = None
|
||||||
|
args: dict[str, str] | None = None
|
||||||
|
is_complete: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def parse_streaming_content(content: str) -> list[StreamSegment]:
|
||||||
|
if not content:
|
||||||
|
return []
|
||||||
|
|
||||||
|
segments: list[StreamSegment] = []
|
||||||
|
|
||||||
|
func_pattern = r"<function=([^>]+)>"
|
||||||
|
func_matches = list(re.finditer(func_pattern, content))
|
||||||
|
|
||||||
|
if not func_matches:
|
||||||
|
safe_content, _ = _get_safe_content(content)
|
||||||
|
text = safe_content.strip()
|
||||||
|
if text:
|
||||||
|
segments.append(StreamSegment(type="text", content=text))
|
||||||
|
return segments
|
||||||
|
|
||||||
|
first_func_start = func_matches[0].start()
|
||||||
|
if first_func_start > 0:
|
||||||
|
text_before = content[:first_func_start].strip()
|
||||||
|
if text_before:
|
||||||
|
segments.append(StreamSegment(type="text", content=text_before))
|
||||||
|
|
||||||
|
for i, match in enumerate(func_matches):
|
||||||
|
tool_name = match.group(1)
|
||||||
|
func_start = match.end()
|
||||||
|
|
||||||
|
func_end_match = re.search(r"</function>", content[func_start:])
|
||||||
|
|
||||||
|
if func_end_match:
|
||||||
|
func_body = content[func_start : func_start + func_end_match.start()]
|
||||||
|
is_complete = True
|
||||||
|
end_pos = func_start + func_end_match.end()
|
||||||
|
else:
|
||||||
|
if i + 1 < len(func_matches):
|
||||||
|
next_func_start = func_matches[i + 1].start()
|
||||||
|
func_body = content[func_start:next_func_start]
|
||||||
|
else:
|
||||||
|
func_body = content[func_start:]
|
||||||
|
is_complete = False
|
||||||
|
end_pos = len(content)
|
||||||
|
|
||||||
|
args = _parse_streaming_params(func_body)
|
||||||
|
|
||||||
|
segments.append(
|
||||||
|
StreamSegment(
|
||||||
|
type="tool",
|
||||||
|
content=func_body,
|
||||||
|
tool_name=tool_name,
|
||||||
|
args=args,
|
||||||
|
is_complete=is_complete,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_complete and i + 1 < len(func_matches):
|
||||||
|
next_start = func_matches[i + 1].start()
|
||||||
|
text_between = content[end_pos:next_start].strip()
|
||||||
|
if text_between:
|
||||||
|
segments.append(StreamSegment(type="text", content=text_between))
|
||||||
|
|
||||||
|
return segments
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_streaming_params(func_body: str) -> dict[str, str]:
|
||||||
|
args: dict[str, str] = {}
|
||||||
|
|
||||||
|
complete_pattern = r"<parameter=([^>]+)>(.*?)</parameter>"
|
||||||
|
complete_matches = list(re.finditer(complete_pattern, func_body, re.DOTALL))
|
||||||
|
complete_end_pos = 0
|
||||||
|
|
||||||
|
for match in complete_matches:
|
||||||
|
param_name = match.group(1)
|
||||||
|
param_value = html.unescape(match.group(2).strip())
|
||||||
|
args[param_name] = param_value
|
||||||
|
complete_end_pos = max(complete_end_pos, match.end())
|
||||||
|
|
||||||
|
remaining = func_body[complete_end_pos:]
|
||||||
|
incomplete_pattern = r"<parameter=([^>]+)>(.*)$"
|
||||||
|
incomplete_match = re.search(incomplete_pattern, remaining, re.DOTALL)
|
||||||
|
if incomplete_match:
|
||||||
|
param_name = incomplete_match.group(1)
|
||||||
|
param_value = html.unescape(incomplete_match.group(2).strip())
|
||||||
|
args[param_name] = param_value
|
||||||
|
|
||||||
|
return args
|
||||||
@@ -181,4 +181,10 @@ class AgentMessageRenderer(BaseToolRenderer):
|
|||||||
if not content:
|
if not content:
|
||||||
return Text()
|
return Text()
|
||||||
|
|
||||||
return _apply_markdown_styles(content)
|
from strix.llm.utils import clean_content
|
||||||
|
|
||||||
|
cleaned = clean_content(content)
|
||||||
|
if not cleaned:
|
||||||
|
return Text()
|
||||||
|
|
||||||
|
return _apply_markdown_styles(cleaned)
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class CreateAgentRenderer(BaseToolRenderer):
|
|||||||
|
|
||||||
if task:
|
if task:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(task, 400), style="dim")
|
text.append(task, style="dim")
|
||||||
else:
|
else:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("Spawning agent...", style="dim")
|
text.append("Spawning agent...", style="dim")
|
||||||
@@ -66,7 +66,7 @@ class SendMessageToAgentRenderer(BaseToolRenderer):
|
|||||||
|
|
||||||
if message:
|
if message:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(message, 400), style="dim")
|
text.append(message, style="dim")
|
||||||
else:
|
else:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("Sending...", style="dim")
|
text.append("Sending...", style="dim")
|
||||||
@@ -129,7 +129,7 @@ class WaitForMessageRenderer(BaseToolRenderer):
|
|||||||
|
|
||||||
if reason:
|
if reason:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(reason, 400), style="dim")
|
text.append(reason, style="dim")
|
||||||
else:
|
else:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("Agent paused until message received...", style="dim")
|
text.append("Agent paused until message received...", style="dim")
|
||||||
|
|||||||
@@ -23,12 +23,6 @@ class BaseToolRenderer(ABC):
|
|||||||
css_classes = cls.get_css_classes(status)
|
css_classes = cls.get_css_classes(status)
|
||||||
return Static(content, classes=css_classes)
|
return Static(content, classes=css_classes)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def truncate(cls, text: str, max_length: int = 500) -> str:
|
|
||||||
if len(text) <= max_length:
|
|
||||||
return text
|
|
||||||
return text[: max_length - 3] + "..."
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def status_icon(cls, status: str) -> tuple[str, str]:
|
def status_icon(cls, status: str) -> tuple[str, str]:
|
||||||
icons = {
|
icons = {
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class BrowserRenderer(BaseToolRenderer):
|
|||||||
def _build_url_action(cls, text: Text, label: str, url: str | None, suffix: str = "") -> None:
|
def _build_url_action(cls, text: Text, label: str, url: str | None, suffix: str = "") -> None:
|
||||||
text.append(label, style="#06b6d4")
|
text.append(label, style="#06b6d4")
|
||||||
if url:
|
if url:
|
||||||
text.append(cls.truncate(url, 300), style="#06b6d4")
|
text.append(url, style="#06b6d4")
|
||||||
if suffix:
|
if suffix:
|
||||||
text.append(suffix, style="#06b6d4")
|
text.append(suffix, style="#06b6d4")
|
||||||
|
|
||||||
@@ -120,7 +120,7 @@ class BrowserRenderer(BaseToolRenderer):
|
|||||||
label, value = handlers[action]
|
label, value = handlers[action]
|
||||||
text.append(label, style="#06b6d4")
|
text.append(label, style="#06b6d4")
|
||||||
if value:
|
if value:
|
||||||
text.append(cls.truncate(str(value), 200), style="#06b6d4")
|
text.append(str(value), style="#06b6d4")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
if action == "execute_js":
|
if action == "execute_js":
|
||||||
@@ -128,7 +128,7 @@ class BrowserRenderer(BaseToolRenderer):
|
|||||||
js_code = args.get("js_code")
|
js_code = args.get("js_code")
|
||||||
if js_code:
|
if js_code:
|
||||||
text.append("\n")
|
text.append("\n")
|
||||||
text.append_text(cls._highlight_js(cls.truncate(js_code, 2000)))
|
text.append_text(cls._highlight_js(js_code))
|
||||||
return text
|
return text
|
||||||
|
|
||||||
text.append(action, style="#06b6d4")
|
text.append(action, style="#06b6d4")
|
||||||
|
|||||||
@@ -83,8 +83,7 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
|
|||||||
|
|
||||||
if command == "str_replace" and (old_str or new_str):
|
if command == "str_replace" and (old_str or new_str):
|
||||||
if old_str:
|
if old_str:
|
||||||
old_display = cls.truncate(old_str, 1000)
|
highlighted_old = cls._highlight_code(old_str, path)
|
||||||
highlighted_old = cls._highlight_code(old_display, path)
|
|
||||||
for line in highlighted_old.plain.split("\n"):
|
for line in highlighted_old.plain.split("\n"):
|
||||||
text.append("\n")
|
text.append("\n")
|
||||||
text.append("-", style="#ef4444")
|
text.append("-", style="#ef4444")
|
||||||
@@ -92,8 +91,7 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
|
|||||||
text.append(line)
|
text.append(line)
|
||||||
|
|
||||||
if new_str:
|
if new_str:
|
||||||
new_display = cls.truncate(new_str, 1000)
|
highlighted_new = cls._highlight_code(new_str, path)
|
||||||
highlighted_new = cls._highlight_code(new_display, path)
|
|
||||||
for line in highlighted_new.plain.split("\n"):
|
for line in highlighted_new.plain.split("\n"):
|
||||||
text.append("\n")
|
text.append("\n")
|
||||||
text.append("+", style="#22c55e")
|
text.append("+", style="#22c55e")
|
||||||
@@ -101,13 +99,11 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
|
|||||||
text.append(line)
|
text.append(line)
|
||||||
|
|
||||||
elif command == "create" and file_text:
|
elif command == "create" and file_text:
|
||||||
text_display = cls.truncate(file_text, 1500)
|
|
||||||
text.append("\n")
|
text.append("\n")
|
||||||
text.append_text(cls._highlight_code(text_display, path))
|
text.append_text(cls._highlight_code(file_text, path))
|
||||||
|
|
||||||
elif command == "insert" and new_str:
|
elif command == "insert" and new_str:
|
||||||
new_display = cls.truncate(new_str, 1000)
|
highlighted_new = cls._highlight_code(new_str, path)
|
||||||
highlighted_new = cls._highlight_code(new_display, path)
|
|
||||||
for line in highlighted_new.plain.split("\n"):
|
for line in highlighted_new.plain.split("\n"):
|
||||||
text.append("\n")
|
text.append("\n")
|
||||||
text.append("+", style="#22c55e")
|
text.append("+", style="#22c55e")
|
||||||
@@ -164,19 +160,15 @@ class SearchFilesRenderer(BaseToolRenderer):
|
|||||||
text.append(" ")
|
text.append(" ")
|
||||||
|
|
||||||
if path and regex:
|
if path and regex:
|
||||||
path_display = path[-30:] if len(path) > 30 else path
|
text.append(path, style="dim")
|
||||||
regex_display = regex[:30] if len(regex) > 30 else regex
|
|
||||||
text.append(path_display, style="dim")
|
|
||||||
text.append(" for '", style="dim")
|
text.append(" for '", style="dim")
|
||||||
text.append(regex_display, style="dim")
|
text.append(regex, style="dim")
|
||||||
text.append("'", style="dim")
|
text.append("'", style="dim")
|
||||||
elif path:
|
elif path:
|
||||||
path_display = path[-60:] if len(path) > 60 else path
|
text.append(path, style="dim")
|
||||||
text.append(path_display, style="dim")
|
|
||||||
elif regex:
|
elif regex:
|
||||||
regex_display = regex[:60] if len(regex) > 60 else regex
|
|
||||||
text.append("'", style="dim")
|
text.append("'", style="dim")
|
||||||
text.append(regex_display, style="dim")
|
text.append(regex, style="dim")
|
||||||
text.append("'", style="dim")
|
text.append("'", style="dim")
|
||||||
else:
|
else:
|
||||||
text.append("Searching...", style="dim")
|
text.append("Searching...", style="dim")
|
||||||
|
|||||||
@@ -28,11 +28,11 @@ class CreateNoteRenderer(BaseToolRenderer):
|
|||||||
|
|
||||||
if title:
|
if title:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(title.strip(), 300))
|
text.append(title.strip())
|
||||||
|
|
||||||
if content:
|
if content:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(content.strip(), 800), style="dim")
|
text.append(content.strip(), style="dim")
|
||||||
|
|
||||||
if not title and not content:
|
if not title and not content:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
@@ -75,11 +75,11 @@ class UpdateNoteRenderer(BaseToolRenderer):
|
|||||||
|
|
||||||
if title:
|
if title:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(title, 300))
|
text.append(title)
|
||||||
|
|
||||||
if content:
|
if content:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(content.strip(), 800), style="dim")
|
text.append(content.strip(), style="dim")
|
||||||
|
|
||||||
if not title and not content:
|
if not title and not content:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
@@ -110,23 +110,18 @@ class ListNotesRenderer(BaseToolRenderer):
|
|||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("No notes", style="dim")
|
text.append("No notes", style="dim")
|
||||||
else:
|
else:
|
||||||
for note in notes[:5]:
|
for note in notes:
|
||||||
title = note.get("title", "").strip() or "(untitled)"
|
title = note.get("title", "").strip() or "(untitled)"
|
||||||
category = note.get("category", "general")
|
category = note.get("category", "general")
|
||||||
note_content = note.get("content", "").strip()
|
note_content = note.get("content", "").strip()
|
||||||
|
|
||||||
text.append("\n - ")
|
text.append("\n - ")
|
||||||
text.append(cls.truncate(title, 300))
|
text.append(title)
|
||||||
text.append(f" ({category})", style="dim")
|
text.append(f" ({category})", style="dim")
|
||||||
|
|
||||||
if note_content:
|
if note_content:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(note_content, 400), style="dim")
|
text.append(note_content, style="dim")
|
||||||
|
|
||||||
remaining = max(count - 5, 0)
|
|
||||||
if remaining:
|
|
||||||
text.append("\n ")
|
|
||||||
text.append(f"... +{remaining} more", style="dim")
|
|
||||||
else:
|
else:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("Loading...", style="dim")
|
text.append("Loading...", style="dim")
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class ListRequestsRenderer(BaseToolRenderer):
|
|||||||
if result and isinstance(result, dict) and "requests" in result:
|
if result and isinstance(result, dict) and "requests" in result:
|
||||||
requests = result["requests"]
|
requests = result["requests"]
|
||||||
if isinstance(requests, list) and requests:
|
if isinstance(requests, list) and requests:
|
||||||
for req in requests[:3]:
|
for req in requests:
|
||||||
if isinstance(req, dict):
|
if isinstance(req, dict):
|
||||||
method = req.get("method", "?")
|
method = req.get("method", "?")
|
||||||
path = req.get("path", "?")
|
path = req.get("path", "?")
|
||||||
@@ -34,16 +34,12 @@ class ListRequestsRenderer(BaseToolRenderer):
|
|||||||
status = response.get("statusCode", "?")
|
status = response.get("statusCode", "?")
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(f"{method} {path} → {status}", style="dim")
|
text.append(f"{method} {path} → {status}", style="dim")
|
||||||
|
|
||||||
if len(requests) > 3:
|
|
||||||
text.append("\n ")
|
|
||||||
text.append(f"... +{len(requests) - 3} more", style="dim")
|
|
||||||
else:
|
else:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("No requests found", style="dim")
|
text.append("No requests found", style="dim")
|
||||||
elif httpql_filter:
|
elif httpql_filter:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(httpql_filter, 300), style="dim")
|
text.append(httpql_filter, style="dim")
|
||||||
else:
|
else:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("All requests", style="dim")
|
text.append("All requests", style="dim")
|
||||||
@@ -72,17 +68,15 @@ class ViewRequestRenderer(BaseToolRenderer):
|
|||||||
if "content" in result:
|
if "content" in result:
|
||||||
content = result["content"]
|
content = result["content"]
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(content, 500), style="dim")
|
text.append(content, style="dim")
|
||||||
elif "matches" in result:
|
elif "matches" in result:
|
||||||
matches = result["matches"]
|
matches = result["matches"]
|
||||||
if isinstance(matches, list) and matches:
|
if isinstance(matches, list) and matches:
|
||||||
for match in matches[:3]:
|
for match in matches:
|
||||||
if isinstance(match, dict) and "match" in match:
|
if isinstance(match, dict) and "match" in match:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(match["match"], style="dim")
|
text.append(match["match"], style="dim")
|
||||||
if len(matches) > 3:
|
|
||||||
text.append("\n ")
|
|
||||||
text.append(f"... +{len(matches) - 3} more matches", style="dim")
|
|
||||||
else:
|
else:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("No matches found", style="dim")
|
text.append("No matches found", style="dim")
|
||||||
@@ -123,13 +117,13 @@ class SendRequestRenderer(BaseToolRenderer):
|
|||||||
text.append(f"Status: {status_code}", style="dim")
|
text.append(f"Status: {status_code}", style="dim")
|
||||||
if response_body:
|
if response_body:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(response_body, 300), style="dim")
|
text.append(response_body, style="dim")
|
||||||
else:
|
else:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("Response received", style="dim")
|
text.append("Response received", style="dim")
|
||||||
elif url:
|
elif url:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(url, 400), style="dim")
|
text.append(url, style="dim")
|
||||||
else:
|
else:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("Sending...", style="dim")
|
text.append("Sending...", style="dim")
|
||||||
@@ -163,13 +157,13 @@ class RepeatRequestRenderer(BaseToolRenderer):
|
|||||||
text.append(f"Status: {status_code}", style="dim")
|
text.append(f"Status: {status_code}", style="dim")
|
||||||
if response_body:
|
if response_body:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(response_body, 300), style="dim")
|
text.append(response_body, style="dim")
|
||||||
else:
|
else:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("Response received", style="dim")
|
text.append("Response received", style="dim")
|
||||||
elif modifications:
|
elif modifications:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(str(modifications), 400), style="dim")
|
text.append(str(modifications), style="dim")
|
||||||
else:
|
else:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("No modifications", style="dim")
|
text.append("No modifications", style="dim")
|
||||||
@@ -211,16 +205,13 @@ class ListSitemapRenderer(BaseToolRenderer):
|
|||||||
if result and isinstance(result, dict) and "entries" in result:
|
if result and isinstance(result, dict) and "entries" in result:
|
||||||
entries = result["entries"]
|
entries = result["entries"]
|
||||||
if isinstance(entries, list) and entries:
|
if isinstance(entries, list) and entries:
|
||||||
for entry in entries[:4]:
|
for entry in entries:
|
||||||
if isinstance(entry, dict):
|
if isinstance(entry, dict):
|
||||||
label = entry.get("label", "?")
|
label = entry.get("label", "?")
|
||||||
kind = entry.get("kind", "?")
|
kind = entry.get("kind", "?")
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(f"{kind}: {label}", style="dim")
|
text.append(f"{kind}: {label}", style="dim")
|
||||||
|
|
||||||
if len(entries) > 4:
|
|
||||||
text.append("\n ")
|
|
||||||
text.append(f"... +{len(entries) - 4} more", style="dim")
|
|
||||||
else:
|
else:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("No entries found", style="dim")
|
text.append("No entries found", style="dim")
|
||||||
|
|||||||
@@ -56,8 +56,7 @@ class PythonRenderer(BaseToolRenderer):
|
|||||||
text.append("\n")
|
text.append("\n")
|
||||||
|
|
||||||
if code and action in ["new_session", "execute"]:
|
if code and action in ["new_session", "execute"]:
|
||||||
code_display = cls.truncate(code, 2000)
|
text.append_text(cls._highlight_python(code))
|
||||||
text.append_text(cls._highlight_python(code_display))
|
|
||||||
elif action == "close":
|
elif action == "close":
|
||||||
text.append(" ")
|
text.append(" ")
|
||||||
text.append("Closing session...", style="dim")
|
text.append("Closing session...", style="dim")
|
||||||
|
|||||||
@@ -59,10 +59,8 @@ def _render_default_tool_widget(tool_data: dict[str, Any]) -> Static:
|
|||||||
text.append(tool_name, style="bold blue")
|
text.append(tool_name, style="bold blue")
|
||||||
text.append("\n")
|
text.append("\n")
|
||||||
|
|
||||||
for k, v in list(args.items())[:2]:
|
for k, v in list(args.items()):
|
||||||
str_v = str(v)
|
str_v = str(v)
|
||||||
if len(str_v) > 80:
|
|
||||||
str_v = str_v[:77] + "..."
|
|
||||||
text.append(" ")
|
text.append(" ")
|
||||||
text.append(k, style="dim")
|
text.append(k, style="dim")
|
||||||
text.append(": ")
|
text.append(": ")
|
||||||
@@ -71,8 +69,6 @@ def _render_default_tool_widget(tool_data: dict[str, Any]) -> Static:
|
|||||||
|
|
||||||
if status in ["completed", "failed", "error"] and result is not None:
|
if status in ["completed", "failed", "error"] and result is not None:
|
||||||
result_str = str(result)
|
result_str = str(result)
|
||||||
if len(result_str) > 150:
|
|
||||||
result_str = result_str[:147] + "..."
|
|
||||||
text.append("Result: ", style="bold")
|
text.append("Result: ", style="bold")
|
||||||
text.append(result_str)
|
text.append(result_str)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -154,6 +154,4 @@ class TerminalRenderer(BaseToolRenderer):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _format_command(cls, command: str) -> Text:
|
def _format_command(cls, command: str) -> Text:
|
||||||
if len(command) > 2000:
|
|
||||||
command = command[:2000] + "..."
|
|
||||||
return cls._highlight_bash(command)
|
return cls._highlight_bash(command)
|
||||||
|
|||||||
@@ -23,8 +23,7 @@ class ThinkRenderer(BaseToolRenderer):
|
|||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
|
|
||||||
if thought:
|
if thought:
|
||||||
thought_display = cls.truncate(thought, 600)
|
text.append(thought, style="italic dim")
|
||||||
text.append(thought_display, style="italic dim")
|
|
||||||
else:
|
else:
|
||||||
text.append("Thinking...", style="italic dim")
|
text.append("Thinking...", style="italic dim")
|
||||||
|
|
||||||
|
|||||||
@@ -14,29 +14,18 @@ STATUS_MARKERS: dict[str, str] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _format_todo_lines(text: Text, result: dict[str, Any], limit: int = 25) -> None:
|
def _format_todo_lines(text: Text, result: dict[str, Any]) -> None:
|
||||||
todos = result.get("todos")
|
todos = result.get("todos")
|
||||||
if not isinstance(todos, list) or not todos:
|
if not isinstance(todos, list) or not todos:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append("No todos", style="dim")
|
text.append("No todos", style="dim")
|
||||||
return
|
return
|
||||||
|
|
||||||
total = len(todos)
|
for todo in todos:
|
||||||
|
|
||||||
for index, todo in enumerate(todos):
|
|
||||||
if index >= limit:
|
|
||||||
remaining = total - limit
|
|
||||||
if remaining > 0:
|
|
||||||
text.append("\n ")
|
|
||||||
text.append(f"... +{remaining} more", style="dim")
|
|
||||||
break
|
|
||||||
|
|
||||||
status = todo.get("status", "pending")
|
status = todo.get("status", "pending")
|
||||||
marker = STATUS_MARKERS.get(status, STATUS_MARKERS["pending"])
|
marker = STATUS_MARKERS.get(status, STATUS_MARKERS["pending"])
|
||||||
|
|
||||||
title = todo.get("title", "").strip() or "(untitled)"
|
title = todo.get("title", "").strip() or "(untitled)"
|
||||||
if len(title) > 90:
|
|
||||||
title = title[:87] + "..."
|
|
||||||
|
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(marker)
|
text.append(marker)
|
||||||
|
|||||||
@@ -34,9 +34,6 @@ class UserMessageRenderer(BaseToolRenderer):
|
|||||||
def _format_user_message(cls, content: str) -> Text:
|
def _format_user_message(cls, content: str) -> Text:
|
||||||
text = Text()
|
text = Text()
|
||||||
|
|
||||||
if len(content) > 300:
|
|
||||||
content = content[:297] + "..."
|
|
||||||
|
|
||||||
text.append("▍", style="#3b82f6")
|
text.append("▍", style="#3b82f6")
|
||||||
text.append(" ")
|
text.append(" ")
|
||||||
text.append("You:", style="bold")
|
text.append("You:", style="bold")
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class WebSearchRenderer(BaseToolRenderer):
|
|||||||
|
|
||||||
if query:
|
if query:
|
||||||
text.append("\n ")
|
text.append("\n ")
|
||||||
text.append(cls.truncate(query, 100), style="dim")
|
text.append(query, style="dim")
|
||||||
|
|
||||||
css_classes = cls.get_css_classes("completed")
|
css_classes = cls.get_css_classes("completed")
|
||||||
return Static(text, classes=css_classes)
|
return Static(text, classes=css_classes)
|
||||||
|
|||||||
@@ -491,7 +491,7 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
|
|
||||||
self._start_scan_thread()
|
self._start_scan_thread()
|
||||||
|
|
||||||
self.set_interval(0.5, self._update_ui_from_tracer)
|
self.set_interval(0.1, self._update_ui_from_tracer)
|
||||||
|
|
||||||
def _update_ui_from_tracer(self) -> None:
|
def _update_ui_from_tracer(self) -> None:
|
||||||
if self.show_splash:
|
if self.show_splash:
|
||||||
@@ -596,13 +596,14 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
events = self._gather_agent_events(self.selected_agent_id)
|
events = self._gather_agent_events(self.selected_agent_id)
|
||||||
if not events:
|
streaming = self.tracer.get_streaming_content(self.selected_agent_id)
|
||||||
|
if not events and not streaming:
|
||||||
content, css_class = self._get_chat_placeholder_content(
|
content, css_class = self._get_chat_placeholder_content(
|
||||||
"Starting agent...", "placeholder-no-activity"
|
"Starting agent...", "placeholder-no-activity"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
current_event_ids = [e["id"] for e in events]
|
current_event_ids = [e["id"] for e in events]
|
||||||
if current_event_ids == self._displayed_events:
|
if current_event_ids == self._displayed_events and not streaming:
|
||||||
return
|
return
|
||||||
content = self._get_rendered_events_content(events)
|
content = self._get_rendered_events_content(events)
|
||||||
css_class = "chat-content"
|
css_class = "chat-content"
|
||||||
@@ -644,8 +645,92 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
result.append_text(content)
|
result.append_text(content)
|
||||||
first = False
|
first = False
|
||||||
|
|
||||||
|
if self.selected_agent_id:
|
||||||
|
streaming = self.tracer.get_streaming_content(self.selected_agent_id)
|
||||||
|
if streaming:
|
||||||
|
streaming_text = self._render_streaming_content(streaming)
|
||||||
|
if streaming_text:
|
||||||
|
if not first:
|
||||||
|
result.append("\n\n")
|
||||||
|
result.append_text(streaming_text)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _render_streaming_content(self, content: str) -> Text:
|
||||||
|
from strix.interface.streaming_parser import parse_streaming_content
|
||||||
|
|
||||||
|
result = Text()
|
||||||
|
segments = parse_streaming_content(content)
|
||||||
|
|
||||||
|
for i, segment in enumerate(segments):
|
||||||
|
if i > 0:
|
||||||
|
result.append("\n\n")
|
||||||
|
|
||||||
|
if segment.type == "text":
|
||||||
|
from strix.interface.tool_components.agent_message_renderer import (
|
||||||
|
AgentMessageRenderer,
|
||||||
|
)
|
||||||
|
|
||||||
|
text_content = AgentMessageRenderer.render_simple(segment.content)
|
||||||
|
result.append_text(text_content)
|
||||||
|
|
||||||
|
elif segment.type == "tool":
|
||||||
|
tool_text = self._render_streaming_tool(
|
||||||
|
segment.tool_name or "unknown",
|
||||||
|
segment.args or {},
|
||||||
|
segment.is_complete,
|
||||||
|
)
|
||||||
|
result.append_text(tool_text)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _render_streaming_tool(
|
||||||
|
self, tool_name: str, args: dict[str, str], is_complete: bool
|
||||||
|
) -> Text:
|
||||||
|
from strix.interface.tool_components.registry import get_tool_renderer
|
||||||
|
|
||||||
|
tool_data = {
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"args": args,
|
||||||
|
"status": "completed" if is_complete else "running",
|
||||||
|
"result": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
renderer = get_tool_renderer(tool_name)
|
||||||
|
if renderer:
|
||||||
|
widget = renderer.render(tool_data)
|
||||||
|
renderable = widget.renderable
|
||||||
|
if isinstance(renderable, Text):
|
||||||
|
return renderable
|
||||||
|
text = Text()
|
||||||
|
text.append(str(renderable))
|
||||||
|
return text
|
||||||
|
|
||||||
|
return self._render_default_streaming_tool(tool_name, args, is_complete)
|
||||||
|
|
||||||
|
def _render_default_streaming_tool(
|
||||||
|
self, tool_name: str, args: dict[str, str], is_complete: bool
|
||||||
|
) -> Text:
|
||||||
|
text = Text()
|
||||||
|
|
||||||
|
if is_complete:
|
||||||
|
text.append("✓ ", style="green")
|
||||||
|
else:
|
||||||
|
text.append("● ", style="yellow")
|
||||||
|
|
||||||
|
text.append("Using tool ", style="dim")
|
||||||
|
text.append(tool_name, style="bold blue")
|
||||||
|
|
||||||
|
if args:
|
||||||
|
for key, value in list(args.items())[:3]:
|
||||||
|
text.append("\n ")
|
||||||
|
text.append(key, style="dim")
|
||||||
|
text.append(": ")
|
||||||
|
display_value = value if len(value) <= 100 else value[:97] + "..."
|
||||||
|
text.append(display_value, style="italic" if not is_complete else None)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
def _get_status_display_content(
|
def _get_status_display_content(
|
||||||
self, agent_id: str, agent_data: dict[str, Any]
|
self, agent_id: str, agent_data: dict[str, Any]
|
||||||
) -> tuple[Text | None, Text, bool]:
|
) -> tuple[Text | None, Text, bool]:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from fnmatch import fnmatch
|
from fnmatch import fnmatch
|
||||||
@@ -12,7 +13,7 @@ from jinja2 import (
|
|||||||
FileSystemLoader,
|
FileSystemLoader,
|
||||||
select_autoescape,
|
select_autoescape,
|
||||||
)
|
)
|
||||||
from litellm import ModelResponse, completion_cost
|
from litellm import completion_cost, stream_chunk_builder
|
||||||
from litellm.utils import supports_prompt_caching, supports_vision
|
from litellm.utils import supports_prompt_caching, supports_vision
|
||||||
|
|
||||||
from strix.llm.config import LLMConfig
|
from strix.llm.config import LLMConfig
|
||||||
@@ -276,7 +277,7 @@ class LLM:
|
|||||||
conversation_history: list[dict[str, Any]],
|
conversation_history: list[dict[str, Any]],
|
||||||
scan_id: str | None = None,
|
scan_id: str | None = None,
|
||||||
step_number: int = 1,
|
step_number: int = 1,
|
||||||
) -> LLMResponse:
|
) -> AsyncIterator[LLMResponse]:
|
||||||
messages = [{"role": "system", "content": self.system_prompt}]
|
messages = [{"role": "system", "content": self.system_prompt}]
|
||||||
|
|
||||||
identity_message = self._build_identity_message()
|
identity_message = self._build_identity_message()
|
||||||
@@ -292,30 +293,43 @@ class LLM:
|
|||||||
cached_messages = self._prepare_cached_messages(messages)
|
cached_messages = self._prepare_cached_messages(messages)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self._make_request(cached_messages)
|
accumulated_content = ""
|
||||||
self._update_usage_stats(response)
|
chunks: list[Any] = []
|
||||||
|
|
||||||
content = ""
|
async for chunk in self._stream_request(cached_messages):
|
||||||
if (
|
chunks.append(chunk)
|
||||||
response.choices
|
delta = self._extract_chunk_delta(chunk)
|
||||||
and hasattr(response.choices[0], "message")
|
if delta:
|
||||||
and response.choices[0].message
|
accumulated_content += delta
|
||||||
):
|
|
||||||
content = getattr(response.choices[0].message, "content", "") or ""
|
|
||||||
|
|
||||||
content = _truncate_to_first_function(content)
|
if "</function>" in accumulated_content:
|
||||||
|
function_end = accumulated_content.find("</function>") + len("</function>")
|
||||||
|
accumulated_content = accumulated_content[:function_end]
|
||||||
|
|
||||||
if "</function>" in content:
|
yield LLMResponse(
|
||||||
function_end_index = content.find("</function>") + len("</function>")
|
|
||||||
content = content[:function_end_index]
|
|
||||||
|
|
||||||
tool_invocations = parse_tool_invocations(content)
|
|
||||||
|
|
||||||
return LLMResponse(
|
|
||||||
scan_id=scan_id,
|
scan_id=scan_id,
|
||||||
step_number=step_number,
|
step_number=step_number,
|
||||||
role=StepRole.AGENT,
|
role=StepRole.AGENT,
|
||||||
content=content,
|
content=accumulated_content,
|
||||||
|
tool_invocations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if chunks:
|
||||||
|
complete_response = stream_chunk_builder(chunks)
|
||||||
|
self._update_usage_stats(complete_response)
|
||||||
|
|
||||||
|
accumulated_content = _truncate_to_first_function(accumulated_content)
|
||||||
|
if "</function>" in accumulated_content:
|
||||||
|
function_end = accumulated_content.find("</function>") + len("</function>")
|
||||||
|
accumulated_content = accumulated_content[:function_end]
|
||||||
|
|
||||||
|
tool_invocations = parse_tool_invocations(accumulated_content)
|
||||||
|
|
||||||
|
yield LLMResponse(
|
||||||
|
scan_id=scan_id,
|
||||||
|
step_number=step_number,
|
||||||
|
role=StepRole.AGENT,
|
||||||
|
content=accumulated_content,
|
||||||
tool_invocations=tool_invocations if tool_invocations else None,
|
tool_invocations=tool_invocations if tool_invocations else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -364,6 +378,12 @@ class LLM:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
|
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
|
||||||
|
|
||||||
|
def _extract_chunk_delta(self, chunk: Any) -> str:
|
||||||
|
if chunk.choices and hasattr(chunk.choices[0], "delta"):
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
return getattr(delta, "content", "") or ""
|
||||||
|
return ""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def usage_stats(self) -> dict[str, dict[str, int | float]]:
|
def usage_stats(self) -> dict[str, dict[str, int | float]]:
|
||||||
return {
|
return {
|
||||||
@@ -436,10 +456,10 @@ class LLM:
|
|||||||
filtered_messages.append(updated_msg)
|
filtered_messages.append(updated_msg)
|
||||||
return filtered_messages
|
return filtered_messages
|
||||||
|
|
||||||
async def _make_request(
|
async def _stream_request(
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
) -> ModelResponse:
|
) -> AsyncIterator[Any]:
|
||||||
if not self._model_supports_vision():
|
if not self._model_supports_vision():
|
||||||
messages = self._filter_images_from_messages(messages)
|
messages = self._filter_images_from_messages(messages)
|
||||||
|
|
||||||
@@ -447,6 +467,7 @@ class LLM:
|
|||||||
"model": self.config.model_name,
|
"model": self.config.model_name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"timeout": self.config.timeout,
|
"timeout": self.config.timeout,
|
||||||
|
"stream_options": {"include_usage": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
if _LLM_API_KEY:
|
if _LLM_API_KEY:
|
||||||
@@ -461,14 +482,13 @@ class LLM:
|
|||||||
completion_args["reasoning_effort"] = "high"
|
completion_args["reasoning_effort"] = "high"
|
||||||
|
|
||||||
queue = get_global_queue()
|
queue = get_global_queue()
|
||||||
response = await queue.make_request(completion_args)
|
|
||||||
|
|
||||||
self._total_stats.requests += 1
|
self._total_stats.requests += 1
|
||||||
self._last_request_stats = RequestStats(requests=1)
|
self._last_request_stats = RequestStats(requests=1)
|
||||||
|
|
||||||
return response
|
async for chunk in queue.stream_request(completion_args):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
def _update_usage_stats(self, response: ModelResponse) -> None:
|
def _update_usage_stats(self, response: Any) -> None:
|
||||||
try:
|
try:
|
||||||
if hasattr(response, "usage") and response.usage:
|
if hasattr(response, "usage") and response.usage:
|
||||||
input_tokens = getattr(response.usage, "prompt_tokens", 0)
|
input_tokens = getattr(response.usage, "prompt_tokens", 0)
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import ModelResponse, completion
|
from litellm import completion
|
||||||
|
from litellm.types.utils import ModelResponseStream
|
||||||
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
|
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
|
||||||
|
|
||||||
|
|
||||||
@@ -42,7 +44,9 @@ class LLMRequestQueue:
|
|||||||
self._last_request_time = 0.0
|
self._last_request_time = 0.0
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
async def make_request(self, completion_args: dict[str, Any]) -> ModelResponse:
|
async def stream_request(
|
||||||
|
self, completion_args: dict[str, Any]
|
||||||
|
) -> AsyncIterator[ModelResponseStream]:
|
||||||
try:
|
try:
|
||||||
while not self._semaphore.acquire(timeout=0.2):
|
while not self._semaphore.acquire(timeout=0.2):
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
@@ -56,7 +60,8 @@ class LLMRequestQueue:
|
|||||||
if sleep_needed > 0:
|
if sleep_needed > 0:
|
||||||
await asyncio.sleep(sleep_needed)
|
await asyncio.sleep(sleep_needed)
|
||||||
|
|
||||||
return await self._reliable_request(completion_args)
|
async for chunk in self._reliable_stream_request(completion_args):
|
||||||
|
yield chunk
|
||||||
finally:
|
finally:
|
||||||
self._semaphore.release()
|
self._semaphore.release()
|
||||||
|
|
||||||
@@ -66,15 +71,12 @@ class LLMRequestQueue:
|
|||||||
retry=retry_if_exception(should_retry_exception),
|
retry=retry_if_exception(should_retry_exception),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse:
|
async def _reliable_stream_request(
|
||||||
response = completion(**completion_args, stream=False)
|
self, completion_args: dict[str, Any]
|
||||||
if isinstance(response, ModelResponse):
|
) -> AsyncIterator[ModelResponseStream]:
|
||||||
return response
|
response = await asyncio.to_thread(completion, **completion_args, stream=True)
|
||||||
self._raise_unexpected_response()
|
for chunk in response:
|
||||||
raise RuntimeError("Unreachable code")
|
yield chunk
|
||||||
|
|
||||||
def _raise_unexpected_response(self) -> None:
|
|
||||||
raise RuntimeError("Unexpected response type")
|
|
||||||
|
|
||||||
|
|
||||||
_global_queue: LLMRequestQueue | None = None
|
_global_queue: LLMRequestQueue | None = None
|
||||||
|
|||||||
@@ -47,10 +47,14 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
|
|||||||
|
|
||||||
|
|
||||||
def _fix_stopword(content: str) -> str:
|
def _fix_stopword(content: str) -> str:
|
||||||
if "<function=" in content and content.count("<function=") == 1:
|
if (
|
||||||
|
"<function=" in content
|
||||||
|
and content.count("<function=") == 1
|
||||||
|
and "</function>" not in content
|
||||||
|
):
|
||||||
if content.endswith("</"):
|
if content.endswith("</"):
|
||||||
content = content.rstrip() + "function>"
|
content = content.rstrip() + "function>"
|
||||||
elif not content.rstrip().endswith("</function>"):
|
else:
|
||||||
content = content + "\n</function>"
|
content = content + "\n</function>"
|
||||||
return content
|
return content
|
||||||
|
|
||||||
@@ -75,6 +79,12 @@ def clean_content(content: str) -> str:
|
|||||||
tool_pattern = r"<function=[^>]+>.*?</function>"
|
tool_pattern = r"<function=[^>]+>.*?</function>"
|
||||||
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)
|
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)
|
||||||
|
|
||||||
|
incomplete_tool_pattern = r"<function=[^>]+>.*$"
|
||||||
|
cleaned = re.sub(incomplete_tool_pattern, "", cleaned, flags=re.DOTALL)
|
||||||
|
|
||||||
|
partial_tag_pattern = r"<f(?:u(?:n(?:c(?:t(?:i(?:o(?:n(?:=(?:[^>]*)?)?)?)?)?)?)?)?)?$"
|
||||||
|
cleaned = re.sub(partial_tag_pattern, "", cleaned)
|
||||||
|
|
||||||
hidden_xml_patterns = [
|
hidden_xml_patterns = [
|
||||||
r"<inter_agent_message>.*?</inter_agent_message>",
|
r"<inter_agent_message>.*?</inter_agent_message>",
|
||||||
r"<agent_completion_report>.*?</agent_completion_report>",
|
r"<agent_completion_report>.*?</agent_completion_report>",
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ class Tracer:
|
|||||||
self.agents: dict[str, dict[str, Any]] = {}
|
self.agents: dict[str, dict[str, Any]] = {}
|
||||||
self.tool_executions: dict[int, dict[str, Any]] = {}
|
self.tool_executions: dict[int, dict[str, Any]] = {}
|
||||||
self.chat_messages: list[dict[str, Any]] = []
|
self.chat_messages: list[dict[str, Any]] = []
|
||||||
|
self.streaming_content: dict[str, str] = {}
|
||||||
|
|
||||||
self.vulnerability_reports: list[dict[str, Any]] = []
|
self.vulnerability_reports: list[dict[str, Any]] = []
|
||||||
self.final_scan_result: str | None = None
|
self.final_scan_result: str | None = None
|
||||||
@@ -333,5 +334,14 @@ class Tracer:
|
|||||||
"total_tokens": total_stats["input_tokens"] + total_stats["output_tokens"],
|
"total_tokens": total_stats["input_tokens"] + total_stats["output_tokens"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def update_streaming_content(self, agent_id: str, content: str) -> None:
|
||||||
|
self.streaming_content[agent_id] = content
|
||||||
|
|
||||||
|
def clear_streaming_content(self, agent_id: str) -> None:
|
||||||
|
self.streaming_content.pop(agent_id, None)
|
||||||
|
|
||||||
|
def get_streaming_content(self, agent_id: str) -> str | None:
|
||||||
|
return self.streaming_content.get(agent_id)
|
||||||
|
|
||||||
def cleanup(self) -> None:
|
def cleanup(self) -> None:
|
||||||
self.save_run_data(mark_complete=True)
|
self.save_run_data(mark_complete=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user