fix(agent): fix agent loop hanging and simplify LLM module
- Fix agent loop getting stuck by adding hard stop mechanism
- Add _force_stop flag for immediate task cancellation across threads
- Use thread-safe loop.call_soon_threadsafe for cross-thread cancellation
- Remove request_queue.py (eliminated threading/queue complexity causing hangs)
- Simplify llm.py: direct acompletion calls, cleaner streaming
- Reduce retry wait times to prevent long hangs during retries
- Make timeouts configurable (llm_max_retries, memory_compressor_timeout, sandbox_execution_timeout)
- Keep essential token tracking (input/output/cached tokens, cost, requests)
- Maintain Anthropic prompt caching for system messages
This commit is contained in:
@@ -111,7 +111,6 @@ hiddenimports = [
|
|||||||
'strix.llm.llm',
|
'strix.llm.llm',
|
||||||
'strix.llm.config',
|
'strix.llm.config',
|
||||||
'strix.llm.utils',
|
'strix.llm.utils',
|
||||||
'strix.llm.request_queue',
|
|
||||||
'strix.llm.memory_compressor',
|
'strix.llm.memory_compressor',
|
||||||
'strix.runtime',
|
'strix.runtime',
|
||||||
'strix.runtime.runtime',
|
'strix.runtime.runtime',
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id)
|
self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id)
|
||||||
self._current_task: asyncio.Task[Any] | None = None
|
self._current_task: asyncio.Task[Any] | None = None
|
||||||
|
self._force_stop = False
|
||||||
|
|
||||||
from strix.telemetry.tracer import get_global_tracer
|
from strix.telemetry.tracer import get_global_tracer
|
||||||
|
|
||||||
@@ -156,6 +157,11 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
return self._handle_sandbox_error(e, tracer)
|
return self._handle_sandbox_error(e, tracer)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
if self._force_stop:
|
||||||
|
self._force_stop = False
|
||||||
|
await self._enter_waiting_state(tracer, was_cancelled=True)
|
||||||
|
continue
|
||||||
|
|
||||||
self._check_agent_messages(self.state)
|
self._check_agent_messages(self.state)
|
||||||
|
|
||||||
if self.state.is_waiting_for_input():
|
if self.state.is_waiting_for_input():
|
||||||
@@ -246,7 +252,8 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
async def _wait_for_input(self) -> None:
|
async def _wait_for_input(self) -> None:
|
||||||
import asyncio
|
if self._force_stop:
|
||||||
|
return
|
||||||
|
|
||||||
if self.state.has_waiting_timeout():
|
if self.state.has_waiting_timeout():
|
||||||
self.state.resume_from_waiting()
|
self.state.resume_from_waiting()
|
||||||
@@ -339,6 +346,7 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
|
|
||||||
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
|
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
|
||||||
final_response = None
|
final_response = None
|
||||||
|
|
||||||
async for response in self.llm.generate(self.state.get_conversation_history()):
|
async for response in self.llm.generate(self.state.get_conversation_history()):
|
||||||
final_response = response
|
final_response = response
|
||||||
if tracer and response.content:
|
if tracer and response.content:
|
||||||
@@ -584,6 +592,11 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def cancel_current_execution(self) -> None:
|
def cancel_current_execution(self) -> None:
|
||||||
|
self._force_stop = True
|
||||||
if self._current_task and not self._current_task.done():
|
if self._current_task and not self._current_task.done():
|
||||||
self._current_task.cancel()
|
try:
|
||||||
|
loop = self._current_task.get_loop()
|
||||||
|
loop.call_soon_threadsafe(self._current_task.cancel)
|
||||||
|
except RuntimeError:
|
||||||
|
self._current_task.cancel()
|
||||||
self._current_task = None
|
self._current_task = None
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ class Config:
|
|||||||
litellm_base_url = None
|
litellm_base_url = None
|
||||||
ollama_api_base = None
|
ollama_api_base = None
|
||||||
strix_reasoning_effort = "high"
|
strix_reasoning_effort = "high"
|
||||||
|
strix_llm_max_retries = "5"
|
||||||
|
strix_memory_compressor_timeout = "30"
|
||||||
llm_timeout = "300"
|
llm_timeout = "300"
|
||||||
llm_rate_limit_delay = "4.0"
|
|
||||||
llm_rate_limit_concurrent = "1"
|
|
||||||
|
|
||||||
# Tool & Feature Configuration
|
# Tool & Feature Configuration
|
||||||
perplexity_api_key = None
|
perplexity_api_key = None
|
||||||
@@ -27,7 +27,7 @@ class Config:
|
|||||||
# Runtime Configuration
|
# Runtime Configuration
|
||||||
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
|
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
|
||||||
strix_runtime_backend = "docker"
|
strix_runtime_backend = "docker"
|
||||||
strix_sandbox_execution_timeout = "500"
|
strix_sandbox_execution_timeout = "120"
|
||||||
strix_sandbox_connect_timeout = "10"
|
strix_sandbox_connect_timeout = "10"
|
||||||
|
|
||||||
# Telemetry
|
# Telemetry
|
||||||
|
|||||||
601
strix/llm/llm.py
601
strix/llm/llm.py
@@ -1,23 +1,16 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from jinja2 import (
|
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
||||||
Environment,
|
from litellm import acompletion, completion_cost, stream_chunk_builder, supports_reasoning
|
||||||
FileSystemLoader,
|
|
||||||
select_autoescape,
|
|
||||||
)
|
|
||||||
from litellm import completion_cost, stream_chunk_builder, supports_reasoning
|
|
||||||
from litellm.utils import supports_prompt_caching, supports_vision
|
from litellm.utils import supports_prompt_caching, supports_vision
|
||||||
|
|
||||||
from strix.config import Config
|
from strix.config import Config
|
||||||
from strix.llm.config import LLMConfig
|
from strix.llm.config import LLMConfig
|
||||||
from strix.llm.memory_compressor import MemoryCompressor
|
from strix.llm.memory_compressor import MemoryCompressor
|
||||||
from strix.llm.request_queue import get_global_queue
|
|
||||||
from strix.llm.utils import (
|
from strix.llm.utils import (
|
||||||
_truncate_to_first_function,
|
_truncate_to_first_function,
|
||||||
fix_incomplete_tool_call,
|
fix_incomplete_tool_call,
|
||||||
@@ -28,37 +21,9 @@ from strix.tools import get_tools_prompt
|
|||||||
from strix.utils.resource_paths import get_strix_resource_path
|
from strix.utils.resource_paths import get_strix_resource_path
|
||||||
|
|
||||||
|
|
||||||
MAX_RETRIES = 5
|
|
||||||
RETRY_MULTIPLIER = 8
|
|
||||||
RETRY_MIN = 8
|
|
||||||
RETRY_MAX = 64
|
|
||||||
|
|
||||||
|
|
||||||
def _should_retry(exception: Exception) -> bool:
|
|
||||||
status_code = None
|
|
||||||
if hasattr(exception, "status_code"):
|
|
||||||
status_code = exception.status_code
|
|
||||||
elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
|
|
||||||
status_code = exception.response.status_code
|
|
||||||
if status_code is not None:
|
|
||||||
return bool(litellm._should_retry(status_code))
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
litellm.modify_params = True
|
litellm.modify_params = True
|
||||||
|
|
||||||
_LLM_API_KEY = Config.get("llm_api_key")
|
|
||||||
_LLM_API_BASE = (
|
|
||||||
Config.get("llm_api_base")
|
|
||||||
or Config.get("openai_api_base")
|
|
||||||
or Config.get("litellm_base_url")
|
|
||||||
or Config.get("ollama_api_base")
|
|
||||||
)
|
|
||||||
_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")
|
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestFailedError(Exception):
|
class LLMRequestFailedError(Exception):
|
||||||
def __init__(self, message: str, details: str | None = None):
|
def __init__(self, message: str, details: str | None = None):
|
||||||
@@ -67,20 +32,11 @@ class LLMRequestFailedError(Exception):
|
|||||||
self.details = details
|
self.details = details
|
||||||
|
|
||||||
|
|
||||||
class StepRole(str, Enum):
|
|
||||||
AGENT = "agent"
|
|
||||||
USER = "user"
|
|
||||||
SYSTEM = "system"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMResponse:
|
class LLMResponse:
|
||||||
content: str
|
content: str
|
||||||
tool_invocations: list[dict[str, Any]] | None = None
|
tool_invocations: list[dict[str, Any]] | None = None
|
||||||
scan_id: str | None = None
|
thinking_blocks: list[dict[str, Any]] | None = None
|
||||||
step_number: int = 1
|
|
||||||
role: StepRole = StepRole.AGENT
|
|
||||||
thinking_blocks: list[dict[str, Any]] | None = None # For reasoning models.
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -88,76 +44,63 @@ class RequestStats:
|
|||||||
input_tokens: int = 0
|
input_tokens: int = 0
|
||||||
output_tokens: int = 0
|
output_tokens: int = 0
|
||||||
cached_tokens: int = 0
|
cached_tokens: int = 0
|
||||||
cache_creation_tokens: int = 0
|
|
||||||
cost: float = 0.0
|
cost: float = 0.0
|
||||||
requests: int = 0
|
requests: int = 0
|
||||||
failed_requests: int = 0
|
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, int | float]:
|
def to_dict(self) -> dict[str, int | float]:
|
||||||
return {
|
return {
|
||||||
"input_tokens": self.input_tokens,
|
"input_tokens": self.input_tokens,
|
||||||
"output_tokens": self.output_tokens,
|
"output_tokens": self.output_tokens,
|
||||||
"cached_tokens": self.cached_tokens,
|
"cached_tokens": self.cached_tokens,
|
||||||
"cache_creation_tokens": self.cache_creation_tokens,
|
|
||||||
"cost": round(self.cost, 4),
|
"cost": round(self.cost, 4),
|
||||||
"requests": self.requests,
|
"requests": self.requests,
|
||||||
"failed_requests": self.failed_requests,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class LLM:
|
class LLM:
|
||||||
def __init__(
|
def __init__(self, config: LLMConfig, agent_name: str | None = None):
|
||||||
self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
|
|
||||||
):
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.agent_name = agent_name
|
self.agent_name = agent_name
|
||||||
self.agent_id = agent_id
|
self.agent_id: str | None = None
|
||||||
self._total_stats = RequestStats()
|
self._total_stats = RequestStats()
|
||||||
self._last_request_stats = RequestStats()
|
self.memory_compressor = MemoryCompressor(model_name=config.model_name)
|
||||||
|
self.system_prompt = self._load_system_prompt(agent_name)
|
||||||
|
|
||||||
if _STRIX_REASONING_EFFORT:
|
reasoning = Config.get("strix_reasoning_effort")
|
||||||
self._reasoning_effort = _STRIX_REASONING_EFFORT
|
if reasoning:
|
||||||
elif self.config.scan_mode == "quick":
|
self._reasoning_effort = reasoning
|
||||||
|
elif config.scan_mode == "quick":
|
||||||
self._reasoning_effort = "medium"
|
self._reasoning_effort = "medium"
|
||||||
else:
|
else:
|
||||||
self._reasoning_effort = "high"
|
self._reasoning_effort = "high"
|
||||||
|
|
||||||
self.memory_compressor = MemoryCompressor(
|
def _load_system_prompt(self, agent_name: str | None) -> str:
|
||||||
model_name=self.config.model_name,
|
if not agent_name:
|
||||||
timeout=self.config.timeout,
|
return ""
|
||||||
)
|
|
||||||
|
|
||||||
if agent_name:
|
try:
|
||||||
prompt_dir = get_strix_resource_path("agents", agent_name)
|
prompt_dir = get_strix_resource_path("agents", agent_name)
|
||||||
skills_dir = get_strix_resource_path("skills")
|
skills_dir = get_strix_resource_path("skills")
|
||||||
|
env = Environment(
|
||||||
loader = FileSystemLoader([prompt_dir, skills_dir])
|
loader=FileSystemLoader([prompt_dir, skills_dir]),
|
||||||
self.jinja_env = Environment(
|
|
||||||
loader=loader,
|
|
||||||
autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
|
autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
skills_to_load = [
|
||||||
skills_to_load = list(self.config.skills or [])
|
*list(self.config.skills or []),
|
||||||
skills_to_load.append(f"scan_modes/{self.config.scan_mode}")
|
f"scan_modes/{self.config.scan_mode}",
|
||||||
|
]
|
||||||
|
skill_content = load_skills(skills_to_load, env)
|
||||||
|
env.globals["get_skill"] = lambda name: skill_content.get(name, "")
|
||||||
|
|
||||||
skill_content = load_skills(skills_to_load, self.jinja_env)
|
result = env.get_template("system_prompt.jinja").render(
|
||||||
|
get_tools_prompt=get_tools_prompt,
|
||||||
def get_skill(name: str) -> str:
|
loaded_skill_names=list(skill_content.keys()),
|
||||||
return skill_content.get(name, "")
|
**skill_content,
|
||||||
|
)
|
||||||
self.jinja_env.globals["get_skill"] = get_skill
|
return str(result)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
|
return ""
|
||||||
get_tools_prompt=get_tools_prompt,
|
|
||||||
loaded_skill_names=list(skill_content.keys()),
|
|
||||||
**skill_content,
|
|
||||||
)
|
|
||||||
except (FileNotFoundError, OSError, ValueError) as e:
|
|
||||||
logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
|
|
||||||
self.system_prompt = "You are a helpful AI assistant."
|
|
||||||
else:
|
|
||||||
self.system_prompt = "You are a helpful AI assistant."
|
|
||||||
|
|
||||||
def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
|
def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
|
||||||
if agent_name:
|
if agent_name:
|
||||||
@@ -165,330 +108,119 @@ class LLM:
|
|||||||
if agent_id:
|
if agent_id:
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
|
|
||||||
def _build_identity_message(self) -> dict[str, Any] | None:
|
async def generate(
|
||||||
if not (self.agent_name and str(self.agent_name).strip()):
|
self, conversation_history: list[dict[str, Any]]
|
||||||
return None
|
) -> AsyncIterator[LLMResponse]:
|
||||||
identity_name = self.agent_name
|
messages = self._prepare_messages(conversation_history)
|
||||||
identity_id = self.agent_id
|
max_retries = int(Config.get("strix_llm_max_retries") or "5")
|
||||||
content = (
|
|
||||||
"\n\n"
|
for attempt in range(max_retries + 1):
|
||||||
"<agent_identity>\n"
|
try:
|
||||||
"<meta>Internal metadata: do not echo or reference; "
|
async for response in self._stream(messages):
|
||||||
"not part of history or tool calls.</meta>\n"
|
yield response
|
||||||
"<note>You are now assuming the role of this agent. "
|
return # noqa: TRY300
|
||||||
"Act strictly as this agent and maintain self-identity for this step. "
|
except Exception as e: # noqa: BLE001
|
||||||
"Now go answer the next needed step!</note>\n"
|
if attempt >= max_retries or not self._should_retry(e):
|
||||||
f"<agent_name>{identity_name}</agent_name>\n"
|
self._raise_error(e)
|
||||||
f"<agent_id>{identity_id}</agent_id>\n"
|
wait = min(10, 2 * (2**attempt))
|
||||||
"</agent_identity>\n\n"
|
await asyncio.sleep(wait)
|
||||||
|
|
||||||
|
async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
|
||||||
|
accumulated = ""
|
||||||
|
chunks: list[Any] = []
|
||||||
|
|
||||||
|
self._total_stats.requests += 1
|
||||||
|
response = await acompletion(**self._build_completion_args(messages), stream=True)
|
||||||
|
|
||||||
|
async for chunk in response:
|
||||||
|
chunks.append(chunk)
|
||||||
|
delta = self._get_chunk_content(chunk)
|
||||||
|
if delta:
|
||||||
|
accumulated += delta
|
||||||
|
if "</function>" in accumulated:
|
||||||
|
accumulated = accumulated[
|
||||||
|
: accumulated.find("</function>") + len("</function>")
|
||||||
|
]
|
||||||
|
yield LLMResponse(content=accumulated)
|
||||||
|
|
||||||
|
if chunks:
|
||||||
|
self._update_usage_stats(stream_chunk_builder(chunks))
|
||||||
|
|
||||||
|
accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
|
||||||
|
yield LLMResponse(
|
||||||
|
content=accumulated,
|
||||||
|
tool_invocations=parse_tool_invocations(accumulated),
|
||||||
|
thinking_blocks=self._extract_thinking(chunks),
|
||||||
)
|
)
|
||||||
return {"role": "user", "content": content}
|
|
||||||
|
|
||||||
def _add_cache_control_to_content(
|
|
||||||
self, content: str | list[dict[str, Any]]
|
|
||||||
) -> str | list[dict[str, Any]]:
|
|
||||||
if isinstance(content, str):
|
|
||||||
return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
|
|
||||||
if isinstance(content, list) and content:
|
|
||||||
last_item = content[-1]
|
|
||||||
if isinstance(last_item, dict) and last_item.get("type") == "text":
|
|
||||||
return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
|
|
||||||
return content
|
|
||||||
|
|
||||||
def _is_anthropic_model(self) -> bool:
|
|
||||||
if not self.config.model_name:
|
|
||||||
return False
|
|
||||||
model_lower = self.config.model_name.lower()
|
|
||||||
return any(provider in model_lower for provider in ["anthropic/", "claude"])
|
|
||||||
|
|
||||||
def _calculate_cache_interval(self, total_messages: int) -> int:
|
|
||||||
if total_messages <= 1:
|
|
||||||
return 10
|
|
||||||
|
|
||||||
max_cached_messages = 3
|
|
||||||
non_system_messages = total_messages - 1
|
|
||||||
|
|
||||||
interval = 10
|
|
||||||
while non_system_messages // interval > max_cached_messages:
|
|
||||||
interval += 10
|
|
||||||
|
|
||||||
return interval
|
|
||||||
|
|
||||||
def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
||||||
if (
|
|
||||||
not self.config.enable_prompt_caching
|
|
||||||
or not supports_prompt_caching(self.config.model_name)
|
|
||||||
or not messages
|
|
||||||
):
|
|
||||||
return messages
|
|
||||||
|
|
||||||
if not self._is_anthropic_model():
|
|
||||||
return messages
|
|
||||||
|
|
||||||
cached_messages = list(messages)
|
|
||||||
|
|
||||||
if cached_messages and cached_messages[0].get("role") == "system":
|
|
||||||
system_message = cached_messages[0].copy()
|
|
||||||
system_message["content"] = self._add_cache_control_to_content(
|
|
||||||
system_message["content"]
|
|
||||||
)
|
|
||||||
cached_messages[0] = system_message
|
|
||||||
|
|
||||||
total_messages = len(cached_messages)
|
|
||||||
if total_messages > 1:
|
|
||||||
interval = self._calculate_cache_interval(total_messages)
|
|
||||||
|
|
||||||
cached_count = 0
|
|
||||||
for i in range(interval, total_messages, interval):
|
|
||||||
if cached_count >= 3:
|
|
||||||
break
|
|
||||||
|
|
||||||
if i < len(cached_messages):
|
|
||||||
message = cached_messages[i].copy()
|
|
||||||
message["content"] = self._add_cache_control_to_content(message["content"])
|
|
||||||
cached_messages[i] = message
|
|
||||||
cached_count += 1
|
|
||||||
|
|
||||||
return cached_messages
|
|
||||||
|
|
||||||
def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
messages = [{"role": "system", "content": self.system_prompt}]
|
messages = [{"role": "system", "content": self.system_prompt}]
|
||||||
|
|
||||||
identity_message = self._build_identity_message()
|
if self.agent_name:
|
||||||
if identity_message:
|
messages.append(
|
||||||
messages.append(identity_message)
|
{
|
||||||
|
"role": "user",
|
||||||
compressed_history = list(self.memory_compressor.compress_history(conversation_history))
|
"content": (
|
||||||
|
f"\n\n<agent_identity>\n"
|
||||||
|
f"<meta>Internal metadata: do not echo or reference.</meta>\n"
|
||||||
|
f"<agent_name>{self.agent_name}</agent_name>\n"
|
||||||
|
f"<agent_id>{self.agent_id}</agent_id>\n"
|
||||||
|
f"</agent_identity>\n\n"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
compressed = list(self.memory_compressor.compress_history(conversation_history))
|
||||||
conversation_history.clear()
|
conversation_history.clear()
|
||||||
conversation_history.extend(compressed_history)
|
conversation_history.extend(compressed)
|
||||||
messages.extend(compressed_history)
|
messages.extend(compressed)
|
||||||
|
|
||||||
return self._prepare_cached_messages(messages)
|
if self._is_anthropic() and self.config.enable_prompt_caching:
|
||||||
|
messages = self._add_cache_control(messages)
|
||||||
|
|
||||||
async def _stream_and_accumulate(
|
return messages
|
||||||
self,
|
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
scan_id: str | None,
|
|
||||||
step_number: int,
|
|
||||||
) -> AsyncIterator[LLMResponse]:
|
|
||||||
accumulated_content = ""
|
|
||||||
chunks: list[Any] = []
|
|
||||||
|
|
||||||
async for chunk in self._stream_request(messages):
|
def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
|
||||||
chunks.append(chunk)
|
if not self._supports_vision():
|
||||||
delta = self._extract_chunk_delta(chunk)
|
messages = self._strip_images(messages)
|
||||||
if delta:
|
|
||||||
accumulated_content += delta
|
|
||||||
|
|
||||||
if "</function>" in accumulated_content:
|
args: dict[str, Any] = {
|
||||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
|
||||||
accumulated_content = accumulated_content[:function_end]
|
|
||||||
|
|
||||||
yield LLMResponse(
|
|
||||||
scan_id=scan_id,
|
|
||||||
step_number=step_number,
|
|
||||||
role=StepRole.AGENT,
|
|
||||||
content=accumulated_content,
|
|
||||||
tool_invocations=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if chunks:
|
|
||||||
complete_response = stream_chunk_builder(chunks)
|
|
||||||
self._update_usage_stats(complete_response)
|
|
||||||
|
|
||||||
accumulated_content = _truncate_to_first_function(accumulated_content)
|
|
||||||
if "</function>" in accumulated_content:
|
|
||||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
|
||||||
accumulated_content = accumulated_content[:function_end]
|
|
||||||
|
|
||||||
accumulated_content = fix_incomplete_tool_call(accumulated_content)
|
|
||||||
|
|
||||||
tool_invocations = parse_tool_invocations(accumulated_content)
|
|
||||||
|
|
||||||
# Extract thinking blocks from the complete response if available
|
|
||||||
thinking_blocks = None
|
|
||||||
if chunks and self._should_include_reasoning_effort():
|
|
||||||
complete_response = stream_chunk_builder(chunks)
|
|
||||||
if (
|
|
||||||
hasattr(complete_response, "choices")
|
|
||||||
and complete_response.choices
|
|
||||||
and hasattr(complete_response.choices[0], "message")
|
|
||||||
):
|
|
||||||
message = complete_response.choices[0].message
|
|
||||||
if hasattr(message, "thinking_blocks") and message.thinking_blocks:
|
|
||||||
thinking_blocks = message.thinking_blocks
|
|
||||||
|
|
||||||
yield LLMResponse(
|
|
||||||
scan_id=scan_id,
|
|
||||||
step_number=step_number,
|
|
||||||
role=StepRole.AGENT,
|
|
||||||
content=accumulated_content,
|
|
||||||
tool_invocations=tool_invocations if tool_invocations else None,
|
|
||||||
thinking_blocks=thinking_blocks,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _raise_llm_error(self, e: Exception) -> None:
|
|
||||||
error_map: list[tuple[type, str]] = [
|
|
||||||
(litellm.RateLimitError, "Rate limit exceeded"),
|
|
||||||
(litellm.AuthenticationError, "Invalid API key"),
|
|
||||||
(litellm.NotFoundError, "Model not found"),
|
|
||||||
(litellm.ContextWindowExceededError, "Context too long"),
|
|
||||||
(litellm.ContentPolicyViolationError, "Content policy violation"),
|
|
||||||
(litellm.ServiceUnavailableError, "Service unavailable"),
|
|
||||||
(litellm.Timeout, "Request timed out"),
|
|
||||||
(litellm.UnprocessableEntityError, "Unprocessable entity"),
|
|
||||||
(litellm.InternalServerError, "Internal server error"),
|
|
||||||
(litellm.APIConnectionError, "Connection error"),
|
|
||||||
(litellm.UnsupportedParamsError, "Unsupported parameters"),
|
|
||||||
(litellm.BudgetExceededError, "Budget exceeded"),
|
|
||||||
(litellm.APIResponseValidationError, "Response validation error"),
|
|
||||||
(litellm.JSONSchemaValidationError, "JSON schema validation error"),
|
|
||||||
(litellm.InvalidRequestError, "Invalid request"),
|
|
||||||
(litellm.BadRequestError, "Bad request"),
|
|
||||||
(litellm.APIError, "API error"),
|
|
||||||
(litellm.OpenAIError, "OpenAI error"),
|
|
||||||
]
|
|
||||||
|
|
||||||
from strix.telemetry import posthog
|
|
||||||
|
|
||||||
for error_type, message in error_map:
|
|
||||||
if isinstance(e, error_type):
|
|
||||||
posthog.error(f"llm_{error_type.__name__}", message)
|
|
||||||
raise LLMRequestFailedError(f"LLM request failed: {message}", str(e)) from e
|
|
||||||
|
|
||||||
posthog.error("llm_unknown_error", type(e).__name__)
|
|
||||||
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
|
|
||||||
|
|
||||||
async def generate(
|
|
||||||
self,
|
|
||||||
conversation_history: list[dict[str, Any]],
|
|
||||||
scan_id: str | None = None,
|
|
||||||
step_number: int = 1,
|
|
||||||
) -> AsyncIterator[LLMResponse]:
|
|
||||||
messages = self._prepare_messages(conversation_history)
|
|
||||||
|
|
||||||
last_error: Exception | None = None
|
|
||||||
for attempt in range(MAX_RETRIES):
|
|
||||||
try:
|
|
||||||
async for response in self._stream_and_accumulate(messages, scan_id, step_number):
|
|
||||||
yield response
|
|
||||||
return # noqa: TRY300
|
|
||||||
except Exception as e: # noqa: BLE001
|
|
||||||
last_error = e
|
|
||||||
if not _should_retry(e) or attempt == MAX_RETRIES - 1:
|
|
||||||
break
|
|
||||||
wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt))
|
|
||||||
wait_time = max(RETRY_MIN, wait_time)
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
|
|
||||||
if last_error:
|
|
||||||
self._raise_llm_error(last_error)
|
|
||||||
|
|
||||||
def _extract_chunk_delta(self, chunk: Any) -> str:
|
|
||||||
if chunk.choices and hasattr(chunk.choices[0], "delta"):
|
|
||||||
delta = chunk.choices[0].delta
|
|
||||||
return getattr(delta, "content", "") or ""
|
|
||||||
return ""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def usage_stats(self) -> dict[str, dict[str, int | float]]:
|
|
||||||
return {
|
|
||||||
"total": self._total_stats.to_dict(),
|
|
||||||
"last_request": self._last_request_stats.to_dict(),
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_cache_config(self) -> dict[str, bool]:
|
|
||||||
return {
|
|
||||||
"enabled": self.config.enable_prompt_caching,
|
|
||||||
"supported": supports_prompt_caching(self.config.model_name),
|
|
||||||
}
|
|
||||||
|
|
||||||
def _should_include_reasoning_effort(self) -> bool:
|
|
||||||
if not self.config.model_name:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
return bool(supports_reasoning(model=self.config.model_name))
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _model_supports_vision(self) -> bool:
|
|
||||||
if not self.config.model_name:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
return bool(supports_vision(model=self.config.model_name))
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _filter_images_from_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
||||||
filtered_messages = []
|
|
||||||
for msg in messages:
|
|
||||||
content = msg.get("content")
|
|
||||||
updated_msg = msg
|
|
||||||
if isinstance(content, list):
|
|
||||||
filtered_content = []
|
|
||||||
for item in content:
|
|
||||||
if isinstance(item, dict):
|
|
||||||
if item.get("type") == "image_url":
|
|
||||||
filtered_content.append(
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "[Screenshot removed - model does not support "
|
|
||||||
"vision. Use view_source or execute_js instead.]",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
filtered_content.append(item)
|
|
||||||
else:
|
|
||||||
filtered_content.append(item)
|
|
||||||
if filtered_content:
|
|
||||||
text_parts = [
|
|
||||||
item.get("text", "") if isinstance(item, dict) else str(item)
|
|
||||||
for item in filtered_content
|
|
||||||
]
|
|
||||||
all_text = all(
|
|
||||||
isinstance(item, dict) and item.get("type") == "text"
|
|
||||||
for item in filtered_content
|
|
||||||
)
|
|
||||||
if all_text:
|
|
||||||
updated_msg = {**msg, "content": "\n".join(text_parts)}
|
|
||||||
else:
|
|
||||||
updated_msg = {**msg, "content": filtered_content}
|
|
||||||
else:
|
|
||||||
updated_msg = {**msg, "content": ""}
|
|
||||||
filtered_messages.append(updated_msg)
|
|
||||||
return filtered_messages
|
|
||||||
|
|
||||||
async def _stream_request(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
) -> AsyncIterator[Any]:
|
|
||||||
if not self._model_supports_vision():
|
|
||||||
messages = self._filter_images_from_messages(messages)
|
|
||||||
|
|
||||||
completion_args: dict[str, Any] = {
|
|
||||||
"model": self.config.model_name,
|
"model": self.config.model_name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"timeout": self.config.timeout,
|
"timeout": self.config.timeout,
|
||||||
"stream_options": {"include_usage": True},
|
"stream_options": {"include_usage": True},
|
||||||
|
"stop": ["</function>"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if _LLM_API_KEY:
|
if api_key := Config.get("llm_api_key"):
|
||||||
completion_args["api_key"] = _LLM_API_KEY
|
args["api_key"] = api_key
|
||||||
if _LLM_API_BASE:
|
if api_base := (
|
||||||
completion_args["api_base"] = _LLM_API_BASE
|
Config.get("llm_api_base")
|
||||||
|
or Config.get("openai_api_base")
|
||||||
|
or Config.get("litellm_base_url")
|
||||||
|
):
|
||||||
|
args["api_base"] = api_base
|
||||||
|
if self._supports_reasoning():
|
||||||
|
args["reasoning_effort"] = self._reasoning_effort
|
||||||
|
|
||||||
completion_args["stop"] = ["</function>"]
|
return args
|
||||||
|
|
||||||
if self._should_include_reasoning_effort():
|
def _get_chunk_content(self, chunk: Any) -> str:
|
||||||
completion_args["reasoning_effort"] = self._reasoning_effort
|
if chunk.choices and hasattr(chunk.choices[0], "delta"):
|
||||||
|
return getattr(chunk.choices[0].delta, "content", "") or ""
|
||||||
|
return ""
|
||||||
|
|
||||||
queue = get_global_queue()
|
def _extract_thinking(self, chunks: list[Any]) -> list[dict[str, Any]] | None:
|
||||||
self._total_stats.requests += 1
|
if not chunks or not self._supports_reasoning():
|
||||||
self._last_request_stats = RequestStats(requests=1)
|
return None
|
||||||
|
try:
|
||||||
async for chunk in queue.stream_request(completion_args):
|
resp = stream_chunk_builder(chunks)
|
||||||
yield chunk
|
if resp.choices and hasattr(resp.choices[0].message, "thinking_blocks"):
|
||||||
|
blocks: list[dict[str, Any]] = resp.choices[0].message.thinking_blocks
|
||||||
|
return blocks
|
||||||
|
except Exception: # noqa: BLE001, S110 # nosec B110
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
def _update_usage_stats(self, response: Any) -> None:
|
def _update_usage_stats(self, response: Any) -> None:
|
||||||
try:
|
try:
|
||||||
@@ -497,45 +229,88 @@ class LLM:
|
|||||||
output_tokens = getattr(response.usage, "completion_tokens", 0)
|
output_tokens = getattr(response.usage, "completion_tokens", 0)
|
||||||
|
|
||||||
cached_tokens = 0
|
cached_tokens = 0
|
||||||
cache_creation_tokens = 0
|
|
||||||
|
|
||||||
if hasattr(response.usage, "prompt_tokens_details"):
|
if hasattr(response.usage, "prompt_tokens_details"):
|
||||||
prompt_details = response.usage.prompt_tokens_details
|
prompt_details = response.usage.prompt_tokens_details
|
||||||
if hasattr(prompt_details, "cached_tokens"):
|
if hasattr(prompt_details, "cached_tokens"):
|
||||||
cached_tokens = prompt_details.cached_tokens or 0
|
cached_tokens = prompt_details.cached_tokens or 0
|
||||||
|
|
||||||
if hasattr(response.usage, "cache_creation_input_tokens"):
|
|
||||||
cache_creation_tokens = response.usage.cache_creation_input_tokens or 0
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
input_tokens = 0
|
input_tokens = 0
|
||||||
output_tokens = 0
|
output_tokens = 0
|
||||||
cached_tokens = 0
|
cached_tokens = 0
|
||||||
cache_creation_tokens = 0
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cost = completion_cost(response) or 0.0
|
cost = completion_cost(response) or 0.0
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
logger.warning(f"Failed to calculate cost: {e}")
|
|
||||||
cost = 0.0
|
cost = 0.0
|
||||||
|
|
||||||
self._total_stats.input_tokens += input_tokens
|
self._total_stats.input_tokens += input_tokens
|
||||||
self._total_stats.output_tokens += output_tokens
|
self._total_stats.output_tokens += output_tokens
|
||||||
self._total_stats.cached_tokens += cached_tokens
|
self._total_stats.cached_tokens += cached_tokens
|
||||||
self._total_stats.cache_creation_tokens += cache_creation_tokens
|
|
||||||
self._total_stats.cost += cost
|
self._total_stats.cost += cost
|
||||||
|
|
||||||
self._last_request_stats.input_tokens = input_tokens
|
except Exception: # noqa: BLE001, S110 # nosec B110
|
||||||
self._last_request_stats.output_tokens = output_tokens
|
pass
|
||||||
self._last_request_stats.cached_tokens = cached_tokens
|
|
||||||
self._last_request_stats.cache_creation_tokens = cache_creation_tokens
|
|
||||||
self._last_request_stats.cost = cost
|
|
||||||
|
|
||||||
if cached_tokens > 0:
|
def _should_retry(self, e: Exception) -> bool:
|
||||||
logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
|
code = getattr(e, "status_code", None) or getattr(
|
||||||
if cache_creation_tokens > 0:
|
getattr(e, "response", None), "status_code", None
|
||||||
logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")
|
)
|
||||||
|
return code is None or litellm._should_retry(code)
|
||||||
|
|
||||||
logger.info(f"Usage stats: {self.usage_stats}")
|
def _raise_error(self, e: Exception) -> None:
|
||||||
except Exception as e: # noqa: BLE001
|
from strix.telemetry import posthog
|
||||||
logger.warning(f"Failed to update usage stats: {e}")
|
|
||||||
|
posthog.error("llm_error", type(e).__name__)
|
||||||
|
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
|
||||||
|
|
||||||
|
def _is_anthropic(self) -> bool:
|
||||||
|
if not self.config.model_name:
|
||||||
|
return False
|
||||||
|
return any(p in self.config.model_name.lower() for p in ["anthropic/", "claude"])
|
||||||
|
|
||||||
|
def _supports_vision(self) -> bool:
|
||||||
|
try:
|
||||||
|
return bool(supports_vision(model=self.config.model_name))
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _supports_reasoning(self) -> bool:
|
||||||
|
try:
|
||||||
|
return bool(supports_reasoning(model=self.config.model_name))
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
result = []
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_parts = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict) and item.get("type") == "text":
|
||||||
|
text_parts.append(item.get("text", ""))
|
||||||
|
elif isinstance(item, dict) and item.get("type") == "image_url":
|
||||||
|
text_parts.append("[Image removed - model doesn't support vision]")
|
||||||
|
result.append({**msg, "content": "\n".join(text_parts)})
|
||||||
|
else:
|
||||||
|
result.append(msg)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
if not messages or not supports_prompt_caching(self.config.model_name):
|
||||||
|
return messages
|
||||||
|
|
||||||
|
result = list(messages)
|
||||||
|
|
||||||
|
if result[0].get("role") == "system":
|
||||||
|
content = result[0]["content"]
|
||||||
|
result[0] = {
|
||||||
|
**result[0],
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
|
||||||
|
]
|
||||||
|
if isinstance(content, str)
|
||||||
|
else content,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ def _extract_message_text(msg: dict[str, Any]) -> str:
|
|||||||
def _summarize_messages(
|
def _summarize_messages(
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
model: str,
|
model: str,
|
||||||
timeout: int = 600,
|
timeout: int = 30,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if not messages:
|
if not messages:
|
||||||
empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
|
empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
|
||||||
@@ -148,11 +148,11 @@ class MemoryCompressor:
|
|||||||
self,
|
self,
|
||||||
max_images: int = 3,
|
max_images: int = 3,
|
||||||
model_name: str | None = None,
|
model_name: str | None = None,
|
||||||
timeout: int = 600,
|
timeout: int | None = None,
|
||||||
):
|
):
|
||||||
self.max_images = max_images
|
self.max_images = max_images
|
||||||
self.model_name = model_name or Config.get("strix_llm")
|
self.model_name = model_name or Config.get("strix_llm")
|
||||||
self.timeout = timeout
|
self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")
|
||||||
|
|
||||||
if not self.model_name:
|
if not self.model_name:
|
||||||
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from collections.abc import AsyncIterator
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from litellm import acompletion
|
|
||||||
from litellm.types.utils import ModelResponseStream
|
|
||||||
|
|
||||||
from strix.config import Config
|
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestQueue:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.delay_between_requests = float(Config.get("llm_rate_limit_delay") or "4.0")
|
|
||||||
self.max_concurrent = int(Config.get("llm_rate_limit_concurrent") or "1")
|
|
||||||
self._semaphore = threading.BoundedSemaphore(self.max_concurrent)
|
|
||||||
self._last_request_time = 0.0
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
|
|
||||||
async def stream_request(
|
|
||||||
self, completion_args: dict[str, Any]
|
|
||||||
) -> AsyncIterator[ModelResponseStream]:
|
|
||||||
try:
|
|
||||||
while not self._semaphore.acquire(timeout=0.2):
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
now = time.time()
|
|
||||||
time_since_last = now - self._last_request_time
|
|
||||||
sleep_needed = max(0, self.delay_between_requests - time_since_last)
|
|
||||||
self._last_request_time = now + sleep_needed
|
|
||||||
|
|
||||||
if sleep_needed > 0:
|
|
||||||
await asyncio.sleep(sleep_needed)
|
|
||||||
|
|
||||||
async for chunk in self._stream_request(completion_args):
|
|
||||||
yield chunk
|
|
||||||
finally:
|
|
||||||
self._semaphore.release()
|
|
||||||
|
|
||||||
async def _stream_request(
|
|
||||||
self, completion_args: dict[str, Any]
|
|
||||||
) -> AsyncIterator[ModelResponseStream]:
|
|
||||||
response = await acompletion(**completion_args, stream=True)
|
|
||||||
|
|
||||||
async for chunk in response:
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
|
|
||||||
_global_queue: LLMRequestQueue | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_global_queue() -> LLMRequestQueue:
|
|
||||||
global _global_queue # noqa: PLW0603
|
|
||||||
if _global_queue is None:
|
|
||||||
_global_queue = LLMRequestQueue()
|
|
||||||
return _global_queue
|
|
||||||
@@ -430,10 +430,8 @@ class Tracer:
|
|||||||
"input_tokens": 0,
|
"input_tokens": 0,
|
||||||
"output_tokens": 0,
|
"output_tokens": 0,
|
||||||
"cached_tokens": 0,
|
"cached_tokens": 0,
|
||||||
"cache_creation_tokens": 0,
|
|
||||||
"cost": 0.0,
|
"cost": 0.0,
|
||||||
"requests": 0,
|
"requests": 0,
|
||||||
"failed_requests": 0,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for agent_instance in _agent_instances.values():
|
for agent_instance in _agent_instances.values():
|
||||||
@@ -442,10 +440,8 @@ class Tracer:
|
|||||||
total_stats["input_tokens"] += agent_stats.input_tokens
|
total_stats["input_tokens"] += agent_stats.input_tokens
|
||||||
total_stats["output_tokens"] += agent_stats.output_tokens
|
total_stats["output_tokens"] += agent_stats.output_tokens
|
||||||
total_stats["cached_tokens"] += agent_stats.cached_tokens
|
total_stats["cached_tokens"] += agent_stats.cached_tokens
|
||||||
total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
|
|
||||||
total_stats["cost"] += agent_stats.cost
|
total_stats["cost"] += agent_stats.cost
|
||||||
total_stats["requests"] += agent_stats.requests
|
total_stats["requests"] += agent_stats.requests
|
||||||
total_stats["failed_requests"] += agent_stats.failed_requests
|
|
||||||
|
|
||||||
total_stats["cost"] = round(total_stats["cost"], 4)
|
total_stats["cost"] = round(total_stats["cost"], 4)
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from .registry import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "500")
|
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
|
||||||
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
|
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user