feat(tui): add real-time streaming LLM output with full content display

- Convert LiteLLM requests to streaming mode with stream_request()
- Add streaming parser to handle live LLM output segments
- Update TUI for real-time streaming content rendering
- Add tracer methods for streaming content tracking
- Clean function tags from streamed content so raw tags are not displayed
- Remove all truncation from tool renderers for full content visibility
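
The pattern behind the first two bullets, sketched minimally with litellm called directly (the repository routes requests through an internal queue and helper methods not shown here; the model name and function below are illustrative, not taken from the commit):

    import litellm
    from litellm import stream_chunk_builder


    async def stream_completion(messages: list[dict]) -> tuple[str, object]:
        # Illustrative sketch only: request a streaming completion and ask the
        # provider to attach token usage to the final chunk.
        response = await litellm.acompletion(
            model="openai/gpt-4o",  # placeholder model name
            messages=messages,
            stream=True,
            stream_options={"include_usage": True},
        )

        chunks: list = []
        text = ""
        async for chunk in response:
            chunks.append(chunk)
            # The usage-only final chunk can have empty choices, so guard the access.
            if chunk.choices and chunk.choices[0].delta:
                text += chunk.choices[0].delta.content or ""

        # Rebuild a complete ModelResponse from the collected chunks so usage and
        # cost can be computed the same way as for a non-streaming request.
        complete = stream_chunk_builder(chunks, messages=messages)
        return text, complete
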
Author: 0xallam
Date: 2026-01-05 09:52:05 -08:00
Committed by: Ahmed Allam
Parent: a2142cc985
Commit: a6dcb7756e
21 changed files with 345 additions and 135 deletions


@@ -1,5 +1,6 @@
import logging
import os
from collections.abc import AsyncIterator
from dataclasses import dataclass
from enum import Enum
from fnmatch import fnmatch
@@ -12,7 +13,7 @@ from jinja2 import (
FileSystemLoader,
select_autoescape,
)
from litellm import ModelResponse, completion_cost
from litellm import completion_cost, stream_chunk_builder
from litellm.utils import supports_prompt_caching, supports_vision
from strix.llm.config import LLMConfig
@@ -276,7 +277,7 @@ class LLM:
conversation_history: list[dict[str, Any]],
scan_id: str | None = None,
step_number: int = 1,
) -> LLMResponse:
) -> AsyncIterator[LLMResponse]:
messages = [{"role": "system", "content": self.system_prompt}]
identity_message = self._build_identity_message()
@@ -292,30 +293,43 @@ class LLM:
cached_messages = self._prepare_cached_messages(messages)
try:
response = await self._make_request(cached_messages)
self._update_usage_stats(response)
accumulated_content = ""
chunks: list[Any] = []
content = ""
if (
response.choices
and hasattr(response.choices[0], "message")
and response.choices[0].message
):
content = getattr(response.choices[0].message, "content", "") or ""
async for chunk in self._stream_request(cached_messages):
chunks.append(chunk)
delta = self._extract_chunk_delta(chunk)
if delta:
accumulated_content += delta
content = _truncate_to_first_function(content)
if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end]
if "</function>" in content:
function_end_index = content.find("</function>") + len("</function>")
content = content[:function_end_index]
yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content=accumulated_content,
tool_invocations=None,
)
tool_invocations = parse_tool_invocations(content)
if chunks:
complete_response = stream_chunk_builder(chunks)
self._update_usage_stats(complete_response)
return LLMResponse(
accumulated_content = _truncate_to_first_function(accumulated_content)
if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end]
tool_invocations = parse_tool_invocations(accumulated_content)
yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content=content,
content=accumulated_content,
tool_invocations=tool_invocations if tool_invocations else None,
)
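
Because the hunk above interleaves the removed and added lines without +/- markers, the new control flow is hard to read. A rough reconstruction of the body of the streaming loop, assuming the partial yields happen once per delta (indentation, blank lines, and a few surrounding lines are not visible in this view, so this is a sketch rather than a verbatim copy):

    # Rough reconstruction of the new generator body, not a verbatim copy.
    accumulated_content = ""
    chunks: list[Any] = []

    async for chunk in self._stream_request(cached_messages):
        chunks.append(chunk)
        delta = self._extract_chunk_delta(chunk)
        if delta:
            accumulated_content += delta
            # Cut the visible text at the first closed function tag.
            if "</function>" in accumulated_content:
                end = accumulated_content.find("</function>") + len("</function>")
                accumulated_content = accumulated_content[:end]
            # Partial update for the TUI; tool calls are not parsed yet.
            yield LLMResponse(
                scan_id=scan_id,
                step_number=step_number,
                role=StepRole.AGENT,
                content=accumulated_content,
                tool_invocations=None,
            )

    if chunks:
        # Rebuild the full response from the chunks for token/cost accounting.
        self._update_usage_stats(stream_chunk_builder(chunks))

    # Final yield: truncated content plus any parsed tool invocations.
    accumulated_content = _truncate_to_first_function(accumulated_content)
    tool_invocations = parse_tool_invocations(accumulated_content)
    yield LLMResponse(
        scan_id=scan_id,
        step_number=step_number,
        role=StepRole.AGENT,
        content=accumulated_content,
        tool_invocations=tool_invocations if tool_invocations else None,
    )
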
@@ -364,6 +378,12 @@ class LLM:
except Exception as e:
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
def _extract_chunk_delta(self, chunk: Any) -> str:
if chunk.choices and hasattr(chunk.choices[0], "delta"):
delta = chunk.choices[0].delta
return getattr(delta, "content", "") or ""
return ""
@property
def usage_stats(self) -> dict[str, dict[str, int | float]]:
return {
@@ -436,10 +456,10 @@ class LLM:
filtered_messages.append(updated_msg)
return filtered_messages
async def _make_request(
async def _stream_request(
self,
messages: list[dict[str, Any]],
) -> ModelResponse:
) -> AsyncIterator[Any]:
if not self._model_supports_vision():
messages = self._filter_images_from_messages(messages)
@@ -447,6 +467,7 @@ class LLM:
"model": self.config.model_name,
"messages": messages,
"timeout": self.config.timeout,
"stream_options": {"include_usage": True},
}
if _LLM_API_KEY:
@@ -461,14 +482,13 @@ class LLM:
completion_args["reasoning_effort"] = "high"
queue = get_global_queue()
response = await queue.make_request(completion_args)
self._total_stats.requests += 1
self._last_request_stats = RequestStats(requests=1)
return response
async for chunk in queue.stream_request(completion_args):
yield chunk
def _update_usage_stats(self, response: ModelResponse) -> None:
def _update_usage_stats(self, response: Any) -> None:
try:
if hasattr(response, "usage") and response.usage:
input_tokens = getattr(response.usage, "prompt_tokens", 0)
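
Since the generation method now returns an AsyncIterator[LLMResponse] instead of a single LLMResponse, callers have to iterate over partial responses. A sketch of how the TUI side might consume it (the method and helper names here are hypothetical, not taken from the commit):

    # Hypothetical consumer: `llm.generate` and `render_partial` are placeholders
    # for whatever the TUI actually calls.
    async def run_agent_step(llm, history, scan_id, step_number):
        final = None
        async for partial in llm.generate(history, scan_id=scan_id, step_number=step_number):
            render_partial(partial.content)  # live update as deltas arrive
            final = partial                  # the last yield carries tool_invocations
        return final
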