feat(tui): add real-time streaming LLM output with full content display

- Convert LiteLLM requests to streaming mode with stream_request()
- Add streaming parser to handle live LLM output segments
- Update TUI for real-time streaming content rendering
- Add tracer methods for streaming content tracking
- Strip function-call tags from streamed content so raw tool markup never appears in the chat view
- Remove all truncation from tool renderers for full content visibility
This commit is contained in:
0xallam
2026-01-05 09:52:05 -08:00
committed by Ahmed Allam
parent a2142cc985
commit a6dcb7756e
21 changed files with 345 additions and 135 deletions

View File

@@ -351,9 +351,16 @@ class BaseAgent(metaclass=AgentMeta):
self.state.add_message("user", task)
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
response = await self.llm.generate(self.state.get_conversation_history())
final_response = None
async for response in self.llm.generate(self.state.get_conversation_history()):
final_response = response
if tracer and response.content:
tracer.update_streaming_content(self.state.agent_id, response.content)
content_stripped = (response.content or "").strip()
if final_response is None:
return False
content_stripped = (final_response.content or "").strip()
if not content_stripped:
corrective_message = (
@@ -369,17 +376,18 @@ class BaseAgent(metaclass=AgentMeta):
self.state.add_message("user", corrective_message)
return False
self.state.add_message("assistant", response.content)
self.state.add_message("assistant", final_response.content)
if tracer:
tracer.clear_streaming_content(self.state.agent_id)
tracer.log_chat_message(
content=clean_content(response.content),
content=clean_content(final_response.content),
role="assistant",
agent_id=self.state.agent_id,
)
actions = (
response.tool_invocations
if hasattr(response, "tool_invocations") and response.tool_invocations
final_response.tool_invocations
if hasattr(final_response, "tool_invocations") and final_response.tool_invocations
else []
)

View File

@@ -0,0 +1,119 @@
import html
import re
from dataclasses import dataclass
from typing import Literal
_FUNCTION_TAG_PREFIX = "<function="


def _get_safe_content(content: str) -> tuple[str, str]:
    """Split streamed text into a displayable part and a held-back tail.

    While a tool-call tag is still arriving, the trailing characters of the
    stream may be an unfinished ``<function=...`` opener.  Returns
    ``(safe, held)`` where ``held`` is the suffix that could still grow into
    a function tag (and must not be rendered yet) and ``safe`` is everything
    before it.  If nothing needs holding back, ``held`` is ``""``.
    """
    if not content:
        return "", ""
    last_lt = content.rfind("<")
    if last_lt == -1:
        # No tag opener anywhere - the whole content is safe to show.
        return content, ""
    suffix = content[last_lt:]
    target = _FUNCTION_TAG_PREFIX  # "<function="
    # Hold the suffix back if it could still become a function tag:
    # either it is a prefix of "<function=" (e.g. "<func"), or it already
    # starts with "<function=" but the tool name / closing ">" has not
    # streamed in yet (e.g. "<function=wri").  Previously the second case
    # was not held back, so a half-streamed tag with a partial tool name
    # leaked raw into the display.  A completed opener contains ">" and is
    # left for the regular parser.
    if target.startswith(suffix) or (suffix.startswith(target) and ">" not in suffix):
        return content[:last_lt], suffix
    return content, ""
@dataclass
class StreamSegment:
    """One rendered unit of in-flight LLM output: plain text or a tool call."""

    type: Literal["text", "tool"]  # which kind of segment this is
    content: str  # the text itself, or the raw body of the <function=...> block
    tool_name: str | None = None  # set only for type == "tool"
    args: dict[str, str] | None = None  # parsed <parameter=...> values (tool only)
    is_complete: bool = False  # True once the closing </function> has streamed in
def parse_streaming_content(content: str) -> list[StreamSegment]:
    """Parse partially streamed LLM output into text and tool-call segments.

    ``content`` may end mid-token: an unfinished ``<function=...`` opener is
    held back (not rendered yet), and a tool call whose ``</function>``
    terminator has not arrived is emitted with ``is_complete=False`` so the
    UI can show it as in-progress.
    """
    if not content:
        return []
    segments: list[StreamSegment] = []
    func_pattern = r"<function=([^>]+)>"
    func_matches = list(re.finditer(func_pattern, content))
    if not func_matches:
        # Pure text so far; hold back anything that might grow into a tag.
        safe_content, _ = _get_safe_content(content)
        text = safe_content.strip()
        if text:
            segments.append(StreamSegment(type="text", content=text))
        return segments
    # Text preceding the first tool call.
    first_func_start = func_matches[0].start()
    if first_func_start > 0:
        text_before = content[:first_func_start].strip()
        if text_before:
            segments.append(StreamSegment(type="text", content=text_before))
    for i, match in enumerate(func_matches):
        tool_name = match.group(1)
        func_start = match.end()
        func_end_match = re.search(r"</function>", content[func_start:])
        if func_end_match:
            func_body = content[func_start : func_start + func_end_match.start()]
            is_complete = True
            end_pos = func_start + func_end_match.end()
        else:
            # Still streaming: the body runs to the next opener, or to EOF.
            if i + 1 < len(func_matches):
                func_body = content[func_start : func_matches[i + 1].start()]
            else:
                func_body = content[func_start:]
            is_complete = False
            end_pos = len(content)
        args = _parse_streaming_params(func_body)
        segments.append(
            StreamSegment(
                type="tool",
                content=func_body,
                tool_name=tool_name,
                args=args,
                is_complete=is_complete,
            )
        )
        if is_complete:
            if i + 1 < len(func_matches):
                # Text sitting between two tool calls.
                text_between = content[end_pos : func_matches[i + 1].start()].strip()
                if text_between:
                    segments.append(StreamSegment(type="text", content=text_between))
            else:
                # Fix: text streamed after the FINAL completed tool call was
                # previously dropped entirely.  Emit it, minus any partial
                # tag still streaming in at the very end.
                safe_tail, _ = _get_safe_content(content[end_pos:])
                tail = safe_tail.strip()
                if tail:
                    segments.append(StreamSegment(type="text", content=tail))
    return segments
def _parse_streaming_params(func_body: str) -> dict[str, str]:
    """Extract parameter name/value pairs from a (possibly partial) tool body.

    Fully closed ``<parameter=name>value</parameter>`` pairs are collected
    first.  A trailing, still-open parameter at the end of the body is also
    captured, so its value can be rendered while it streams in.  Values are
    HTML-unescaped and stripped of surrounding whitespace.
    """
    params: dict[str, str] = {}
    closed_re = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
    tail_start = 0
    for closed in closed_re.finditer(func_body):
        params[closed.group(1)] = html.unescape(closed.group(2).strip())
        # Remember where the last closed pair ends so the open-parameter
        # scan below only looks at the unconsumed tail.
        if closed.end() > tail_start:
            tail_start = closed.end()
    open_match = re.search(r"<parameter=([^>]+)>(.*)$", func_body[tail_start:], re.DOTALL)
    if open_match is not None:
        params[open_match.group(1)] = html.unescape(open_match.group(2).strip())
    return params

View File

@@ -181,4 +181,10 @@ class AgentMessageRenderer(BaseToolRenderer):
if not content:
return Text()
return _apply_markdown_styles(content)
from strix.llm.utils import clean_content
cleaned = clean_content(content)
if not cleaned:
return Text()
return _apply_markdown_styles(cleaned)

View File

@@ -40,7 +40,7 @@ class CreateAgentRenderer(BaseToolRenderer):
if task:
text.append("\n ")
text.append(cls.truncate(task, 400), style="dim")
text.append(task, style="dim")
else:
text.append("\n ")
text.append("Spawning agent...", style="dim")
@@ -66,7 +66,7 @@ class SendMessageToAgentRenderer(BaseToolRenderer):
if message:
text.append("\n ")
text.append(cls.truncate(message, 400), style="dim")
text.append(message, style="dim")
else:
text.append("\n ")
text.append("Sending...", style="dim")
@@ -129,7 +129,7 @@ class WaitForMessageRenderer(BaseToolRenderer):
if reason:
text.append("\n ")
text.append(cls.truncate(reason, 400), style="dim")
text.append(reason, style="dim")
else:
text.append("\n ")
text.append("Agent paused until message received...", style="dim")

View File

@@ -23,12 +23,6 @@ class BaseToolRenderer(ABC):
css_classes = cls.get_css_classes(status)
return Static(content, classes=css_classes)
@classmethod
def truncate(cls, text: str, max_length: int = 500) -> str:
if len(text) <= max_length:
return text
return text[: max_length - 3] + "..."
@classmethod
def status_icon(cls, status: str) -> tuple[str, str]:
icons = {

View File

@@ -74,7 +74,7 @@ class BrowserRenderer(BaseToolRenderer):
def _build_url_action(cls, text: Text, label: str, url: str | None, suffix: str = "") -> None:
text.append(label, style="#06b6d4")
if url:
text.append(cls.truncate(url, 300), style="#06b6d4")
text.append(url, style="#06b6d4")
if suffix:
text.append(suffix, style="#06b6d4")
@@ -120,7 +120,7 @@ class BrowserRenderer(BaseToolRenderer):
label, value = handlers[action]
text.append(label, style="#06b6d4")
if value:
text.append(cls.truncate(str(value), 200), style="#06b6d4")
text.append(str(value), style="#06b6d4")
return text
if action == "execute_js":
@@ -128,7 +128,7 @@ class BrowserRenderer(BaseToolRenderer):
js_code = args.get("js_code")
if js_code:
text.append("\n")
text.append_text(cls._highlight_js(cls.truncate(js_code, 2000)))
text.append_text(cls._highlight_js(js_code))
return text
text.append(action, style="#06b6d4")

View File

@@ -83,8 +83,7 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
if command == "str_replace" and (old_str or new_str):
if old_str:
old_display = cls.truncate(old_str, 1000)
highlighted_old = cls._highlight_code(old_display, path)
highlighted_old = cls._highlight_code(old_str, path)
for line in highlighted_old.plain.split("\n"):
text.append("\n")
text.append("-", style="#ef4444")
@@ -92,8 +91,7 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
text.append(line)
if new_str:
new_display = cls.truncate(new_str, 1000)
highlighted_new = cls._highlight_code(new_display, path)
highlighted_new = cls._highlight_code(new_str, path)
for line in highlighted_new.plain.split("\n"):
text.append("\n")
text.append("+", style="#22c55e")
@@ -101,13 +99,11 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
text.append(line)
elif command == "create" and file_text:
text_display = cls.truncate(file_text, 1500)
text.append("\n")
text.append_text(cls._highlight_code(text_display, path))
text.append_text(cls._highlight_code(file_text, path))
elif command == "insert" and new_str:
new_display = cls.truncate(new_str, 1000)
highlighted_new = cls._highlight_code(new_display, path)
highlighted_new = cls._highlight_code(new_str, path)
for line in highlighted_new.plain.split("\n"):
text.append("\n")
text.append("+", style="#22c55e")
@@ -164,19 +160,15 @@ class SearchFilesRenderer(BaseToolRenderer):
text.append(" ")
if path and regex:
path_display = path[-30:] if len(path) > 30 else path
regex_display = regex[:30] if len(regex) > 30 else regex
text.append(path_display, style="dim")
text.append(path, style="dim")
text.append(" for '", style="dim")
text.append(regex_display, style="dim")
text.append(regex, style="dim")
text.append("'", style="dim")
elif path:
path_display = path[-60:] if len(path) > 60 else path
text.append(path_display, style="dim")
text.append(path, style="dim")
elif regex:
regex_display = regex[:60] if len(regex) > 60 else regex
text.append("'", style="dim")
text.append(regex_display, style="dim")
text.append(regex, style="dim")
text.append("'", style="dim")
else:
text.append("Searching...", style="dim")

View File

@@ -28,11 +28,11 @@ class CreateNoteRenderer(BaseToolRenderer):
if title:
text.append("\n ")
text.append(cls.truncate(title.strip(), 300))
text.append(title.strip())
if content:
text.append("\n ")
text.append(cls.truncate(content.strip(), 800), style="dim")
text.append(content.strip(), style="dim")
if not title and not content:
text.append("\n ")
@@ -75,11 +75,11 @@ class UpdateNoteRenderer(BaseToolRenderer):
if title:
text.append("\n ")
text.append(cls.truncate(title, 300))
text.append(title)
if content:
text.append("\n ")
text.append(cls.truncate(content.strip(), 800), style="dim")
text.append(content.strip(), style="dim")
if not title and not content:
text.append("\n ")
@@ -110,23 +110,18 @@ class ListNotesRenderer(BaseToolRenderer):
text.append("\n ")
text.append("No notes", style="dim")
else:
for note in notes[:5]:
for note in notes:
title = note.get("title", "").strip() or "(untitled)"
category = note.get("category", "general")
note_content = note.get("content", "").strip()
text.append("\n - ")
text.append(cls.truncate(title, 300))
text.append(title)
text.append(f" ({category})", style="dim")
if note_content:
text.append("\n ")
text.append(cls.truncate(note_content, 400), style="dim")
remaining = max(count - 5, 0)
if remaining:
text.append("\n ")
text.append(f"... +{remaining} more", style="dim")
text.append(note_content, style="dim")
else:
text.append("\n ")
text.append("Loading...", style="dim")

View File

@@ -26,7 +26,7 @@ class ListRequestsRenderer(BaseToolRenderer):
if result and isinstance(result, dict) and "requests" in result:
requests = result["requests"]
if isinstance(requests, list) and requests:
for req in requests[:3]:
for req in requests:
if isinstance(req, dict):
method = req.get("method", "?")
path = req.get("path", "?")
@@ -34,16 +34,12 @@ class ListRequestsRenderer(BaseToolRenderer):
status = response.get("statusCode", "?")
text.append("\n ")
text.append(f"{method} {path}{status}", style="dim")
if len(requests) > 3:
text.append("\n ")
text.append(f"... +{len(requests) - 3} more", style="dim")
else:
text.append("\n ")
text.append("No requests found", style="dim")
elif httpql_filter:
text.append("\n ")
text.append(cls.truncate(httpql_filter, 300), style="dim")
text.append(httpql_filter, style="dim")
else:
text.append("\n ")
text.append("All requests", style="dim")
@@ -72,17 +68,15 @@ class ViewRequestRenderer(BaseToolRenderer):
if "content" in result:
content = result["content"]
text.append("\n ")
text.append(cls.truncate(content, 500), style="dim")
text.append(content, style="dim")
elif "matches" in result:
matches = result["matches"]
if isinstance(matches, list) and matches:
for match in matches[:3]:
for match in matches:
if isinstance(match, dict) and "match" in match:
text.append("\n ")
text.append(match["match"], style="dim")
if len(matches) > 3:
text.append("\n ")
text.append(f"... +{len(matches) - 3} more matches", style="dim")
else:
text.append("\n ")
text.append("No matches found", style="dim")
@@ -123,13 +117,13 @@ class SendRequestRenderer(BaseToolRenderer):
text.append(f"Status: {status_code}", style="dim")
if response_body:
text.append("\n ")
text.append(cls.truncate(response_body, 300), style="dim")
text.append(response_body, style="dim")
else:
text.append("\n ")
text.append("Response received", style="dim")
elif url:
text.append("\n ")
text.append(cls.truncate(url, 400), style="dim")
text.append(url, style="dim")
else:
text.append("\n ")
text.append("Sending...", style="dim")
@@ -163,13 +157,13 @@ class RepeatRequestRenderer(BaseToolRenderer):
text.append(f"Status: {status_code}", style="dim")
if response_body:
text.append("\n ")
text.append(cls.truncate(response_body, 300), style="dim")
text.append(response_body, style="dim")
else:
text.append("\n ")
text.append("Response received", style="dim")
elif modifications:
text.append("\n ")
text.append(cls.truncate(str(modifications), 400), style="dim")
text.append(str(modifications), style="dim")
else:
text.append("\n ")
text.append("No modifications", style="dim")
@@ -211,16 +205,13 @@ class ListSitemapRenderer(BaseToolRenderer):
if result and isinstance(result, dict) and "entries" in result:
entries = result["entries"]
if isinstance(entries, list) and entries:
for entry in entries[:4]:
for entry in entries:
if isinstance(entry, dict):
label = entry.get("label", "?")
kind = entry.get("kind", "?")
text.append("\n ")
text.append(f"{kind}: {label}", style="dim")
if len(entries) > 4:
text.append("\n ")
text.append(f"... +{len(entries) - 4} more", style="dim")
else:
text.append("\n ")
text.append("No entries found", style="dim")

View File

@@ -56,8 +56,7 @@ class PythonRenderer(BaseToolRenderer):
text.append("\n")
if code and action in ["new_session", "execute"]:
code_display = cls.truncate(code, 2000)
text.append_text(cls._highlight_python(code_display))
text.append_text(cls._highlight_python(code))
elif action == "close":
text.append(" ")
text.append("Closing session...", style="dim")

View File

@@ -59,10 +59,8 @@ def _render_default_tool_widget(tool_data: dict[str, Any]) -> Static:
text.append(tool_name, style="bold blue")
text.append("\n")
for k, v in list(args.items())[:2]:
for k, v in list(args.items()):
str_v = str(v)
if len(str_v) > 80:
str_v = str_v[:77] + "..."
text.append(" ")
text.append(k, style="dim")
text.append(": ")
@@ -71,8 +69,6 @@ def _render_default_tool_widget(tool_data: dict[str, Any]) -> Static:
if status in ["completed", "failed", "error"] and result is not None:
result_str = str(result)
if len(result_str) > 150:
result_str = result_str[:147] + "..."
text.append("Result: ", style="bold")
text.append(result_str)
else:

View File

@@ -154,6 +154,4 @@ class TerminalRenderer(BaseToolRenderer):
@classmethod
def _format_command(cls, command: str) -> Text:
if len(command) > 2000:
command = command[:2000] + "..."
return cls._highlight_bash(command)

View File

@@ -23,8 +23,7 @@ class ThinkRenderer(BaseToolRenderer):
text.append("\n ")
if thought:
thought_display = cls.truncate(thought, 600)
text.append(thought_display, style="italic dim")
text.append(thought, style="italic dim")
else:
text.append("Thinking...", style="italic dim")

View File

@@ -14,29 +14,18 @@ STATUS_MARKERS: dict[str, str] = {
}
def _format_todo_lines(text: Text, result: dict[str, Any], limit: int = 25) -> None:
def _format_todo_lines(text: Text, result: dict[str, Any]) -> None:
todos = result.get("todos")
if not isinstance(todos, list) or not todos:
text.append("\n ")
text.append("No todos", style="dim")
return
total = len(todos)
for index, todo in enumerate(todos):
if index >= limit:
remaining = total - limit
if remaining > 0:
text.append("\n ")
text.append(f"... +{remaining} more", style="dim")
break
for todo in todos:
status = todo.get("status", "pending")
marker = STATUS_MARKERS.get(status, STATUS_MARKERS["pending"])
title = todo.get("title", "").strip() or "(untitled)"
if len(title) > 90:
title = title[:87] + "..."
text.append("\n ")
text.append(marker)

View File

@@ -34,9 +34,6 @@ class UserMessageRenderer(BaseToolRenderer):
def _format_user_message(cls, content: str) -> Text:
text = Text()
if len(content) > 300:
content = content[:297] + "..."
text.append("", style="#3b82f6")
text.append(" ")
text.append("You:", style="bold")

View File

@@ -23,7 +23,7 @@ class WebSearchRenderer(BaseToolRenderer):
if query:
text.append("\n ")
text.append(cls.truncate(query, 100), style="dim")
text.append(query, style="dim")
css_classes = cls.get_css_classes("completed")
return Static(text, classes=css_classes)

View File

@@ -491,7 +491,7 @@ class StrixTUIApp(App): # type: ignore[misc]
self._start_scan_thread()
self.set_interval(0.5, self._update_ui_from_tracer)
self.set_interval(0.1, self._update_ui_from_tracer)
def _update_ui_from_tracer(self) -> None:
if self.show_splash:
@@ -596,13 +596,14 @@ class StrixTUIApp(App): # type: ignore[misc]
)
else:
events = self._gather_agent_events(self.selected_agent_id)
if not events:
streaming = self.tracer.get_streaming_content(self.selected_agent_id)
if not events and not streaming:
content, css_class = self._get_chat_placeholder_content(
"Starting agent...", "placeholder-no-activity"
)
else:
current_event_ids = [e["id"] for e in events]
if current_event_ids == self._displayed_events:
if current_event_ids == self._displayed_events and not streaming:
return
content = self._get_rendered_events_content(events)
css_class = "chat-content"
@@ -644,8 +645,92 @@ class StrixTUIApp(App): # type: ignore[misc]
result.append_text(content)
first = False
if self.selected_agent_id:
streaming = self.tracer.get_streaming_content(self.selected_agent_id)
if streaming:
streaming_text = self._render_streaming_content(streaming)
if streaming_text:
if not first:
result.append("\n\n")
result.append_text(streaming_text)
return result
def _render_streaming_content(self, content: str) -> Text:
from strix.interface.streaming_parser import parse_streaming_content
result = Text()
segments = parse_streaming_content(content)
for i, segment in enumerate(segments):
if i > 0:
result.append("\n\n")
if segment.type == "text":
from strix.interface.tool_components.agent_message_renderer import (
AgentMessageRenderer,
)
text_content = AgentMessageRenderer.render_simple(segment.content)
result.append_text(text_content)
elif segment.type == "tool":
tool_text = self._render_streaming_tool(
segment.tool_name or "unknown",
segment.args or {},
segment.is_complete,
)
result.append_text(tool_text)
return result
def _render_streaming_tool(
self, tool_name: str, args: dict[str, str], is_complete: bool
) -> Text:
from strix.interface.tool_components.registry import get_tool_renderer
tool_data = {
"tool_name": tool_name,
"args": args,
"status": "completed" if is_complete else "running",
"result": None,
}
renderer = get_tool_renderer(tool_name)
if renderer:
widget = renderer.render(tool_data)
renderable = widget.renderable
if isinstance(renderable, Text):
return renderable
text = Text()
text.append(str(renderable))
return text
return self._render_default_streaming_tool(tool_name, args, is_complete)
def _render_default_streaming_tool(
self, tool_name: str, args: dict[str, str], is_complete: bool
) -> Text:
text = Text()
if is_complete:
text.append("", style="green")
else:
text.append("", style="yellow")
text.append("Using tool ", style="dim")
text.append(tool_name, style="bold blue")
if args:
for key, value in list(args.items())[:3]:
text.append("\n ")
text.append(key, style="dim")
text.append(": ")
display_value = value if len(value) <= 100 else value[:97] + "..."
text.append(display_value, style="italic" if not is_complete else None)
return text
def _get_status_display_content(
self, agent_id: str, agent_data: dict[str, Any]
) -> tuple[Text | None, Text, bool]:

View File

@@ -1,5 +1,6 @@
import logging
import os
from collections.abc import AsyncIterator
from dataclasses import dataclass
from enum import Enum
from fnmatch import fnmatch
@@ -12,7 +13,7 @@ from jinja2 import (
FileSystemLoader,
select_autoescape,
)
from litellm import ModelResponse, completion_cost
from litellm import completion_cost, stream_chunk_builder
from litellm.utils import supports_prompt_caching, supports_vision
from strix.llm.config import LLMConfig
@@ -276,7 +277,7 @@ class LLM:
conversation_history: list[dict[str, Any]],
scan_id: str | None = None,
step_number: int = 1,
) -> LLMResponse:
) -> AsyncIterator[LLMResponse]:
messages = [{"role": "system", "content": self.system_prompt}]
identity_message = self._build_identity_message()
@@ -292,30 +293,43 @@ class LLM:
cached_messages = self._prepare_cached_messages(messages)
try:
response = await self._make_request(cached_messages)
self._update_usage_stats(response)
accumulated_content = ""
chunks: list[Any] = []
content = ""
if (
response.choices
and hasattr(response.choices[0], "message")
and response.choices[0].message
):
content = getattr(response.choices[0].message, "content", "") or ""
async for chunk in self._stream_request(cached_messages):
chunks.append(chunk)
delta = self._extract_chunk_delta(chunk)
if delta:
accumulated_content += delta
content = _truncate_to_first_function(content)
if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end]
if "</function>" in content:
function_end_index = content.find("</function>") + len("</function>")
content = content[:function_end_index]
yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content=accumulated_content,
tool_invocations=None,
)
tool_invocations = parse_tool_invocations(content)
if chunks:
complete_response = stream_chunk_builder(chunks)
self._update_usage_stats(complete_response)
return LLMResponse(
accumulated_content = _truncate_to_first_function(accumulated_content)
if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end]
tool_invocations = parse_tool_invocations(accumulated_content)
yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content=content,
content=accumulated_content,
tool_invocations=tool_invocations if tool_invocations else None,
)
@@ -364,6 +378,12 @@ class LLM:
except Exception as e:
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
def _extract_chunk_delta(self, chunk: Any) -> str:
    """Return the incremental text carried by one streaming chunk.

    Safely digs ``chunk.choices[0].delta.content`` out of a streamed
    response chunk; returns ``""`` when the chunk carries no choices, no
    delta, or a ``None`` content (e.g. usage-only or role-only chunks).
    """
    if chunk.choices and hasattr(chunk.choices[0], "delta"):
        delta = chunk.choices[0].delta
        # "or ''" guards against delta.content being explicitly None.
        return getattr(delta, "content", "") or ""
    return ""
@property
def usage_stats(self) -> dict[str, dict[str, int | float]]:
return {
@@ -436,10 +456,10 @@ class LLM:
filtered_messages.append(updated_msg)
return filtered_messages
async def _make_request(
async def _stream_request(
self,
messages: list[dict[str, Any]],
) -> ModelResponse:
) -> AsyncIterator[Any]:
if not self._model_supports_vision():
messages = self._filter_images_from_messages(messages)
@@ -447,6 +467,7 @@ class LLM:
"model": self.config.model_name,
"messages": messages,
"timeout": self.config.timeout,
"stream_options": {"include_usage": True},
}
if _LLM_API_KEY:
@@ -461,14 +482,13 @@ class LLM:
completion_args["reasoning_effort"] = "high"
queue = get_global_queue()
response = await queue.make_request(completion_args)
self._total_stats.requests += 1
self._last_request_stats = RequestStats(requests=1)
return response
async for chunk in queue.stream_request(completion_args):
yield chunk
def _update_usage_stats(self, response: ModelResponse) -> None:
def _update_usage_stats(self, response: Any) -> None:
try:
if hasattr(response, "usage") and response.usage:
input_tokens = getattr(response.usage, "prompt_tokens", 0)

View File

@@ -3,10 +3,12 @@ import logging
import os
import threading
import time
from collections.abc import AsyncIterator
from typing import Any
import litellm
from litellm import ModelResponse, completion
from litellm import completion
from litellm.types.utils import ModelResponseStream
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
@@ -42,7 +44,9 @@ class LLMRequestQueue:
self._last_request_time = 0.0
self._lock = threading.Lock()
async def make_request(self, completion_args: dict[str, Any]) -> ModelResponse:
async def stream_request(
self, completion_args: dict[str, Any]
) -> AsyncIterator[ModelResponseStream]:
try:
while not self._semaphore.acquire(timeout=0.2):
await asyncio.sleep(0.1)
@@ -56,7 +60,8 @@ class LLMRequestQueue:
if sleep_needed > 0:
await asyncio.sleep(sleep_needed)
return await self._reliable_request(completion_args)
async for chunk in self._reliable_stream_request(completion_args):
yield chunk
finally:
self._semaphore.release()
@@ -66,15 +71,12 @@ class LLMRequestQueue:
retry=retry_if_exception(should_retry_exception),
reraise=True,
)
async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse:
response = completion(**completion_args, stream=False)
if isinstance(response, ModelResponse):
return response
self._raise_unexpected_response()
raise RuntimeError("Unreachable code")
def _raise_unexpected_response(self) -> None:
raise RuntimeError("Unexpected response type")
async def _reliable_stream_request(
self, completion_args: dict[str, Any]
) -> AsyncIterator[ModelResponseStream]:
response = await asyncio.to_thread(completion, **completion_args, stream=True)
for chunk in response:
yield chunk
_global_queue: LLMRequestQueue | None = None

View File

@@ -47,10 +47,14 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
def _fix_stopword(content: str) -> str:
if "<function=" in content and content.count("<function=") == 1:
if (
"<function=" in content
and content.count("<function=") == 1
and "</function>" not in content
):
if content.endswith("</"):
content = content.rstrip() + "function>"
elif not content.rstrip().endswith("</function>"):
else:
content = content + "\n</function>"
return content
@@ -75,6 +79,12 @@ def clean_content(content: str) -> str:
tool_pattern = r"<function=[^>]+>.*?</function>"
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)
incomplete_tool_pattern = r"<function=[^>]+>.*$"
cleaned = re.sub(incomplete_tool_pattern, "", cleaned, flags=re.DOTALL)
partial_tag_pattern = r"<f(?:u(?:n(?:c(?:t(?:i(?:o(?:n(?:=(?:[^>]*)?)?)?)?)?)?)?)?)?$"
cleaned = re.sub(partial_tag_pattern, "", cleaned)
hidden_xml_patterns = [
r"<inter_agent_message>.*?</inter_agent_message>",
r"<agent_completion_report>.*?</agent_completion_report>",

View File

@@ -33,6 +33,7 @@ class Tracer:
self.agents: dict[str, dict[str, Any]] = {}
self.tool_executions: dict[int, dict[str, Any]] = {}
self.chat_messages: list[dict[str, Any]] = []
self.streaming_content: dict[str, str] = {}
self.vulnerability_reports: list[dict[str, Any]] = []
self.final_scan_result: str | None = None
@@ -333,5 +334,14 @@ class Tracer:
"total_tokens": total_stats["input_tokens"] + total_stats["output_tokens"],
}
def update_streaming_content(self, agent_id: str, content: str) -> None:
    """Store the latest partial LLM output for *agent_id*, replacing any previous snapshot."""
    self.streaming_content[agent_id] = content
def clear_streaming_content(self, agent_id: str) -> None:
    """Drop the streaming snapshot for *agent_id*; a no-op if none is stored."""
    self.streaming_content.pop(agent_id, None)
def get_streaming_content(self, agent_id: str) -> str | None:
    """Return the current streaming snapshot for *agent_id*, or None if not streaming."""
    return self.streaming_content.get(agent_id)
def cleanup(self) -> None:
self.save_run_data(mark_complete=True)