diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py
index 67aeb38..56ae6b1 100644
--- a/strix/agents/base_agent.py
+++ b/strix/agents/base_agent.py
@@ -351,9 +351,16 @@ class BaseAgent(metaclass=AgentMeta):
self.state.add_message("user", task)
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
- response = await self.llm.generate(self.state.get_conversation_history())
+ final_response = None
+ async for response in self.llm.generate(self.state.get_conversation_history()):
+ final_response = response
+ if tracer and response.content:
+ tracer.update_streaming_content(self.state.agent_id, response.content)
- content_stripped = (response.content or "").strip()
+ if final_response is None:
+ return False
+
+ content_stripped = (final_response.content or "").strip()
if not content_stripped:
corrective_message = (
@@ -369,17 +376,18 @@ class BaseAgent(metaclass=AgentMeta):
self.state.add_message("user", corrective_message)
return False
- self.state.add_message("assistant", response.content)
+ self.state.add_message("assistant", final_response.content)
if tracer:
+ tracer.clear_streaming_content(self.state.agent_id)
tracer.log_chat_message(
- content=clean_content(response.content),
+ content=clean_content(final_response.content),
role="assistant",
agent_id=self.state.agent_id,
)
actions = (
- response.tool_invocations
- if hasattr(response, "tool_invocations") and response.tool_invocations
+ final_response.tool_invocations
+ if hasattr(final_response, "tool_invocations") and final_response.tool_invocations
else []
)
diff --git a/strix/interface/streaming_parser.py b/strix/interface/streaming_parser.py
new file mode 100644
index 0000000..8adbc9b
--- /dev/null
+++ b/strix/interface/streaming_parser.py
@@ -0,0 +1,119 @@
+import html
+import re
+from dataclasses import dataclass
+from typing import Literal
+
+
+_FUNCTION_TAG_PREFIX = "<function="
+
+
+@dataclass
+class StreamSegment:
+    type: Literal["text", "tool"]
+    content: str
+    tool_name: str | None = None
+    args: dict[str, str] | None = None
+    is_complete: bool = False
+
+
+def _get_safe_content(content: str) -> tuple[str, str]:
+ if not content:
+ return "", ""
+
+ last_lt = content.rfind("<")
+ if last_lt == -1:
+ return content, ""
+
+ suffix = content[last_lt:]
+    target = _FUNCTION_TAG_PREFIX  # "<function="
+    if target.startswith(suffix) or suffix.startswith(target):
+        return content[:last_lt], suffix
+
+    return content, ""
+
+
+def parse_streaming_content(content: str) -> list[StreamSegment]:
+ if not content:
+ return []
+
+ segments: list[StreamSegment] = []
+
+    func_pattern = r"<function=([^>]+)>"
+ func_matches = list(re.finditer(func_pattern, content))
+
+ if not func_matches:
+ safe_content, _ = _get_safe_content(content)
+ text = safe_content.strip()
+ if text:
+ segments.append(StreamSegment(type="text", content=text))
+ return segments
+
+ first_func_start = func_matches[0].start()
+ if first_func_start > 0:
+ text_before = content[:first_func_start].strip()
+ if text_before:
+ segments.append(StreamSegment(type="text", content=text_before))
+
+ for i, match in enumerate(func_matches):
+ tool_name = match.group(1)
+ func_start = match.end()
+
+        func_end_match = re.search(r"</function>", content[func_start:])
+
+ if func_end_match:
+ func_body = content[func_start : func_start + func_end_match.start()]
+ is_complete = True
+ end_pos = func_start + func_end_match.end()
+ else:
+ if i + 1 < len(func_matches):
+ next_func_start = func_matches[i + 1].start()
+ func_body = content[func_start:next_func_start]
+ else:
+ func_body = content[func_start:]
+ is_complete = False
+ end_pos = len(content)
+
+ args = _parse_streaming_params(func_body)
+
+ segments.append(
+ StreamSegment(
+ type="tool",
+ content=func_body,
+ tool_name=tool_name,
+ args=args,
+ is_complete=is_complete,
+ )
+ )
+
+ if is_complete and i + 1 < len(func_matches):
+ next_start = func_matches[i + 1].start()
+ text_between = content[end_pos:next_start].strip()
+ if text_between:
+ segments.append(StreamSegment(type="text", content=text_between))
+
+ return segments
+
+
+def _parse_streaming_params(func_body: str) -> dict[str, str]:
+ args: dict[str, str] = {}
+
+    complete_pattern = r"<parameter=([^>]+)>(.*?)</parameter>"
+ complete_matches = list(re.finditer(complete_pattern, func_body, re.DOTALL))
+ complete_end_pos = 0
+
+ for match in complete_matches:
+ param_name = match.group(1)
+ param_value = html.unescape(match.group(2).strip())
+ args[param_name] = param_value
+ complete_end_pos = max(complete_end_pos, match.end())
+
+ remaining = func_body[complete_end_pos:]
+    incomplete_pattern = r"<parameter=([^>]+)>(.*)$"
+ incomplete_match = re.search(incomplete_pattern, remaining, re.DOTALL)
+ if incomplete_match:
+ param_name = incomplete_match.group(1)
+ param_value = html.unescape(incomplete_match.group(2).strip())
+ args[param_name] = param_value
+
+ return args
diff --git a/strix/interface/tool_components/agent_message_renderer.py b/strix/interface/tool_components/agent_message_renderer.py
index ea16653..a51ea2a 100644
--- a/strix/interface/tool_components/agent_message_renderer.py
+++ b/strix/interface/tool_components/agent_message_renderer.py
@@ -181,4 +181,10 @@ class AgentMessageRenderer(BaseToolRenderer):
if not content:
return Text()
- return _apply_markdown_styles(content)
+ from strix.llm.utils import clean_content
+
+ cleaned = clean_content(content)
+ if not cleaned:
+ return Text()
+
+ return _apply_markdown_styles(cleaned)
diff --git a/strix/interface/tool_components/agents_graph_renderer.py b/strix/interface/tool_components/agents_graph_renderer.py
index d414984..356bdcb 100644
--- a/strix/interface/tool_components/agents_graph_renderer.py
+++ b/strix/interface/tool_components/agents_graph_renderer.py
@@ -40,7 +40,7 @@ class CreateAgentRenderer(BaseToolRenderer):
if task:
text.append("\n ")
- text.append(cls.truncate(task, 400), style="dim")
+ text.append(task, style="dim")
else:
text.append("\n ")
text.append("Spawning agent...", style="dim")
@@ -66,7 +66,7 @@ class SendMessageToAgentRenderer(BaseToolRenderer):
if message:
text.append("\n ")
- text.append(cls.truncate(message, 400), style="dim")
+ text.append(message, style="dim")
else:
text.append("\n ")
text.append("Sending...", style="dim")
@@ -129,7 +129,7 @@ class WaitForMessageRenderer(BaseToolRenderer):
if reason:
text.append("\n ")
- text.append(cls.truncate(reason, 400), style="dim")
+ text.append(reason, style="dim")
else:
text.append("\n ")
text.append("Agent paused until message received...", style="dim")
diff --git a/strix/interface/tool_components/base_renderer.py b/strix/interface/tool_components/base_renderer.py
index ef25c37..11e8458 100644
--- a/strix/interface/tool_components/base_renderer.py
+++ b/strix/interface/tool_components/base_renderer.py
@@ -23,12 +23,6 @@ class BaseToolRenderer(ABC):
css_classes = cls.get_css_classes(status)
return Static(content, classes=css_classes)
- @classmethod
- def truncate(cls, text: str, max_length: int = 500) -> str:
- if len(text) <= max_length:
- return text
- return text[: max_length - 3] + "..."
-
@classmethod
def status_icon(cls, status: str) -> tuple[str, str]:
icons = {
diff --git a/strix/interface/tool_components/browser_renderer.py b/strix/interface/tool_components/browser_renderer.py
index 0d22079..dc7b868 100644
--- a/strix/interface/tool_components/browser_renderer.py
+++ b/strix/interface/tool_components/browser_renderer.py
@@ -74,7 +74,7 @@ class BrowserRenderer(BaseToolRenderer):
def _build_url_action(cls, text: Text, label: str, url: str | None, suffix: str = "") -> None:
text.append(label, style="#06b6d4")
if url:
- text.append(cls.truncate(url, 300), style="#06b6d4")
+ text.append(url, style="#06b6d4")
if suffix:
text.append(suffix, style="#06b6d4")
@@ -120,7 +120,7 @@ class BrowserRenderer(BaseToolRenderer):
label, value = handlers[action]
text.append(label, style="#06b6d4")
if value:
- text.append(cls.truncate(str(value), 200), style="#06b6d4")
+ text.append(str(value), style="#06b6d4")
return text
if action == "execute_js":
@@ -128,7 +128,7 @@ class BrowserRenderer(BaseToolRenderer):
js_code = args.get("js_code")
if js_code:
text.append("\n")
- text.append_text(cls._highlight_js(cls.truncate(js_code, 2000)))
+ text.append_text(cls._highlight_js(js_code))
return text
text.append(action, style="#06b6d4")
diff --git a/strix/interface/tool_components/file_edit_renderer.py b/strix/interface/tool_components/file_edit_renderer.py
index b7fee27..8b3ca40 100644
--- a/strix/interface/tool_components/file_edit_renderer.py
+++ b/strix/interface/tool_components/file_edit_renderer.py
@@ -83,8 +83,7 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
if command == "str_replace" and (old_str or new_str):
if old_str:
- old_display = cls.truncate(old_str, 1000)
- highlighted_old = cls._highlight_code(old_display, path)
+ highlighted_old = cls._highlight_code(old_str, path)
for line in highlighted_old.plain.split("\n"):
text.append("\n")
text.append("-", style="#ef4444")
@@ -92,8 +91,7 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
text.append(line)
if new_str:
- new_display = cls.truncate(new_str, 1000)
- highlighted_new = cls._highlight_code(new_display, path)
+ highlighted_new = cls._highlight_code(new_str, path)
for line in highlighted_new.plain.split("\n"):
text.append("\n")
text.append("+", style="#22c55e")
@@ -101,13 +99,11 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
text.append(line)
elif command == "create" and file_text:
- text_display = cls.truncate(file_text, 1500)
text.append("\n")
- text.append_text(cls._highlight_code(text_display, path))
+ text.append_text(cls._highlight_code(file_text, path))
elif command == "insert" and new_str:
- new_display = cls.truncate(new_str, 1000)
- highlighted_new = cls._highlight_code(new_display, path)
+ highlighted_new = cls._highlight_code(new_str, path)
for line in highlighted_new.plain.split("\n"):
text.append("\n")
text.append("+", style="#22c55e")
@@ -164,19 +160,15 @@ class SearchFilesRenderer(BaseToolRenderer):
text.append(" ")
if path and regex:
- path_display = path[-30:] if len(path) > 30 else path
- regex_display = regex[:30] if len(regex) > 30 else regex
- text.append(path_display, style="dim")
+ text.append(path, style="dim")
text.append(" for '", style="dim")
- text.append(regex_display, style="dim")
+ text.append(regex, style="dim")
text.append("'", style="dim")
elif path:
- path_display = path[-60:] if len(path) > 60 else path
- text.append(path_display, style="dim")
+ text.append(path, style="dim")
elif regex:
- regex_display = regex[:60] if len(regex) > 60 else regex
text.append("'", style="dim")
- text.append(regex_display, style="dim")
+ text.append(regex, style="dim")
text.append("'", style="dim")
else:
text.append("Searching...", style="dim")
diff --git a/strix/interface/tool_components/notes_renderer.py b/strix/interface/tool_components/notes_renderer.py
index a44542d..4660a27 100644
--- a/strix/interface/tool_components/notes_renderer.py
+++ b/strix/interface/tool_components/notes_renderer.py
@@ -28,11 +28,11 @@ class CreateNoteRenderer(BaseToolRenderer):
if title:
text.append("\n ")
- text.append(cls.truncate(title.strip(), 300))
+ text.append(title.strip())
if content:
text.append("\n ")
- text.append(cls.truncate(content.strip(), 800), style="dim")
+ text.append(content.strip(), style="dim")
if not title and not content:
text.append("\n ")
@@ -75,11 +75,11 @@ class UpdateNoteRenderer(BaseToolRenderer):
if title:
text.append("\n ")
- text.append(cls.truncate(title, 300))
+ text.append(title)
if content:
text.append("\n ")
- text.append(cls.truncate(content.strip(), 800), style="dim")
+ text.append(content.strip(), style="dim")
if not title and not content:
text.append("\n ")
@@ -110,23 +110,18 @@ class ListNotesRenderer(BaseToolRenderer):
text.append("\n ")
text.append("No notes", style="dim")
else:
- for note in notes[:5]:
+ for note in notes:
title = note.get("title", "").strip() or "(untitled)"
category = note.get("category", "general")
note_content = note.get("content", "").strip()
text.append("\n - ")
- text.append(cls.truncate(title, 300))
+ text.append(title)
text.append(f" ({category})", style="dim")
if note_content:
text.append("\n ")
- text.append(cls.truncate(note_content, 400), style="dim")
-
- remaining = max(count - 5, 0)
- if remaining:
- text.append("\n ")
- text.append(f"... +{remaining} more", style="dim")
+ text.append(note_content, style="dim")
else:
text.append("\n ")
text.append("Loading...", style="dim")
diff --git a/strix/interface/tool_components/proxy_renderer.py b/strix/interface/tool_components/proxy_renderer.py
index f42e9c2..4d8c658 100644
--- a/strix/interface/tool_components/proxy_renderer.py
+++ b/strix/interface/tool_components/proxy_renderer.py
@@ -26,7 +26,7 @@ class ListRequestsRenderer(BaseToolRenderer):
if result and isinstance(result, dict) and "requests" in result:
requests = result["requests"]
if isinstance(requests, list) and requests:
- for req in requests[:3]:
+ for req in requests:
if isinstance(req, dict):
method = req.get("method", "?")
path = req.get("path", "?")
@@ -34,16 +34,12 @@ class ListRequestsRenderer(BaseToolRenderer):
status = response.get("statusCode", "?")
text.append("\n ")
text.append(f"{method} {path} → {status}", style="dim")
-
- if len(requests) > 3:
- text.append("\n ")
- text.append(f"... +{len(requests) - 3} more", style="dim")
else:
text.append("\n ")
text.append("No requests found", style="dim")
elif httpql_filter:
text.append("\n ")
- text.append(cls.truncate(httpql_filter, 300), style="dim")
+ text.append(httpql_filter, style="dim")
else:
text.append("\n ")
text.append("All requests", style="dim")
@@ -72,17 +68,15 @@ class ViewRequestRenderer(BaseToolRenderer):
if "content" in result:
content = result["content"]
text.append("\n ")
- text.append(cls.truncate(content, 500), style="dim")
+ text.append(content, style="dim")
elif "matches" in result:
matches = result["matches"]
if isinstance(matches, list) and matches:
- for match in matches[:3]:
+ for match in matches:
if isinstance(match, dict) and "match" in match:
text.append("\n ")
text.append(match["match"], style="dim")
- if len(matches) > 3:
- text.append("\n ")
- text.append(f"... +{len(matches) - 3} more matches", style="dim")
+
else:
text.append("\n ")
text.append("No matches found", style="dim")
@@ -123,13 +117,13 @@ class SendRequestRenderer(BaseToolRenderer):
text.append(f"Status: {status_code}", style="dim")
if response_body:
text.append("\n ")
- text.append(cls.truncate(response_body, 300), style="dim")
+ text.append(response_body, style="dim")
else:
text.append("\n ")
text.append("Response received", style="dim")
elif url:
text.append("\n ")
- text.append(cls.truncate(url, 400), style="dim")
+ text.append(url, style="dim")
else:
text.append("\n ")
text.append("Sending...", style="dim")
@@ -163,13 +157,13 @@ class RepeatRequestRenderer(BaseToolRenderer):
text.append(f"Status: {status_code}", style="dim")
if response_body:
text.append("\n ")
- text.append(cls.truncate(response_body, 300), style="dim")
+ text.append(response_body, style="dim")
else:
text.append("\n ")
text.append("Response received", style="dim")
elif modifications:
text.append("\n ")
- text.append(cls.truncate(str(modifications), 400), style="dim")
+ text.append(str(modifications), style="dim")
else:
text.append("\n ")
text.append("No modifications", style="dim")
@@ -211,16 +205,13 @@ class ListSitemapRenderer(BaseToolRenderer):
if result and isinstance(result, dict) and "entries" in result:
entries = result["entries"]
if isinstance(entries, list) and entries:
- for entry in entries[:4]:
+ for entry in entries:
if isinstance(entry, dict):
label = entry.get("label", "?")
kind = entry.get("kind", "?")
text.append("\n ")
text.append(f"{kind}: {label}", style="dim")
- if len(entries) > 4:
- text.append("\n ")
- text.append(f"... +{len(entries) - 4} more", style="dim")
else:
text.append("\n ")
text.append("No entries found", style="dim")
diff --git a/strix/interface/tool_components/python_renderer.py b/strix/interface/tool_components/python_renderer.py
index 5e1bef3..8fb54ac 100644
--- a/strix/interface/tool_components/python_renderer.py
+++ b/strix/interface/tool_components/python_renderer.py
@@ -56,8 +56,7 @@ class PythonRenderer(BaseToolRenderer):
text.append("\n")
if code and action in ["new_session", "execute"]:
- code_display = cls.truncate(code, 2000)
- text.append_text(cls._highlight_python(code_display))
+ text.append_text(cls._highlight_python(code))
elif action == "close":
text.append(" ")
text.append("Closing session...", style="dim")
diff --git a/strix/interface/tool_components/registry.py b/strix/interface/tool_components/registry.py
index f9567ae..25267d9 100644
--- a/strix/interface/tool_components/registry.py
+++ b/strix/interface/tool_components/registry.py
@@ -59,10 +59,8 @@ def _render_default_tool_widget(tool_data: dict[str, Any]) -> Static:
text.append(tool_name, style="bold blue")
text.append("\n")
- for k, v in list(args.items())[:2]:
+ for k, v in list(args.items()):
str_v = str(v)
- if len(str_v) > 80:
- str_v = str_v[:77] + "..."
text.append(" ")
text.append(k, style="dim")
text.append(": ")
@@ -71,8 +69,6 @@ def _render_default_tool_widget(tool_data: dict[str, Any]) -> Static:
if status in ["completed", "failed", "error"] and result is not None:
result_str = str(result)
- if len(result_str) > 150:
- result_str = result_str[:147] + "..."
text.append("Result: ", style="bold")
text.append(result_str)
else:
diff --git a/strix/interface/tool_components/terminal_renderer.py b/strix/interface/tool_components/terminal_renderer.py
index 56decec..f707a00 100644
--- a/strix/interface/tool_components/terminal_renderer.py
+++ b/strix/interface/tool_components/terminal_renderer.py
@@ -154,6 +154,4 @@ class TerminalRenderer(BaseToolRenderer):
@classmethod
def _format_command(cls, command: str) -> Text:
- if len(command) > 2000:
- command = command[:2000] + "..."
return cls._highlight_bash(command)
diff --git a/strix/interface/tool_components/thinking_renderer.py b/strix/interface/tool_components/thinking_renderer.py
index dd7ff09..598bdf3 100644
--- a/strix/interface/tool_components/thinking_renderer.py
+++ b/strix/interface/tool_components/thinking_renderer.py
@@ -23,8 +23,7 @@ class ThinkRenderer(BaseToolRenderer):
text.append("\n ")
if thought:
- thought_display = cls.truncate(thought, 600)
- text.append(thought_display, style="italic dim")
+ text.append(thought, style="italic dim")
else:
text.append("Thinking...", style="italic dim")
diff --git a/strix/interface/tool_components/todo_renderer.py b/strix/interface/tool_components/todo_renderer.py
index 2ba58ef..6224f9f 100644
--- a/strix/interface/tool_components/todo_renderer.py
+++ b/strix/interface/tool_components/todo_renderer.py
@@ -14,29 +14,18 @@ STATUS_MARKERS: dict[str, str] = {
}
-def _format_todo_lines(text: Text, result: dict[str, Any], limit: int = 25) -> None:
+def _format_todo_lines(text: Text, result: dict[str, Any]) -> None:
todos = result.get("todos")
if not isinstance(todos, list) or not todos:
text.append("\n ")
text.append("No todos", style="dim")
return
- total = len(todos)
-
- for index, todo in enumerate(todos):
- if index >= limit:
- remaining = total - limit
- if remaining > 0:
- text.append("\n ")
- text.append(f"... +{remaining} more", style="dim")
- break
-
+ for todo in todos:
status = todo.get("status", "pending")
marker = STATUS_MARKERS.get(status, STATUS_MARKERS["pending"])
title = todo.get("title", "").strip() or "(untitled)"
- if len(title) > 90:
- title = title[:87] + "..."
text.append("\n ")
text.append(marker)
diff --git a/strix/interface/tool_components/user_message_renderer.py b/strix/interface/tool_components/user_message_renderer.py
index ad80924..b1081e8 100644
--- a/strix/interface/tool_components/user_message_renderer.py
+++ b/strix/interface/tool_components/user_message_renderer.py
@@ -34,9 +34,6 @@ class UserMessageRenderer(BaseToolRenderer):
def _format_user_message(cls, content: str) -> Text:
text = Text()
- if len(content) > 300:
- content = content[:297] + "..."
-
text.append("▍", style="#3b82f6")
text.append(" ")
text.append("You:", style="bold")
diff --git a/strix/interface/tool_components/web_search_renderer.py b/strix/interface/tool_components/web_search_renderer.py
index ef07972..4bd20f7 100644
--- a/strix/interface/tool_components/web_search_renderer.py
+++ b/strix/interface/tool_components/web_search_renderer.py
@@ -23,7 +23,7 @@ class WebSearchRenderer(BaseToolRenderer):
if query:
text.append("\n ")
- text.append(cls.truncate(query, 100), style="dim")
+ text.append(query, style="dim")
css_classes = cls.get_css_classes("completed")
return Static(text, classes=css_classes)
diff --git a/strix/interface/tui.py b/strix/interface/tui.py
index deae4b7..2a0cccc 100644
--- a/strix/interface/tui.py
+++ b/strix/interface/tui.py
@@ -491,7 +491,7 @@ class StrixTUIApp(App): # type: ignore[misc]
self._start_scan_thread()
- self.set_interval(0.5, self._update_ui_from_tracer)
+ self.set_interval(0.1, self._update_ui_from_tracer)
def _update_ui_from_tracer(self) -> None:
if self.show_splash:
@@ -596,13 +596,14 @@ class StrixTUIApp(App): # type: ignore[misc]
)
else:
events = self._gather_agent_events(self.selected_agent_id)
- if not events:
+ streaming = self.tracer.get_streaming_content(self.selected_agent_id)
+ if not events and not streaming:
content, css_class = self._get_chat_placeholder_content(
"Starting agent...", "placeholder-no-activity"
)
else:
current_event_ids = [e["id"] for e in events]
- if current_event_ids == self._displayed_events:
+ if current_event_ids == self._displayed_events and not streaming:
return
content = self._get_rendered_events_content(events)
css_class = "chat-content"
@@ -644,8 +645,92 @@ class StrixTUIApp(App): # type: ignore[misc]
result.append_text(content)
first = False
+ if self.selected_agent_id:
+ streaming = self.tracer.get_streaming_content(self.selected_agent_id)
+ if streaming:
+ streaming_text = self._render_streaming_content(streaming)
+ if streaming_text:
+ if not first:
+ result.append("\n\n")
+ result.append_text(streaming_text)
+
return result
+ def _render_streaming_content(self, content: str) -> Text:
+ from strix.interface.streaming_parser import parse_streaming_content
+
+ result = Text()
+ segments = parse_streaming_content(content)
+
+ for i, segment in enumerate(segments):
+ if i > 0:
+ result.append("\n\n")
+
+ if segment.type == "text":
+ from strix.interface.tool_components.agent_message_renderer import (
+ AgentMessageRenderer,
+ )
+
+ text_content = AgentMessageRenderer.render_simple(segment.content)
+ result.append_text(text_content)
+
+ elif segment.type == "tool":
+ tool_text = self._render_streaming_tool(
+ segment.tool_name or "unknown",
+ segment.args or {},
+ segment.is_complete,
+ )
+ result.append_text(tool_text)
+
+ return result
+
+ def _render_streaming_tool(
+ self, tool_name: str, args: dict[str, str], is_complete: bool
+ ) -> Text:
+ from strix.interface.tool_components.registry import get_tool_renderer
+
+ tool_data = {
+ "tool_name": tool_name,
+ "args": args,
+ "status": "completed" if is_complete else "running",
+ "result": None,
+ }
+
+ renderer = get_tool_renderer(tool_name)
+ if renderer:
+ widget = renderer.render(tool_data)
+ renderable = widget.renderable
+ if isinstance(renderable, Text):
+ return renderable
+ text = Text()
+ text.append(str(renderable))
+ return text
+
+ return self._render_default_streaming_tool(tool_name, args, is_complete)
+
+ def _render_default_streaming_tool(
+ self, tool_name: str, args: dict[str, str], is_complete: bool
+ ) -> Text:
+ text = Text()
+
+ if is_complete:
+ text.append("✓ ", style="green")
+ else:
+ text.append("● ", style="yellow")
+
+ text.append("Using tool ", style="dim")
+ text.append(tool_name, style="bold blue")
+
+ if args:
+ for key, value in list(args.items())[:3]:
+ text.append("\n ")
+ text.append(key, style="dim")
+ text.append(": ")
+ display_value = value if len(value) <= 100 else value[:97] + "..."
+ text.append(display_value, style="italic" if not is_complete else None)
+
+ return text
+
def _get_status_display_content(
self, agent_id: str, agent_data: dict[str, Any]
) -> tuple[Text | None, Text, bool]:
diff --git a/strix/llm/llm.py b/strix/llm/llm.py
index e3df248..558a0c5 100644
--- a/strix/llm/llm.py
+++ b/strix/llm/llm.py
@@ -1,5 +1,6 @@
import logging
import os
+from collections.abc import AsyncIterator
from dataclasses import dataclass
from enum import Enum
from fnmatch import fnmatch
@@ -12,7 +13,7 @@ from jinja2 import (
FileSystemLoader,
select_autoescape,
)
-from litellm import ModelResponse, completion_cost
+from litellm import completion_cost, stream_chunk_builder
from litellm.utils import supports_prompt_caching, supports_vision
from strix.llm.config import LLMConfig
@@ -276,7 +277,7 @@ class LLM:
conversation_history: list[dict[str, Any]],
scan_id: str | None = None,
step_number: int = 1,
- ) -> LLMResponse:
+ ) -> AsyncIterator[LLMResponse]:
messages = [{"role": "system", "content": self.system_prompt}]
identity_message = self._build_identity_message()
@@ -292,30 +293,43 @@ class LLM:
cached_messages = self._prepare_cached_messages(messages)
try:
- response = await self._make_request(cached_messages)
- self._update_usage_stats(response)
+ accumulated_content = ""
+ chunks: list[Any] = []
- content = ""
- if (
- response.choices
- and hasattr(response.choices[0], "message")
- and response.choices[0].message
- ):
- content = getattr(response.choices[0].message, "content", "") or ""
+ async for chunk in self._stream_request(cached_messages):
+ chunks.append(chunk)
+ delta = self._extract_chunk_delta(chunk)
+ if delta:
+ accumulated_content += delta
- content = _truncate_to_first_function(content)
+            if "</function>" in accumulated_content:
+                function_end = accumulated_content.find("</function>") + len("</function>")
+ accumulated_content = accumulated_content[:function_end]
-        if "</function>" in content:
-            function_end_index = content.find("</function>") + len("</function>")
- content = content[:function_end_index]
+ yield LLMResponse(
+ scan_id=scan_id,
+ step_number=step_number,
+ role=StepRole.AGENT,
+ content=accumulated_content,
+ tool_invocations=None,
+ )
- tool_invocations = parse_tool_invocations(content)
+ if chunks:
+ complete_response = stream_chunk_builder(chunks)
+ self._update_usage_stats(complete_response)
- return LLMResponse(
+ accumulated_content = _truncate_to_first_function(accumulated_content)
+            if "</function>" in accumulated_content:
+                function_end = accumulated_content.find("</function>") + len("</function>")
+ accumulated_content = accumulated_content[:function_end]
+
+ tool_invocations = parse_tool_invocations(accumulated_content)
+
+ yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
- content=content,
+ content=accumulated_content,
tool_invocations=tool_invocations if tool_invocations else None,
)
@@ -364,6 +378,12 @@ class LLM:
except Exception as e:
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
+ def _extract_chunk_delta(self, chunk: Any) -> str:
+ if chunk.choices and hasattr(chunk.choices[0], "delta"):
+ delta = chunk.choices[0].delta
+ return getattr(delta, "content", "") or ""
+ return ""
+
@property
def usage_stats(self) -> dict[str, dict[str, int | float]]:
return {
@@ -436,10 +456,10 @@ class LLM:
filtered_messages.append(updated_msg)
return filtered_messages
- async def _make_request(
+ async def _stream_request(
self,
messages: list[dict[str, Any]],
- ) -> ModelResponse:
+ ) -> AsyncIterator[Any]:
if not self._model_supports_vision():
messages = self._filter_images_from_messages(messages)
@@ -447,6 +467,7 @@ class LLM:
"model": self.config.model_name,
"messages": messages,
"timeout": self.config.timeout,
+ "stream_options": {"include_usage": True},
}
if _LLM_API_KEY:
@@ -461,14 +482,13 @@ class LLM:
completion_args["reasoning_effort"] = "high"
queue = get_global_queue()
- response = await queue.make_request(completion_args)
-
self._total_stats.requests += 1
self._last_request_stats = RequestStats(requests=1)
- return response
+ async for chunk in queue.stream_request(completion_args):
+ yield chunk
- def _update_usage_stats(self, response: ModelResponse) -> None:
+ def _update_usage_stats(self, response: Any) -> None:
try:
if hasattr(response, "usage") and response.usage:
input_tokens = getattr(response.usage, "prompt_tokens", 0)
diff --git a/strix/llm/request_queue.py b/strix/llm/request_queue.py
index 3c6a00f..4760196 100644
--- a/strix/llm/request_queue.py
+++ b/strix/llm/request_queue.py
@@ -3,10 +3,12 @@ import logging
import os
import threading
import time
+from collections.abc import AsyncIterator
from typing import Any
import litellm
-from litellm import ModelResponse, completion
+from litellm import completion
+from litellm.types.utils import ModelResponseStream
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
@@ -42,7 +44,9 @@ class LLMRequestQueue:
self._last_request_time = 0.0
self._lock = threading.Lock()
- async def make_request(self, completion_args: dict[str, Any]) -> ModelResponse:
+ async def stream_request(
+ self, completion_args: dict[str, Any]
+ ) -> AsyncIterator[ModelResponseStream]:
try:
while not self._semaphore.acquire(timeout=0.2):
await asyncio.sleep(0.1)
@@ -56,7 +60,8 @@ class LLMRequestQueue:
if sleep_needed > 0:
await asyncio.sleep(sleep_needed)
- return await self._reliable_request(completion_args)
+ async for chunk in self._reliable_stream_request(completion_args):
+ yield chunk
finally:
self._semaphore.release()
@@ -66,15 +71,12 @@ class LLMRequestQueue:
retry=retry_if_exception(should_retry_exception),
reraise=True,
)
- async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse:
- response = completion(**completion_args, stream=False)
- if isinstance(response, ModelResponse):
- return response
- self._raise_unexpected_response()
- raise RuntimeError("Unreachable code")
-
- def _raise_unexpected_response(self) -> None:
- raise RuntimeError("Unexpected response type")
+ async def _reliable_stream_request(
+ self, completion_args: dict[str, Any]
+ ) -> AsyncIterator[ModelResponseStream]:
+ response = await asyncio.to_thread(completion, **completion_args, stream=True)
+ for chunk in response:
+ yield chunk
_global_queue: LLMRequestQueue | None = None
diff --git a/strix/llm/utils.py b/strix/llm/utils.py
index 8c141c6..e775cff 100644
--- a/strix/llm/utils.py
+++ b/strix/llm/utils.py
@@ -47,10 +47,14 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
def _fix_stopword(content: str) -> str:
-    if "<function=" in content and "</function>" not in content:
+    if (
+        "<function=" in content
+        and "</function>" not in content
+    ):
         if content.endswith("</"):
             content = content.rstrip() + "function>"
-        elif not content.rstrip().endswith("</function>"):
+        else:
             content = content + "\n</function>"
return content
@@ -75,6 +79,12 @@ def clean_content(content: str) -> str:
     tool_pattern = r"<function=[^>]+>.*?</function>"
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)
+    incomplete_tool_pattern = r"<function=[^>]+>.*$"
+ cleaned = re.sub(incomplete_tool_pattern, "", cleaned, flags=re.DOTALL)
+
+    partial_tag_pattern = r"<(f(u(n(c(t(i(o(n([^>]*)?)?)?)?)?)?)?)?)?$"
+ cleaned = re.sub(partial_tag_pattern, "", cleaned)
+
hidden_xml_patterns = [
r".*?",
r".*?",
diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py
index 63781bc..59423f2 100644
--- a/strix/telemetry/tracer.py
+++ b/strix/telemetry/tracer.py
@@ -33,6 +33,7 @@ class Tracer:
self.agents: dict[str, dict[str, Any]] = {}
self.tool_executions: dict[int, dict[str, Any]] = {}
self.chat_messages: list[dict[str, Any]] = []
+ self.streaming_content: dict[str, str] = {}
self.vulnerability_reports: list[dict[str, Any]] = []
self.final_scan_result: str | None = None
@@ -333,5 +334,14 @@ class Tracer:
"total_tokens": total_stats["input_tokens"] + total_stats["output_tokens"],
}
+ def update_streaming_content(self, agent_id: str, content: str) -> None:
+ self.streaming_content[agent_id] = content
+
+ def clear_streaming_content(self, agent_id: str) -> None:
+ self.streaming_content.pop(agent_id, None)
+
+ def get_streaming_content(self, agent_id: str) -> str | None:
+ return self.streaming_content.get(agent_id)
+
def cleanup(self) -> None:
self.save_run_data(mark_complete=True)