perf: optimize TUI streaming rendering performance
- Pre-compile regex patterns in streaming_parser.py - Move hot-path imports to module level in tui.py - Add streaming content caching to avoid re-rendering unchanged content - Track streaming length to skip unnecessary re-renders - Reduce UI update interval from 250ms to 350ms
This commit is contained in:
@@ -6,6 +6,11 @@ from typing import Literal
|
|||||||
|
|
||||||
_FUNCTION_TAG_PREFIX = "<function="
|
_FUNCTION_TAG_PREFIX = "<function="
|
||||||
|
|
||||||
|
_FUNC_PATTERN = re.compile(r"<function=([^>]+)>")
|
||||||
|
_FUNC_END_PATTERN = re.compile(r"</function>")
|
||||||
|
_COMPLETE_PARAM_PATTERN = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
|
||||||
|
_INCOMPLETE_PARAM_PATTERN = re.compile(r"<parameter=([^>]+)>(.*)$", re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
def _get_safe_content(content: str) -> tuple[str, str]:
|
def _get_safe_content(content: str) -> tuple[str, str]:
|
||||||
if not content:
|
if not content:
|
||||||
@@ -39,8 +44,7 @@ def parse_streaming_content(content: str) -> list[StreamSegment]:
|
|||||||
|
|
||||||
segments: list[StreamSegment] = []
|
segments: list[StreamSegment] = []
|
||||||
|
|
||||||
func_pattern = r"<function=([^>]+)>"
|
func_matches = list(_FUNC_PATTERN.finditer(content))
|
||||||
func_matches = list(re.finditer(func_pattern, content))
|
|
||||||
|
|
||||||
if not func_matches:
|
if not func_matches:
|
||||||
safe_content, _ = _get_safe_content(content)
|
safe_content, _ = _get_safe_content(content)
|
||||||
@@ -59,12 +63,12 @@ def parse_streaming_content(content: str) -> list[StreamSegment]:
|
|||||||
tool_name = match.group(1)
|
tool_name = match.group(1)
|
||||||
func_start = match.end()
|
func_start = match.end()
|
||||||
|
|
||||||
func_end_match = re.search(r"</function>", content[func_start:])
|
func_end_match = _FUNC_END_PATTERN.search(content, func_start)
|
||||||
|
|
||||||
if func_end_match:
|
if func_end_match:
|
||||||
func_body = content[func_start : func_start + func_end_match.start()]
|
func_body = content[func_start : func_end_match.start()]
|
||||||
is_complete = True
|
is_complete = True
|
||||||
end_pos = func_start + func_end_match.end()
|
end_pos = func_end_match.end()
|
||||||
else:
|
else:
|
||||||
if i + 1 < len(func_matches):
|
if i + 1 < len(func_matches):
|
||||||
next_func_start = func_matches[i + 1].start()
|
next_func_start = func_matches[i + 1].start()
|
||||||
@@ -98,8 +102,7 @@ def parse_streaming_content(content: str) -> list[StreamSegment]:
|
|||||||
def _parse_streaming_params(func_body: str) -> dict[str, str]:
|
def _parse_streaming_params(func_body: str) -> dict[str, str]:
|
||||||
args: dict[str, str] = {}
|
args: dict[str, str] = {}
|
||||||
|
|
||||||
complete_pattern = r"<parameter=([^>]+)>(.*?)</parameter>"
|
complete_matches = list(_COMPLETE_PARAM_PATTERN.finditer(func_body))
|
||||||
complete_matches = list(re.finditer(complete_pattern, func_body, re.DOTALL))
|
|
||||||
complete_end_pos = 0
|
complete_end_pos = 0
|
||||||
|
|
||||||
for match in complete_matches:
|
for match in complete_matches:
|
||||||
@@ -109,8 +112,7 @@ def _parse_streaming_params(func_body: str) -> dict[str, str]:
|
|||||||
complete_end_pos = max(complete_end_pos, match.end())
|
complete_end_pos = max(complete_end_pos, match.end())
|
||||||
|
|
||||||
remaining = func_body[complete_end_pos:]
|
remaining = func_body[complete_end_pos:]
|
||||||
incomplete_pattern = r"<parameter=([^>]+)>(.*)$"
|
incomplete_match = _INCOMPLETE_PARAM_PATTERN.search(remaining)
|
||||||
incomplete_match = re.search(incomplete_pattern, remaining, re.DOTALL)
|
|
||||||
if incomplete_match:
|
if incomplete_match:
|
||||||
param_name = incomplete_match.group(1)
|
param_name = incomplete_match.group(1)
|
||||||
param_value = html.unescape(incomplete_match.group(2).strip())
|
param_value = html.unescape(incomplete_match.group(2).strip())
|
||||||
|
|||||||
@@ -29,6 +29,10 @@ from textual.widgets import Button, Label, Static, TextArea, Tree
|
|||||||
from textual.widgets.tree import TreeNode
|
from textual.widgets.tree import TreeNode
|
||||||
|
|
||||||
from strix.agents.StrixAgent import StrixAgent
|
from strix.agents.StrixAgent import StrixAgent
|
||||||
|
from strix.interface.streaming_parser import parse_streaming_content
|
||||||
|
from strix.interface.tool_components.agent_message_renderer import AgentMessageRenderer
|
||||||
|
from strix.interface.tool_components.registry import get_tool_renderer
|
||||||
|
from strix.interface.tool_components.user_message_renderer import UserMessageRenderer
|
||||||
from strix.interface.utils import build_tui_stats_text
|
from strix.interface.utils import build_tui_stats_text
|
||||||
from strix.llm.config import LLMConfig
|
from strix.llm.config import LLMConfig
|
||||||
from strix.telemetry.tracer import Tracer, set_global_tracer
|
from strix.telemetry.tracer import Tracer, set_global_tracer
|
||||||
@@ -691,6 +695,9 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
self._displayed_agents: set[str] = set()
|
self._displayed_agents: set[str] = set()
|
||||||
self._displayed_events: list[str] = []
|
self._displayed_events: list[str] = []
|
||||||
|
|
||||||
|
self._streaming_render_cache: dict[str, tuple[int, Any]] = {}
|
||||||
|
self._last_streaming_len: dict[str, int] = {}
|
||||||
|
|
||||||
self._scan_thread: threading.Thread | None = None
|
self._scan_thread: threading.Thread | None = None
|
||||||
self._scan_stop_event = threading.Event()
|
self._scan_stop_event = threading.Event()
|
||||||
self._scan_completed = threading.Event()
|
self._scan_completed = threading.Event()
|
||||||
@@ -853,7 +860,7 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
|
|
||||||
self._start_scan_thread()
|
self._start_scan_thread()
|
||||||
|
|
||||||
self.set_interval(0.25, self._update_ui_from_tracer)
|
self.set_interval(0.35, self._update_ui_from_tracer)
|
||||||
|
|
||||||
def _update_ui_from_tracer(self) -> None:
|
def _update_ui_from_tracer(self) -> None:
|
||||||
if self.show_splash:
|
if self.show_splash:
|
||||||
@@ -946,11 +953,17 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
)
|
)
|
||||||
|
|
||||||
current_event_ids = [e["id"] for e in events]
|
current_event_ids = [e["id"] for e in events]
|
||||||
|
current_streaming_len = len(streaming) if streaming else 0
|
||||||
|
last_streaming_len = self._last_streaming_len.get(self.selected_agent_id, 0)
|
||||||
|
|
||||||
if not streaming and current_event_ids == self._displayed_events:
|
if (
|
||||||
|
current_event_ids == self._displayed_events
|
||||||
|
and current_streaming_len == last_streaming_len
|
||||||
|
):
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
self._displayed_events = current_event_ids
|
self._displayed_events = current_event_ids
|
||||||
|
self._last_streaming_len[self.selected_agent_id] = current_streaming_len
|
||||||
return self._get_rendered_events_content(events), "chat-content"
|
return self._get_rendered_events_content(events), "chat-content"
|
||||||
|
|
||||||
def _update_chat_view(self) -> None:
|
def _update_chat_view(self) -> None:
|
||||||
@@ -1025,18 +1038,20 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
|
|
||||||
return Group(*renderables)
|
return Group(*renderables)
|
||||||
|
|
||||||
def _render_streaming_content(self, content: str) -> Any:
|
def _render_streaming_content(self, content: str, agent_id: str | None = None) -> Any:
|
||||||
from strix.interface.streaming_parser import parse_streaming_content
|
cache_key = agent_id or self.selected_agent_id or ""
|
||||||
|
content_len = len(content)
|
||||||
|
|
||||||
|
if cache_key in self._streaming_render_cache:
|
||||||
|
cached_len, cached_output = self._streaming_render_cache[cache_key]
|
||||||
|
if cached_len == content_len:
|
||||||
|
return cached_output
|
||||||
|
|
||||||
renderables: list[Any] = []
|
renderables: list[Any] = []
|
||||||
segments = parse_streaming_content(content)
|
segments = parse_streaming_content(content)
|
||||||
|
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
if segment.type == "text":
|
if segment.type == "text":
|
||||||
from strix.interface.tool_components.agent_message_renderer import (
|
|
||||||
AgentMessageRenderer,
|
|
||||||
)
|
|
||||||
|
|
||||||
text_content = AgentMessageRenderer.render_simple(segment.content)
|
text_content = AgentMessageRenderer.render_simple(segment.content)
|
||||||
if renderables:
|
if renderables:
|
||||||
renderables.append(Text(""))
|
renderables.append(Text(""))
|
||||||
@@ -1053,18 +1068,18 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
renderables.append(tool_renderable)
|
renderables.append(tool_renderable)
|
||||||
|
|
||||||
if not renderables:
|
if not renderables:
|
||||||
return Text()
|
result = Text()
|
||||||
|
elif len(renderables) == 1:
|
||||||
|
result = renderables[0]
|
||||||
|
else:
|
||||||
|
result = Group(*renderables)
|
||||||
|
|
||||||
if len(renderables) == 1:
|
self._streaming_render_cache[cache_key] = (content_len, result)
|
||||||
return renderables[0]
|
return result
|
||||||
|
|
||||||
return Group(*renderables)
|
|
||||||
|
|
||||||
def _render_streaming_tool(
|
def _render_streaming_tool(
|
||||||
self, tool_name: str, args: dict[str, str], is_complete: bool
|
self, tool_name: str, args: dict[str, str], is_complete: bool
|
||||||
) -> Any:
|
) -> Any:
|
||||||
from strix.interface.tool_components.registry import get_tool_renderer
|
|
||||||
|
|
||||||
tool_data = {
|
tool_data = {
|
||||||
"tool_name": tool_name,
|
"tool_name": tool_name,
|
||||||
"args": args,
|
"args": args,
|
||||||
@@ -1395,6 +1410,8 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
return
|
return
|
||||||
|
|
||||||
self._displayed_events.clear()
|
self._displayed_events.clear()
|
||||||
|
self._streaming_render_cache.clear()
|
||||||
|
self._last_streaming_len.clear()
|
||||||
|
|
||||||
self.call_later(self._update_chat_view)
|
self.call_later(self._update_chat_view)
|
||||||
self._update_agent_status_display()
|
self._update_agent_status_display()
|
||||||
@@ -1589,8 +1606,6 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if role == "user":
|
if role == "user":
|
||||||
from strix.interface.tool_components.user_message_renderer import UserMessageRenderer
|
|
||||||
|
|
||||||
return UserMessageRenderer.render_simple(content)
|
return UserMessageRenderer.render_simple(content)
|
||||||
|
|
||||||
if metadata.get("interrupted"):
|
if metadata.get("interrupted"):
|
||||||
@@ -1601,8 +1616,6 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
interrupted_text.append("Interrupted by user", style="yellow dim")
|
interrupted_text.append("Interrupted by user", style="yellow dim")
|
||||||
return Group(streaming_result, interrupted_text)
|
return Group(streaming_result, interrupted_text)
|
||||||
|
|
||||||
from strix.interface.tool_components.agent_message_renderer import AgentMessageRenderer
|
|
||||||
|
|
||||||
return AgentMessageRenderer.render_simple(content)
|
return AgentMessageRenderer.render_simple(content)
|
||||||
|
|
||||||
def _render_tool_content_simple(self, tool_data: dict[str, Any]) -> Any:
|
def _render_tool_content_simple(self, tool_data: dict[str, Any]) -> Any:
|
||||||
@@ -1611,8 +1624,6 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
status = tool_data.get("status", "unknown")
|
status = tool_data.get("status", "unknown")
|
||||||
result = tool_data.get("result")
|
result = tool_data.get("result")
|
||||||
|
|
||||||
from strix.interface.tool_components.registry import get_tool_renderer
|
|
||||||
|
|
||||||
renderer = get_tool_renderer(tool_name)
|
renderer = get_tool_renderer(tool_name)
|
||||||
|
|
||||||
if renderer:
|
if renderer:
|
||||||
|
|||||||
Reference in New Issue
Block a user