Files
strix/strix/llm/memory_compressor.py
2025-08-08 20:36:44 -07:00

207 lines
6.7 KiB
Python

import logging
import os
from typing import Any
import litellm
# Module-level logger; token-counting and summarization failures are reported here.
logger = logging.getLogger(__name__)
# Token budget for a conversation. MemoryCompressor.compress_history() only
# summarizes older messages once the counted tokens exceed 90% of this value.
MAX_TOTAL_TOKENS = 100_000
# Number of most recent non-system messages that compress_history() keeps
# verbatim; only non-system messages older than these are summarized.
MIN_RECENT_MESSAGES = 15
# Prompt sent to the LLM by _summarize_messages(). The single "{conversation}"
# placeholder is filled with str.format() using a "role: text" transcript of the
# chunk being summarized, so any other literal braces added here must be doubled.
SUMMARY_PROMPT_TEMPLATE = """You are an agent performing context
condensation for a security agent. Your job is to compress scan data while preserving
ALL operationally critical information for continuing the security assessment.
CRITICAL ELEMENTS TO PRESERVE:
- Discovered vulnerabilities and potential attack vectors
- Scan results and tool outputs (compressed but maintaining key findings)
- Access credentials, tokens, or authentication details found
- System architecture insights and potential weak points
- Progress made in the assessment
- Failed attempts and dead ends (to avoid duplication)
- Any decisions made about the testing approach
COMPRESSION GUIDELINES:
- Preserve exact technical details (URLs, paths, parameters, payloads)
- Summarize verbose tool outputs while keeping critical findings
- Maintain version numbers, specific technologies identified
- Keep exact error messages that might indicate vulnerabilities
- Compress repetitive or similar findings into consolidated form
Remember: Another security agent will use this summary to continue the assessment.
They must be able to pick up exactly where you left off without losing any
operational advantage or context needed to find vulnerabilities.
CONVERSATION SEGMENT TO SUMMARIZE:
{conversation}
Provide a technically precise summary that preserves all operational security context while
keeping the summary concise and to the point."""
def _count_tokens(text: str, model: str) -> int:
    """Return how many tokens ``text`` occupies for ``model``.

    The count comes from ``litellm.token_counter``. If that call (or the int
    conversion of its result) raises, the error is logged with its traceback
    and a coarse estimate of one token per four characters is returned.
    """
    try:
        return int(litellm.token_counter(model=model, text=text))
    except Exception:
        logger.exception("Failed to count tokens")
    # Fallback heuristic: roughly four characters per token.
    return len(text) // 4
def _get_message_tokens(msg: dict[str, Any], model: str) -> int:
    """Count the tokens in a chat message's ``content``.

    Plain-string content is counted as a whole. For list (multi-part) content
    only parts shaped like ``{"type": "text", ...}`` are counted; image and
    other parts contribute nothing. Any other content value counts as 0.
    """
    content = msg.get("content", "")
    if isinstance(content, str):
        return _count_tokens(content, model)
    if not isinstance(content, list):
        return 0
    text_pieces = (
        part.get("text", "")
        for part in content
        if isinstance(part, dict) and part.get("type") == "text"
    )
    return sum(_count_tokens(piece, model) for piece in text_pieces)
def _extract_message_text(msg: dict[str, Any]) -> str:
    """Flatten a chat message's ``content`` into plain text for summarization.

    - ``str`` content is returned unchanged.
    - Missing or ``None`` content becomes ``""``. It is not rendered as the
      literal string ``"None"``, which would otherwise leak into the transcript.
    - List content is joined with single spaces. ``{"type": "text"}`` parts
      contribute their text (a null/missing text contributes ``""``), and
      ``{"type": "image_url"}`` parts become the placeholder ``"[IMAGE]"``.
      Non-dict parts and parts of other types are skipped.
    - Any other value is converted with ``str()``.
    """
    content = msg.get("content", "")
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            if not isinstance(item, dict):
                continue
            item_type = item.get("type")
            if item_type == "text":
                # Guard against {"type": "text", "text": None}, which would make join() fail.
                parts.append(item.get("text") or "")
            elif item_type == "image_url":
                parts.append("[IMAGE]")
        return " ".join(parts)
    return str(content)
def _summarize_messages(
    messages: list[dict[str, Any]],
    model: str,
) -> dict[str, Any]:
    """Condense a chunk of chat messages into a single assistant summary message.

    The chunk is rendered as a ``"role: text"`` transcript, inserted into
    SUMMARY_PROMPT_TEMPLATE, and sent to ``model`` via ``litellm.completion``.

    Args:
        messages: The chunk to summarize. May be empty.
        model: litellm model identifier used for the summarization call.

    Returns:
        An ``{"role": "assistant", "content": ...}`` message whose content is
        wrapped in ``<context_summary message_count='N'>...</context_summary>``.
        For an empty chunk, a placeholder summary with ``message_count='0'`` is
        returned without calling the model. If the completion call raises, or
        returns no text, the failure is logged and ``messages[0]`` is returned
        unchanged. This best-effort fallback drops the rest of the chunk.
    """
    wrapper = "<context_summary message_count='{count}'>{text}</context_summary>"
    if not messages:
        return {
            "role": "assistant",
            "content": wrapper.format(count=0, text="No messages to summarize"),
        }

    conversation = "\n".join(
        f"{msg.get('role', 'unknown')}: {_extract_message_text(msg)}" for msg in messages
    )
    prompt = SUMMARY_PROMPT_TEMPLATE.format(conversation=conversation)

    try:
        response = litellm.completion(
            model=model,
            messages=[{"role": "user", "content": prompt}],
        )
        summary = response.choices[0].message.content
    except Exception:
        logger.exception("Failed to summarize messages")
        # Best-effort fallback: keep the chunk's first message; the others are lost.
        return messages[0]

    if not isinstance(summary, str) or not summary.strip():
        # A null/empty completion would otherwise be wrapped as the literal text
        # "None" (or nothing), replacing the whole chunk with a useless summary.
        logger.warning("Summarization returned no content for %d messages", len(messages))
        return messages[0]

    return {
        "role": "assistant",
        "content": wrapper.format(count=len(messages), text=summary),
    }
def _handle_images(messages: list[dict[str, Any]], max_images: int) -> None:
    """Keep only the ``max_images`` most recent images, replacing older ones in place.

    Messages are scanned newest-first. Within a message, parts are scanned in
    their stored order. The first ``max_images`` ``image_url`` parts found are
    left intact. Every later one is rewritten into a text part carrying a short
    placeholder, and its ``image_url`` payload is removed. Both ``messages`` and
    the content-part dicts are mutated in place; nothing is returned.
    """
    kept = 0
    for msg in reversed(messages):
        content = msg.get("content", [])
        if not isinstance(content, list):
            continue
        for item in content:
            if not (isinstance(item, dict) and item.get("type") == "image_url"):
                continue
            if kept < max_images:
                kept += 1
                continue
            # Drop the image payload (often a large base64 data URL). Otherwise it
            # stays attached to the "removed" part and is still sent with it.
            item.pop("image_url", None)
            item.update(
                {
                    "type": "text",
                    "text": "[Previously attached image removed to preserve context]",
                }
            )
class MemoryCompressor:
def __init__(
self,
max_images: int = 3,
model_name: str | None = None,
):
self.max_images = max_images
self.model_name = model_name or os.getenv("STRIX_LLM", "anthropic/claude-sonnet-4-20250514")
if not self.model_name:
raise ValueError("STRIX_LLM environment variable must be set and not empty")
def compress_history(
self,
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Compress conversation history to stay within token limits.
Strategy:
1. Handle image limits first
2. Keep all system messages
3. Keep minimum recent messages
4. Summarize older messages when total tokens exceed limit
The compression preserves:
- All system messages unchanged
- Most recent messages intact
- Critical security context in summaries
- Recent images for visual context
- Technical details and findings
"""
if not messages:
return messages
_handle_images(messages, self.max_images)
system_msgs = []
regular_msgs = []
for msg in messages:
if msg.get("role") == "system":
system_msgs.append(msg)
else:
regular_msgs.append(msg)
recent_msgs = regular_msgs[-MIN_RECENT_MESSAGES:]
old_msgs = regular_msgs[:-MIN_RECENT_MESSAGES]
# Type assertion since we ensure model_name is not None in __init__
model_name: str = self.model_name # type: ignore[assignment]
total_tokens = sum(
_get_message_tokens(msg, model_name) for msg in system_msgs + regular_msgs
)
if total_tokens <= MAX_TOTAL_TOKENS * 0.9:
return messages
compressed = []
chunk_size = 10
for i in range(0, len(old_msgs), chunk_size):
chunk = old_msgs[i : i + chunk_size]
summary = _summarize_messages(chunk, model_name)
if summary:
compressed.append(summary)
return system_msgs + compressed + recent_msgs