diff --git a/strix/interface/tool_components/python_renderer.py b/strix/interface/tool_components/python_renderer.py index c61f3df..e784989 100644 --- a/strix/interface/tool_components/python_renderer.py +++ b/strix/interface/tool_components/python_renderer.py @@ -14,6 +14,8 @@ from .registry import register_tool_renderer MAX_OUTPUT_LINES = 50 MAX_LINE_LENGTH = 200 +ANSI_PATTERN = re.compile(r"\x1b(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~]|\][^\x07]*\x07)") + STRIP_PATTERNS = [ r"\.\.\. \[(stdout|stderr|result|output|error) truncated at \d+k? chars\]", ] @@ -25,31 +27,32 @@ def _get_style_colors() -> dict[Any, str]: return {token: f"#{style_def['color']}" for token, style_def in style if style_def["color"]} +@cache +def _get_lexer() -> PythonLexer: + return PythonLexer() + + +@cache +def _get_token_color(token_type: Any) -> str | None: + colors = _get_style_colors() + while token_type: + if token_type in colors: + return colors[token_type] + token_type = token_type.parent + return None + + @register_tool_renderer class PythonRenderer(BaseToolRenderer): tool_name: ClassVar[str] = "python_action" css_classes: ClassVar[list[str]] = ["tool-call", "python-tool"] - @classmethod - def _get_token_color(cls, token_type: Any) -> str | None: - colors = _get_style_colors() - while token_type: - if token_type in colors: - return colors[token_type] - token_type = token_type.parent - return None - @classmethod def _highlight_python(cls, code: str) -> Text: - lexer = PythonLexer() text = Text() - - for token_type, token_value in lexer.get_tokens(code): - if not token_value: - continue - color = cls._get_token_color(token_type) - text.append(token_value, style=color) - + for token_type, token_value in _get_lexer().get_tokens(code): + if token_value: + text.append(token_value, style=_get_token_color(token_type)) return text @classmethod @@ -59,11 +62,16 @@ class PythonRenderer(BaseToolRenderer): cleaned = re.sub(pattern, "", cleaned) return cleaned.strip() + @classmethod + def _strip_ansi(cls, text: str) -> str: + return ANSI_PATTERN.sub("", text) + @classmethod def _truncate_line(cls, line: str) -> str: - if len(line) > MAX_LINE_LENGTH: - return line[: MAX_LINE_LENGTH - 3] + "..." - return line + clean_line = cls._strip_ansi(line) + if len(clean_line) > MAX_LINE_LENGTH: + return clean_line[: MAX_LINE_LENGTH - 3] + "..." + return clean_line @classmethod def _format_output(cls, output: str) -> Text: