207 lines
6.7 KiB
Python
207 lines
6.7 KiB
Python
import logging
|
|
import os
|
|
from typing import Any
|
|
|
|
import litellm
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Token budget for the whole conversation history; compress_history() starts
# summarizing older turns once the estimated total exceeds 90% of this value.
MAX_TOTAL_TOKENS: int = 100_000

# Number of most recent non-system messages that compress_history() always
# keeps verbatim when it summarizes.
MIN_RECENT_MESSAGES: int = 15
|
|
|
|
# Prompt sent to the summarization model. ``{conversation}`` is filled in with
# ``str.format`` by _summarize_messages(), so any other literal braces added to
# this text would have to be doubled (``{{`` / ``}}``).
SUMMARY_PROMPT_TEMPLATE = """You are an agent performing context
condensation for a security agent. Your job is to compress scan data while preserving
ALL operationally critical information for continuing the security assessment.

CRITICAL ELEMENTS TO PRESERVE:
- Discovered vulnerabilities and potential attack vectors
- Scan results and tool outputs (compressed but maintaining key findings)
- Access credentials, tokens, or authentication details found
- System architecture insights and potential weak points
- Progress made in the assessment
- Failed attempts and dead ends (to avoid duplication)
- Any decisions made about the testing approach

COMPRESSION GUIDELINES:
- Preserve exact technical details (URLs, paths, parameters, payloads)
- Summarize verbose tool outputs while keeping critical findings
- Maintain version numbers, specific technologies identified
- Keep exact error messages that might indicate vulnerabilities
- Compress repetitive or similar findings into consolidated form

Remember: Another security agent will use this summary to continue the assessment.
They must be able to pick up exactly where you left off without losing any
operational advantage or context needed to find vulnerabilities.

CONVERSATION SEGMENT TO SUMMARIZE:
{conversation}

Provide a technically precise summary that preserves all operational security context while
keeping the summary concise and to the point."""
|
|
|
|
|
|
def _count_tokens(text: str, model: str) -> int:
    """Count the tokens in *text* with LiteLLM's tokenizer for *model*.

    Any exception raised by ``litellm.token_counter`` (or by converting its
    result to ``int``) is logged, and a character-based estimate of
    ``len(text) // 4`` is returned instead.
    """
    try:
        return int(litellm.token_counter(model=model, text=text))
    except Exception:
        logger.exception("Failed to count tokens")
    # Heuristic fallback: roughly four characters per token.
    return len(text) // 4  # Rough estimate
|
|
|
|
|
|
def _get_message_tokens(msg: dict[str, Any], model: str) -> int:
|
|
content = msg.get("content", "")
|
|
if isinstance(content, str):
|
|
return _count_tokens(content, model)
|
|
if isinstance(content, list):
|
|
return sum(
|
|
_count_tokens(item.get("text", ""), model)
|
|
for item in content
|
|
if isinstance(item, dict) and item.get("type") == "text"
|
|
)
|
|
return 0
|
|
|
|
|
|
def _extract_message_text(msg: dict[str, Any]) -> str:
|
|
content = msg.get("content", "")
|
|
if isinstance(content, str):
|
|
return content
|
|
|
|
if isinstance(content, list):
|
|
parts = []
|
|
for item in content:
|
|
if isinstance(item, dict):
|
|
if item.get("type") == "text":
|
|
parts.append(item.get("text", ""))
|
|
elif item.get("type") == "image_url":
|
|
parts.append("[IMAGE]")
|
|
return " ".join(parts)
|
|
|
|
return str(content)
|
|
|
|
|
|
def _summarize_messages(
|
|
messages: list[dict[str, Any]],
|
|
model: str,
|
|
) -> dict[str, Any]:
|
|
if not messages:
|
|
empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
|
|
return {
|
|
"role": "assistant",
|
|
"content": empty_summary.format(text="No messages to summarize"),
|
|
}
|
|
|
|
formatted = []
|
|
for msg in messages:
|
|
role = msg.get("role", "unknown")
|
|
text = _extract_message_text(msg)
|
|
formatted.append(f"{role}: {text}")
|
|
|
|
conversation = "\n".join(formatted)
|
|
prompt = SUMMARY_PROMPT_TEMPLATE.format(conversation=conversation)
|
|
|
|
try:
|
|
completion_args = {
|
|
"model": model,
|
|
"messages": [{"role": "user", "content": prompt}],
|
|
}
|
|
|
|
response = litellm.completion(**completion_args)
|
|
summary = response.choices[0].message.content
|
|
summary_msg = "<context_summary message_count='{count}'>{text}</context_summary>"
|
|
return {
|
|
"role": "assistant",
|
|
"content": summary_msg.format(count=len(messages), text=summary),
|
|
}
|
|
except Exception:
|
|
logger.exception("Failed to summarize messages")
|
|
return messages[0]
|
|
|
|
|
|
def _handle_images(messages: list[dict[str, Any]], max_images: int) -> None:
|
|
image_count = 0
|
|
for msg in reversed(messages):
|
|
content = msg.get("content", [])
|
|
if isinstance(content, list):
|
|
for item in content:
|
|
if isinstance(item, dict) and item.get("type") == "image_url":
|
|
if image_count >= max_images:
|
|
item.update(
|
|
{
|
|
"type": "text",
|
|
"text": "[Previously attached image removed to preserve context]",
|
|
}
|
|
)
|
|
else:
|
|
image_count += 1
|
|
|
|
|
|
class MemoryCompressor:
|
|
def __init__(
|
|
self,
|
|
max_images: int = 3,
|
|
model_name: str | None = None,
|
|
):
|
|
self.max_images = max_images
|
|
self.model_name = model_name or os.getenv("STRIX_LLM", "anthropic/claude-sonnet-4-20250514")
|
|
|
|
if not self.model_name:
|
|
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
|
|
|
def compress_history(
|
|
self,
|
|
messages: list[dict[str, Any]],
|
|
) -> list[dict[str, Any]]:
|
|
"""Compress conversation history to stay within token limits.
|
|
|
|
Strategy:
|
|
1. Handle image limits first
|
|
2. Keep all system messages
|
|
3. Keep minimum recent messages
|
|
4. Summarize older messages when total tokens exceed limit
|
|
|
|
The compression preserves:
|
|
- All system messages unchanged
|
|
- Most recent messages intact
|
|
- Critical security context in summaries
|
|
- Recent images for visual context
|
|
- Technical details and findings
|
|
"""
|
|
if not messages:
|
|
return messages
|
|
|
|
_handle_images(messages, self.max_images)
|
|
|
|
system_msgs = []
|
|
regular_msgs = []
|
|
for msg in messages:
|
|
if msg.get("role") == "system":
|
|
system_msgs.append(msg)
|
|
else:
|
|
regular_msgs.append(msg)
|
|
|
|
recent_msgs = regular_msgs[-MIN_RECENT_MESSAGES:]
|
|
old_msgs = regular_msgs[:-MIN_RECENT_MESSAGES]
|
|
|
|
# Type assertion since we ensure model_name is not None in __init__
|
|
model_name: str = self.model_name # type: ignore[assignment]
|
|
|
|
total_tokens = sum(
|
|
_get_message_tokens(msg, model_name) for msg in system_msgs + regular_msgs
|
|
)
|
|
|
|
if total_tokens <= MAX_TOTAL_TOKENS * 0.9:
|
|
return messages
|
|
|
|
compressed = []
|
|
chunk_size = 10
|
|
for i in range(0, len(old_msgs), chunk_size):
|
|
chunk = old_msgs[i : i + chunk_size]
|
|
summary = _summarize_messages(chunk, model_name)
|
|
if summary:
|
|
compressed.append(summary)
|
|
|
|
return system_msgs + compressed + recent_msgs
|