From e30ef9aec87172c975a0619e0a74f92bfd7e7a6a Mon Sep 17 00:00:00 2001 From: 0xallam Date: Mon, 19 Jan 2026 11:35:05 -0800 Subject: [PATCH] perf: optimize TUI streaming rendering performance - Pre-compile regex patterns in streaming_parser.py - Move hot-path imports to module level in tui.py - Add streaming content caching to avoid re-rendering unchanged content - Track streaming length to skip unnecessary re-renders - Reduce UI update interval from 250ms to 350ms --- strix/interface/streaming_parser.py | 20 ++++++----- strix/interface/tui.py | 53 +++++++++++++++++------------ 2 files changed, 43 insertions(+), 30 deletions(-) diff --git a/strix/interface/streaming_parser.py b/strix/interface/streaming_parser.py index 8adbc9b..95e9523 100644 --- a/strix/interface/streaming_parser.py +++ b/strix/interface/streaming_parser.py @@ -6,6 +6,11 @@ from typing import Literal _FUNCTION_TAG_PREFIX = "]+)>") +_FUNC_END_PATTERN = re.compile(r"") +_COMPLETE_PARAM_PATTERN = re.compile(r"]+)>(.*?)", re.DOTALL) +_INCOMPLETE_PARAM_PATTERN = re.compile(r"]+)>(.*)$", re.DOTALL) + def _get_safe_content(content: str) -> tuple[str, str]: if not content: @@ -39,8 +44,7 @@ def parse_streaming_content(content: str) -> list[StreamSegment]: segments: list[StreamSegment] = [] - func_pattern = r"]+)>" - func_matches = list(re.finditer(func_pattern, content)) + func_matches = list(_FUNC_PATTERN.finditer(content)) if not func_matches: safe_content, _ = _get_safe_content(content) @@ -59,12 +63,12 @@ def parse_streaming_content(content: str) -> list[StreamSegment]: tool_name = match.group(1) func_start = match.end() - func_end_match = re.search(r"", content[func_start:]) + func_end_match = _FUNC_END_PATTERN.search(content, func_start) if func_end_match: - func_body = content[func_start : func_start + func_end_match.start()] + func_body = content[func_start : func_end_match.start()] is_complete = True - end_pos = func_start + func_end_match.end() + end_pos = func_end_match.end() else: if i + 1 < len(func_matches): next_func_start = func_matches[i + 1].start() @@ -98,8 +102,7 @@ def parse_streaming_content(content: str) -> list[StreamSegment]: def _parse_streaming_params(func_body: str) -> dict[str, str]: args: dict[str, str] = {} - complete_pattern = r"]+)>(.*?)" - complete_matches = list(re.finditer(complete_pattern, func_body, re.DOTALL)) + complete_matches = list(_COMPLETE_PARAM_PATTERN.finditer(func_body)) complete_end_pos = 0 for match in complete_matches: @@ -109,8 +112,7 @@ def _parse_streaming_params(func_body: str) -> dict[str, str]: complete_end_pos = max(complete_end_pos, match.end()) remaining = func_body[complete_end_pos:] - incomplete_pattern = r"]+)>(.*)$" - incomplete_match = re.search(incomplete_pattern, remaining, re.DOTALL) + incomplete_match = _INCOMPLETE_PARAM_PATTERN.search(remaining) if incomplete_match: param_name = incomplete_match.group(1) param_value = html.unescape(incomplete_match.group(2).strip()) diff --git a/strix/interface/tui.py b/strix/interface/tui.py index e21baae..6181b32 100644 --- a/strix/interface/tui.py +++ b/strix/interface/tui.py @@ -29,6 +29,10 @@ from textual.widgets import Button, Label, Static, TextArea, Tree from textual.widgets.tree import TreeNode from strix.agents.StrixAgent import StrixAgent +from strix.interface.streaming_parser import parse_streaming_content +from strix.interface.tool_components.agent_message_renderer import AgentMessageRenderer +from strix.interface.tool_components.registry import get_tool_renderer +from strix.interface.tool_components.user_message_renderer import UserMessageRenderer from strix.interface.utils import build_tui_stats_text from strix.llm.config import LLMConfig from strix.telemetry.tracer import Tracer, set_global_tracer @@ -691,6 +695,9 @@ class StrixTUIApp(App): # type: ignore[misc] self._displayed_agents: set[str] = set() self._displayed_events: list[str] = [] + self._streaming_render_cache: dict[str, tuple[int, Any]] = {} + self._last_streaming_len: dict[str, int] = {} + self._scan_thread: threading.Thread | None = None self._scan_stop_event = threading.Event() self._scan_completed = threading.Event() @@ -853,7 +860,7 @@ class StrixTUIApp(App): # type: ignore[misc] self._start_scan_thread() - self.set_interval(0.25, self._update_ui_from_tracer) + self.set_interval(0.35, self._update_ui_from_tracer) def _update_ui_from_tracer(self) -> None: if self.show_splash: @@ -946,11 +953,17 @@ class StrixTUIApp(App): # type: ignore[misc] ) current_event_ids = [e["id"] for e in events] + current_streaming_len = len(streaming) if streaming else 0 + last_streaming_len = self._last_streaming_len.get(self.selected_agent_id, 0) - if not streaming and current_event_ids == self._displayed_events: + if ( + current_event_ids == self._displayed_events + and current_streaming_len == last_streaming_len + ): return None, None self._displayed_events = current_event_ids + self._last_streaming_len[self.selected_agent_id] = current_streaming_len return self._get_rendered_events_content(events), "chat-content" def _update_chat_view(self) -> None: @@ -1025,18 +1038,20 @@ class StrixTUIApp(App): # type: ignore[misc] return Group(*renderables) - def _render_streaming_content(self, content: str) -> Any: - from strix.interface.streaming_parser import parse_streaming_content + def _render_streaming_content(self, content: str, agent_id: str | None = None) -> Any: + cache_key = agent_id or self.selected_agent_id or "" + content_len = len(content) + + if cache_key in self._streaming_render_cache: + cached_len, cached_output = self._streaming_render_cache[cache_key] + if cached_len == content_len: + return cached_output renderables: list[Any] = [] segments = parse_streaming_content(content) for segment in segments: if segment.type == "text": - from strix.interface.tool_components.agent_message_renderer import ( - AgentMessageRenderer, - ) - text_content = AgentMessageRenderer.render_simple(segment.content) if renderables: renderables.append(Text("")) @@ -1053,18 +1068,18 @@ class StrixTUIApp(App): # type: ignore[misc] renderables.append(tool_renderable) if not renderables: - return Text() + result = Text() + elif len(renderables) == 1: + result = renderables[0] + else: + result = Group(*renderables) - if len(renderables) == 1: - return renderables[0] - - return Group(*renderables) + self._streaming_render_cache[cache_key] = (content_len, result) + return result def _render_streaming_tool( self, tool_name: str, args: dict[str, str], is_complete: bool ) -> Any: - from strix.interface.tool_components.registry import get_tool_renderer - tool_data = { "tool_name": tool_name, "args": args, @@ -1395,6 +1410,8 @@ class StrixTUIApp(App): # type: ignore[misc] return self._displayed_events.clear() + self._streaming_render_cache.clear() + self._last_streaming_len.clear() self.call_later(self._update_chat_view) self._update_agent_status_display() @@ -1589,8 +1606,6 @@ class StrixTUIApp(App): # type: ignore[misc] return None if role == "user": - from strix.interface.tool_components.user_message_renderer import UserMessageRenderer - return UserMessageRenderer.render_simple(content) if metadata.get("interrupted"): @@ -1601,8 +1616,6 @@ class StrixTUIApp(App): # type: ignore[misc] interrupted_text.append("Interrupted by user", style="yellow dim") return Group(streaming_result, interrupted_text) - from strix.interface.tool_components.agent_message_renderer import AgentMessageRenderer - return AgentMessageRenderer.render_simple(content) def _render_tool_content_simple(self, tool_data: dict[str, Any]) -> Any: @@ -1611,8 +1624,6 @@ class StrixTUIApp(App): # type: ignore[misc] status = tool_data.get("status", "unknown") result = tool_data.get("result") - from strix.interface.tool_components.registry import get_tool_renderer - renderer = get_tool_renderer(tool_name) if renderer: