From a6dcb7756ef82a08553ebf1a3d2fe9c39709dfe3 Mon Sep 17 00:00:00 2001 From: 0xallam Date: Mon, 5 Jan 2026 09:52:05 -0800 Subject: [PATCH] feat(tui): add real-time streaming LLM output with full content display - Convert LiteLLM requests to streaming mode with stream_request() - Add streaming parser to handle live LLM output segments - Update TUI for real-time streaming content rendering - Add tracer methods for streaming content tracking - Clean function tags from streamed content to prevent display - Remove all truncation from tool renderers for full content visibility --- strix/agents/base_agent.py | 20 ++- strix/interface/streaming_parser.py | 119 ++++++++++++++++++ .../tool_components/agent_message_renderer.py | 8 +- .../tool_components/agents_graph_renderer.py | 6 +- .../tool_components/base_renderer.py | 6 - .../tool_components/browser_renderer.py | 6 +- .../tool_components/file_edit_renderer.py | 24 ++-- .../tool_components/notes_renderer.py | 19 ++- .../tool_components/proxy_renderer.py | 29 ++--- .../tool_components/python_renderer.py | 3 +- strix/interface/tool_components/registry.py | 6 +- .../tool_components/terminal_renderer.py | 2 - .../tool_components/thinking_renderer.py | 3 +- .../tool_components/todo_renderer.py | 15 +-- .../tool_components/user_message_renderer.py | 3 - .../tool_components/web_search_renderer.py | 2 +- strix/interface/tui.py | 91 +++++++++++++- strix/llm/llm.py | 68 ++++++---- strix/llm/request_queue.py | 26 ++-- strix/llm/utils.py | 14 ++- strix/telemetry/tracer.py | 10 ++ 21 files changed, 345 insertions(+), 135 deletions(-) create mode 100644 strix/interface/streaming_parser.py diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index 67aeb38..56ae6b1 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -351,9 +351,16 @@ class BaseAgent(metaclass=AgentMeta): self.state.add_message("user", task) async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool: - response = await self.llm.generate(self.state.get_conversation_history()) + final_response = None + async for response in self.llm.generate(self.state.get_conversation_history()): + final_response = response + if tracer and response.content: + tracer.update_streaming_content(self.state.agent_id, response.content) - content_stripped = (response.content or "").strip() + if final_response is None: + return False + + content_stripped = (final_response.content or "").strip() if not content_stripped: corrective_message = ( @@ -369,17 +376,18 @@ class BaseAgent(metaclass=AgentMeta): self.state.add_message("user", corrective_message) return False - self.state.add_message("assistant", response.content) + self.state.add_message("assistant", final_response.content) if tracer: + tracer.clear_streaming_content(self.state.agent_id) tracer.log_chat_message( - content=clean_content(response.content), + content=clean_content(final_response.content), role="assistant", agent_id=self.state.agent_id, ) actions = ( - response.tool_invocations - if hasattr(response, "tool_invocations") and response.tool_invocations + final_response.tool_invocations + if hasattr(final_response, "tool_invocations") and final_response.tool_invocations else [] ) diff --git a/strix/interface/streaming_parser.py b/strix/interface/streaming_parser.py new file mode 100644 index 0000000..8adbc9b --- /dev/null +++ b/strix/interface/streaming_parser.py @@ -0,0 +1,119 @@ +import html +import re +from dataclasses import dataclass +from typing import Literal + + +_FUNCTION_TAG_PREFIX = " tuple[str, str]: + if not content: + return "", "" + + last_lt = content.rfind("<") + if last_lt == -1: + return content, "" + + suffix = content[last_lt:] + target = _FUNCTION_TAG_PREFIX # " list[StreamSegment]: + if not content: + return [] + + segments: list[StreamSegment] = [] + + func_pattern = r"]+)>" + func_matches = list(re.finditer(func_pattern, content)) + + if not func_matches: + safe_content, _ = _get_safe_content(content) + text = safe_content.strip() + if text: + segments.append(StreamSegment(type="text", content=text)) + return segments + + first_func_start = func_matches[0].start() + if first_func_start > 0: + text_before = content[:first_func_start].strip() + if text_before: + segments.append(StreamSegment(type="text", content=text_before)) + + for i, match in enumerate(func_matches): + tool_name = match.group(1) + func_start = match.end() + + func_end_match = re.search(r"", content[func_start:]) + + if func_end_match: + func_body = content[func_start : func_start + func_end_match.start()] + is_complete = True + end_pos = func_start + func_end_match.end() + else: + if i + 1 < len(func_matches): + next_func_start = func_matches[i + 1].start() + func_body = content[func_start:next_func_start] + else: + func_body = content[func_start:] + is_complete = False + end_pos = len(content) + + args = _parse_streaming_params(func_body) + + segments.append( + StreamSegment( + type="tool", + content=func_body, + tool_name=tool_name, + args=args, + is_complete=is_complete, + ) + ) + + if is_complete and i + 1 < len(func_matches): + next_start = func_matches[i + 1].start() + text_between = content[end_pos:next_start].strip() + if text_between: + segments.append(StreamSegment(type="text", content=text_between)) + + return segments + + +def _parse_streaming_params(func_body: str) -> dict[str, str]: + args: dict[str, str] = {} + + complete_pattern = r"]+)>(.*?)" + complete_matches = list(re.finditer(complete_pattern, func_body, re.DOTALL)) + complete_end_pos = 0 + + for match in complete_matches: + param_name = match.group(1) + param_value = html.unescape(match.group(2).strip()) + args[param_name] = param_value + complete_end_pos = max(complete_end_pos, match.end()) + + remaining = func_body[complete_end_pos:] + incomplete_pattern = r"]+)>(.*)$" + incomplete_match = re.search(incomplete_pattern, remaining, re.DOTALL) + if incomplete_match: + param_name = incomplete_match.group(1) + param_value = html.unescape(incomplete_match.group(2).strip()) + args[param_name] = param_value + + return args diff --git a/strix/interface/tool_components/agent_message_renderer.py b/strix/interface/tool_components/agent_message_renderer.py index ea16653..a51ea2a 100644 --- a/strix/interface/tool_components/agent_message_renderer.py +++ b/strix/interface/tool_components/agent_message_renderer.py @@ -181,4 +181,10 @@ class AgentMessageRenderer(BaseToolRenderer): if not content: return Text() - return _apply_markdown_styles(content) + from strix.llm.utils import clean_content + + cleaned = clean_content(content) + if not cleaned: + return Text() + + return _apply_markdown_styles(cleaned) diff --git a/strix/interface/tool_components/agents_graph_renderer.py b/strix/interface/tool_components/agents_graph_renderer.py index d414984..356bdcb 100644 --- a/strix/interface/tool_components/agents_graph_renderer.py +++ b/strix/interface/tool_components/agents_graph_renderer.py @@ -40,7 +40,7 @@ class CreateAgentRenderer(BaseToolRenderer): if task: text.append("\n ") - text.append(cls.truncate(task, 400), style="dim") + text.append(task, style="dim") else: text.append("\n ") text.append("Spawning agent...", style="dim") @@ -66,7 +66,7 @@ class SendMessageToAgentRenderer(BaseToolRenderer): if message: text.append("\n ") - text.append(cls.truncate(message, 400), style="dim") + text.append(message, style="dim") else: text.append("\n ") text.append("Sending...", style="dim") @@ -129,7 +129,7 @@ class WaitForMessageRenderer(BaseToolRenderer): if reason: text.append("\n ") - text.append(cls.truncate(reason, 400), style="dim") + text.append(reason, style="dim") else: text.append("\n ") text.append("Agent paused until message received...", style="dim") diff --git a/strix/interface/tool_components/base_renderer.py b/strix/interface/tool_components/base_renderer.py index ef25c37..11e8458 100644 --- a/strix/interface/tool_components/base_renderer.py +++ b/strix/interface/tool_components/base_renderer.py @@ -23,12 +23,6 @@ class BaseToolRenderer(ABC): css_classes = cls.get_css_classes(status) return Static(content, classes=css_classes) - @classmethod - def truncate(cls, text: str, max_length: int = 500) -> str: - if len(text) <= max_length: - return text - return text[: max_length - 3] + "..." - @classmethod def status_icon(cls, status: str) -> tuple[str, str]: icons = { diff --git a/strix/interface/tool_components/browser_renderer.py b/strix/interface/tool_components/browser_renderer.py index 0d22079..dc7b868 100644 --- a/strix/interface/tool_components/browser_renderer.py +++ b/strix/interface/tool_components/browser_renderer.py @@ -74,7 +74,7 @@ class BrowserRenderer(BaseToolRenderer): def _build_url_action(cls, text: Text, label: str, url: str | None, suffix: str = "") -> None: text.append(label, style="#06b6d4") if url: - text.append(cls.truncate(url, 300), style="#06b6d4") + text.append(url, style="#06b6d4") if suffix: text.append(suffix, style="#06b6d4") @@ -120,7 +120,7 @@ class BrowserRenderer(BaseToolRenderer): label, value = handlers[action] text.append(label, style="#06b6d4") if value: - text.append(cls.truncate(str(value), 200), style="#06b6d4") + text.append(str(value), style="#06b6d4") return text if action == "execute_js": @@ -128,7 +128,7 @@ class BrowserRenderer(BaseToolRenderer): js_code = args.get("js_code") if js_code: text.append("\n") - text.append_text(cls._highlight_js(cls.truncate(js_code, 2000))) + text.append_text(cls._highlight_js(js_code)) return text text.append(action, style="#06b6d4") diff --git a/strix/interface/tool_components/file_edit_renderer.py b/strix/interface/tool_components/file_edit_renderer.py index b7fee27..8b3ca40 100644 --- a/strix/interface/tool_components/file_edit_renderer.py +++ b/strix/interface/tool_components/file_edit_renderer.py @@ -83,8 +83,7 @@ class StrReplaceEditorRenderer(BaseToolRenderer): if command == "str_replace" and (old_str or new_str): if old_str: - old_display = cls.truncate(old_str, 1000) - highlighted_old = cls._highlight_code(old_display, path) + highlighted_old = cls._highlight_code(old_str, path) for line in highlighted_old.plain.split("\n"): text.append("\n") text.append("-", style="#ef4444") @@ -92,8 +91,7 @@ class StrReplaceEditorRenderer(BaseToolRenderer): text.append(line) if new_str: - new_display = cls.truncate(new_str, 1000) - highlighted_new = cls._highlight_code(new_display, path) + highlighted_new = cls._highlight_code(new_str, path) for line in highlighted_new.plain.split("\n"): text.append("\n") text.append("+", style="#22c55e") @@ -101,13 +99,11 @@ class StrReplaceEditorRenderer(BaseToolRenderer): text.append(line) elif command == "create" and file_text: - text_display = cls.truncate(file_text, 1500) text.append("\n") - text.append_text(cls._highlight_code(text_display, path)) + text.append_text(cls._highlight_code(file_text, path)) elif command == "insert" and new_str: - new_display = cls.truncate(new_str, 1000) - highlighted_new = cls._highlight_code(new_display, path) + highlighted_new = cls._highlight_code(new_str, path) for line in highlighted_new.plain.split("\n"): text.append("\n") text.append("+", style="#22c55e") @@ -164,19 +160,15 @@ class SearchFilesRenderer(BaseToolRenderer): text.append(" ") if path and regex: - path_display = path[-30:] if len(path) > 30 else path - regex_display = regex[:30] if len(regex) > 30 else regex - text.append(path_display, style="dim") + text.append(path, style="dim") text.append(" for '", style="dim") - text.append(regex_display, style="dim") + text.append(regex, style="dim") text.append("'", style="dim") elif path: - path_display = path[-60:] if len(path) > 60 else path - text.append(path_display, style="dim") + text.append(path, style="dim") elif regex: - regex_display = regex[:60] if len(regex) > 60 else regex text.append("'", style="dim") - text.append(regex_display, style="dim") + text.append(regex, style="dim") text.append("'", style="dim") else: text.append("Searching...", style="dim") diff --git a/strix/interface/tool_components/notes_renderer.py b/strix/interface/tool_components/notes_renderer.py index a44542d..4660a27 100644 --- a/strix/interface/tool_components/notes_renderer.py +++ b/strix/interface/tool_components/notes_renderer.py @@ -28,11 +28,11 @@ class CreateNoteRenderer(BaseToolRenderer): if title: text.append("\n ") - text.append(cls.truncate(title.strip(), 300)) + text.append(title.strip()) if content: text.append("\n ") - text.append(cls.truncate(content.strip(), 800), style="dim") + text.append(content.strip(), style="dim") if not title and not content: text.append("\n ") @@ -75,11 +75,11 @@ class UpdateNoteRenderer(BaseToolRenderer): if title: text.append("\n ") - text.append(cls.truncate(title, 300)) + text.append(title) if content: text.append("\n ") - text.append(cls.truncate(content.strip(), 800), style="dim") + text.append(content.strip(), style="dim") if not title and not content: text.append("\n ") @@ -110,23 +110,18 @@ class ListNotesRenderer(BaseToolRenderer): text.append("\n ") text.append("No notes", style="dim") else: - for note in notes[:5]: + for note in notes: title = note.get("title", "").strip() or "(untitled)" category = note.get("category", "general") note_content = note.get("content", "").strip() text.append("\n - ") - text.append(cls.truncate(title, 300)) + text.append(title) text.append(f" ({category})", style="dim") if note_content: text.append("\n ") - text.append(cls.truncate(note_content, 400), style="dim") - - remaining = max(count - 5, 0) - if remaining: - text.append("\n ") - text.append(f"... +{remaining} more", style="dim") + text.append(note_content, style="dim") else: text.append("\n ") text.append("Loading...", style="dim") diff --git a/strix/interface/tool_components/proxy_renderer.py b/strix/interface/tool_components/proxy_renderer.py index f42e9c2..4d8c658 100644 --- a/strix/interface/tool_components/proxy_renderer.py +++ b/strix/interface/tool_components/proxy_renderer.py @@ -26,7 +26,7 @@ class ListRequestsRenderer(BaseToolRenderer): if result and isinstance(result, dict) and "requests" in result: requests = result["requests"] if isinstance(requests, list) and requests: - for req in requests[:3]: + for req in requests: if isinstance(req, dict): method = req.get("method", "?") path = req.get("path", "?") @@ -34,16 +34,12 @@ class ListRequestsRenderer(BaseToolRenderer): status = response.get("statusCode", "?") text.append("\n ") text.append(f"{method} {path} → {status}", style="dim") - - if len(requests) > 3: - text.append("\n ") - text.append(f"... +{len(requests) - 3} more", style="dim") else: text.append("\n ") text.append("No requests found", style="dim") elif httpql_filter: text.append("\n ") - text.append(cls.truncate(httpql_filter, 300), style="dim") + text.append(httpql_filter, style="dim") else: text.append("\n ") text.append("All requests", style="dim") @@ -72,17 +68,15 @@ class ViewRequestRenderer(BaseToolRenderer): if "content" in result: content = result["content"] text.append("\n ") - text.append(cls.truncate(content, 500), style="dim") + text.append(content, style="dim") elif "matches" in result: matches = result["matches"] if isinstance(matches, list) and matches: - for match in matches[:3]: + for match in matches: if isinstance(match, dict) and "match" in match: text.append("\n ") text.append(match["match"], style="dim") - if len(matches) > 3: - text.append("\n ") - text.append(f"... +{len(matches) - 3} more matches", style="dim") + else: text.append("\n ") text.append("No matches found", style="dim") @@ -123,13 +117,13 @@ class SendRequestRenderer(BaseToolRenderer): text.append(f"Status: {status_code}", style="dim") if response_body: text.append("\n ") - text.append(cls.truncate(response_body, 300), style="dim") + text.append(response_body, style="dim") else: text.append("\n ") text.append("Response received", style="dim") elif url: text.append("\n ") - text.append(cls.truncate(url, 400), style="dim") + text.append(url, style="dim") else: text.append("\n ") text.append("Sending...", style="dim") @@ -163,13 +157,13 @@ class RepeatRequestRenderer(BaseToolRenderer): text.append(f"Status: {status_code}", style="dim") if response_body: text.append("\n ") - text.append(cls.truncate(response_body, 300), style="dim") + text.append(response_body, style="dim") else: text.append("\n ") text.append("Response received", style="dim") elif modifications: text.append("\n ") - text.append(cls.truncate(str(modifications), 400), style="dim") + text.append(str(modifications), style="dim") else: text.append("\n ") text.append("No modifications", style="dim") @@ -211,16 +205,13 @@ class ListSitemapRenderer(BaseToolRenderer): if result and isinstance(result, dict) and "entries" in result: entries = result["entries"] if isinstance(entries, list) and entries: - for entry in entries[:4]: + for entry in entries: if isinstance(entry, dict): label = entry.get("label", "?") kind = entry.get("kind", "?") text.append("\n ") text.append(f"{kind}: {label}", style="dim") - if len(entries) > 4: - text.append("\n ") - text.append(f"... +{len(entries) - 4} more", style="dim") else: text.append("\n ") text.append("No entries found", style="dim") diff --git a/strix/interface/tool_components/python_renderer.py b/strix/interface/tool_components/python_renderer.py index 5e1bef3..8fb54ac 100644 --- a/strix/interface/tool_components/python_renderer.py +++ b/strix/interface/tool_components/python_renderer.py @@ -56,8 +56,7 @@ class PythonRenderer(BaseToolRenderer): text.append("\n") if code and action in ["new_session", "execute"]: - code_display = cls.truncate(code, 2000) - text.append_text(cls._highlight_python(code_display)) + text.append_text(cls._highlight_python(code)) elif action == "close": text.append(" ") text.append("Closing session...", style="dim") diff --git a/strix/interface/tool_components/registry.py b/strix/interface/tool_components/registry.py index f9567ae..25267d9 100644 --- a/strix/interface/tool_components/registry.py +++ b/strix/interface/tool_components/registry.py @@ -59,10 +59,8 @@ def _render_default_tool_widget(tool_data: dict[str, Any]) -> Static: text.append(tool_name, style="bold blue") text.append("\n") - for k, v in list(args.items())[:2]: + for k, v in list(args.items()): str_v = str(v) - if len(str_v) > 80: - str_v = str_v[:77] + "..." text.append(" ") text.append(k, style="dim") text.append(": ") @@ -71,8 +69,6 @@ def _render_default_tool_widget(tool_data: dict[str, Any]) -> Static: if status in ["completed", "failed", "error"] and result is not None: result_str = str(result) - if len(result_str) > 150: - result_str = result_str[:147] + "..." text.append("Result: ", style="bold") text.append(result_str) else: diff --git a/strix/interface/tool_components/terminal_renderer.py b/strix/interface/tool_components/terminal_renderer.py index 56decec..f707a00 100644 --- a/strix/interface/tool_components/terminal_renderer.py +++ b/strix/interface/tool_components/terminal_renderer.py @@ -154,6 +154,4 @@ class TerminalRenderer(BaseToolRenderer): @classmethod def _format_command(cls, command: str) -> Text: - if len(command) > 2000: - command = command[:2000] + "..." return cls._highlight_bash(command) diff --git a/strix/interface/tool_components/thinking_renderer.py b/strix/interface/tool_components/thinking_renderer.py index dd7ff09..598bdf3 100644 --- a/strix/interface/tool_components/thinking_renderer.py +++ b/strix/interface/tool_components/thinking_renderer.py @@ -23,8 +23,7 @@ class ThinkRenderer(BaseToolRenderer): text.append("\n ") if thought: - thought_display = cls.truncate(thought, 600) - text.append(thought_display, style="italic dim") + text.append(thought, style="italic dim") else: text.append("Thinking...", style="italic dim") diff --git a/strix/interface/tool_components/todo_renderer.py b/strix/interface/tool_components/todo_renderer.py index 2ba58ef..6224f9f 100644 --- a/strix/interface/tool_components/todo_renderer.py +++ b/strix/interface/tool_components/todo_renderer.py @@ -14,29 +14,18 @@ STATUS_MARKERS: dict[str, str] = { } -def _format_todo_lines(text: Text, result: dict[str, Any], limit: int = 25) -> None: +def _format_todo_lines(text: Text, result: dict[str, Any]) -> None: todos = result.get("todos") if not isinstance(todos, list) or not todos: text.append("\n ") text.append("No todos", style="dim") return - total = len(todos) - - for index, todo in enumerate(todos): - if index >= limit: - remaining = total - limit - if remaining > 0: - text.append("\n ") - text.append(f"... +{remaining} more", style="dim") - break - + for todo in todos: status = todo.get("status", "pending") marker = STATUS_MARKERS.get(status, STATUS_MARKERS["pending"]) title = todo.get("title", "").strip() or "(untitled)" - if len(title) > 90: - title = title[:87] + "..." text.append("\n ") text.append(marker) diff --git a/strix/interface/tool_components/user_message_renderer.py b/strix/interface/tool_components/user_message_renderer.py index ad80924..b1081e8 100644 --- a/strix/interface/tool_components/user_message_renderer.py +++ b/strix/interface/tool_components/user_message_renderer.py @@ -34,9 +34,6 @@ class UserMessageRenderer(BaseToolRenderer): def _format_user_message(cls, content: str) -> Text: text = Text() - if len(content) > 300: - content = content[:297] + "..." - text.append("▍", style="#3b82f6") text.append(" ") text.append("You:", style="bold") diff --git a/strix/interface/tool_components/web_search_renderer.py b/strix/interface/tool_components/web_search_renderer.py index ef07972..4bd20f7 100644 --- a/strix/interface/tool_components/web_search_renderer.py +++ b/strix/interface/tool_components/web_search_renderer.py @@ -23,7 +23,7 @@ class WebSearchRenderer(BaseToolRenderer): if query: text.append("\n ") - text.append(cls.truncate(query, 100), style="dim") + text.append(query, style="dim") css_classes = cls.get_css_classes("completed") return Static(text, classes=css_classes) diff --git a/strix/interface/tui.py b/strix/interface/tui.py index deae4b7..2a0cccc 100644 --- a/strix/interface/tui.py +++ b/strix/interface/tui.py @@ -491,7 +491,7 @@ class StrixTUIApp(App): # type: ignore[misc] self._start_scan_thread() - self.set_interval(0.5, self._update_ui_from_tracer) + self.set_interval(0.1, self._update_ui_from_tracer) def _update_ui_from_tracer(self) -> None: if self.show_splash: @@ -596,13 +596,14 @@ class StrixTUIApp(App): # type: ignore[misc] ) else: events = self._gather_agent_events(self.selected_agent_id) - if not events: + streaming = self.tracer.get_streaming_content(self.selected_agent_id) + if not events and not streaming: content, css_class = self._get_chat_placeholder_content( "Starting agent...", "placeholder-no-activity" ) else: current_event_ids = [e["id"] for e in events] - if current_event_ids == self._displayed_events: + if current_event_ids == self._displayed_events and not streaming: return content = self._get_rendered_events_content(events) css_class = "chat-content" @@ -644,8 +645,92 @@ class StrixTUIApp(App): # type: ignore[misc] result.append_text(content) first = False + if self.selected_agent_id: + streaming = self.tracer.get_streaming_content(self.selected_agent_id) + if streaming: + streaming_text = self._render_streaming_content(streaming) + if streaming_text: + if not first: + result.append("\n\n") + result.append_text(streaming_text) + return result + def _render_streaming_content(self, content: str) -> Text: + from strix.interface.streaming_parser import parse_streaming_content + + result = Text() + segments = parse_streaming_content(content) + + for i, segment in enumerate(segments): + if i > 0: + result.append("\n\n") + + if segment.type == "text": + from strix.interface.tool_components.agent_message_renderer import ( + AgentMessageRenderer, + ) + + text_content = AgentMessageRenderer.render_simple(segment.content) + result.append_text(text_content) + + elif segment.type == "tool": + tool_text = self._render_streaming_tool( + segment.tool_name or "unknown", + segment.args or {}, + segment.is_complete, + ) + result.append_text(tool_text) + + return result + + def _render_streaming_tool( + self, tool_name: str, args: dict[str, str], is_complete: bool + ) -> Text: + from strix.interface.tool_components.registry import get_tool_renderer + + tool_data = { + "tool_name": tool_name, + "args": args, + "status": "completed" if is_complete else "running", + "result": None, + } + + renderer = get_tool_renderer(tool_name) + if renderer: + widget = renderer.render(tool_data) + renderable = widget.renderable + if isinstance(renderable, Text): + return renderable + text = Text() + text.append(str(renderable)) + return text + + return self._render_default_streaming_tool(tool_name, args, is_complete) + + def _render_default_streaming_tool( + self, tool_name: str, args: dict[str, str], is_complete: bool + ) -> Text: + text = Text() + + if is_complete: + text.append("✓ ", style="green") + else: + text.append("● ", style="yellow") + + text.append("Using tool ", style="dim") + text.append(tool_name, style="bold blue") + + if args: + for key, value in list(args.items())[:3]: + text.append("\n ") + text.append(key, style="dim") + text.append(": ") + display_value = value if len(value) <= 100 else value[:97] + "..." + text.append(display_value, style="italic" if not is_complete else None) + + return text + def _get_status_display_content( self, agent_id: str, agent_data: dict[str, Any] ) -> tuple[Text | None, Text, bool]: diff --git a/strix/llm/llm.py b/strix/llm/llm.py index e3df248..558a0c5 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -1,5 +1,6 @@ import logging import os +from collections.abc import AsyncIterator from dataclasses import dataclass from enum import Enum from fnmatch import fnmatch @@ -12,7 +13,7 @@ from jinja2 import ( FileSystemLoader, select_autoescape, ) -from litellm import ModelResponse, completion_cost +from litellm import completion_cost, stream_chunk_builder from litellm.utils import supports_prompt_caching, supports_vision from strix.llm.config import LLMConfig @@ -276,7 +277,7 @@ class LLM: conversation_history: list[dict[str, Any]], scan_id: str | None = None, step_number: int = 1, - ) -> LLMResponse: + ) -> AsyncIterator[LLMResponse]: messages = [{"role": "system", "content": self.system_prompt}] identity_message = self._build_identity_message() @@ -292,30 +293,43 @@ class LLM: cached_messages = self._prepare_cached_messages(messages) try: - response = await self._make_request(cached_messages) - self._update_usage_stats(response) + accumulated_content = "" + chunks: list[Any] = [] - content = "" - if ( - response.choices - and hasattr(response.choices[0], "message") - and response.choices[0].message - ): - content = getattr(response.choices[0].message, "content", "") or "" + async for chunk in self._stream_request(cached_messages): + chunks.append(chunk) + delta = self._extract_chunk_delta(chunk) + if delta: + accumulated_content += delta - content = _truncate_to_first_function(content) + if "" in accumulated_content: + function_end = accumulated_content.find("") + len("") + accumulated_content = accumulated_content[:function_end] - if "" in content: - function_end_index = content.find("") + len("") - content = content[:function_end_index] + yield LLMResponse( + scan_id=scan_id, + step_number=step_number, + role=StepRole.AGENT, + content=accumulated_content, + tool_invocations=None, + ) - tool_invocations = parse_tool_invocations(content) + if chunks: + complete_response = stream_chunk_builder(chunks) + self._update_usage_stats(complete_response) - return LLMResponse( + accumulated_content = _truncate_to_first_function(accumulated_content) + if "" in accumulated_content: + function_end = accumulated_content.find("") + len("") + accumulated_content = accumulated_content[:function_end] + + tool_invocations = parse_tool_invocations(accumulated_content) + + yield LLMResponse( scan_id=scan_id, step_number=step_number, role=StepRole.AGENT, - content=content, + content=accumulated_content, tool_invocations=tool_invocations if tool_invocations else None, ) @@ -364,6 +378,12 @@ class LLM: except Exception as e: raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e + def _extract_chunk_delta(self, chunk: Any) -> str: + if chunk.choices and hasattr(chunk.choices[0], "delta"): + delta = chunk.choices[0].delta + return getattr(delta, "content", "") or "" + return "" + @property def usage_stats(self) -> dict[str, dict[str, int | float]]: return { @@ -436,10 +456,10 @@ class LLM: filtered_messages.append(updated_msg) return filtered_messages - async def _make_request( + async def _stream_request( self, messages: list[dict[str, Any]], - ) -> ModelResponse: + ) -> AsyncIterator[Any]: if not self._model_supports_vision(): messages = self._filter_images_from_messages(messages) @@ -447,6 +467,7 @@ class LLM: "model": self.config.model_name, "messages": messages, "timeout": self.config.timeout, + "stream_options": {"include_usage": True}, } if _LLM_API_KEY: @@ -461,14 +482,13 @@ class LLM: completion_args["reasoning_effort"] = "high" queue = get_global_queue() - response = await queue.make_request(completion_args) - self._total_stats.requests += 1 self._last_request_stats = RequestStats(requests=1) - return response + async for chunk in queue.stream_request(completion_args): + yield chunk - def _update_usage_stats(self, response: ModelResponse) -> None: + def _update_usage_stats(self, response: Any) -> None: try: if hasattr(response, "usage") and response.usage: input_tokens = getattr(response.usage, "prompt_tokens", 0) diff --git a/strix/llm/request_queue.py b/strix/llm/request_queue.py index 3c6a00f..4760196 100644 --- a/strix/llm/request_queue.py +++ b/strix/llm/request_queue.py @@ -3,10 +3,12 @@ import logging import os import threading import time +from collections.abc import AsyncIterator from typing import Any import litellm -from litellm import ModelResponse, completion +from litellm import completion +from litellm.types.utils import ModelResponseStream from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential @@ -42,7 +44,9 @@ class LLMRequestQueue: self._last_request_time = 0.0 self._lock = threading.Lock() - async def make_request(self, completion_args: dict[str, Any]) -> ModelResponse: + async def stream_request( + self, completion_args: dict[str, Any] + ) -> AsyncIterator[ModelResponseStream]: try: while not self._semaphore.acquire(timeout=0.2): await asyncio.sleep(0.1) @@ -56,7 +60,8 @@ class LLMRequestQueue: if sleep_needed > 0: await asyncio.sleep(sleep_needed) - return await self._reliable_request(completion_args) + async for chunk in self._reliable_stream_request(completion_args): + yield chunk finally: self._semaphore.release() @@ -66,15 +71,12 @@ class LLMRequestQueue: retry=retry_if_exception(should_retry_exception), reraise=True, ) - async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse: - response = completion(**completion_args, stream=False) - if isinstance(response, ModelResponse): - return response - self._raise_unexpected_response() - raise RuntimeError("Unreachable code") - - def _raise_unexpected_response(self) -> None: - raise RuntimeError("Unexpected response type") + async def _reliable_stream_request( + self, completion_args: dict[str, Any] + ) -> AsyncIterator[ModelResponseStream]: + response = await asyncio.to_thread(completion, **completion_args, stream=True) + for chunk in response: + yield chunk _global_queue: LLMRequestQueue | None = None diff --git a/strix/llm/utils.py b/strix/llm/utils.py index 8c141c6..e775cff 100644 --- a/strix/llm/utils.py +++ b/strix/llm/utils.py @@ -47,10 +47,14 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None: def _fix_stopword(content: str) -> str: - if "" not in content + ): if content.endswith("" - elif not content.rstrip().endswith(""): + else: content = content + "\n" return content @@ -75,6 +79,12 @@ def clean_content(content: str) -> str: tool_pattern = r"]+>.*?" cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL) + incomplete_tool_pattern = r"]+>.*$" + cleaned = re.sub(incomplete_tool_pattern, "", cleaned, flags=re.DOTALL) + + partial_tag_pattern = r"]*)?)?)?)?)?)?)?)?)?$" + cleaned = re.sub(partial_tag_pattern, "", cleaned) + hidden_xml_patterns = [ r".*?", r".*?", diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py index 63781bc..59423f2 100644 --- a/strix/telemetry/tracer.py +++ b/strix/telemetry/tracer.py @@ -33,6 +33,7 @@ class Tracer: self.agents: dict[str, dict[str, Any]] = {} self.tool_executions: dict[int, dict[str, Any]] = {} self.chat_messages: list[dict[str, Any]] = [] + self.streaming_content: dict[str, str] = {} self.vulnerability_reports: list[dict[str, Any]] = [] self.final_scan_result: str | None = None @@ -333,5 +334,14 @@ class Tracer: "total_tokens": total_stats["input_tokens"] + total_stats["output_tokens"], } + def update_streaming_content(self, agent_id: str, content: str) -> None: + self.streaming_content[agent_id] = content + + def clear_streaming_content(self, agent_id: str) -> None: + self.streaming_content.pop(agent_id, None) + + def get_streaming_content(self, agent_id: str) -> str | None: + return self.streaming_content.get(agent_id) + def cleanup(self) -> None: self.save_run_data(mark_complete=True)