feat(tui): add real-time streaming LLM output with full content display
- Convert LiteLLM requests to streaming mode via stream_request()
- Add a streaming parser to handle live LLM output segments
- Update the TUI for real-time rendering of streamed content
- Add tracer methods for tracking streamed content
- Strip function tags from streamed content so raw tags are never displayed
- Remove all truncation from tool renderers for full content visibility

A sketch of the resulting caller-side API follows below.
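For context, a minimal sketch of how a caller might consume the new streaming API. The generator method name "generate" and the printing hook are assumptions for illustration; LLM, LLMResponse, scan_id, and step_number come from the diff itself.

# Hypothetical caller-side sketch; "generate" is an assumed method name.
async def run_step(llm) -> None:
    history: list[dict] = []  # prior conversation turns
    final = None
    # Each partial LLMResponse carries the content accumulated so far;
    # the last one also carries the parsed tool invocations.
    async for partial in llm.generate(history, scan_id="scan-1", step_number=1):
        print(partial.content)  # a real TUI would re-render in place
        final = partial
    if final is not None and final.tool_invocations:
        ...  # hand the tool invocations to the executor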
@@ -1,5 +1,6 @@
 import logging
 import os
+from collections.abc import AsyncIterator
 from dataclasses import dataclass
 from enum import Enum
 from fnmatch import fnmatch
@@ -12,7 +13,7 @@ from jinja2 import (
     FileSystemLoader,
     select_autoescape,
 )
-from litellm import ModelResponse, completion_cost
+from litellm import completion_cost, stream_chunk_builder
 from litellm.utils import supports_prompt_caching, supports_vision
 
 from strix.llm.config import LLMConfig
@@ -276,7 +277,7 @@ class LLM:
         conversation_history: list[dict[str, Any]],
         scan_id: str | None = None,
         step_number: int = 1,
-    ) -> LLMResponse:
+    ) -> AsyncIterator[LLMResponse]:
         messages = [{"role": "system", "content": self.system_prompt}]
 
         identity_message = self._build_identity_message()
@@ -292,30 +293,43 @@
         cached_messages = self._prepare_cached_messages(messages)
 
         try:
-            response = await self._make_request(cached_messages)
-            self._update_usage_stats(response)
-
-            content = ""
-            if (
-                response.choices
-                and hasattr(response.choices[0], "message")
-                and response.choices[0].message
-            ):
-                content = getattr(response.choices[0].message, "content", "") or ""
-
-            content = _truncate_to_first_function(content)
-            if "</function>" in content:
-                function_end_index = content.find("</function>") + len("</function>")
-                content = content[:function_end_index]
-
-            tool_invocations = parse_tool_invocations(content)
-
-            return LLMResponse(
+            accumulated_content = ""
+            chunks: list[Any] = []
+
+            async for chunk in self._stream_request(cached_messages):
+                chunks.append(chunk)
+                delta = self._extract_chunk_delta(chunk)
+                if delta:
+                    accumulated_content += delta
+
+                    if "</function>" in accumulated_content:
+                        function_end = accumulated_content.find("</function>") + len("</function>")
+                        accumulated_content = accumulated_content[:function_end]
+
+                    yield LLMResponse(
+                        scan_id=scan_id,
+                        step_number=step_number,
+                        role=StepRole.AGENT,
+                        content=accumulated_content,
+                        tool_invocations=None,
+                    )
+
+            if chunks:
+                complete_response = stream_chunk_builder(chunks)
+                self._update_usage_stats(complete_response)
+
+            accumulated_content = _truncate_to_first_function(accumulated_content)
+            if "</function>" in accumulated_content:
+                function_end = accumulated_content.find("</function>") + len("</function>")
+                accumulated_content = accumulated_content[:function_end]
+
+            tool_invocations = parse_tool_invocations(accumulated_content)
+
+            yield LLMResponse(
                 scan_id=scan_id,
                 step_number=step_number,
                 role=StepRole.AGENT,
-                content=content,
+                content=accumulated_content,
                 tool_invocations=tool_invocations if tool_invocations else None,
             )
 
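_truncate_to_first_function is referenced here but defined outside the visible hunks. Since the code above still trims trailing text after </function> separately, one plausible reading is that the helper only cuts off any second function block; a sketch under that assumption, not the diff's actual body:

def _truncate_to_first_function(content: str) -> str:
    # Assumed behavior: keep at most one <function ...> block by dropping
    # a second opening tag and everything after it. The real helper in
    # the codebase may differ.
    first = content.find("<function")
    if first == -1:
        return content
    second = content.find("<function", first + 1)
    if second == -1:
        return content
    return content[:second]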
@@ -364,6 +378,12 @@ class LLM:
         except Exception as e:
             raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
 
+    def _extract_chunk_delta(self, chunk: Any) -> str:
+        if chunk.choices and hasattr(chunk.choices[0], "delta"):
+            delta = chunk.choices[0].delta
+            return getattr(delta, "content", "") or ""
+        return ""
+
     @property
     def usage_stats(self) -> dict[str, dict[str, int | float]]:
         return {
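_extract_chunk_delta unpacks the OpenAI-style delta format that litellm streams. A quick illustration with a stand-in chunk (SimpleNamespace is only a stub for the real chunk object):

from types import SimpleNamespace

# Stand-in for one streaming chunk; real chunks come from litellm.
chunk = SimpleNamespace(
    choices=[SimpleNamespace(delta=SimpleNamespace(content="Hel"))]
)

# Mirrors _extract_chunk_delta: guard on choices/delta, default to "".
delta = ""
if chunk.choices and hasattr(chunk.choices[0], "delta"):
    delta = getattr(chunk.choices[0].delta, "content", "") or ""
assert delta == "Hel"

The final usage-bearing chunk (enabled below via stream_options) typically arrives with empty choices, which this guard skips cleanly.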
@@ -436,10 +456,10 @@ class LLM:
             filtered_messages.append(updated_msg)
         return filtered_messages
 
-    async def _make_request(
+    async def _stream_request(
         self,
         messages: list[dict[str, Any]],
-    ) -> ModelResponse:
+    ) -> AsyncIterator[Any]:
         if not self._model_supports_vision():
             messages = self._filter_images_from_messages(messages)
 
@@ -447,6 +467,7 @@
             "model": self.config.model_name,
             "messages": messages,
             "timeout": self.config.timeout,
+            "stream_options": {"include_usage": True},
         }
 
         if _LLM_API_KEY:
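"stream_options": {"include_usage": True} asks OpenAI-compatible backends to emit one final chunk carrying token usage. Outside the queue abstraction, the same pattern with litellm directly looks roughly like this ("gpt-4o" is a placeholder model, not a value from this diff):

import litellm

def stream_demo() -> None:
    chunks = []
    response = litellm.completion(
        model="gpt-4o",  # placeholder
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    for chunk in response:
        chunks.append(chunk)
        # The trailing usage chunk has no choices, so guard before reading.
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)
    # Reassemble a complete response (including usage) from the chunks.
    full = litellm.stream_chunk_builder(chunks)
    print(full.usage)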
@@ -461,14 +482,13 @@
             completion_args["reasoning_effort"] = "high"
 
         queue = get_global_queue()
-        response = await queue.make_request(completion_args)
-
         self._total_stats.requests += 1
         self._last_request_stats = RequestStats(requests=1)
 
-        return response
+        async for chunk in queue.stream_request(completion_args):
+            yield chunk
 
-    def _update_usage_stats(self, response: ModelResponse) -> None:
+    def _update_usage_stats(self, response: Any) -> None:
         try:
             if hasattr(response, "usage") and response.usage:
                 input_tokens = getattr(response.usage, "prompt_tokens", 0)
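Widening _update_usage_stats from ModelResponse to Any lets it accept the rebuilt response from stream_chunk_builder as well as raw chunk objects; the getattr guards shown above do the rest. For completeness, a hedged sketch of pulling usage and cost from the rebuilt response (completion_cost is litellm's helper; the returned dict shape is an assumption, not code from this diff):

from litellm import completion_cost, stream_chunk_builder

def summarize_usage(chunks: list) -> dict:
    # Rebuild one ModelResponse-like object from the streamed chunks.
    full = stream_chunk_builder(chunks)
    usage = getattr(full, "usage", None)
    return {
        "prompt_tokens": getattr(usage, "prompt_tokens", 0),
        "completion_tokens": getattr(usage, "completion_tokens", 0),
        # completion_cost estimates USD cost from the model and usage.
        "cost_usd": completion_cost(completion_response=full),
    }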