fix(agent): resolve agent loop hangs and simplify LLM module

- Fix the agent loop getting stuck by adding a hard-stop mechanism
- Add a _force_stop flag for immediate task cancellation across threads
- Use loop.call_soon_threadsafe for thread-safe cross-thread cancellation (sketched below)
- Remove request_queue.py (eliminated threading/queue complexity causing hangs)
- Simplify llm.py: direct acompletion calls, cleaner streaming
- Reduce retry wait times to prevent long hangs during retries
- Make timeouts configurable (llm_max_retries, memory_compressor_timeout, sandbox_execution_timeout)
- Keep essential token tracking (input/output/cached tokens, cost, requests)
- Maintain Anthropic prompt caching for system messages
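
For reference, a minimal, self-contained sketch of the hard-stop pattern described above (MiniAgent, run, and the timer driver are illustrative stand-ins, not the actual BaseAgent code):

    import asyncio
    import threading


    class MiniAgent:
        """Minimal sketch of the hard-stop pattern: the worker loop checks a
        _force_stop flag, and cancel_current_execution() cancels the in-flight
        task from another thread via loop.call_soon_threadsafe."""

        def __init__(self) -> None:
            self._force_stop = False
            self._current_task: asyncio.Task | None = None

        async def run(self) -> None:
            while True:
                if self._force_stop:
                    self._force_stop = False
                    print("hard stop: entering waiting state")
                    return
                # The real agent awaits an LLM iteration here; sleep stands in for it.
                self._current_task = asyncio.ensure_future(asyncio.sleep(3600))
                try:
                    await self._current_task
                except asyncio.CancelledError:
                    continue  # loop back and observe _force_stop

        def cancel_current_execution(self) -> None:
            # May be called from a different thread, so the cancel must be
            # scheduled onto the task's own event loop rather than called directly.
            self._force_stop = True
            if self._current_task and not self._current_task.done():
                try:
                    loop = self._current_task.get_loop()
                    loop.call_soon_threadsafe(self._current_task.cancel)
                except RuntimeError:
                    self._current_task.cancel()


    agent = MiniAgent()
    threading.Timer(0.5, agent.cancel_current_execution).start()
    asyncio.run(agent.run())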
0xallam
2026-01-14 18:44:04 -08:00
committed by Ahmed Allam
parent 47faeb1ef3
commit 56526cbf90
8 changed files with 210 additions and 485 deletions

View File

@@ -111,7 +111,6 @@ hiddenimports = [
'strix.llm.llm',
'strix.llm.config',
'strix.llm.utils',
'strix.llm.request_queue',
'strix.llm.memory_compressor',
'strix.runtime',
'strix.runtime.runtime',

View File

@@ -79,6 +79,7 @@ class BaseAgent(metaclass=AgentMeta):
with contextlib.suppress(Exception):
self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id)
self._current_task: asyncio.Task[Any] | None = None
self._force_stop = False
from strix.telemetry.tracer import get_global_tracer
@@ -156,6 +157,11 @@ class BaseAgent(metaclass=AgentMeta):
return self._handle_sandbox_error(e, tracer)
while True:
if self._force_stop:
self._force_stop = False
await self._enter_waiting_state(tracer, was_cancelled=True)
continue
self._check_agent_messages(self.state)
if self.state.is_waiting_for_input():
@@ -246,7 +252,8 @@ class BaseAgent(metaclass=AgentMeta):
continue
async def _wait_for_input(self) -> None:
import asyncio
if self._force_stop:
return
if self.state.has_waiting_timeout():
self.state.resume_from_waiting()
@@ -339,6 +346,7 @@ class BaseAgent(metaclass=AgentMeta):
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
final_response = None
async for response in self.llm.generate(self.state.get_conversation_history()):
final_response = response
if tracer and response.content:
@@ -584,6 +592,11 @@ class BaseAgent(metaclass=AgentMeta):
return True
def cancel_current_execution(self) -> None:
self._force_stop = True
if self._current_task and not self._current_task.done():
self._current_task.cancel()
try:
loop = self._current_task.get_loop()
loop.call_soon_threadsafe(self._current_task.cancel)
except RuntimeError:
self._current_task.cancel()
self._current_task = None

View File

@@ -16,9 +16,9 @@ class Config:
litellm_base_url = None
ollama_api_base = None
strix_reasoning_effort = "high"
strix_llm_max_retries = "5"
strix_memory_compressor_timeout = "30"
llm_timeout = "300"
llm_rate_limit_delay = "4.0"
llm_rate_limit_concurrent = "1"
# Tool & Feature Configuration
perplexity_api_key = None
@@ -27,7 +27,7 @@ class Config:
# Runtime Configuration
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
strix_runtime_backend = "docker"
strix_sandbox_execution_timeout = "500"
strix_sandbox_execution_timeout = "120"
strix_sandbox_connect_timeout = "10"
# Telemetry
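
For orientation, a small sketch of how the new knobs are consumed elsewhere in this diff (the keys and string defaults below are the ones introduced by this commit):

    from strix.config import Config

    # Values are stored as strings and coerced where they are used.
    max_retries = int(Config.get("strix_llm_max_retries") or "5")
    compressor_timeout = int(Config.get("strix_memory_compressor_timeout") or "30")
    sandbox_timeout = float(Config.get("strix_sandbox_execution_timeout") or "120")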

View File

@@ -1,23 +1,16 @@
import asyncio
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
from enum import Enum
from typing import Any
import litellm
from jinja2 import (
Environment,
FileSystemLoader,
select_autoescape,
)
from litellm import completion_cost, stream_chunk_builder, supports_reasoning
from jinja2 import Environment, FileSystemLoader, select_autoescape
from litellm import acompletion, completion_cost, stream_chunk_builder, supports_reasoning
from litellm.utils import supports_prompt_caching, supports_vision
from strix.config import Config
from strix.llm.config import LLMConfig
from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.request_queue import get_global_queue
from strix.llm.utils import (
_truncate_to_first_function,
fix_incomplete_tool_call,
@@ -28,37 +21,9 @@ from strix.tools import get_tools_prompt
from strix.utils.resource_paths import get_strix_resource_path
MAX_RETRIES = 5
RETRY_MULTIPLIER = 8
RETRY_MIN = 8
RETRY_MAX = 64
def _should_retry(exception: Exception) -> bool:
status_code = None
if hasattr(exception, "status_code"):
status_code = exception.status_code
elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
status_code = exception.response.status_code
if status_code is not None:
return bool(litellm._should_retry(status_code))
return True
logger = logging.getLogger(__name__)
litellm.drop_params = True
litellm.modify_params = True
_LLM_API_KEY = Config.get("llm_api_key")
_LLM_API_BASE = (
Config.get("llm_api_base")
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
)
_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")
class LLMRequestFailedError(Exception):
def __init__(self, message: str, details: str | None = None):
@@ -67,20 +32,11 @@ class LLMRequestFailedError(Exception):
self.details = details
class StepRole(str, Enum):
AGENT = "agent"
USER = "user"
SYSTEM = "system"
@dataclass
class LLMResponse:
content: str
tool_invocations: list[dict[str, Any]] | None = None
scan_id: str | None = None
step_number: int = 1
role: StepRole = StepRole.AGENT
thinking_blocks: list[dict[str, Any]] | None = None # For reasoning models.
thinking_blocks: list[dict[str, Any]] | None = None
@dataclass
@@ -88,76 +44,63 @@ class RequestStats:
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0
cache_creation_tokens: int = 0
cost: float = 0.0
requests: int = 0
failed_requests: int = 0
def to_dict(self) -> dict[str, int | float]:
return {
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"cached_tokens": self.cached_tokens,
"cache_creation_tokens": self.cache_creation_tokens,
"cost": round(self.cost, 4),
"requests": self.requests,
"failed_requests": self.failed_requests,
}
class LLM:
def __init__(
self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
):
def __init__(self, config: LLMConfig, agent_name: str | None = None):
self.config = config
self.agent_name = agent_name
self.agent_id = agent_id
self.agent_id: str | None = None
self._total_stats = RequestStats()
self._last_request_stats = RequestStats()
self.memory_compressor = MemoryCompressor(model_name=config.model_name)
self.system_prompt = self._load_system_prompt(agent_name)
if _STRIX_REASONING_EFFORT:
self._reasoning_effort = _STRIX_REASONING_EFFORT
elif self.config.scan_mode == "quick":
reasoning = Config.get("strix_reasoning_effort")
if reasoning:
self._reasoning_effort = reasoning
elif config.scan_mode == "quick":
self._reasoning_effort = "medium"
else:
self._reasoning_effort = "high"
self.memory_compressor = MemoryCompressor(
model_name=self.config.model_name,
timeout=self.config.timeout,
)
def _load_system_prompt(self, agent_name: str | None) -> str:
if not agent_name:
return ""
if agent_name:
try:
prompt_dir = get_strix_resource_path("agents", agent_name)
skills_dir = get_strix_resource_path("skills")
loader = FileSystemLoader([prompt_dir, skills_dir])
self.jinja_env = Environment(
loader=loader,
env = Environment(
loader=FileSystemLoader([prompt_dir, skills_dir]),
autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
)
try:
skills_to_load = list(self.config.skills or [])
skills_to_load.append(f"scan_modes/{self.config.scan_mode}")
skills_to_load = [
*list(self.config.skills or []),
f"scan_modes/{self.config.scan_mode}",
]
skill_content = load_skills(skills_to_load, env)
env.globals["get_skill"] = lambda name: skill_content.get(name, "")
skill_content = load_skills(skills_to_load, self.jinja_env)
def get_skill(name: str) -> str:
return skill_content.get(name, "")
self.jinja_env.globals["get_skill"] = get_skill
self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
get_tools_prompt=get_tools_prompt,
loaded_skill_names=list(skill_content.keys()),
**skill_content,
)
except (FileNotFoundError, OSError, ValueError) as e:
logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
self.system_prompt = "You are a helpful AI assistant."
else:
self.system_prompt = "You are a helpful AI assistant."
result = env.get_template("system_prompt.jinja").render(
get_tools_prompt=get_tools_prompt,
loaded_skill_names=list(skill_content.keys()),
**skill_content,
)
return str(result)
except Exception: # noqa: BLE001
return ""
def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
if agent_name:
@@ -165,330 +108,119 @@ class LLM:
if agent_id:
self.agent_id = agent_id
def _build_identity_message(self) -> dict[str, Any] | None:
if not (self.agent_name and str(self.agent_name).strip()):
return None
identity_name = self.agent_name
identity_id = self.agent_id
content = (
"\n\n"
"<agent_identity>\n"
"<meta>Internal metadata: do not echo or reference; "
"not part of history or tool calls.</meta>\n"
"<note>You are now assuming the role of this agent. "
"Act strictly as this agent and maintain self-identity for this step. "
"Now go answer the next needed step!</note>\n"
f"<agent_name>{identity_name}</agent_name>\n"
f"<agent_id>{identity_id}</agent_id>\n"
"</agent_identity>\n\n"
async def generate(
self, conversation_history: list[dict[str, Any]]
) -> AsyncIterator[LLMResponse]:
messages = self._prepare_messages(conversation_history)
max_retries = int(Config.get("strix_llm_max_retries") or "5")
for attempt in range(max_retries + 1):
try:
async for response in self._stream(messages):
yield response
return # noqa: TRY300
except Exception as e: # noqa: BLE001
if attempt >= max_retries or not self._should_retry(e):
self._raise_error(e)
wait = min(10, 2 * (2**attempt))
await asyncio.sleep(wait)
async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
accumulated = ""
chunks: list[Any] = []
self._total_stats.requests += 1
response = await acompletion(**self._build_completion_args(messages), stream=True)
async for chunk in response:
chunks.append(chunk)
delta = self._get_chunk_content(chunk)
if delta:
accumulated += delta
if "</function>" in accumulated:
accumulated = accumulated[
: accumulated.find("</function>") + len("</function>")
]
yield LLMResponse(content=accumulated)
if chunks:
self._update_usage_stats(stream_chunk_builder(chunks))
accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
yield LLMResponse(
content=accumulated,
tool_invocations=parse_tool_invocations(accumulated),
thinking_blocks=self._extract_thinking(chunks),
)
return {"role": "user", "content": content}
def _add_cache_control_to_content(
self, content: str | list[dict[str, Any]]
) -> str | list[dict[str, Any]]:
if isinstance(content, str):
return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
if isinstance(content, list) and content:
last_item = content[-1]
if isinstance(last_item, dict) and last_item.get("type") == "text":
return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
return content
def _is_anthropic_model(self) -> bool:
if not self.config.model_name:
return False
model_lower = self.config.model_name.lower()
return any(provider in model_lower for provider in ["anthropic/", "claude"])
def _calculate_cache_interval(self, total_messages: int) -> int:
if total_messages <= 1:
return 10
max_cached_messages = 3
non_system_messages = total_messages - 1
interval = 10
while non_system_messages // interval > max_cached_messages:
interval += 10
return interval
def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
if (
not self.config.enable_prompt_caching
or not supports_prompt_caching(self.config.model_name)
or not messages
):
return messages
if not self._is_anthropic_model():
return messages
cached_messages = list(messages)
if cached_messages and cached_messages[0].get("role") == "system":
system_message = cached_messages[0].copy()
system_message["content"] = self._add_cache_control_to_content(
system_message["content"]
)
cached_messages[0] = system_message
total_messages = len(cached_messages)
if total_messages > 1:
interval = self._calculate_cache_interval(total_messages)
cached_count = 0
for i in range(interval, total_messages, interval):
if cached_count >= 3:
break
if i < len(cached_messages):
message = cached_messages[i].copy()
message["content"] = self._add_cache_control_to_content(message["content"])
cached_messages[i] = message
cached_count += 1
return cached_messages
def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
messages = [{"role": "system", "content": self.system_prompt}]
identity_message = self._build_identity_message()
if identity_message:
messages.append(identity_message)
compressed_history = list(self.memory_compressor.compress_history(conversation_history))
if self.agent_name:
messages.append(
{
"role": "user",
"content": (
f"\n\n<agent_identity>\n"
f"<meta>Internal metadata: do not echo or reference.</meta>\n"
f"<agent_name>{self.agent_name}</agent_name>\n"
f"<agent_id>{self.agent_id}</agent_id>\n"
f"</agent_identity>\n\n"
),
}
)
compressed = list(self.memory_compressor.compress_history(conversation_history))
conversation_history.clear()
conversation_history.extend(compressed_history)
messages.extend(compressed_history)
conversation_history.extend(compressed)
messages.extend(compressed)
return self._prepare_cached_messages(messages)
if self._is_anthropic() and self.config.enable_prompt_caching:
messages = self._add_cache_control(messages)
async def _stream_and_accumulate(
self,
messages: list[dict[str, Any]],
scan_id: str | None,
step_number: int,
) -> AsyncIterator[LLMResponse]:
accumulated_content = ""
chunks: list[Any] = []
return messages
async for chunk in self._stream_request(messages):
chunks.append(chunk)
delta = self._extract_chunk_delta(chunk)
if delta:
accumulated_content += delta
def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
if not self._supports_vision():
messages = self._strip_images(messages)
if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end]
yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content=accumulated_content,
tool_invocations=None,
)
if chunks:
complete_response = stream_chunk_builder(chunks)
self._update_usage_stats(complete_response)
accumulated_content = _truncate_to_first_function(accumulated_content)
if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end]
accumulated_content = fix_incomplete_tool_call(accumulated_content)
tool_invocations = parse_tool_invocations(accumulated_content)
# Extract thinking blocks from the complete response if available
thinking_blocks = None
if chunks and self._should_include_reasoning_effort():
complete_response = stream_chunk_builder(chunks)
if (
hasattr(complete_response, "choices")
and complete_response.choices
and hasattr(complete_response.choices[0], "message")
):
message = complete_response.choices[0].message
if hasattr(message, "thinking_blocks") and message.thinking_blocks:
thinking_blocks = message.thinking_blocks
yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content=accumulated_content,
tool_invocations=tool_invocations if tool_invocations else None,
thinking_blocks=thinking_blocks,
)
def _raise_llm_error(self, e: Exception) -> None:
error_map: list[tuple[type, str]] = [
(litellm.RateLimitError, "Rate limit exceeded"),
(litellm.AuthenticationError, "Invalid API key"),
(litellm.NotFoundError, "Model not found"),
(litellm.ContextWindowExceededError, "Context too long"),
(litellm.ContentPolicyViolationError, "Content policy violation"),
(litellm.ServiceUnavailableError, "Service unavailable"),
(litellm.Timeout, "Request timed out"),
(litellm.UnprocessableEntityError, "Unprocessable entity"),
(litellm.InternalServerError, "Internal server error"),
(litellm.APIConnectionError, "Connection error"),
(litellm.UnsupportedParamsError, "Unsupported parameters"),
(litellm.BudgetExceededError, "Budget exceeded"),
(litellm.APIResponseValidationError, "Response validation error"),
(litellm.JSONSchemaValidationError, "JSON schema validation error"),
(litellm.InvalidRequestError, "Invalid request"),
(litellm.BadRequestError, "Bad request"),
(litellm.APIError, "API error"),
(litellm.OpenAIError, "OpenAI error"),
]
from strix.telemetry import posthog
for error_type, message in error_map:
if isinstance(e, error_type):
posthog.error(f"llm_{error_type.__name__}", message)
raise LLMRequestFailedError(f"LLM request failed: {message}", str(e)) from e
posthog.error("llm_unknown_error", type(e).__name__)
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
async def generate(
self,
conversation_history: list[dict[str, Any]],
scan_id: str | None = None,
step_number: int = 1,
) -> AsyncIterator[LLMResponse]:
messages = self._prepare_messages(conversation_history)
last_error: Exception | None = None
for attempt in range(MAX_RETRIES):
try:
async for response in self._stream_and_accumulate(messages, scan_id, step_number):
yield response
return # noqa: TRY300
except Exception as e: # noqa: BLE001
last_error = e
if not _should_retry(e) or attempt == MAX_RETRIES - 1:
break
wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt))
wait_time = max(RETRY_MIN, wait_time)
await asyncio.sleep(wait_time)
if last_error:
self._raise_llm_error(last_error)
def _extract_chunk_delta(self, chunk: Any) -> str:
if chunk.choices and hasattr(chunk.choices[0], "delta"):
delta = chunk.choices[0].delta
return getattr(delta, "content", "") or ""
return ""
@property
def usage_stats(self) -> dict[str, dict[str, int | float]]:
return {
"total": self._total_stats.to_dict(),
"last_request": self._last_request_stats.to_dict(),
}
def get_cache_config(self) -> dict[str, bool]:
return {
"enabled": self.config.enable_prompt_caching,
"supported": supports_prompt_caching(self.config.model_name),
}
def _should_include_reasoning_effort(self) -> bool:
if not self.config.model_name:
return False
try:
return bool(supports_reasoning(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _model_supports_vision(self) -> bool:
if not self.config.model_name:
return False
try:
return bool(supports_vision(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _filter_images_from_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
filtered_messages = []
for msg in messages:
content = msg.get("content")
updated_msg = msg
if isinstance(content, list):
filtered_content = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "image_url":
filtered_content.append(
{
"type": "text",
"text": "[Screenshot removed - model does not support "
"vision. Use view_source or execute_js instead.]",
}
)
else:
filtered_content.append(item)
else:
filtered_content.append(item)
if filtered_content:
text_parts = [
item.get("text", "") if isinstance(item, dict) else str(item)
for item in filtered_content
]
all_text = all(
isinstance(item, dict) and item.get("type") == "text"
for item in filtered_content
)
if all_text:
updated_msg = {**msg, "content": "\n".join(text_parts)}
else:
updated_msg = {**msg, "content": filtered_content}
else:
updated_msg = {**msg, "content": ""}
filtered_messages.append(updated_msg)
return filtered_messages
async def _stream_request(
self,
messages: list[dict[str, Any]],
) -> AsyncIterator[Any]:
if not self._model_supports_vision():
messages = self._filter_images_from_messages(messages)
completion_args: dict[str, Any] = {
args: dict[str, Any] = {
"model": self.config.model_name,
"messages": messages,
"timeout": self.config.timeout,
"stream_options": {"include_usage": True},
"stop": ["</function>"],
}
if _LLM_API_KEY:
completion_args["api_key"] = _LLM_API_KEY
if _LLM_API_BASE:
completion_args["api_base"] = _LLM_API_BASE
if api_key := Config.get("llm_api_key"):
args["api_key"] = api_key
if api_base := (
Config.get("llm_api_base")
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
):
args["api_base"] = api_base
if self._supports_reasoning():
args["reasoning_effort"] = self._reasoning_effort
completion_args["stop"] = ["</function>"]
return args
if self._should_include_reasoning_effort():
completion_args["reasoning_effort"] = self._reasoning_effort
def _get_chunk_content(self, chunk: Any) -> str:
if chunk.choices and hasattr(chunk.choices[0], "delta"):
return getattr(chunk.choices[0].delta, "content", "") or ""
return ""
queue = get_global_queue()
self._total_stats.requests += 1
self._last_request_stats = RequestStats(requests=1)
async for chunk in queue.stream_request(completion_args):
yield chunk
def _extract_thinking(self, chunks: list[Any]) -> list[dict[str, Any]] | None:
if not chunks or not self._supports_reasoning():
return None
try:
resp = stream_chunk_builder(chunks)
if resp.choices and hasattr(resp.choices[0].message, "thinking_blocks"):
blocks: list[dict[str, Any]] = resp.choices[0].message.thinking_blocks
return blocks
except Exception: # noqa: BLE001, S110 # nosec B110
pass
return None
def _update_usage_stats(self, response: Any) -> None:
try:
@@ -497,45 +229,88 @@ class LLM:
output_tokens = getattr(response.usage, "completion_tokens", 0)
cached_tokens = 0
cache_creation_tokens = 0
if hasattr(response.usage, "prompt_tokens_details"):
prompt_details = response.usage.prompt_tokens_details
if hasattr(prompt_details, "cached_tokens"):
cached_tokens = prompt_details.cached_tokens or 0
if hasattr(response.usage, "cache_creation_input_tokens"):
cache_creation_tokens = response.usage.cache_creation_input_tokens or 0
else:
input_tokens = 0
output_tokens = 0
cached_tokens = 0
cache_creation_tokens = 0
try:
cost = completion_cost(response) or 0.0
except Exception as e: # noqa: BLE001
logger.warning(f"Failed to calculate cost: {e}")
except Exception: # noqa: BLE001
cost = 0.0
self._total_stats.input_tokens += input_tokens
self._total_stats.output_tokens += output_tokens
self._total_stats.cached_tokens += cached_tokens
self._total_stats.cache_creation_tokens += cache_creation_tokens
self._total_stats.cost += cost
self._last_request_stats.input_tokens = input_tokens
self._last_request_stats.output_tokens = output_tokens
self._last_request_stats.cached_tokens = cached_tokens
self._last_request_stats.cache_creation_tokens = cache_creation_tokens
self._last_request_stats.cost = cost
except Exception: # noqa: BLE001, S110 # nosec B110
pass
if cached_tokens > 0:
logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
if cache_creation_tokens > 0:
logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")
def _should_retry(self, e: Exception) -> bool:
code = getattr(e, "status_code", None) or getattr(
getattr(e, "response", None), "status_code", None
)
return code is None or litellm._should_retry(code)
logger.info(f"Usage stats: {self.usage_stats}")
except Exception as e: # noqa: BLE001
logger.warning(f"Failed to update usage stats: {e}")
def _raise_error(self, e: Exception) -> None:
from strix.telemetry import posthog
posthog.error("llm_error", type(e).__name__)
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
def _is_anthropic(self) -> bool:
if not self.config.model_name:
return False
return any(p in self.config.model_name.lower() for p in ["anthropic/", "claude"])
def _supports_vision(self) -> bool:
try:
return bool(supports_vision(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _supports_reasoning(self) -> bool:
try:
return bool(supports_reasoning(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
result = []
for msg in messages:
content = msg.get("content")
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif isinstance(item, dict) and item.get("type") == "image_url":
text_parts.append("[Image removed - model doesn't support vision]")
result.append({**msg, "content": "\n".join(text_parts)})
else:
result.append(msg)
return result
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
if not messages or not supports_prompt_caching(self.config.model_name):
return messages
result = list(messages)
if result[0].get("role") == "system":
content = result[0]["content"]
result[0] = {
**result[0],
"content": [
{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
]
if isinstance(content, str)
else content,
}
return result
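
For context, a quick illustration of the shortened retry schedule in generate(), wait = min(10, 2 * (2 ** attempt)), which caps each wait at 10 seconds (the removed schedule grew from 8 s up to 64 s):

    # Illustrative only: retry waits for the default of five retries.
    for attempt in range(5):
        print(attempt, min(10, 2 * (2**attempt)))  # 2, 4, 8, 10, 10 seconds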

View File

@@ -86,7 +86,7 @@ def _extract_message_text(msg: dict[str, Any]) -> str:
def _summarize_messages(
messages: list[dict[str, Any]],
model: str,
timeout: int = 600,
timeout: int = 30,
) -> dict[str, Any]:
if not messages:
empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
@@ -148,11 +148,11 @@ class MemoryCompressor:
self,
max_images: int = 3,
model_name: str | None = None,
timeout: int = 600,
timeout: int | None = None,
):
self.max_images = max_images
self.model_name = model_name or Config.get("strix_llm")
self.timeout = timeout
self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")
if not self.model_name:
raise ValueError("STRIX_LLM environment variable must be set and not empty")

View File

@@ -1,58 +0,0 @@
import asyncio
import threading
import time
from collections.abc import AsyncIterator
from typing import Any
from litellm import acompletion
from litellm.types.utils import ModelResponseStream
from strix.config import Config
class LLMRequestQueue:
def __init__(self) -> None:
self.delay_between_requests = float(Config.get("llm_rate_limit_delay") or "4.0")
self.max_concurrent = int(Config.get("llm_rate_limit_concurrent") or "1")
self._semaphore = threading.BoundedSemaphore(self.max_concurrent)
self._last_request_time = 0.0
self._lock = threading.Lock()
async def stream_request(
self, completion_args: dict[str, Any]
) -> AsyncIterator[ModelResponseStream]:
try:
while not self._semaphore.acquire(timeout=0.2):
await asyncio.sleep(0.1)
with self._lock:
now = time.time()
time_since_last = now - self._last_request_time
sleep_needed = max(0, self.delay_between_requests - time_since_last)
self._last_request_time = now + sleep_needed
if sleep_needed > 0:
await asyncio.sleep(sleep_needed)
async for chunk in self._stream_request(completion_args):
yield chunk
finally:
self._semaphore.release()
async def _stream_request(
self, completion_args: dict[str, Any]
) -> AsyncIterator[ModelResponseStream]:
response = await acompletion(**completion_args, stream=True)
async for chunk in response:
yield chunk
_global_queue: LLMRequestQueue | None = None
def get_global_queue() -> LLMRequestQueue:
global _global_queue # noqa: PLW0603
if _global_queue is None:
_global_queue = LLMRequestQueue()
return _global_queue

View File

@@ -430,10 +430,8 @@ class Tracer:
"input_tokens": 0,
"output_tokens": 0,
"cached_tokens": 0,
"cache_creation_tokens": 0,
"cost": 0.0,
"requests": 0,
"failed_requests": 0,
}
for agent_instance in _agent_instances.values():
@@ -442,10 +440,8 @@ class Tracer:
total_stats["input_tokens"] += agent_stats.input_tokens
total_stats["output_tokens"] += agent_stats.output_tokens
total_stats["cached_tokens"] += agent_stats.cached_tokens
total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
total_stats["cost"] += agent_stats.cost
total_stats["requests"] += agent_stats.requests
total_stats["failed_requests"] += agent_stats.failed_requests
total_stats["cost"] = round(total_stats["cost"], 4)

View File

@@ -20,7 +20,7 @@ from .registry import (
)
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "500")
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")