Open-source release for Alpha version
310
strix/llm/llm.py
Normal file
@@ -0,0 +1,310 @@
import logging
import os
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any

import litellm
from jinja2 import (
    Environment,
    FileSystemLoader,
    select_autoescape,
)
from litellm import ModelResponse, completion_cost
from litellm.utils import supports_prompt_caching

from strix.llm.config import LLMConfig
from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.request_queue import get_global_queue
from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations
from strix.prompts import load_prompt_modules
from strix.tools import get_tools_prompt


logger = logging.getLogger(__name__)

api_key = os.getenv("LLM_API_KEY")
if api_key:
    litellm.api_key = api_key
class StepRole(str, Enum):
    AGENT = "agent"
    USER = "user"
    SYSTEM = "system"


@dataclass
class LLMResponse:
    content: str
    tool_invocations: list[dict[str, Any]] | None = None
    scan_id: str | None = None
    step_number: int = 1
    role: StepRole = StepRole.AGENT


@dataclass
class RequestStats:
    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0
    cache_creation_tokens: int = 0
    cost: float = 0.0
    requests: int = 0
    failed_requests: int = 0

    def to_dict(self) -> dict[str, int | float]:
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cached_tokens": self.cached_tokens,
            "cache_creation_tokens": self.cache_creation_tokens,
            "cost": round(self.cost, 4),
            "requests": self.requests,
            "failed_requests": self.failed_requests,
        }


class LLM:
    def __init__(self, config: LLMConfig, agent_name: str | None = None):
        self.config = config
        self.agent_name = agent_name
        self._total_stats = RequestStats()
        self._last_request_stats = RequestStats()

        self.memory_compressor = MemoryCompressor()

        if agent_name:
            prompt_dir = Path(__file__).parent.parent / "agents" / agent_name
            prompts_dir = Path(__file__).parent.parent / "prompts"

            loader = FileSystemLoader([prompt_dir, prompts_dir])
            self.jinja_env = Environment(
                loader=loader,
                autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
            )

            try:
                prompt_module_content = load_prompt_modules(
                    self.config.prompt_modules or [], self.jinja_env
                )

                def get_module(name: str) -> str:
                    return prompt_module_content.get(name, "")

                self.jinja_env.globals["get_module"] = get_module

                self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
                    get_tools_prompt=get_tools_prompt,
                    loaded_module_names=list(prompt_module_content.keys()),
                    **prompt_module_content,
                )
            except (FileNotFoundError, OSError, ValueError) as e:
                logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
                self.system_prompt = "You are a helpful AI assistant."
        else:
            self.system_prompt = "You are a helpful AI assistant."

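    # Anthropic-style prompt caching: a text block tagged with
    # cache_control {"type": "ephemeral"} marks a cache breakpoint, so the
    # provider caches the prompt prefix up to that block and later requests
    # sharing the prefix are billed at the cheaper cached-token rate.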
    def _add_cache_control_to_content(
        self, content: str | list[dict[str, Any]]
    ) -> str | list[dict[str, Any]]:
        if isinstance(content, str):
            return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
        if isinstance(content, list) and content:
            last_item = content[-1]
            if isinstance(last_item, dict) and last_item.get("type") == "text":
                return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
        return content

    def _is_anthropic_model(self) -> bool:
        if not self.config.model_name:
            return False
        model_lower = self.config.model_name.lower()
        return any(provider in model_lower for provider in ["anthropic/", "claude"])

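    # Pick the smallest multiple of 10 such that marking every interval-th
    # message uses at most 3 cache breakpoints. E.g. with 85 non-system
    # messages, 85 // 10 = 8 and 85 // 20 = 4 both exceed 3, so the
    # interval settles at 30 (85 // 30 = 2).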
    def _calculate_cache_interval(self, total_messages: int) -> int:
        if total_messages <= 1:
            return 10

        max_cached_messages = 3
        non_system_messages = total_messages - 1

        interval = 10
        while non_system_messages // interval > max_cached_messages:
            interval += 10

        return interval

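    # Anthropic allows at most 4 cache_control breakpoints per request:
    # one goes on the system prompt and up to 3 more are spread evenly
    # across the conversation, so the cached prefix can keep growing as
    # the history gets longer.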
    def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        if (
            not self.config.enable_prompt_caching
            or not supports_prompt_caching(self.config.model_name)
            or not messages
        ):
            return messages

        if not self._is_anthropic_model():
            return messages

        cached_messages = list(messages)

        if cached_messages and cached_messages[0].get("role") == "system":
            system_message = cached_messages[0].copy()
            system_message["content"] = self._add_cache_control_to_content(
                system_message["content"]
            )
            cached_messages[0] = system_message

        total_messages = len(cached_messages)
        if total_messages > 1:
            interval = self._calculate_cache_interval(total_messages)

            cached_count = 0
            for i in range(interval, total_messages, interval):
                if cached_count >= 3:
                    break

                if i < len(cached_messages):
                    message = cached_messages[i].copy()
                    message["content"] = self._add_cache_control_to_content(message["content"])
                    cached_messages[i] = message
                    cached_count += 1

        return cached_messages

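    # One agent step: compress the history (mutating the caller's list in
    # place so the compression persists), layer on cache markers, send the
    # request, then parse any tool invocation out of the reply.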
    async def generate(
        self,
        conversation_history: list[dict[str, Any]],
        scan_id: str | None = None,
        step_number: int = 1,
    ) -> LLMResponse:
        messages = [{"role": "system", "content": self.system_prompt}]

        compressed_history = list(self.memory_compressor.compress_history(conversation_history))

        conversation_history.clear()
        conversation_history.extend(compressed_history)
        messages.extend(compressed_history)

        cached_messages = self._prepare_cached_messages(messages)

        try:
            response = await self._make_request(cached_messages)
            self._update_usage_stats(response)

            content = ""
            if (
                response.choices
                and hasattr(response.choices[0], "message")
                and response.choices[0].message
            ):
                content = getattr(response.choices[0].message, "content", "") or ""

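            # Keep only the first tool call and drop anything after its
            # closing tag, so the agent executes a single action per step.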
            content = _truncate_to_first_function(content)

            if "</function>" in content:
                function_end_index = content.find("</function>") + len("</function>")
                content = content[:function_end_index]

            tool_invocations = parse_tool_invocations(content)

            return LLMResponse(
                scan_id=scan_id,
                step_number=step_number,
                role=StepRole.AGENT,
                content=content,
                tool_invocations=tool_invocations if tool_invocations else None,
            )

        except (ValueError, TypeError, RuntimeError):
            logger.exception("Error in LLM generation")
            return LLMResponse(
                scan_id=scan_id,
                step_number=step_number,
                role=StepRole.AGENT,
                content="An error occurred while generating the response",
                tool_invocations=None,
            )

    @property
    def usage_stats(self) -> dict[str, dict[str, int | float]]:
        return {
            "total": self._total_stats.to_dict(),
            "last_request": self._last_request_stats.to_dict(),
        }

    def get_cache_config(self) -> dict[str, bool]:
        return {
            "enabled": self.config.enable_prompt_caching,
            "supported": supports_prompt_caching(self.config.model_name),
        }

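    # All completions go through the shared global queue from
    # strix.llm.request_queue rather than calling litellm directly. The
    # "</function>" stop sequence ends generation at the close of the
    # first tool call.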
    async def _make_request(
        self,
        messages: list[dict[str, Any]],
    ) -> ModelResponse:
        completion_args = {
            "model": self.config.model_name,
            "messages": messages,
            "temperature": self.config.temperature,
            "stop": ["</function>"],
        }

        queue = get_global_queue()
        response = await queue.make_request(completion_args)

        self._total_stats.requests += 1
        self._last_request_stats = RequestStats(requests=1)

        return response

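    # litellm normalizes usage to the OpenAI shape: cache reads surface as
    # usage.prompt_tokens_details.cached_tokens, while Anthropic responses
    # additionally report cache writes as usage.cache_creation_input_tokens.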
    def _update_usage_stats(self, response: ModelResponse) -> None:
        try:
            if hasattr(response, "usage") and response.usage:
                input_tokens = getattr(response.usage, "prompt_tokens", 0)
                output_tokens = getattr(response.usage, "completion_tokens", 0)

                cached_tokens = 0
                cache_creation_tokens = 0

                if hasattr(response.usage, "prompt_tokens_details"):
                    prompt_details = response.usage.prompt_tokens_details
                    if hasattr(prompt_details, "cached_tokens"):
                        cached_tokens = prompt_details.cached_tokens or 0

                if hasattr(response.usage, "cache_creation_input_tokens"):
                    cache_creation_tokens = response.usage.cache_creation_input_tokens or 0

            else:
                input_tokens = 0
                output_tokens = 0
                cached_tokens = 0
                cache_creation_tokens = 0

            try:
                cost = completion_cost(response) or 0.0
            except (ValueError, TypeError, RuntimeError) as e:
                logger.warning(f"Failed to calculate cost: {e}")
                cost = 0.0

            self._total_stats.input_tokens += input_tokens
            self._total_stats.output_tokens += output_tokens
            self._total_stats.cached_tokens += cached_tokens
            self._total_stats.cache_creation_tokens += cache_creation_tokens
            self._total_stats.cost += cost

            self._last_request_stats.input_tokens = input_tokens
            self._last_request_stats.output_tokens = output_tokens
            self._last_request_stats.cached_tokens = cached_tokens
            self._last_request_stats.cache_creation_tokens = cache_creation_tokens
            self._last_request_stats.cost = cost

            if cached_tokens > 0:
                logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
            if cache_creation_tokens > 0:
                logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")

            logger.info(f"Usage stats: {self.usage_stats}")
        except (AttributeError, TypeError, ValueError) as e:
            logger.warning(f"Failed to update usage stats: {e}")