From 878d6ebf57503cec56fe8bcba8c36ea44cc8a69d Mon Sep 17 00:00:00 2001 From: 0xallam Date: Mon, 5 Jan 2026 16:22:51 -0800 Subject: [PATCH] refactor(tui): improve agent node expansion handling and add tree node selection functionality --- strix/interface/tui.py | 27 +++++++++++++++++++++++++-- strix/llm/__init__.py | 3 +++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/strix/interface/tui.py b/strix/interface/tui.py index 4d98114..2c927f7 100644 --- a/strix/interface/tui.py +++ b/strix/interface/tui.py @@ -522,7 +522,7 @@ class StrixTUIApp(App): # type: ignore[misc] agent_updates = True if agent_updates: - self._expand_all_agent_nodes() + self._expand_new_agent_nodes() self._update_chat_view() @@ -1077,6 +1077,13 @@ class StrixTUIApp(App): # type: ignore[misc] logging.warning(f"Failed to add agent node {agent_id}: {e}") + def _expand_new_agent_nodes(self) -> None: + if len(self.screen_stack) > 1 or self.show_splash: + return + + if not self.is_mounted: + return + def _expand_all_agent_nodes(self) -> None: if len(self.screen_stack) > 1 or self.show_splash: return @@ -1156,7 +1163,7 @@ class StrixTUIApp(App): # type: ignore[misc] old_node.remove() parent_node.allow_expand = True - self._expand_all_agent_nodes() + parent_node.expand() def _render_chat_content(self, msg_data: dict[str, Any]) -> Text | None: role = msg_data.get("role") @@ -1266,6 +1273,22 @@ class StrixTUIApp(App): # type: ignore[misc] if agent_id: self.selected_agent_id = agent_id + @on(Tree.NodeSelected) # type: ignore[misc] + def handle_tree_node_selected(self, event: Tree.NodeSelected) -> None: + if len(self.screen_stack) > 1 or self.show_splash: + return + + if not self.is_mounted: + return + + node = event.node + + if node.allow_expand: + if node.is_expanded: + node.collapse() + else: + node.expand() + def _send_user_message(self, message: str) -> None: if not self.selected_agent_id: return diff --git a/strix/llm/__init__.py b/strix/llm/__init__.py index f3f8b67..a971dc0 100644 --- a/strix/llm/__init__.py +++ b/strix/llm/__init__.py @@ -1,3 +1,5 @@ +import logging + import litellm from .config import LLMConfig @@ -11,3 +13,4 @@ __all__ = [ ] litellm._logging._disable_debugging() +logging.getLogger("aiohttp").setLevel(logging.CRITICAL)