fix(agent): fix agent loop hanging and simplify LLM module

- Fix agent loop getting stuck by adding hard stop mechanism
- Add _force_stop flag for immediate task cancellation across threads
- Use thread-safe loop.call_soon_threadsafe for cross-thread cancellation
- Remove request_queue.py (eliminated threading/queue complexity causing hangs)
- Simplify llm.py: direct acompletion calls, cleaner streaming
- Reduce retry wait times to prevent long hangs during retries
- Make timeouts configurable (llm_max_retries, memory_compressor_timeout, sandbox_execution_timeout)
- Keep essential token tracking (input/output/cached tokens, cost, requests)
- Maintain Anthropic prompt caching for system messages
This commit is contained in:
0xallam
2026-01-14 18:44:04 -08:00
committed by Ahmed Allam
parent 47faeb1ef3
commit 56526cbf90
8 changed files with 210 additions and 485 deletions

View File

@@ -111,7 +111,6 @@ hiddenimports = [
'strix.llm.llm', 'strix.llm.llm',
'strix.llm.config', 'strix.llm.config',
'strix.llm.utils', 'strix.llm.utils',
'strix.llm.request_queue',
'strix.llm.memory_compressor', 'strix.llm.memory_compressor',
'strix.runtime', 'strix.runtime',
'strix.runtime.runtime', 'strix.runtime.runtime',

View File

@@ -79,6 +79,7 @@ class BaseAgent(metaclass=AgentMeta):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id) self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id)
self._current_task: asyncio.Task[Any] | None = None self._current_task: asyncio.Task[Any] | None = None
self._force_stop = False
from strix.telemetry.tracer import get_global_tracer from strix.telemetry.tracer import get_global_tracer
@@ -156,6 +157,11 @@ class BaseAgent(metaclass=AgentMeta):
return self._handle_sandbox_error(e, tracer) return self._handle_sandbox_error(e, tracer)
while True: while True:
if self._force_stop:
self._force_stop = False
await self._enter_waiting_state(tracer, was_cancelled=True)
continue
self._check_agent_messages(self.state) self._check_agent_messages(self.state)
if self.state.is_waiting_for_input(): if self.state.is_waiting_for_input():
@@ -246,7 +252,8 @@ class BaseAgent(metaclass=AgentMeta):
continue continue
async def _wait_for_input(self) -> None: async def _wait_for_input(self) -> None:
import asyncio if self._force_stop:
return
if self.state.has_waiting_timeout(): if self.state.has_waiting_timeout():
self.state.resume_from_waiting() self.state.resume_from_waiting()
@@ -339,6 +346,7 @@ class BaseAgent(metaclass=AgentMeta):
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool: async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
final_response = None final_response = None
async for response in self.llm.generate(self.state.get_conversation_history()): async for response in self.llm.generate(self.state.get_conversation_history()):
final_response = response final_response = response
if tracer and response.content: if tracer and response.content:
@@ -584,6 +592,11 @@ class BaseAgent(metaclass=AgentMeta):
return True return True
def cancel_current_execution(self) -> None: def cancel_current_execution(self) -> None:
self._force_stop = True
if self._current_task and not self._current_task.done(): if self._current_task and not self._current_task.done():
try:
loop = self._current_task.get_loop()
loop.call_soon_threadsafe(self._current_task.cancel)
except RuntimeError:
self._current_task.cancel() self._current_task.cancel()
self._current_task = None self._current_task = None

View File

@@ -16,9 +16,9 @@ class Config:
litellm_base_url = None litellm_base_url = None
ollama_api_base = None ollama_api_base = None
strix_reasoning_effort = "high" strix_reasoning_effort = "high"
strix_llm_max_retries = "5"
strix_memory_compressor_timeout = "30"
llm_timeout = "300" llm_timeout = "300"
llm_rate_limit_delay = "4.0"
llm_rate_limit_concurrent = "1"
# Tool & Feature Configuration # Tool & Feature Configuration
perplexity_api_key = None perplexity_api_key = None
@@ -27,7 +27,7 @@ class Config:
# Runtime Configuration # Runtime Configuration
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10" strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
strix_runtime_backend = "docker" strix_runtime_backend = "docker"
strix_sandbox_execution_timeout = "500" strix_sandbox_execution_timeout = "120"
strix_sandbox_connect_timeout = "10" strix_sandbox_connect_timeout = "10"
# Telemetry # Telemetry

View File

@@ -1,23 +1,16 @@
import asyncio import asyncio
import logging
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import Any from typing import Any
import litellm import litellm
from jinja2 import ( from jinja2 import Environment, FileSystemLoader, select_autoescape
Environment, from litellm import acompletion, completion_cost, stream_chunk_builder, supports_reasoning
FileSystemLoader,
select_autoescape,
)
from litellm import completion_cost, stream_chunk_builder, supports_reasoning
from litellm.utils import supports_prompt_caching, supports_vision from litellm.utils import supports_prompt_caching, supports_vision
from strix.config import Config from strix.config import Config
from strix.llm.config import LLMConfig from strix.llm.config import LLMConfig
from strix.llm.memory_compressor import MemoryCompressor from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.request_queue import get_global_queue
from strix.llm.utils import ( from strix.llm.utils import (
_truncate_to_first_function, _truncate_to_first_function,
fix_incomplete_tool_call, fix_incomplete_tool_call,
@@ -28,37 +21,9 @@ from strix.tools import get_tools_prompt
from strix.utils.resource_paths import get_strix_resource_path from strix.utils.resource_paths import get_strix_resource_path
MAX_RETRIES = 5
RETRY_MULTIPLIER = 8
RETRY_MIN = 8
RETRY_MAX = 64
def _should_retry(exception: Exception) -> bool:
status_code = None
if hasattr(exception, "status_code"):
status_code = exception.status_code
elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
status_code = exception.response.status_code
if status_code is not None:
return bool(litellm._should_retry(status_code))
return True
logger = logging.getLogger(__name__)
litellm.drop_params = True litellm.drop_params = True
litellm.modify_params = True litellm.modify_params = True
_LLM_API_KEY = Config.get("llm_api_key")
_LLM_API_BASE = (
Config.get("llm_api_base")
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
)
_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")
class LLMRequestFailedError(Exception): class LLMRequestFailedError(Exception):
def __init__(self, message: str, details: str | None = None): def __init__(self, message: str, details: str | None = None):
@@ -67,20 +32,11 @@ class LLMRequestFailedError(Exception):
self.details = details self.details = details
class StepRole(str, Enum):
AGENT = "agent"
USER = "user"
SYSTEM = "system"
@dataclass @dataclass
class LLMResponse: class LLMResponse:
content: str content: str
tool_invocations: list[dict[str, Any]] | None = None tool_invocations: list[dict[str, Any]] | None = None
scan_id: str | None = None thinking_blocks: list[dict[str, Any]] | None = None
step_number: int = 1
role: StepRole = StepRole.AGENT
thinking_blocks: list[dict[str, Any]] | None = None # For reasoning models.
@dataclass @dataclass
@@ -88,76 +44,63 @@ class RequestStats:
input_tokens: int = 0 input_tokens: int = 0
output_tokens: int = 0 output_tokens: int = 0
cached_tokens: int = 0 cached_tokens: int = 0
cache_creation_tokens: int = 0
cost: float = 0.0 cost: float = 0.0
requests: int = 0 requests: int = 0
failed_requests: int = 0
def to_dict(self) -> dict[str, int | float]: def to_dict(self) -> dict[str, int | float]:
return { return {
"input_tokens": self.input_tokens, "input_tokens": self.input_tokens,
"output_tokens": self.output_tokens, "output_tokens": self.output_tokens,
"cached_tokens": self.cached_tokens, "cached_tokens": self.cached_tokens,
"cache_creation_tokens": self.cache_creation_tokens,
"cost": round(self.cost, 4), "cost": round(self.cost, 4),
"requests": self.requests, "requests": self.requests,
"failed_requests": self.failed_requests,
} }
class LLM: class LLM:
def __init__( def __init__(self, config: LLMConfig, agent_name: str | None = None):
self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
):
self.config = config self.config = config
self.agent_name = agent_name self.agent_name = agent_name
self.agent_id = agent_id self.agent_id: str | None = None
self._total_stats = RequestStats() self._total_stats = RequestStats()
self._last_request_stats = RequestStats() self.memory_compressor = MemoryCompressor(model_name=config.model_name)
self.system_prompt = self._load_system_prompt(agent_name)
if _STRIX_REASONING_EFFORT: reasoning = Config.get("strix_reasoning_effort")
self._reasoning_effort = _STRIX_REASONING_EFFORT if reasoning:
elif self.config.scan_mode == "quick": self._reasoning_effort = reasoning
elif config.scan_mode == "quick":
self._reasoning_effort = "medium" self._reasoning_effort = "medium"
else: else:
self._reasoning_effort = "high" self._reasoning_effort = "high"
self.memory_compressor = MemoryCompressor( def _load_system_prompt(self, agent_name: str | None) -> str:
model_name=self.config.model_name, if not agent_name:
timeout=self.config.timeout, return ""
)
if agent_name: try:
prompt_dir = get_strix_resource_path("agents", agent_name) prompt_dir = get_strix_resource_path("agents", agent_name)
skills_dir = get_strix_resource_path("skills") skills_dir = get_strix_resource_path("skills")
env = Environment(
loader = FileSystemLoader([prompt_dir, skills_dir]) loader=FileSystemLoader([prompt_dir, skills_dir]),
self.jinja_env = Environment(
loader=loader,
autoescape=select_autoescape(enabled_extensions=(), default_for_string=False), autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
) )
try: skills_to_load = [
skills_to_load = list(self.config.skills or []) *list(self.config.skills or []),
skills_to_load.append(f"scan_modes/{self.config.scan_mode}") f"scan_modes/{self.config.scan_mode}",
]
skill_content = load_skills(skills_to_load, env)
env.globals["get_skill"] = lambda name: skill_content.get(name, "")
skill_content = load_skills(skills_to_load, self.jinja_env) result = env.get_template("system_prompt.jinja").render(
def get_skill(name: str) -> str:
return skill_content.get(name, "")
self.jinja_env.globals["get_skill"] = get_skill
self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
get_tools_prompt=get_tools_prompt, get_tools_prompt=get_tools_prompt,
loaded_skill_names=list(skill_content.keys()), loaded_skill_names=list(skill_content.keys()),
**skill_content, **skill_content,
) )
except (FileNotFoundError, OSError, ValueError) as e: return str(result)
logger.warning(f"Failed to load system prompt for {agent_name}: {e}") except Exception: # noqa: BLE001
self.system_prompt = "You are a helpful AI assistant." return ""
else:
self.system_prompt = "You are a helpful AI assistant."
def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None: def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
if agent_name: if agent_name:
@@ -165,330 +108,119 @@ class LLM:
if agent_id: if agent_id:
self.agent_id = agent_id self.agent_id = agent_id
def _build_identity_message(self) -> dict[str, Any] | None: async def generate(
if not (self.agent_name and str(self.agent_name).strip()): self, conversation_history: list[dict[str, Any]]
return None ) -> AsyncIterator[LLMResponse]:
identity_name = self.agent_name messages = self._prepare_messages(conversation_history)
identity_id = self.agent_id max_retries = int(Config.get("strix_llm_max_retries") or "5")
content = (
"\n\n" for attempt in range(max_retries + 1):
"<agent_identity>\n" try:
"<meta>Internal metadata: do not echo or reference; " async for response in self._stream(messages):
"not part of history or tool calls.</meta>\n" yield response
"<note>You are now assuming the role of this agent. " return # noqa: TRY300
"Act strictly as this agent and maintain self-identity for this step. " except Exception as e: # noqa: BLE001
"Now go answer the next needed step!</note>\n" if attempt >= max_retries or not self._should_retry(e):
f"<agent_name>{identity_name}</agent_name>\n" self._raise_error(e)
f"<agent_id>{identity_id}</agent_id>\n" wait = min(10, 2 * (2**attempt))
"</agent_identity>\n\n" await asyncio.sleep(wait)
async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
accumulated = ""
chunks: list[Any] = []
self._total_stats.requests += 1
response = await acompletion(**self._build_completion_args(messages), stream=True)
async for chunk in response:
chunks.append(chunk)
delta = self._get_chunk_content(chunk)
if delta:
accumulated += delta
if "</function>" in accumulated:
accumulated = accumulated[
: accumulated.find("</function>") + len("</function>")
]
yield LLMResponse(content=accumulated)
if chunks:
self._update_usage_stats(stream_chunk_builder(chunks))
accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
yield LLMResponse(
content=accumulated,
tool_invocations=parse_tool_invocations(accumulated),
thinking_blocks=self._extract_thinking(chunks),
) )
return {"role": "user", "content": content}
def _add_cache_control_to_content(
self, content: str | list[dict[str, Any]]
) -> str | list[dict[str, Any]]:
if isinstance(content, str):
return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
if isinstance(content, list) and content:
last_item = content[-1]
if isinstance(last_item, dict) and last_item.get("type") == "text":
return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
return content
def _is_anthropic_model(self) -> bool:
if not self.config.model_name:
return False
model_lower = self.config.model_name.lower()
return any(provider in model_lower for provider in ["anthropic/", "claude"])
def _calculate_cache_interval(self, total_messages: int) -> int:
if total_messages <= 1:
return 10
max_cached_messages = 3
non_system_messages = total_messages - 1
interval = 10
while non_system_messages // interval > max_cached_messages:
interval += 10
return interval
def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
if (
not self.config.enable_prompt_caching
or not supports_prompt_caching(self.config.model_name)
or not messages
):
return messages
if not self._is_anthropic_model():
return messages
cached_messages = list(messages)
if cached_messages and cached_messages[0].get("role") == "system":
system_message = cached_messages[0].copy()
system_message["content"] = self._add_cache_control_to_content(
system_message["content"]
)
cached_messages[0] = system_message
total_messages = len(cached_messages)
if total_messages > 1:
interval = self._calculate_cache_interval(total_messages)
cached_count = 0
for i in range(interval, total_messages, interval):
if cached_count >= 3:
break
if i < len(cached_messages):
message = cached_messages[i].copy()
message["content"] = self._add_cache_control_to_content(message["content"])
cached_messages[i] = message
cached_count += 1
return cached_messages
def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]: def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
messages = [{"role": "system", "content": self.system_prompt}] messages = [{"role": "system", "content": self.system_prompt}]
identity_message = self._build_identity_message() if self.agent_name:
if identity_message: messages.append(
messages.append(identity_message)
compressed_history = list(self.memory_compressor.compress_history(conversation_history))
conversation_history.clear()
conversation_history.extend(compressed_history)
messages.extend(compressed_history)
return self._prepare_cached_messages(messages)
async def _stream_and_accumulate(
self,
messages: list[dict[str, Any]],
scan_id: str | None,
step_number: int,
) -> AsyncIterator[LLMResponse]:
accumulated_content = ""
chunks: list[Any] = []
async for chunk in self._stream_request(messages):
chunks.append(chunk)
delta = self._extract_chunk_delta(chunk)
if delta:
accumulated_content += delta
if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end]
yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content=accumulated_content,
tool_invocations=None,
)
if chunks:
complete_response = stream_chunk_builder(chunks)
self._update_usage_stats(complete_response)
accumulated_content = _truncate_to_first_function(accumulated_content)
if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end]
accumulated_content = fix_incomplete_tool_call(accumulated_content)
tool_invocations = parse_tool_invocations(accumulated_content)
# Extract thinking blocks from the complete response if available
thinking_blocks = None
if chunks and self._should_include_reasoning_effort():
complete_response = stream_chunk_builder(chunks)
if (
hasattr(complete_response, "choices")
and complete_response.choices
and hasattr(complete_response.choices[0], "message")
):
message = complete_response.choices[0].message
if hasattr(message, "thinking_blocks") and message.thinking_blocks:
thinking_blocks = message.thinking_blocks
yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content=accumulated_content,
tool_invocations=tool_invocations if tool_invocations else None,
thinking_blocks=thinking_blocks,
)
def _raise_llm_error(self, e: Exception) -> None:
error_map: list[tuple[type, str]] = [
(litellm.RateLimitError, "Rate limit exceeded"),
(litellm.AuthenticationError, "Invalid API key"),
(litellm.NotFoundError, "Model not found"),
(litellm.ContextWindowExceededError, "Context too long"),
(litellm.ContentPolicyViolationError, "Content policy violation"),
(litellm.ServiceUnavailableError, "Service unavailable"),
(litellm.Timeout, "Request timed out"),
(litellm.UnprocessableEntityError, "Unprocessable entity"),
(litellm.InternalServerError, "Internal server error"),
(litellm.APIConnectionError, "Connection error"),
(litellm.UnsupportedParamsError, "Unsupported parameters"),
(litellm.BudgetExceededError, "Budget exceeded"),
(litellm.APIResponseValidationError, "Response validation error"),
(litellm.JSONSchemaValidationError, "JSON schema validation error"),
(litellm.InvalidRequestError, "Invalid request"),
(litellm.BadRequestError, "Bad request"),
(litellm.APIError, "API error"),
(litellm.OpenAIError, "OpenAI error"),
]
from strix.telemetry import posthog
for error_type, message in error_map:
if isinstance(e, error_type):
posthog.error(f"llm_{error_type.__name__}", message)
raise LLMRequestFailedError(f"LLM request failed: {message}", str(e)) from e
posthog.error("llm_unknown_error", type(e).__name__)
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
async def generate(
self,
conversation_history: list[dict[str, Any]],
scan_id: str | None = None,
step_number: int = 1,
) -> AsyncIterator[LLMResponse]:
messages = self._prepare_messages(conversation_history)
last_error: Exception | None = None
for attempt in range(MAX_RETRIES):
try:
async for response in self._stream_and_accumulate(messages, scan_id, step_number):
yield response
return # noqa: TRY300
except Exception as e: # noqa: BLE001
last_error = e
if not _should_retry(e) or attempt == MAX_RETRIES - 1:
break
wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt))
wait_time = max(RETRY_MIN, wait_time)
await asyncio.sleep(wait_time)
if last_error:
self._raise_llm_error(last_error)
def _extract_chunk_delta(self, chunk: Any) -> str:
if chunk.choices and hasattr(chunk.choices[0], "delta"):
delta = chunk.choices[0].delta
return getattr(delta, "content", "") or ""
return ""
@property
def usage_stats(self) -> dict[str, dict[str, int | float]]:
return {
"total": self._total_stats.to_dict(),
"last_request": self._last_request_stats.to_dict(),
}
def get_cache_config(self) -> dict[str, bool]:
return {
"enabled": self.config.enable_prompt_caching,
"supported": supports_prompt_caching(self.config.model_name),
}
def _should_include_reasoning_effort(self) -> bool:
if not self.config.model_name:
return False
try:
return bool(supports_reasoning(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _model_supports_vision(self) -> bool:
if not self.config.model_name:
return False
try:
return bool(supports_vision(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _filter_images_from_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
filtered_messages = []
for msg in messages:
content = msg.get("content")
updated_msg = msg
if isinstance(content, list):
filtered_content = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "image_url":
filtered_content.append(
{ {
"type": "text", "role": "user",
"text": "[Screenshot removed - model does not support " "content": (
"vision. Use view_source or execute_js instead.]", f"\n\n<agent_identity>\n"
f"<meta>Internal metadata: do not echo or reference.</meta>\n"
f"<agent_name>{self.agent_name}</agent_name>\n"
f"<agent_id>{self.agent_id}</agent_id>\n"
f"</agent_identity>\n\n"
),
} }
) )
else:
filtered_content.append(item)
else:
filtered_content.append(item)
if filtered_content:
text_parts = [
item.get("text", "") if isinstance(item, dict) else str(item)
for item in filtered_content
]
all_text = all(
isinstance(item, dict) and item.get("type") == "text"
for item in filtered_content
)
if all_text:
updated_msg = {**msg, "content": "\n".join(text_parts)}
else:
updated_msg = {**msg, "content": filtered_content}
else:
updated_msg = {**msg, "content": ""}
filtered_messages.append(updated_msg)
return filtered_messages
async def _stream_request( compressed = list(self.memory_compressor.compress_history(conversation_history))
self, conversation_history.clear()
messages: list[dict[str, Any]], conversation_history.extend(compressed)
) -> AsyncIterator[Any]: messages.extend(compressed)
if not self._model_supports_vision():
messages = self._filter_images_from_messages(messages)
completion_args: dict[str, Any] = { if self._is_anthropic() and self.config.enable_prompt_caching:
messages = self._add_cache_control(messages)
return messages
def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
if not self._supports_vision():
messages = self._strip_images(messages)
args: dict[str, Any] = {
"model": self.config.model_name, "model": self.config.model_name,
"messages": messages, "messages": messages,
"timeout": self.config.timeout, "timeout": self.config.timeout,
"stream_options": {"include_usage": True}, "stream_options": {"include_usage": True},
"stop": ["</function>"],
} }
if _LLM_API_KEY: if api_key := Config.get("llm_api_key"):
completion_args["api_key"] = _LLM_API_KEY args["api_key"] = api_key
if _LLM_API_BASE: if api_base := (
completion_args["api_base"] = _LLM_API_BASE Config.get("llm_api_base")
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
):
args["api_base"] = api_base
if self._supports_reasoning():
args["reasoning_effort"] = self._reasoning_effort
completion_args["stop"] = ["</function>"] return args
if self._should_include_reasoning_effort(): def _get_chunk_content(self, chunk: Any) -> str:
completion_args["reasoning_effort"] = self._reasoning_effort if chunk.choices and hasattr(chunk.choices[0], "delta"):
return getattr(chunk.choices[0].delta, "content", "") or ""
return ""
queue = get_global_queue() def _extract_thinking(self, chunks: list[Any]) -> list[dict[str, Any]] | None:
self._total_stats.requests += 1 if not chunks or not self._supports_reasoning():
self._last_request_stats = RequestStats(requests=1) return None
try:
async for chunk in queue.stream_request(completion_args): resp = stream_chunk_builder(chunks)
yield chunk if resp.choices and hasattr(resp.choices[0].message, "thinking_blocks"):
blocks: list[dict[str, Any]] = resp.choices[0].message.thinking_blocks
return blocks
except Exception: # noqa: BLE001, S110 # nosec B110
pass
return None
def _update_usage_stats(self, response: Any) -> None: def _update_usage_stats(self, response: Any) -> None:
try: try:
@@ -497,45 +229,88 @@ class LLM:
output_tokens = getattr(response.usage, "completion_tokens", 0) output_tokens = getattr(response.usage, "completion_tokens", 0)
cached_tokens = 0 cached_tokens = 0
cache_creation_tokens = 0
if hasattr(response.usage, "prompt_tokens_details"): if hasattr(response.usage, "prompt_tokens_details"):
prompt_details = response.usage.prompt_tokens_details prompt_details = response.usage.prompt_tokens_details
if hasattr(prompt_details, "cached_tokens"): if hasattr(prompt_details, "cached_tokens"):
cached_tokens = prompt_details.cached_tokens or 0 cached_tokens = prompt_details.cached_tokens or 0
if hasattr(response.usage, "cache_creation_input_tokens"):
cache_creation_tokens = response.usage.cache_creation_input_tokens or 0
else: else:
input_tokens = 0 input_tokens = 0
output_tokens = 0 output_tokens = 0
cached_tokens = 0 cached_tokens = 0
cache_creation_tokens = 0
try: try:
cost = completion_cost(response) or 0.0 cost = completion_cost(response) or 0.0
except Exception as e: # noqa: BLE001 except Exception: # noqa: BLE001
logger.warning(f"Failed to calculate cost: {e}")
cost = 0.0 cost = 0.0
self._total_stats.input_tokens += input_tokens self._total_stats.input_tokens += input_tokens
self._total_stats.output_tokens += output_tokens self._total_stats.output_tokens += output_tokens
self._total_stats.cached_tokens += cached_tokens self._total_stats.cached_tokens += cached_tokens
self._total_stats.cache_creation_tokens += cache_creation_tokens
self._total_stats.cost += cost self._total_stats.cost += cost
self._last_request_stats.input_tokens = input_tokens except Exception: # noqa: BLE001, S110 # nosec B110
self._last_request_stats.output_tokens = output_tokens pass
self._last_request_stats.cached_tokens = cached_tokens
self._last_request_stats.cache_creation_tokens = cache_creation_tokens
self._last_request_stats.cost = cost
if cached_tokens > 0: def _should_retry(self, e: Exception) -> bool:
logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens") code = getattr(e, "status_code", None) or getattr(
if cache_creation_tokens > 0: getattr(e, "response", None), "status_code", None
logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache") )
return code is None or litellm._should_retry(code)
logger.info(f"Usage stats: {self.usage_stats}") def _raise_error(self, e: Exception) -> None:
except Exception as e: # noqa: BLE001 from strix.telemetry import posthog
logger.warning(f"Failed to update usage stats: {e}")
posthog.error("llm_error", type(e).__name__)
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
def _is_anthropic(self) -> bool:
if not self.config.model_name:
return False
return any(p in self.config.model_name.lower() for p in ["anthropic/", "claude"])
def _supports_vision(self) -> bool:
try:
return bool(supports_vision(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _supports_reasoning(self) -> bool:
try:
return bool(supports_reasoning(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
result = []
for msg in messages:
content = msg.get("content")
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif isinstance(item, dict) and item.get("type") == "image_url":
text_parts.append("[Image removed - model doesn't support vision]")
result.append({**msg, "content": "\n".join(text_parts)})
else:
result.append(msg)
return result
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
if not messages or not supports_prompt_caching(self.config.model_name):
return messages
result = list(messages)
if result[0].get("role") == "system":
content = result[0]["content"]
result[0] = {
**result[0],
"content": [
{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
]
if isinstance(content, str)
else content,
}
return result

View File

@@ -86,7 +86,7 @@ def _extract_message_text(msg: dict[str, Any]) -> str:
def _summarize_messages( def _summarize_messages(
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
model: str, model: str,
timeout: int = 600, timeout: int = 30,
) -> dict[str, Any]: ) -> dict[str, Any]:
if not messages: if not messages:
empty_summary = "<context_summary message_count='0'>{text}</context_summary>" empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
@@ -148,11 +148,11 @@ class MemoryCompressor:
self, self,
max_images: int = 3, max_images: int = 3,
model_name: str | None = None, model_name: str | None = None,
timeout: int = 600, timeout: int | None = None,
): ):
self.max_images = max_images self.max_images = max_images
self.model_name = model_name or Config.get("strix_llm") self.model_name = model_name or Config.get("strix_llm")
self.timeout = timeout self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")
if not self.model_name: if not self.model_name:
raise ValueError("STRIX_LLM environment variable must be set and not empty") raise ValueError("STRIX_LLM environment variable must be set and not empty")

View File

@@ -1,58 +0,0 @@
import asyncio
import threading
import time
from collections.abc import AsyncIterator
from typing import Any
from litellm import acompletion
from litellm.types.utils import ModelResponseStream
from strix.config import Config
class LLMRequestQueue:
    """Serialize and rate-limit streaming LLM requests across agents.

    A bounded semaphore caps the number of concurrent requests and a
    minimum delay is enforced between request start times.

    NOTE(review): the semaphore is a *threading* primitive polled from
    async code, so each failed acquire blocks the event loop for up to
    0.2s — presumably chosen so the queue works across threads/loops;
    confirm no latency-sensitive tasks share the loop.
    """

    def __init__(self) -> None:
        # Both knobs come from configuration with conservative defaults:
        # at most one in-flight request, 4 seconds between request starts.
        self.delay_between_requests = float(Config.get("llm_rate_limit_delay") or "4.0")
        self.max_concurrent = int(Config.get("llm_rate_limit_concurrent") or "1")
        self._semaphore = threading.BoundedSemaphore(self.max_concurrent)
        self._last_request_time = 0.0
        self._lock = threading.Lock()

    async def stream_request(
        self, completion_args: dict[str, Any]
    ) -> AsyncIterator[ModelResponseStream]:
        """Yield streaming chunks for one completion, respecting rate limits.

        Bug fix: the semaphore is now acquired *before* entering the
        ``try`` block.  Previously a cancellation raised during the
        ``await asyncio.sleep(0.1)`` poll reached the ``finally`` and
        released a slot that was never acquired; over-releasing a
        ``BoundedSemaphore`` raises ``ValueError`` and corrupts the
        concurrency cap for every later caller.
        """
        # Poll with a short timeout so the event loop is never blocked
        # for long stretches while waiting for a free slot.
        while not self._semaphore.acquire(timeout=0.2):
            await asyncio.sleep(0.1)
        try:
            with self._lock:
                now = time.time()
                time_since_last = now - self._last_request_time
                sleep_needed = max(0, self.delay_between_requests - time_since_last)
                # Reserve this request's start slot before sleeping so
                # concurrent callers space themselves out correctly.
                self._last_request_time = now + sleep_needed
            if sleep_needed > 0:
                await asyncio.sleep(sleep_needed)
            async for chunk in self._stream_request(completion_args):
                yield chunk
        finally:
            self._semaphore.release()

    async def _stream_request(
        self, completion_args: dict[str, Any]
    ) -> AsyncIterator[ModelResponseStream]:
        """Issue the underlying litellm streaming call and relay its chunks."""
        response = await acompletion(**completion_args, stream=True)
        async for chunk in response:
            yield chunk
# Lazily-created process-wide singleton; created on first use so that
# configuration is read only when a request is actually made.
_global_queue: LLMRequestQueue | None = None


def get_global_queue() -> LLMRequestQueue:
    """Return the shared request queue, constructing it on first call."""
    global _global_queue  # noqa: PLW0603
    queue = _global_queue
    if queue is None:
        queue = LLMRequestQueue()
        _global_queue = queue
    return queue

View File

@@ -430,10 +430,8 @@ class Tracer:
"input_tokens": 0, "input_tokens": 0,
"output_tokens": 0, "output_tokens": 0,
"cached_tokens": 0, "cached_tokens": 0,
"cache_creation_tokens": 0,
"cost": 0.0, "cost": 0.0,
"requests": 0, "requests": 0,
"failed_requests": 0,
} }
for agent_instance in _agent_instances.values(): for agent_instance in _agent_instances.values():
@@ -442,10 +440,8 @@ class Tracer:
total_stats["input_tokens"] += agent_stats.input_tokens total_stats["input_tokens"] += agent_stats.input_tokens
total_stats["output_tokens"] += agent_stats.output_tokens total_stats["output_tokens"] += agent_stats.output_tokens
total_stats["cached_tokens"] += agent_stats.cached_tokens total_stats["cached_tokens"] += agent_stats.cached_tokens
total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
total_stats["cost"] += agent_stats.cost total_stats["cost"] += agent_stats.cost
total_stats["requests"] += agent_stats.requests total_stats["requests"] += agent_stats.requests
total_stats["failed_requests"] += agent_stats.failed_requests
total_stats["cost"] = round(total_stats["cost"], 4) total_stats["cost"] = round(total_stats["cost"], 4)

View File

@@ -20,7 +20,7 @@ from .registry import (
) )
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "500") SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10") SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")