feat(tui): add real-time streaming LLM output with full content display
- Convert LiteLLM requests to streaming mode with stream_request()
- Add streaming parser to handle live LLM output segments
- Update TUI for real-time streaming content rendering
- Add tracer methods for streaming content tracking
- Clean function tags from streamed content to prevent display
- Remove all truncation from tool renderers for full content visibility
This commit is contained in:
@@ -351,9 +351,16 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
self.state.add_message("user", task)
|
||||
|
||||
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
|
||||
response = await self.llm.generate(self.state.get_conversation_history())
|
||||
final_response = None
|
||||
async for response in self.llm.generate(self.state.get_conversation_history()):
|
||||
final_response = response
|
||||
if tracer and response.content:
|
||||
tracer.update_streaming_content(self.state.agent_id, response.content)
|
||||
|
||||
content_stripped = (response.content or "").strip()
|
||||
if final_response is None:
|
||||
return False
|
||||
|
||||
content_stripped = (final_response.content or "").strip()
|
||||
|
||||
if not content_stripped:
|
||||
corrective_message = (
|
||||
@@ -369,17 +376,18 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
self.state.add_message("user", corrective_message)
|
||||
return False
|
||||
|
||||
self.state.add_message("assistant", response.content)
|
||||
self.state.add_message("assistant", final_response.content)
|
||||
if tracer:
|
||||
tracer.clear_streaming_content(self.state.agent_id)
|
||||
tracer.log_chat_message(
|
||||
content=clean_content(response.content),
|
||||
content=clean_content(final_response.content),
|
||||
role="assistant",
|
||||
agent_id=self.state.agent_id,
|
||||
)
|
||||
|
||||
actions = (
|
||||
response.tool_invocations
|
||||
if hasattr(response, "tool_invocations") and response.tool_invocations
|
||||
final_response.tool_invocations
|
||||
if hasattr(final_response, "tool_invocations") and final_response.tool_invocations
|
||||
else []
|
||||
)
|
||||
|
||||
|
||||
119
strix/interface/streaming_parser.py
Normal file
119
strix/interface/streaming_parser.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import html
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
|
||||
_FUNCTION_TAG_PREFIX = "<function="
|
||||
|
||||
|
||||
def _get_safe_content(content: str) -> tuple[str, str]:
|
||||
if not content:
|
||||
return "", ""
|
||||
|
||||
last_lt = content.rfind("<")
|
||||
if last_lt == -1:
|
||||
return content, ""
|
||||
|
||||
suffix = content[last_lt:]
|
||||
target = _FUNCTION_TAG_PREFIX # "<function="
|
||||
|
||||
if target.startswith(suffix):
|
||||
return content[:last_lt], suffix
|
||||
|
||||
return content, ""
|
||||
|
||||
|
||||
@dataclass
class StreamSegment:
    """One parsed piece of an in-flight LLM response.

    A streamed response is split into plain-text runs and tool-call spans
    so the TUI can render each with the appropriate widget.
    """

    # "text" for prose, "tool" for a <function=...> invocation span.
    type: Literal["text", "tool"]
    # Raw segment text (for tools: the body between the function tags).
    content: str
    # Tool name parsed from <function=NAME>; None for text segments.
    tool_name: str | None = None
    # Parameter name -> value pairs parsed so far; None for text segments.
    args: dict[str, str] | None = None
    # True once the closing </function> tag has been seen.
    is_complete: bool = False
|
||||
|
||||
|
||||
def parse_streaming_content(content: str) -> list[StreamSegment]:
    """Split a partially streamed LLM response into text and tool segments.

    Scans *content* for ``<function=NAME>...</function>`` spans.  Text between
    and around complete tool calls becomes ``text`` segments; each function
    span becomes a ``tool`` segment, marked incomplete when its closing tag
    has not arrived yet.  A trailing partial ``<function=`` prefix is held
    back entirely (via ``_get_safe_content``) so it never flashes on screen.

    Returns:
        Segments in document order; empty list for empty input.
    """
    if not content:
        return []

    segments: list[StreamSegment] = []

    func_pattern = r"<function=([^>]+)>"
    func_matches = list(re.finditer(func_pattern, content))

    # No complete opening tag anywhere: plain text, except a possible
    # partial "<function=" prefix at the tail, which is suppressed.
    if not func_matches:
        safe_content, _ = _get_safe_content(content)
        text = safe_content.strip()
        if text:
            segments.append(StreamSegment(type="text", content=text))
        return segments

    # Text preceding the first tool call.
    first_func_start = func_matches[0].start()
    if first_func_start > 0:
        text_before = content[:first_func_start].strip()
        if text_before:
            segments.append(StreamSegment(type="text", content=text_before))

    for i, match in enumerate(func_matches):
        tool_name = match.group(1)
        func_start = match.end()

        # First closing tag after this opening tag. NOTE(review): assumes
        # tool calls are never nested and arrive sequentially — an unclosed
        # call followed by a closed one would swallow the wrong closer;
        # presumably upstream truncation to the first function prevents this.
        func_end_match = re.search(r"</function>", content[func_start:])

        if func_end_match:
            func_body = content[func_start : func_start + func_end_match.start()]
            is_complete = True
            end_pos = func_start + func_end_match.end()
        else:
            # Unclosed call: body runs to the next opening tag (if any)
            # or to the end of what has streamed in so far.
            if i + 1 < len(func_matches):
                next_func_start = func_matches[i + 1].start()
                func_body = content[func_start:next_func_start]
            else:
                func_body = content[func_start:]
            is_complete = False
            end_pos = len(content)

        args = _parse_streaming_params(func_body)

        segments.append(
            StreamSegment(
                type="tool",
                content=func_body,
                tool_name=tool_name,
                args=args,
                is_complete=is_complete,
            )
        )

        # Text between this (closed) call and the next opening tag.
        if is_complete and i + 1 < len(func_matches):
            next_start = func_matches[i + 1].start()
            text_between = content[end_pos:next_start].strip()
            if text_between:
                segments.append(StreamSegment(type="text", content=text_between))

    return segments
|
||||
|
||||
|
||||
def _parse_streaming_params(func_body: str) -> dict[str, str]:
|
||||
args: dict[str, str] = {}
|
||||
|
||||
complete_pattern = r"<parameter=([^>]+)>(.*?)</parameter>"
|
||||
complete_matches = list(re.finditer(complete_pattern, func_body, re.DOTALL))
|
||||
complete_end_pos = 0
|
||||
|
||||
for match in complete_matches:
|
||||
param_name = match.group(1)
|
||||
param_value = html.unescape(match.group(2).strip())
|
||||
args[param_name] = param_value
|
||||
complete_end_pos = max(complete_end_pos, match.end())
|
||||
|
||||
remaining = func_body[complete_end_pos:]
|
||||
incomplete_pattern = r"<parameter=([^>]+)>(.*)$"
|
||||
incomplete_match = re.search(incomplete_pattern, remaining, re.DOTALL)
|
||||
if incomplete_match:
|
||||
param_name = incomplete_match.group(1)
|
||||
param_value = html.unescape(incomplete_match.group(2).strip())
|
||||
args[param_name] = param_value
|
||||
|
||||
return args
|
||||
@@ -181,4 +181,10 @@ class AgentMessageRenderer(BaseToolRenderer):
|
||||
if not content:
|
||||
return Text()
|
||||
|
||||
return _apply_markdown_styles(content)
|
||||
from strix.llm.utils import clean_content
|
||||
|
||||
cleaned = clean_content(content)
|
||||
if not cleaned:
|
||||
return Text()
|
||||
|
||||
return _apply_markdown_styles(cleaned)
|
||||
|
||||
@@ -40,7 +40,7 @@ class CreateAgentRenderer(BaseToolRenderer):
|
||||
|
||||
if task:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(task, 400), style="dim")
|
||||
text.append(task, style="dim")
|
||||
else:
|
||||
text.append("\n ")
|
||||
text.append("Spawning agent...", style="dim")
|
||||
@@ -66,7 +66,7 @@ class SendMessageToAgentRenderer(BaseToolRenderer):
|
||||
|
||||
if message:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(message, 400), style="dim")
|
||||
text.append(message, style="dim")
|
||||
else:
|
||||
text.append("\n ")
|
||||
text.append("Sending...", style="dim")
|
||||
@@ -129,7 +129,7 @@ class WaitForMessageRenderer(BaseToolRenderer):
|
||||
|
||||
if reason:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(reason, 400), style="dim")
|
||||
text.append(reason, style="dim")
|
||||
else:
|
||||
text.append("\n ")
|
||||
text.append("Agent paused until message received...", style="dim")
|
||||
|
||||
@@ -23,12 +23,6 @@ class BaseToolRenderer(ABC):
|
||||
css_classes = cls.get_css_classes(status)
|
||||
return Static(content, classes=css_classes)
|
||||
|
||||
@classmethod
|
||||
def truncate(cls, text: str, max_length: int = 500) -> str:
|
||||
if len(text) <= max_length:
|
||||
return text
|
||||
return text[: max_length - 3] + "..."
|
||||
|
||||
@classmethod
|
||||
def status_icon(cls, status: str) -> tuple[str, str]:
|
||||
icons = {
|
||||
|
||||
@@ -74,7 +74,7 @@ class BrowserRenderer(BaseToolRenderer):
|
||||
def _build_url_action(cls, text: Text, label: str, url: str | None, suffix: str = "") -> None:
|
||||
text.append(label, style="#06b6d4")
|
||||
if url:
|
||||
text.append(cls.truncate(url, 300), style="#06b6d4")
|
||||
text.append(url, style="#06b6d4")
|
||||
if suffix:
|
||||
text.append(suffix, style="#06b6d4")
|
||||
|
||||
@@ -120,7 +120,7 @@ class BrowserRenderer(BaseToolRenderer):
|
||||
label, value = handlers[action]
|
||||
text.append(label, style="#06b6d4")
|
||||
if value:
|
||||
text.append(cls.truncate(str(value), 200), style="#06b6d4")
|
||||
text.append(str(value), style="#06b6d4")
|
||||
return text
|
||||
|
||||
if action == "execute_js":
|
||||
@@ -128,7 +128,7 @@ class BrowserRenderer(BaseToolRenderer):
|
||||
js_code = args.get("js_code")
|
||||
if js_code:
|
||||
text.append("\n")
|
||||
text.append_text(cls._highlight_js(cls.truncate(js_code, 2000)))
|
||||
text.append_text(cls._highlight_js(js_code))
|
||||
return text
|
||||
|
||||
text.append(action, style="#06b6d4")
|
||||
|
||||
@@ -83,8 +83,7 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
|
||||
|
||||
if command == "str_replace" and (old_str or new_str):
|
||||
if old_str:
|
||||
old_display = cls.truncate(old_str, 1000)
|
||||
highlighted_old = cls._highlight_code(old_display, path)
|
||||
highlighted_old = cls._highlight_code(old_str, path)
|
||||
for line in highlighted_old.plain.split("\n"):
|
||||
text.append("\n")
|
||||
text.append("-", style="#ef4444")
|
||||
@@ -92,8 +91,7 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
|
||||
text.append(line)
|
||||
|
||||
if new_str:
|
||||
new_display = cls.truncate(new_str, 1000)
|
||||
highlighted_new = cls._highlight_code(new_display, path)
|
||||
highlighted_new = cls._highlight_code(new_str, path)
|
||||
for line in highlighted_new.plain.split("\n"):
|
||||
text.append("\n")
|
||||
text.append("+", style="#22c55e")
|
||||
@@ -101,13 +99,11 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
|
||||
text.append(line)
|
||||
|
||||
elif command == "create" and file_text:
|
||||
text_display = cls.truncate(file_text, 1500)
|
||||
text.append("\n")
|
||||
text.append_text(cls._highlight_code(text_display, path))
|
||||
text.append_text(cls._highlight_code(file_text, path))
|
||||
|
||||
elif command == "insert" and new_str:
|
||||
new_display = cls.truncate(new_str, 1000)
|
||||
highlighted_new = cls._highlight_code(new_display, path)
|
||||
highlighted_new = cls._highlight_code(new_str, path)
|
||||
for line in highlighted_new.plain.split("\n"):
|
||||
text.append("\n")
|
||||
text.append("+", style="#22c55e")
|
||||
@@ -164,19 +160,15 @@ class SearchFilesRenderer(BaseToolRenderer):
|
||||
text.append(" ")
|
||||
|
||||
if path and regex:
|
||||
path_display = path[-30:] if len(path) > 30 else path
|
||||
regex_display = regex[:30] if len(regex) > 30 else regex
|
||||
text.append(path_display, style="dim")
|
||||
text.append(path, style="dim")
|
||||
text.append(" for '", style="dim")
|
||||
text.append(regex_display, style="dim")
|
||||
text.append(regex, style="dim")
|
||||
text.append("'", style="dim")
|
||||
elif path:
|
||||
path_display = path[-60:] if len(path) > 60 else path
|
||||
text.append(path_display, style="dim")
|
||||
text.append(path, style="dim")
|
||||
elif regex:
|
||||
regex_display = regex[:60] if len(regex) > 60 else regex
|
||||
text.append("'", style="dim")
|
||||
text.append(regex_display, style="dim")
|
||||
text.append(regex, style="dim")
|
||||
text.append("'", style="dim")
|
||||
else:
|
||||
text.append("Searching...", style="dim")
|
||||
|
||||
@@ -28,11 +28,11 @@ class CreateNoteRenderer(BaseToolRenderer):
|
||||
|
||||
if title:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(title.strip(), 300))
|
||||
text.append(title.strip())
|
||||
|
||||
if content:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(content.strip(), 800), style="dim")
|
||||
text.append(content.strip(), style="dim")
|
||||
|
||||
if not title and not content:
|
||||
text.append("\n ")
|
||||
@@ -75,11 +75,11 @@ class UpdateNoteRenderer(BaseToolRenderer):
|
||||
|
||||
if title:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(title, 300))
|
||||
text.append(title)
|
||||
|
||||
if content:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(content.strip(), 800), style="dim")
|
||||
text.append(content.strip(), style="dim")
|
||||
|
||||
if not title and not content:
|
||||
text.append("\n ")
|
||||
@@ -110,23 +110,18 @@ class ListNotesRenderer(BaseToolRenderer):
|
||||
text.append("\n ")
|
||||
text.append("No notes", style="dim")
|
||||
else:
|
||||
for note in notes[:5]:
|
||||
for note in notes:
|
||||
title = note.get("title", "").strip() or "(untitled)"
|
||||
category = note.get("category", "general")
|
||||
note_content = note.get("content", "").strip()
|
||||
|
||||
text.append("\n - ")
|
||||
text.append(cls.truncate(title, 300))
|
||||
text.append(title)
|
||||
text.append(f" ({category})", style="dim")
|
||||
|
||||
if note_content:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(note_content, 400), style="dim")
|
||||
|
||||
remaining = max(count - 5, 0)
|
||||
if remaining:
|
||||
text.append("\n ")
|
||||
text.append(f"... +{remaining} more", style="dim")
|
||||
text.append(note_content, style="dim")
|
||||
else:
|
||||
text.append("\n ")
|
||||
text.append("Loading...", style="dim")
|
||||
|
||||
@@ -26,7 +26,7 @@ class ListRequestsRenderer(BaseToolRenderer):
|
||||
if result and isinstance(result, dict) and "requests" in result:
|
||||
requests = result["requests"]
|
||||
if isinstance(requests, list) and requests:
|
||||
for req in requests[:3]:
|
||||
for req in requests:
|
||||
if isinstance(req, dict):
|
||||
method = req.get("method", "?")
|
||||
path = req.get("path", "?")
|
||||
@@ -34,16 +34,12 @@ class ListRequestsRenderer(BaseToolRenderer):
|
||||
status = response.get("statusCode", "?")
|
||||
text.append("\n ")
|
||||
text.append(f"{method} {path} → {status}", style="dim")
|
||||
|
||||
if len(requests) > 3:
|
||||
text.append("\n ")
|
||||
text.append(f"... +{len(requests) - 3} more", style="dim")
|
||||
else:
|
||||
text.append("\n ")
|
||||
text.append("No requests found", style="dim")
|
||||
elif httpql_filter:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(httpql_filter, 300), style="dim")
|
||||
text.append(httpql_filter, style="dim")
|
||||
else:
|
||||
text.append("\n ")
|
||||
text.append("All requests", style="dim")
|
||||
@@ -72,17 +68,15 @@ class ViewRequestRenderer(BaseToolRenderer):
|
||||
if "content" in result:
|
||||
content = result["content"]
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(content, 500), style="dim")
|
||||
text.append(content, style="dim")
|
||||
elif "matches" in result:
|
||||
matches = result["matches"]
|
||||
if isinstance(matches, list) and matches:
|
||||
for match in matches[:3]:
|
||||
for match in matches:
|
||||
if isinstance(match, dict) and "match" in match:
|
||||
text.append("\n ")
|
||||
text.append(match["match"], style="dim")
|
||||
if len(matches) > 3:
|
||||
text.append("\n ")
|
||||
text.append(f"... +{len(matches) - 3} more matches", style="dim")
|
||||
|
||||
else:
|
||||
text.append("\n ")
|
||||
text.append("No matches found", style="dim")
|
||||
@@ -123,13 +117,13 @@ class SendRequestRenderer(BaseToolRenderer):
|
||||
text.append(f"Status: {status_code}", style="dim")
|
||||
if response_body:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(response_body, 300), style="dim")
|
||||
text.append(response_body, style="dim")
|
||||
else:
|
||||
text.append("\n ")
|
||||
text.append("Response received", style="dim")
|
||||
elif url:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(url, 400), style="dim")
|
||||
text.append(url, style="dim")
|
||||
else:
|
||||
text.append("\n ")
|
||||
text.append("Sending...", style="dim")
|
||||
@@ -163,13 +157,13 @@ class RepeatRequestRenderer(BaseToolRenderer):
|
||||
text.append(f"Status: {status_code}", style="dim")
|
||||
if response_body:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(response_body, 300), style="dim")
|
||||
text.append(response_body, style="dim")
|
||||
else:
|
||||
text.append("\n ")
|
||||
text.append("Response received", style="dim")
|
||||
elif modifications:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(str(modifications), 400), style="dim")
|
||||
text.append(str(modifications), style="dim")
|
||||
else:
|
||||
text.append("\n ")
|
||||
text.append("No modifications", style="dim")
|
||||
@@ -211,16 +205,13 @@ class ListSitemapRenderer(BaseToolRenderer):
|
||||
if result and isinstance(result, dict) and "entries" in result:
|
||||
entries = result["entries"]
|
||||
if isinstance(entries, list) and entries:
|
||||
for entry in entries[:4]:
|
||||
for entry in entries:
|
||||
if isinstance(entry, dict):
|
||||
label = entry.get("label", "?")
|
||||
kind = entry.get("kind", "?")
|
||||
text.append("\n ")
|
||||
text.append(f"{kind}: {label}", style="dim")
|
||||
|
||||
if len(entries) > 4:
|
||||
text.append("\n ")
|
||||
text.append(f"... +{len(entries) - 4} more", style="dim")
|
||||
else:
|
||||
text.append("\n ")
|
||||
text.append("No entries found", style="dim")
|
||||
|
||||
@@ -56,8 +56,7 @@ class PythonRenderer(BaseToolRenderer):
|
||||
text.append("\n")
|
||||
|
||||
if code and action in ["new_session", "execute"]:
|
||||
code_display = cls.truncate(code, 2000)
|
||||
text.append_text(cls._highlight_python(code_display))
|
||||
text.append_text(cls._highlight_python(code))
|
||||
elif action == "close":
|
||||
text.append(" ")
|
||||
text.append("Closing session...", style="dim")
|
||||
|
||||
@@ -59,10 +59,8 @@ def _render_default_tool_widget(tool_data: dict[str, Any]) -> Static:
|
||||
text.append(tool_name, style="bold blue")
|
||||
text.append("\n")
|
||||
|
||||
for k, v in list(args.items())[:2]:
|
||||
for k, v in list(args.items()):
|
||||
str_v = str(v)
|
||||
if len(str_v) > 80:
|
||||
str_v = str_v[:77] + "..."
|
||||
text.append(" ")
|
||||
text.append(k, style="dim")
|
||||
text.append(": ")
|
||||
@@ -71,8 +69,6 @@ def _render_default_tool_widget(tool_data: dict[str, Any]) -> Static:
|
||||
|
||||
if status in ["completed", "failed", "error"] and result is not None:
|
||||
result_str = str(result)
|
||||
if len(result_str) > 150:
|
||||
result_str = result_str[:147] + "..."
|
||||
text.append("Result: ", style="bold")
|
||||
text.append(result_str)
|
||||
else:
|
||||
|
||||
@@ -154,6 +154,4 @@ class TerminalRenderer(BaseToolRenderer):
|
||||
|
||||
@classmethod
|
||||
def _format_command(cls, command: str) -> Text:
|
||||
if len(command) > 2000:
|
||||
command = command[:2000] + "..."
|
||||
return cls._highlight_bash(command)
|
||||
|
||||
@@ -23,8 +23,7 @@ class ThinkRenderer(BaseToolRenderer):
|
||||
text.append("\n ")
|
||||
|
||||
if thought:
|
||||
thought_display = cls.truncate(thought, 600)
|
||||
text.append(thought_display, style="italic dim")
|
||||
text.append(thought, style="italic dim")
|
||||
else:
|
||||
text.append("Thinking...", style="italic dim")
|
||||
|
||||
|
||||
@@ -14,29 +14,18 @@ STATUS_MARKERS: dict[str, str] = {
|
||||
}
|
||||
|
||||
|
||||
def _format_todo_lines(text: Text, result: dict[str, Any], limit: int = 25) -> None:
|
||||
def _format_todo_lines(text: Text, result: dict[str, Any]) -> None:
|
||||
todos = result.get("todos")
|
||||
if not isinstance(todos, list) or not todos:
|
||||
text.append("\n ")
|
||||
text.append("No todos", style="dim")
|
||||
return
|
||||
|
||||
total = len(todos)
|
||||
|
||||
for index, todo in enumerate(todos):
|
||||
if index >= limit:
|
||||
remaining = total - limit
|
||||
if remaining > 0:
|
||||
text.append("\n ")
|
||||
text.append(f"... +{remaining} more", style="dim")
|
||||
break
|
||||
|
||||
for todo in todos:
|
||||
status = todo.get("status", "pending")
|
||||
marker = STATUS_MARKERS.get(status, STATUS_MARKERS["pending"])
|
||||
|
||||
title = todo.get("title", "").strip() or "(untitled)"
|
||||
if len(title) > 90:
|
||||
title = title[:87] + "..."
|
||||
|
||||
text.append("\n ")
|
||||
text.append(marker)
|
||||
|
||||
@@ -34,9 +34,6 @@ class UserMessageRenderer(BaseToolRenderer):
|
||||
def _format_user_message(cls, content: str) -> Text:
|
||||
text = Text()
|
||||
|
||||
if len(content) > 300:
|
||||
content = content[:297] + "..."
|
||||
|
||||
text.append("▍", style="#3b82f6")
|
||||
text.append(" ")
|
||||
text.append("You:", style="bold")
|
||||
|
||||
@@ -23,7 +23,7 @@ class WebSearchRenderer(BaseToolRenderer):
|
||||
|
||||
if query:
|
||||
text.append("\n ")
|
||||
text.append(cls.truncate(query, 100), style="dim")
|
||||
text.append(query, style="dim")
|
||||
|
||||
css_classes = cls.get_css_classes("completed")
|
||||
return Static(text, classes=css_classes)
|
||||
|
||||
@@ -491,7 +491,7 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
|
||||
self._start_scan_thread()
|
||||
|
||||
self.set_interval(0.5, self._update_ui_from_tracer)
|
||||
self.set_interval(0.1, self._update_ui_from_tracer)
|
||||
|
||||
def _update_ui_from_tracer(self) -> None:
|
||||
if self.show_splash:
|
||||
@@ -596,13 +596,14 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
)
|
||||
else:
|
||||
events = self._gather_agent_events(self.selected_agent_id)
|
||||
if not events:
|
||||
streaming = self.tracer.get_streaming_content(self.selected_agent_id)
|
||||
if not events and not streaming:
|
||||
content, css_class = self._get_chat_placeholder_content(
|
||||
"Starting agent...", "placeholder-no-activity"
|
||||
)
|
||||
else:
|
||||
current_event_ids = [e["id"] for e in events]
|
||||
if current_event_ids == self._displayed_events:
|
||||
if current_event_ids == self._displayed_events and not streaming:
|
||||
return
|
||||
content = self._get_rendered_events_content(events)
|
||||
css_class = "chat-content"
|
||||
@@ -644,8 +645,92 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
result.append_text(content)
|
||||
first = False
|
||||
|
||||
if self.selected_agent_id:
|
||||
streaming = self.tracer.get_streaming_content(self.selected_agent_id)
|
||||
if streaming:
|
||||
streaming_text = self._render_streaming_content(streaming)
|
||||
if streaming_text:
|
||||
if not first:
|
||||
result.append("\n\n")
|
||||
result.append_text(streaming_text)
|
||||
|
||||
return result
|
||||
|
||||
def _render_streaming_content(self, content: str) -> Text:
    """Render a partially streamed LLM response as a rich ``Text``.

    Splits *content* into text/tool segments via the streaming parser and
    renders each: text segments through ``AgentMessageRenderer`` (markdown
    styling), tool segments through the per-tool streaming renderer.
    Segments are separated by blank lines.
    """
    # Local import — presumably to avoid a circular import at module
    # load time; confirm before hoisting.
    from strix.interface.streaming_parser import parse_streaming_content

    result = Text()
    segments = parse_streaming_content(content)

    for i, segment in enumerate(segments):
        if i > 0:
            result.append("\n\n")

        if segment.type == "text":
            from strix.interface.tool_components.agent_message_renderer import (
                AgentMessageRenderer,
            )

            text_content = AgentMessageRenderer.render_simple(segment.content)
            result.append_text(text_content)

        elif segment.type == "tool":
            tool_text = self._render_streaming_tool(
                segment.tool_name or "unknown",
                segment.args or {},
                segment.is_complete,
            )
            result.append_text(tool_text)

    return result
|
||||
|
||||
def _render_streaming_tool(
    self, tool_name: str, args: dict[str, str], is_complete: bool
) -> Text:
    """Render one (possibly in-progress) tool invocation as rich ``Text``.

    Uses the tool's registered renderer when one exists, extracting the
    ``Text`` renderable from the widget it produces; otherwise falls back
    to the generic streaming-tool display.
    """
    from strix.interface.tool_components.registry import get_tool_renderer

    # Shape mirrors the tool_data dict the registered renderers expect;
    # no result is available while the call is still streaming.
    tool_data = {
        "tool_name": tool_name,
        "args": args,
        "status": "completed" if is_complete else "running",
        "result": None,
    }

    renderer = get_tool_renderer(tool_name)
    if renderer:
        widget = renderer.render(tool_data)
        renderable = widget.renderable
        if isinstance(renderable, Text):
            return renderable
        # Renderer produced something other than Text; coerce via str().
        text = Text()
        text.append(str(renderable))
        return text

    return self._render_default_streaming_tool(tool_name, args, is_complete)
|
||||
|
||||
def _render_default_streaming_tool(
    self, tool_name: str, args: dict[str, str], is_complete: bool
) -> Text:
    """Generic display for a streaming tool call with no registered renderer.

    Shows a status marker, the tool name, and up to three argument
    previews (values clipped to 100 characters).
    """
    text = Text()

    marker, marker_style = ("✓ ", "green") if is_complete else ("● ", "yellow")
    text.append(marker, style=marker_style)
    text.append("Using tool ", style="dim")
    text.append(tool_name, style="bold blue")

    value_style = None if is_complete else "italic"
    for key, value in list(args.items())[:3]:
        text.append("\n  ")
        text.append(key, style="dim")
        text.append(": ")
        shown = value if len(value) <= 100 else value[:97] + "..."
        text.append(shown, style=value_style)

    return text
|
||||
|
||||
def _get_status_display_content(
|
||||
self, agent_id: str, agent_data: dict[str, Any]
|
||||
) -> tuple[Text | None, Text, bool]:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from fnmatch import fnmatch
|
||||
@@ -12,7 +13,7 @@ from jinja2 import (
|
||||
FileSystemLoader,
|
||||
select_autoescape,
|
||||
)
|
||||
from litellm import ModelResponse, completion_cost
|
||||
from litellm import completion_cost, stream_chunk_builder
|
||||
from litellm.utils import supports_prompt_caching, supports_vision
|
||||
|
||||
from strix.llm.config import LLMConfig
|
||||
@@ -276,7 +277,7 @@ class LLM:
|
||||
conversation_history: list[dict[str, Any]],
|
||||
scan_id: str | None = None,
|
||||
step_number: int = 1,
|
||||
) -> LLMResponse:
|
||||
) -> AsyncIterator[LLMResponse]:
|
||||
messages = [{"role": "system", "content": self.system_prompt}]
|
||||
|
||||
identity_message = self._build_identity_message()
|
||||
@@ -292,30 +293,43 @@ class LLM:
|
||||
cached_messages = self._prepare_cached_messages(messages)
|
||||
|
||||
try:
|
||||
response = await self._make_request(cached_messages)
|
||||
self._update_usage_stats(response)
|
||||
accumulated_content = ""
|
||||
chunks: list[Any] = []
|
||||
|
||||
content = ""
|
||||
if (
|
||||
response.choices
|
||||
and hasattr(response.choices[0], "message")
|
||||
and response.choices[0].message
|
||||
):
|
||||
content = getattr(response.choices[0].message, "content", "") or ""
|
||||
async for chunk in self._stream_request(cached_messages):
|
||||
chunks.append(chunk)
|
||||
delta = self._extract_chunk_delta(chunk)
|
||||
if delta:
|
||||
accumulated_content += delta
|
||||
|
||||
content = _truncate_to_first_function(content)
|
||||
if "</function>" in accumulated_content:
|
||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
||||
accumulated_content = accumulated_content[:function_end]
|
||||
|
||||
if "</function>" in content:
|
||||
function_end_index = content.find("</function>") + len("</function>")
|
||||
content = content[:function_end_index]
|
||||
yield LLMResponse(
|
||||
scan_id=scan_id,
|
||||
step_number=step_number,
|
||||
role=StepRole.AGENT,
|
||||
content=accumulated_content,
|
||||
tool_invocations=None,
|
||||
)
|
||||
|
||||
tool_invocations = parse_tool_invocations(content)
|
||||
if chunks:
|
||||
complete_response = stream_chunk_builder(chunks)
|
||||
self._update_usage_stats(complete_response)
|
||||
|
||||
return LLMResponse(
|
||||
accumulated_content = _truncate_to_first_function(accumulated_content)
|
||||
if "</function>" in accumulated_content:
|
||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
||||
accumulated_content = accumulated_content[:function_end]
|
||||
|
||||
tool_invocations = parse_tool_invocations(accumulated_content)
|
||||
|
||||
yield LLMResponse(
|
||||
scan_id=scan_id,
|
||||
step_number=step_number,
|
||||
role=StepRole.AGENT,
|
||||
content=content,
|
||||
content=accumulated_content,
|
||||
tool_invocations=tool_invocations if tool_invocations else None,
|
||||
)
|
||||
|
||||
@@ -364,6 +378,12 @@ class LLM:
|
||||
except Exception as e:
|
||||
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
|
||||
|
||||
def _extract_chunk_delta(self, chunk: Any) -> str:
|
||||
if chunk.choices and hasattr(chunk.choices[0], "delta"):
|
||||
delta = chunk.choices[0].delta
|
||||
return getattr(delta, "content", "") or ""
|
||||
return ""
|
||||
|
||||
@property
|
||||
def usage_stats(self) -> dict[str, dict[str, int | float]]:
|
||||
return {
|
||||
@@ -436,10 +456,10 @@ class LLM:
|
||||
filtered_messages.append(updated_msg)
|
||||
return filtered_messages
|
||||
|
||||
async def _make_request(
|
||||
async def _stream_request(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
) -> ModelResponse:
|
||||
) -> AsyncIterator[Any]:
|
||||
if not self._model_supports_vision():
|
||||
messages = self._filter_images_from_messages(messages)
|
||||
|
||||
@@ -447,6 +467,7 @@ class LLM:
|
||||
"model": self.config.model_name,
|
||||
"messages": messages,
|
||||
"timeout": self.config.timeout,
|
||||
"stream_options": {"include_usage": True},
|
||||
}
|
||||
|
||||
if _LLM_API_KEY:
|
||||
@@ -461,14 +482,13 @@ class LLM:
|
||||
completion_args["reasoning_effort"] = "high"
|
||||
|
||||
queue = get_global_queue()
|
||||
response = await queue.make_request(completion_args)
|
||||
|
||||
self._total_stats.requests += 1
|
||||
self._last_request_stats = RequestStats(requests=1)
|
||||
|
||||
return response
|
||||
async for chunk in queue.stream_request(completion_args):
|
||||
yield chunk
|
||||
|
||||
def _update_usage_stats(self, response: ModelResponse) -> None:
|
||||
def _update_usage_stats(self, response: Any) -> None:
|
||||
try:
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
input_tokens = getattr(response.usage, "prompt_tokens", 0)
|
||||
|
||||
@@ -3,10 +3,12 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm import ModelResponse, completion
|
||||
from litellm import completion
|
||||
from litellm.types.utils import ModelResponseStream
|
||||
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
|
||||
|
||||
|
||||
@@ -42,7 +44,9 @@ class LLMRequestQueue:
|
||||
self._last_request_time = 0.0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
async def make_request(self, completion_args: dict[str, Any]) -> ModelResponse:
|
||||
async def stream_request(
|
||||
self, completion_args: dict[str, Any]
|
||||
) -> AsyncIterator[ModelResponseStream]:
|
||||
try:
|
||||
while not self._semaphore.acquire(timeout=0.2):
|
||||
await asyncio.sleep(0.1)
|
||||
@@ -56,7 +60,8 @@ class LLMRequestQueue:
|
||||
if sleep_needed > 0:
|
||||
await asyncio.sleep(sleep_needed)
|
||||
|
||||
return await self._reliable_request(completion_args)
|
||||
async for chunk in self._reliable_stream_request(completion_args):
|
||||
yield chunk
|
||||
finally:
|
||||
self._semaphore.release()
|
||||
|
||||
@@ -66,15 +71,12 @@ class LLMRequestQueue:
|
||||
retry=retry_if_exception(should_retry_exception),
|
||||
reraise=True,
|
||||
)
|
||||
async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse:
|
||||
response = completion(**completion_args, stream=False)
|
||||
if isinstance(response, ModelResponse):
|
||||
return response
|
||||
self._raise_unexpected_response()
|
||||
raise RuntimeError("Unreachable code")
|
||||
|
||||
def _raise_unexpected_response(self) -> None:
|
||||
raise RuntimeError("Unexpected response type")
|
||||
async def _reliable_stream_request(
    self, completion_args: dict[str, Any]
) -> AsyncIterator[ModelResponseStream]:
    """Run a streaming completion and yield its chunks asynchronously.

    ``completion(stream=True)`` returns a *synchronous* iterator: both the
    initial call and every subsequent chunk fetch can block on network I/O.
    Previously only the initial call ran in a worker thread and the
    ``for chunk in response`` loop blocked the event loop between chunks,
    freezing the TUI.  Each blocking ``next()`` is now also pushed onto a
    worker thread so the loop stays responsive while chunks trickle in.
    """
    response = await asyncio.to_thread(completion, **completion_args, stream=True)
    iterator = iter(response)
    # Sentinel distinguishes exhaustion from a chunk that happens to be falsy.
    done = object()
    while True:
        chunk = await asyncio.to_thread(next, iterator, done)
        if chunk is done:
            break
        yield chunk
|
||||
|
||||
|
||||
_global_queue: LLMRequestQueue | None = None
|
||||
|
||||
@@ -47,10 +47,14 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
|
||||
|
||||
|
||||
def _fix_stopword(content: str) -> str:
|
||||
if "<function=" in content and content.count("<function=") == 1:
|
||||
if (
|
||||
"<function=" in content
|
||||
and content.count("<function=") == 1
|
||||
and "</function>" not in content
|
||||
):
|
||||
if content.endswith("</"):
|
||||
content = content.rstrip() + "function>"
|
||||
elif not content.rstrip().endswith("</function>"):
|
||||
else:
|
||||
content = content + "\n</function>"
|
||||
return content
|
||||
|
||||
@@ -75,6 +79,12 @@ def clean_content(content: str) -> str:
|
||||
tool_pattern = r"<function=[^>]+>.*?</function>"
|
||||
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)
|
||||
|
||||
incomplete_tool_pattern = r"<function=[^>]+>.*$"
|
||||
cleaned = re.sub(incomplete_tool_pattern, "", cleaned, flags=re.DOTALL)
|
||||
|
||||
partial_tag_pattern = r"<f(?:u(?:n(?:c(?:t(?:i(?:o(?:n(?:=(?:[^>]*)?)?)?)?)?)?)?)?)?$"
|
||||
cleaned = re.sub(partial_tag_pattern, "", cleaned)
|
||||
|
||||
hidden_xml_patterns = [
|
||||
r"<inter_agent_message>.*?</inter_agent_message>",
|
||||
r"<agent_completion_report>.*?</agent_completion_report>",
|
||||
|
||||
@@ -33,6 +33,7 @@ class Tracer:
|
||||
self.agents: dict[str, dict[str, Any]] = {}
|
||||
self.tool_executions: dict[int, dict[str, Any]] = {}
|
||||
self.chat_messages: list[dict[str, Any]] = []
|
||||
self.streaming_content: dict[str, str] = {}
|
||||
|
||||
self.vulnerability_reports: list[dict[str, Any]] = []
|
||||
self.final_scan_result: str | None = None
|
||||
@@ -333,5 +334,14 @@ class Tracer:
|
||||
"total_tokens": total_stats["input_tokens"] + total_stats["output_tokens"],
|
||||
}
|
||||
|
||||
def update_streaming_content(self, agent_id: str, content: str) -> None:
    """Record the latest partial LLM output for *agent_id* (overwrites any prior value)."""
    self.streaming_content[agent_id] = content
|
||||
|
||||
def clear_streaming_content(self, agent_id: str) -> None:
    """Drop any buffered partial output for *agent_id*; no-op when none exists."""
    self.streaming_content.pop(agent_id, None)
|
||||
|
||||
def get_streaming_content(self, agent_id: str) -> str | None:
    """Return the buffered partial output for *agent_id*, or None when nothing is streaming."""
    return self.streaming_content.get(agent_id)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
self.save_run_data(mark_complete=True)
|
||||
|
||||
Reference in New Issue
Block a user