diff --git a/strix.spec b/strix.spec
index 03dbf86..fbe2d90 100644
--- a/strix.spec
+++ b/strix.spec
@@ -111,7 +111,6 @@ hiddenimports = [
'strix.llm.llm',
'strix.llm.config',
'strix.llm.utils',
- 'strix.llm.request_queue',
'strix.llm.memory_compressor',
'strix.runtime',
'strix.runtime.runtime',
diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py
index 8905197..f955892 100644
--- a/strix/agents/base_agent.py
+++ b/strix/agents/base_agent.py
@@ -79,6 +79,7 @@ class BaseAgent(metaclass=AgentMeta):
with contextlib.suppress(Exception):
self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id)
self._current_task: asyncio.Task[Any] | None = None
+ self._force_stop = False
from strix.telemetry.tracer import get_global_tracer
@@ -156,6 +157,11 @@ class BaseAgent(metaclass=AgentMeta):
return self._handle_sandbox_error(e, tracer)
while True:
+ if self._force_stop:
+ self._force_stop = False
+ await self._enter_waiting_state(tracer, was_cancelled=True)
+ continue
+
self._check_agent_messages(self.state)
if self.state.is_waiting_for_input():
@@ -246,7 +252,8 @@ class BaseAgent(metaclass=AgentMeta):
continue
async def _wait_for_input(self) -> None:
- import asyncio
+ if self._force_stop:
+ return
if self.state.has_waiting_timeout():
self.state.resume_from_waiting()
@@ -339,6 +346,7 @@ class BaseAgent(metaclass=AgentMeta):
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
final_response = None
+
async for response in self.llm.generate(self.state.get_conversation_history()):
final_response = response
if tracer and response.content:
@@ -584,6 +592,11 @@ class BaseAgent(metaclass=AgentMeta):
return True
def cancel_current_execution(self) -> None:
+ self._force_stop = True
if self._current_task and not self._current_task.done():
- self._current_task.cancel()
+ try:
+ loop = self._current_task.get_loop()
+ loop.call_soon_threadsafe(self._current_task.cancel)
+ except RuntimeError:
+ self._current_task.cancel()
self._current_task = None
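
Note on the cancellation path above: asyncio's Task.cancel() is only safe to call from the thread running the event loop, so a cancel triggered from another thread (e.g. a UI thread) has to be marshalled through loop.call_soon_threadsafe, with the direct cancel() kept as a fallback when no loop is reachable. A minimal, standard-library-only sketch of the pattern (all names here are illustrative, not strix APIs):

    import asyncio
    import contextlib
    import threading
    import time


    async def worker() -> None:
        await asyncio.sleep(60)


    async def main() -> None:
        task = asyncio.create_task(worker())
        loop = asyncio.get_running_loop()

        def cancel_from_other_thread() -> None:
            time.sleep(0.1)
            # Calling task.cancel() directly from this thread would race
            # the loop; call_soon_threadsafe schedules the cancel on the
            # loop's own thread instead.
            loop.call_soon_threadsafe(task.cancel)

        threading.Thread(target=cancel_from_other_thread).start()
        with contextlib.suppress(asyncio.CancelledError):
            await task


    asyncio.run(main())
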
diff --git a/strix/config/config.py b/strix/config/config.py
index 0e0d16f..a602658 100644
--- a/strix/config/config.py
+++ b/strix/config/config.py
@@ -16,9 +16,9 @@ class Config:
litellm_base_url = None
ollama_api_base = None
strix_reasoning_effort = "high"
+ strix_llm_max_retries = "5"
+ strix_memory_compressor_timeout = "30"
llm_timeout = "300"
- llm_rate_limit_delay = "4.0"
- llm_rate_limit_concurrent = "1"
# Tool & Feature Configuration
perplexity_api_key = None
@@ -27,7 +27,7 @@ class Config:
# Runtime Configuration
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
strix_runtime_backend = "docker"
- strix_sandbox_execution_timeout = "500"
+ strix_sandbox_execution_timeout = "120"
strix_sandbox_connect_timeout = "10"
# Telemetry
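
The two new knobs replace the rate-limit settings removed above and are consumed later in this diff with the same pattern as the existing timeouts: values are stored as strings and coerced at the call site, with a hard fallback in case Config.get returns nothing. For reference, the coercion expressions as they appear in this patch:

    # Call-site pattern used in strix/llm/llm.py, strix/llm/memory_compressor.py
    # and strix/tools/executor.py below.
    from strix.config import Config

    max_retries = int(Config.get("strix_llm_max_retries") or "5")
    compressor_timeout = int(Config.get("strix_memory_compressor_timeout") or "30")
    sandbox_timeout = float(Config.get("strix_sandbox_execution_timeout") or "120")
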
diff --git a/strix/llm/llm.py b/strix/llm/llm.py
index 34132a8..f4e831e 100644
--- a/strix/llm/llm.py
+++ b/strix/llm/llm.py
@@ -1,23 +1,16 @@
import asyncio
-import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
-from enum import Enum
from typing import Any
import litellm
-from jinja2 import (
- Environment,
- FileSystemLoader,
- select_autoescape,
-)
-from litellm import completion_cost, stream_chunk_builder, supports_reasoning
+from jinja2 import Environment, FileSystemLoader, select_autoescape
+from litellm import acompletion, completion_cost, stream_chunk_builder, supports_reasoning
from litellm.utils import supports_prompt_caching, supports_vision
from strix.config import Config
from strix.llm.config import LLMConfig
from strix.llm.memory_compressor import MemoryCompressor
-from strix.llm.request_queue import get_global_queue
from strix.llm.utils import (
_truncate_to_first_function,
fix_incomplete_tool_call,
@@ -28,37 +21,9 @@ from strix.tools import get_tools_prompt
from strix.utils.resource_paths import get_strix_resource_path
-MAX_RETRIES = 5
-RETRY_MULTIPLIER = 8
-RETRY_MIN = 8
-RETRY_MAX = 64
-
-
-def _should_retry(exception: Exception) -> bool:
- status_code = None
- if hasattr(exception, "status_code"):
- status_code = exception.status_code
- elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
- status_code = exception.response.status_code
- if status_code is not None:
- return bool(litellm._should_retry(status_code))
- return True
-
-
-logger = logging.getLogger(__name__)
-
litellm.drop_params = True
litellm.modify_params = True
-_LLM_API_KEY = Config.get("llm_api_key")
-_LLM_API_BASE = (
- Config.get("llm_api_base")
- or Config.get("openai_api_base")
- or Config.get("litellm_base_url")
- or Config.get("ollama_api_base")
-)
-_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")
-
class LLMRequestFailedError(Exception):
def __init__(self, message: str, details: str | None = None):
@@ -67,20 +32,11 @@ class LLMRequestFailedError(Exception):
self.details = details
-class StepRole(str, Enum):
- AGENT = "agent"
- USER = "user"
- SYSTEM = "system"
-
-
@dataclass
class LLMResponse:
content: str
tool_invocations: list[dict[str, Any]] | None = None
- scan_id: str | None = None
- step_number: int = 1
- role: StepRole = StepRole.AGENT
- thinking_blocks: list[dict[str, Any]] | None = None # For reasoning models.
+ thinking_blocks: list[dict[str, Any]] | None = None
@dataclass
@@ -88,76 +44,63 @@ class RequestStats:
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0
- cache_creation_tokens: int = 0
cost: float = 0.0
requests: int = 0
- failed_requests: int = 0
def to_dict(self) -> dict[str, int | float]:
return {
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"cached_tokens": self.cached_tokens,
- "cache_creation_tokens": self.cache_creation_tokens,
"cost": round(self.cost, 4),
"requests": self.requests,
- "failed_requests": self.failed_requests,
}
class LLM:
- def __init__(
- self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
- ):
+ def __init__(self, config: LLMConfig, agent_name: str | None = None):
self.config = config
self.agent_name = agent_name
- self.agent_id = agent_id
+ self.agent_id: str | None = None
self._total_stats = RequestStats()
- self._last_request_stats = RequestStats()
+ self.memory_compressor = MemoryCompressor(model_name=config.model_name)
+ self.system_prompt = self._load_system_prompt(agent_name)
- if _STRIX_REASONING_EFFORT:
- self._reasoning_effort = _STRIX_REASONING_EFFORT
- elif self.config.scan_mode == "quick":
+ reasoning = Config.get("strix_reasoning_effort")
+ if reasoning:
+ self._reasoning_effort = reasoning
+ elif config.scan_mode == "quick":
self._reasoning_effort = "medium"
else:
self._reasoning_effort = "high"
- self.memory_compressor = MemoryCompressor(
- model_name=self.config.model_name,
- timeout=self.config.timeout,
- )
+ def _load_system_prompt(self, agent_name: str | None) -> str:
+ if not agent_name:
+ return ""
- if agent_name:
+ try:
prompt_dir = get_strix_resource_path("agents", agent_name)
skills_dir = get_strix_resource_path("skills")
-
- loader = FileSystemLoader([prompt_dir, skills_dir])
- self.jinja_env = Environment(
- loader=loader,
+ env = Environment(
+ loader=FileSystemLoader([prompt_dir, skills_dir]),
autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
)
- try:
- skills_to_load = list(self.config.skills or [])
- skills_to_load.append(f"scan_modes/{self.config.scan_mode}")
+ skills_to_load = [
+ *list(self.config.skills or []),
+ f"scan_modes/{self.config.scan_mode}",
+ ]
+ skill_content = load_skills(skills_to_load, env)
+ env.globals["get_skill"] = lambda name: skill_content.get(name, "")
- skill_content = load_skills(skills_to_load, self.jinja_env)
-
- def get_skill(name: str) -> str:
- return skill_content.get(name, "")
-
- self.jinja_env.globals["get_skill"] = get_skill
-
- self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
- get_tools_prompt=get_tools_prompt,
- loaded_skill_names=list(skill_content.keys()),
- **skill_content,
- )
- except (FileNotFoundError, OSError, ValueError) as e:
- logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
- self.system_prompt = "You are a helpful AI assistant."
- else:
- self.system_prompt = "You are a helpful AI assistant."
+ result = env.get_template("system_prompt.jinja").render(
+ get_tools_prompt=get_tools_prompt,
+ loaded_skill_names=list(skill_content.keys()),
+ **skill_content,
+ )
+ return str(result)
+ except Exception: # noqa: BLE001
+ return ""
def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
if agent_name:
@@ -165,330 +108,119 @@ class LLM:
if agent_id:
self.agent_id = agent_id
- def _build_identity_message(self) -> dict[str, Any] | None:
- if not (self.agent_name and str(self.agent_name).strip()):
- return None
- identity_name = self.agent_name
- identity_id = self.agent_id
- content = (
- "\n\n"
- "\n"
- "Internal metadata: do not echo or reference; "
- "not part of history or tool calls.\n"
- "You are now assuming the role of this agent. "
- "Act strictly as this agent and maintain self-identity for this step. "
- "Now go answer the next needed step!\n"
- f"{identity_name}\n"
- f"{identity_id}\n"
- "\n\n"
+ async def generate(
+ self, conversation_history: list[dict[str, Any]]
+ ) -> AsyncIterator[LLMResponse]:
+ messages = self._prepare_messages(conversation_history)
+ max_retries = int(Config.get("strix_llm_max_retries") or "5")
+
+ for attempt in range(max_retries + 1):
+ try:
+ async for response in self._stream(messages):
+ yield response
+ return # noqa: TRY300
+ except Exception as e: # noqa: BLE001
+ if attempt >= max_retries or not self._should_retry(e):
+ self._raise_error(e)
+ wait = min(10, 2 * (2**attempt))
+ await asyncio.sleep(wait)
+
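
The inline backoff replaces the old module-level constants (multiplier 8, min 8s, max 64s) with a much tighter schedule. For the default of 5 retries, min(10, 2 * 2**attempt) produces waits of 2, 4, 8, 10, 10 seconds; a quick illustration:

    # Wait (seconds) slept before each retry; attempt 5 raises instead.
    for attempt in range(5):
        print(attempt, min(10, 2 * (2**attempt)))
    # 0 2
    # 1 4
    # 2 8
    # 3 10
    # 4 10
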
+ async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
+ accumulated = ""
+ chunks: list[Any] = []
+
+ self._total_stats.requests += 1
+ response = await acompletion(**self._build_completion_args(messages), stream=True)
+
+ async for chunk in response:
+ chunks.append(chunk)
+ delta = self._get_chunk_content(chunk)
+ if delta:
+ accumulated += delta
+ if "" in accumulated:
+ accumulated = accumulated[
+ : accumulated.find("") + len("")
+ ]
+ yield LLMResponse(content=accumulated)
+
+ if chunks:
+ self._update_usage_stats(stream_chunk_builder(chunks))
+
+ accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
+ yield LLMResponse(
+ content=accumulated,
+ tool_invocations=parse_tool_invocations(accumulated),
+ thinking_blocks=self._extract_thinking(chunks),
)
- return {"role": "user", "content": content}
-
- def _add_cache_control_to_content(
- self, content: str | list[dict[str, Any]]
- ) -> str | list[dict[str, Any]]:
- if isinstance(content, str):
- return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
- if isinstance(content, list) and content:
- last_item = content[-1]
- if isinstance(last_item, dict) and last_item.get("type") == "text":
- return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
- return content
-
- def _is_anthropic_model(self) -> bool:
- if not self.config.model_name:
- return False
- model_lower = self.config.model_name.lower()
- return any(provider in model_lower for provider in ["anthropic/", "claude"])
-
- def _calculate_cache_interval(self, total_messages: int) -> int:
- if total_messages <= 1:
- return 10
-
- max_cached_messages = 3
- non_system_messages = total_messages - 1
-
- interval = 10
- while non_system_messages // interval > max_cached_messages:
- interval += 10
-
- return interval
-
- def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
- if (
- not self.config.enable_prompt_caching
- or not supports_prompt_caching(self.config.model_name)
- or not messages
- ):
- return messages
-
- if not self._is_anthropic_model():
- return messages
-
- cached_messages = list(messages)
-
- if cached_messages and cached_messages[0].get("role") == "system":
- system_message = cached_messages[0].copy()
- system_message["content"] = self._add_cache_control_to_content(
- system_message["content"]
- )
- cached_messages[0] = system_message
-
- total_messages = len(cached_messages)
- if total_messages > 1:
- interval = self._calculate_cache_interval(total_messages)
-
- cached_count = 0
- for i in range(interval, total_messages, interval):
- if cached_count >= 3:
- break
-
- if i < len(cached_messages):
- message = cached_messages[i].copy()
- message["content"] = self._add_cache_control_to_content(message["content"])
- cached_messages[i] = message
- cached_count += 1
-
- return cached_messages
def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
messages = [{"role": "system", "content": self.system_prompt}]
- identity_message = self._build_identity_message()
- if identity_message:
- messages.append(identity_message)
-
- compressed_history = list(self.memory_compressor.compress_history(conversation_history))
+ if self.agent_name:
+ messages.append(
+ {
+ "role": "user",
+ "content": (
+ f"\n\n\n"
+ f"Internal metadata: do not echo or reference.\n"
+ f"{self.agent_name}\n"
+ f"{self.agent_id}\n"
+ f"\n\n"
+ ),
+ }
+ )
+ compressed = list(self.memory_compressor.compress_history(conversation_history))
conversation_history.clear()
- conversation_history.extend(compressed_history)
- messages.extend(compressed_history)
+ conversation_history.extend(compressed)
+ messages.extend(compressed)
- return self._prepare_cached_messages(messages)
+ if self._is_anthropic() and self.config.enable_prompt_caching:
+ messages = self._add_cache_control(messages)
- async def _stream_and_accumulate(
- self,
- messages: list[dict[str, Any]],
- scan_id: str | None,
- step_number: int,
- ) -> AsyncIterator[LLMResponse]:
- accumulated_content = ""
- chunks: list[Any] = []
+ return messages
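
Note that _prepare_messages still mutates conversation_history in place: clear() plus extend() swaps the compressed history into the caller's own list object, so the compaction persists in the agent's state instead of being recomputed from the full transcript on every call. Illustrative sketch of the idiom:

    # The caller keeps its reference; only the list's contents change.
    def compact_in_place(history: list[dict], compressed: list[dict]) -> None:
        history.clear()
        history.extend(compressed)


    history = [{"role": "user", "content": "very long transcript ..."}]
    compact_in_place(history, [{"role": "user", "content": "summary"}])
    print(history)  # [{'role': 'user', 'content': 'summary'}]
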
- async for chunk in self._stream_request(messages):
- chunks.append(chunk)
- delta = self._extract_chunk_delta(chunk)
- if delta:
- accumulated_content += delta
+ def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
+ if not self._supports_vision():
+ messages = self._strip_images(messages)
- if "" in accumulated_content:
- function_end = accumulated_content.find("") + len("")
- accumulated_content = accumulated_content[:function_end]
-
- yield LLMResponse(
- scan_id=scan_id,
- step_number=step_number,
- role=StepRole.AGENT,
- content=accumulated_content,
- tool_invocations=None,
- )
-
- if chunks:
- complete_response = stream_chunk_builder(chunks)
- self._update_usage_stats(complete_response)
-
- accumulated_content = _truncate_to_first_function(accumulated_content)
- if "" in accumulated_content:
- function_end = accumulated_content.find("") + len("")
- accumulated_content = accumulated_content[:function_end]
-
- accumulated_content = fix_incomplete_tool_call(accumulated_content)
-
- tool_invocations = parse_tool_invocations(accumulated_content)
-
- # Extract thinking blocks from the complete response if available
- thinking_blocks = None
- if chunks and self._should_include_reasoning_effort():
- complete_response = stream_chunk_builder(chunks)
- if (
- hasattr(complete_response, "choices")
- and complete_response.choices
- and hasattr(complete_response.choices[0], "message")
- ):
- message = complete_response.choices[0].message
- if hasattr(message, "thinking_blocks") and message.thinking_blocks:
- thinking_blocks = message.thinking_blocks
-
- yield LLMResponse(
- scan_id=scan_id,
- step_number=step_number,
- role=StepRole.AGENT,
- content=accumulated_content,
- tool_invocations=tool_invocations if tool_invocations else None,
- thinking_blocks=thinking_blocks,
- )
-
- def _raise_llm_error(self, e: Exception) -> None:
- error_map: list[tuple[type, str]] = [
- (litellm.RateLimitError, "Rate limit exceeded"),
- (litellm.AuthenticationError, "Invalid API key"),
- (litellm.NotFoundError, "Model not found"),
- (litellm.ContextWindowExceededError, "Context too long"),
- (litellm.ContentPolicyViolationError, "Content policy violation"),
- (litellm.ServiceUnavailableError, "Service unavailable"),
- (litellm.Timeout, "Request timed out"),
- (litellm.UnprocessableEntityError, "Unprocessable entity"),
- (litellm.InternalServerError, "Internal server error"),
- (litellm.APIConnectionError, "Connection error"),
- (litellm.UnsupportedParamsError, "Unsupported parameters"),
- (litellm.BudgetExceededError, "Budget exceeded"),
- (litellm.APIResponseValidationError, "Response validation error"),
- (litellm.JSONSchemaValidationError, "JSON schema validation error"),
- (litellm.InvalidRequestError, "Invalid request"),
- (litellm.BadRequestError, "Bad request"),
- (litellm.APIError, "API error"),
- (litellm.OpenAIError, "OpenAI error"),
- ]
-
- from strix.telemetry import posthog
-
- for error_type, message in error_map:
- if isinstance(e, error_type):
- posthog.error(f"llm_{error_type.__name__}", message)
- raise LLMRequestFailedError(f"LLM request failed: {message}", str(e)) from e
-
- posthog.error("llm_unknown_error", type(e).__name__)
- raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
-
- async def generate(
- self,
- conversation_history: list[dict[str, Any]],
- scan_id: str | None = None,
- step_number: int = 1,
- ) -> AsyncIterator[LLMResponse]:
- messages = self._prepare_messages(conversation_history)
-
- last_error: Exception | None = None
- for attempt in range(MAX_RETRIES):
- try:
- async for response in self._stream_and_accumulate(messages, scan_id, step_number):
- yield response
- return # noqa: TRY300
- except Exception as e: # noqa: BLE001
- last_error = e
- if not _should_retry(e) or attempt == MAX_RETRIES - 1:
- break
- wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt))
- wait_time = max(RETRY_MIN, wait_time)
- await asyncio.sleep(wait_time)
-
- if last_error:
- self._raise_llm_error(last_error)
-
- def _extract_chunk_delta(self, chunk: Any) -> str:
- if chunk.choices and hasattr(chunk.choices[0], "delta"):
- delta = chunk.choices[0].delta
- return getattr(delta, "content", "") or ""
- return ""
-
- @property
- def usage_stats(self) -> dict[str, dict[str, int | float]]:
- return {
- "total": self._total_stats.to_dict(),
- "last_request": self._last_request_stats.to_dict(),
- }
-
- def get_cache_config(self) -> dict[str, bool]:
- return {
- "enabled": self.config.enable_prompt_caching,
- "supported": supports_prompt_caching(self.config.model_name),
- }
-
- def _should_include_reasoning_effort(self) -> bool:
- if not self.config.model_name:
- return False
- try:
- return bool(supports_reasoning(model=self.config.model_name))
- except Exception: # noqa: BLE001
- return False
-
- def _model_supports_vision(self) -> bool:
- if not self.config.model_name:
- return False
- try:
- return bool(supports_vision(model=self.config.model_name))
- except Exception: # noqa: BLE001
- return False
-
- def _filter_images_from_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
- filtered_messages = []
- for msg in messages:
- content = msg.get("content")
- updated_msg = msg
- if isinstance(content, list):
- filtered_content = []
- for item in content:
- if isinstance(item, dict):
- if item.get("type") == "image_url":
- filtered_content.append(
- {
- "type": "text",
- "text": "[Screenshot removed - model does not support "
- "vision. Use view_source or execute_js instead.]",
- }
- )
- else:
- filtered_content.append(item)
- else:
- filtered_content.append(item)
- if filtered_content:
- text_parts = [
- item.get("text", "") if isinstance(item, dict) else str(item)
- for item in filtered_content
- ]
- all_text = all(
- isinstance(item, dict) and item.get("type") == "text"
- for item in filtered_content
- )
- if all_text:
- updated_msg = {**msg, "content": "\n".join(text_parts)}
- else:
- updated_msg = {**msg, "content": filtered_content}
- else:
- updated_msg = {**msg, "content": ""}
- filtered_messages.append(updated_msg)
- return filtered_messages
-
- async def _stream_request(
- self,
- messages: list[dict[str, Any]],
- ) -> AsyncIterator[Any]:
- if not self._model_supports_vision():
- messages = self._filter_images_from_messages(messages)
-
- completion_args: dict[str, Any] = {
+ args: dict[str, Any] = {
"model": self.config.model_name,
"messages": messages,
"timeout": self.config.timeout,
"stream_options": {"include_usage": True},
+ "stop": [""],
}
- if _LLM_API_KEY:
- completion_args["api_key"] = _LLM_API_KEY
- if _LLM_API_BASE:
- completion_args["api_base"] = _LLM_API_BASE
+ if api_key := Config.get("llm_api_key"):
+ args["api_key"] = api_key
+ if api_base := (
+ Config.get("llm_api_base")
+ or Config.get("openai_api_base")
+ or Config.get("litellm_base_url")
+ ):
+ args["api_base"] = api_base
+ if self._supports_reasoning():
+ args["reasoning_effort"] = self._reasoning_effort
- completion_args["stop"] = [""]
+ return args
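
With the request queue gone, _stream awaits litellm's acompletion directly and iterates the chunks itself. A minimal sketch of that consumption pattern, with a placeholder model and prompt:

    import asyncio

    from litellm import acompletion, stream_chunk_builder


    async def demo() -> None:
        chunks = []
        response = await acompletion(
            model="openai/gpt-4o-mini",  # placeholder
            messages=[{"role": "user", "content": "hello"}],
            stream=True,
            stream_options={"include_usage": True},
        )
        async for chunk in response:
            chunks.append(chunk)
        # Rebuilding the full response afterwards is what makes usage and
        # cost accounting possible on a streamed request.
        print(stream_chunk_builder(chunks).usage)


    asyncio.run(demo())
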
- if self._should_include_reasoning_effort():
- completion_args["reasoning_effort"] = self._reasoning_effort
+ def _get_chunk_content(self, chunk: Any) -> str:
+ if chunk.choices and hasattr(chunk.choices[0], "delta"):
+ return getattr(chunk.choices[0].delta, "content", "") or ""
+ return ""
- queue = get_global_queue()
- self._total_stats.requests += 1
- self._last_request_stats = RequestStats(requests=1)
-
- async for chunk in queue.stream_request(completion_args):
- yield chunk
+ def _extract_thinking(self, chunks: list[Any]) -> list[dict[str, Any]] | None:
+ if not chunks or not self._supports_reasoning():
+ return None
+ try:
+ resp = stream_chunk_builder(chunks)
+ if resp.choices and hasattr(resp.choices[0].message, "thinking_blocks"):
+ blocks: list[dict[str, Any]] = resp.choices[0].message.thinking_blocks
+ return blocks
+ except Exception: # noqa: BLE001, S110 # nosec B110
+ pass
+ return None
def _update_usage_stats(self, response: Any) -> None:
try:
@@ -497,45 +229,88 @@ class LLM:
output_tokens = getattr(response.usage, "completion_tokens", 0)
cached_tokens = 0
- cache_creation_tokens = 0
-
if hasattr(response.usage, "prompt_tokens_details"):
prompt_details = response.usage.prompt_tokens_details
if hasattr(prompt_details, "cached_tokens"):
cached_tokens = prompt_details.cached_tokens or 0
- if hasattr(response.usage, "cache_creation_input_tokens"):
- cache_creation_tokens = response.usage.cache_creation_input_tokens or 0
-
else:
input_tokens = 0
output_tokens = 0
cached_tokens = 0
- cache_creation_tokens = 0
try:
cost = completion_cost(response) or 0.0
- except Exception as e: # noqa: BLE001
- logger.warning(f"Failed to calculate cost: {e}")
+ except Exception: # noqa: BLE001
cost = 0.0
self._total_stats.input_tokens += input_tokens
self._total_stats.output_tokens += output_tokens
self._total_stats.cached_tokens += cached_tokens
- self._total_stats.cache_creation_tokens += cache_creation_tokens
self._total_stats.cost += cost
- self._last_request_stats.input_tokens = input_tokens
- self._last_request_stats.output_tokens = output_tokens
- self._last_request_stats.cached_tokens = cached_tokens
- self._last_request_stats.cache_creation_tokens = cache_creation_tokens
- self._last_request_stats.cost = cost
+ except Exception: # noqa: BLE001, S110 # nosec B110
+ pass
- if cached_tokens > 0:
- logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
- if cache_creation_tokens > 0:
- logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")
+ def _should_retry(self, e: Exception) -> bool:
+ code = getattr(e, "status_code", None) or getattr(
+ getattr(e, "response", None), "status_code", None
+ )
+ return code is None or litellm._should_retry(code)
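
The simplified predicate keeps the old semantics: exceptions carrying no HTTP status (connection drops, timeouts surfaced as plain exceptions) are always retried, while anything with a status defers to litellm's underscore-prefixed _should_retry helper, which treats transient statuses such as rate limits and server errors as retryable. Behaviour sketch with an illustrative stand-in exception:

    import litellm


    class RateLimited(Exception):
        status_code = 429  # stand-in for a provider rate-limit error


    def should_retry(e: Exception) -> bool:
        code = getattr(e, "status_code", None) or getattr(
            getattr(e, "response", None), "status_code", None
        )
        return code is None or bool(litellm._should_retry(code))


    print(should_retry(ValueError("network dropped")))  # True: no status code
    print(should_retry(RateLimited()))                  # True: 429 is retryable
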
- logger.info(f"Usage stats: {self.usage_stats}")
- except Exception as e: # noqa: BLE001
- logger.warning(f"Failed to update usage stats: {e}")
+ def _raise_error(self, e: Exception) -> None:
+ from strix.telemetry import posthog
+
+ posthog.error("llm_error", type(e).__name__)
+ raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
+
+ def _is_anthropic(self) -> bool:
+ if not self.config.model_name:
+ return False
+ return any(p in self.config.model_name.lower() for p in ["anthropic/", "claude"])
+
+ def _supports_vision(self) -> bool:
+ try:
+ return bool(supports_vision(model=self.config.model_name))
+ except Exception: # noqa: BLE001
+ return False
+
+ def _supports_reasoning(self) -> bool:
+ try:
+ return bool(supports_reasoning(model=self.config.model_name))
+ except Exception: # noqa: BLE001
+ return False
+
+ def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ result = []
+ for msg in messages:
+ content = msg.get("content")
+ if isinstance(content, list):
+ text_parts = []
+ for item in content:
+ if isinstance(item, dict) and item.get("type") == "text":
+ text_parts.append(item.get("text", ""))
+ elif isinstance(item, dict) and item.get("type") == "image_url":
+ text_parts.append("[Image removed - model doesn't support vision]")
+ result.append({**msg, "content": "\n".join(text_parts)})
+ else:
+ result.append(msg)
+ return result
+
+ def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ if not messages or not supports_prompt_caching(self.config.model_name):
+ return messages
+
+ result = list(messages)
+
+ if result[0].get("role") == "system":
+ content = result[0]["content"]
+ result[0] = {
+ **result[0],
+ "content": [
+ {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
+ ]
+ if isinstance(content, str)
+ else content,
+ }
+ return result
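
Compared with the removed _prepare_cached_messages, the new helper only marks the system prompt; the interval-based caching of later messages is dropped. For reference, the message shape it produces on an Anthropic model, as a sketch with placeholder content:

    messages = [
        {"role": "system", "content": "You are ..."},  # placeholder prompt
        {"role": "user", "content": "Begin the scan."},
    ]

    # String system content is wrapped in a single text block carrying the
    # ephemeral cache_control marker; list content is left untouched.
    system = messages[0]
    messages[0] = {
        **system,
        "content": [
            {
                "type": "text",
                "text": system["content"],
                "cache_control": {"type": "ephemeral"},
            }
        ],
    }
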
diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py
index bfc6480..a9532f8 100644
--- a/strix/llm/memory_compressor.py
+++ b/strix/llm/memory_compressor.py
@@ -86,7 +86,7 @@ def _extract_message_text(msg: dict[str, Any]) -> str:
def _summarize_messages(
messages: list[dict[str, Any]],
model: str,
- timeout: int = 600,
+ timeout: int = 30,
) -> dict[str, Any]:
if not messages:
empty_summary = "{text}"
@@ -148,11 +148,11 @@ class MemoryCompressor:
self,
max_images: int = 3,
model_name: str | None = None,
- timeout: int = 600,
+ timeout: int | None = None,
):
self.max_images = max_images
self.model_name = model_name or Config.get("strix_llm")
- self.timeout = timeout
+ self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")
if not self.model_name:
raise ValueError("STRIX_LLM environment variable must be set and not empty")
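
The default summarization timeout drops from 600s to a configurable 30s, and passing None now means "read the config". One caveat worth noting: the check is truthiness rather than `is None`, so an explicit timeout=0 also falls through to the config value. Resolution order as a standalone sketch:

    from strix.config import Config


    def resolve_timeout(timeout: int | None = None) -> int:
        # Explicit argument > strix_memory_compressor_timeout config > 30.
        return timeout or int(Config.get("strix_memory_compressor_timeout") or "30")
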
diff --git a/strix/llm/request_queue.py b/strix/llm/request_queue.py
deleted file mode 100644
index 35c9725..0000000
--- a/strix/llm/request_queue.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import asyncio
-import threading
-import time
-from collections.abc import AsyncIterator
-from typing import Any
-
-from litellm import acompletion
-from litellm.types.utils import ModelResponseStream
-
-from strix.config import Config
-
-
-class LLMRequestQueue:
- def __init__(self) -> None:
- self.delay_between_requests = float(Config.get("llm_rate_limit_delay") or "4.0")
- self.max_concurrent = int(Config.get("llm_rate_limit_concurrent") or "1")
- self._semaphore = threading.BoundedSemaphore(self.max_concurrent)
- self._last_request_time = 0.0
- self._lock = threading.Lock()
-
- async def stream_request(
- self, completion_args: dict[str, Any]
- ) -> AsyncIterator[ModelResponseStream]:
- try:
- while not self._semaphore.acquire(timeout=0.2):
- await asyncio.sleep(0.1)
-
- with self._lock:
- now = time.time()
- time_since_last = now - self._last_request_time
- sleep_needed = max(0, self.delay_between_requests - time_since_last)
- self._last_request_time = now + sleep_needed
-
- if sleep_needed > 0:
- await asyncio.sleep(sleep_needed)
-
- async for chunk in self._stream_request(completion_args):
- yield chunk
- finally:
- self._semaphore.release()
-
- async def _stream_request(
- self, completion_args: dict[str, Any]
- ) -> AsyncIterator[ModelResponseStream]:
- response = await acompletion(**completion_args, stream=True)
-
- async for chunk in response:
- yield chunk
-
-
-_global_queue: LLMRequestQueue | None = None
-
-
-def get_global_queue() -> LLMRequestQueue:
- global _global_queue # noqa: PLW0603
- if _global_queue is None:
- _global_queue = LLMRequestQueue()
- return _global_queue
diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py
index af299a9..25af62c 100644
--- a/strix/telemetry/tracer.py
+++ b/strix/telemetry/tracer.py
@@ -430,10 +430,8 @@ class Tracer:
"input_tokens": 0,
"output_tokens": 0,
"cached_tokens": 0,
- "cache_creation_tokens": 0,
"cost": 0.0,
"requests": 0,
- "failed_requests": 0,
}
for agent_instance in _agent_instances.values():
@@ -442,10 +440,8 @@ class Tracer:
total_stats["input_tokens"] += agent_stats.input_tokens
total_stats["output_tokens"] += agent_stats.output_tokens
total_stats["cached_tokens"] += agent_stats.cached_tokens
- total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
total_stats["cost"] += agent_stats.cost
total_stats["requests"] += agent_stats.requests
- total_stats["failed_requests"] += agent_stats.failed_requests
total_stats["cost"] = round(total_stats["cost"], 4)
diff --git a/strix/tools/executor.py b/strix/tools/executor.py
index ad0aeef..db06477 100644
--- a/strix/tools/executor.py
+++ b/strix/tools/executor.py
@@ -20,7 +20,7 @@ from .registry import (
)
-SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "500")
+SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")