diff --git a/strix.spec b/strix.spec
index 03dbf86..fbe2d90 100644
--- a/strix.spec
+++ b/strix.spec
@@ -111,7 +111,6 @@ hiddenimports = [
     'strix.llm.llm',
     'strix.llm.config',
     'strix.llm.utils',
-    'strix.llm.request_queue',
     'strix.llm.memory_compressor',
     'strix.runtime',
     'strix.runtime.runtime',
diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py
index 8905197..f955892 100644
--- a/strix/agents/base_agent.py
+++ b/strix/agents/base_agent.py
@@ -79,6 +79,7 @@ class BaseAgent(metaclass=AgentMeta):
         with contextlib.suppress(Exception):
             self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id)
         self._current_task: asyncio.Task[Any] | None = None
+        self._force_stop = False

         from strix.telemetry.tracer import get_global_tracer

@@ -156,6 +157,11 @@ class BaseAgent(metaclass=AgentMeta):
                 return self._handle_sandbox_error(e, tracer)

         while True:
+            if self._force_stop:
+                self._force_stop = False
+                await self._enter_waiting_state(tracer, was_cancelled=True)
+                continue
+
             self._check_agent_messages(self.state)

             if self.state.is_waiting_for_input():
@@ -246,7 +252,8 @@ class BaseAgent(metaclass=AgentMeta):
                 continue

     async def _wait_for_input(self) -> None:
-        import asyncio
+        if self._force_stop:
+            return

         if self.state.has_waiting_timeout():
             self.state.resume_from_waiting()
@@ -339,6 +346,7 @@ class BaseAgent(metaclass=AgentMeta):

     async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
         final_response = None
+
         async for response in self.llm.generate(self.state.get_conversation_history()):
             final_response = response
             if tracer and response.content:
@@ -584,6 +592,11 @@ class BaseAgent(metaclass=AgentMeta):
         return True

     def cancel_current_execution(self) -> None:
+        self._force_stop = True
         if self._current_task and not self._current_task.done():
-            self._current_task.cancel()
+            try:
+                loop = self._current_task.get_loop()
+                loop.call_soon_threadsafe(self._current_task.cancel)
+            except RuntimeError:
+                self._current_task.cancel()
         self._current_task = None
diff --git a/strix/config/config.py b/strix/config/config.py
index 0e0d16f..a602658 100644
--- a/strix/config/config.py
+++ b/strix/config/config.py
@@ -16,9 +16,9 @@ class Config:
     litellm_base_url = None
    ollama_api_base = None
     strix_reasoning_effort = "high"
+    strix_llm_max_retries = "5"
+    strix_memory_compressor_timeout = "30"
     llm_timeout = "300"
-    llm_rate_limit_delay = "4.0"
-    llm_rate_limit_concurrent = "1"

     # Tool & Feature Configuration
     perplexity_api_key = None
@@ -27,7 +27,7 @@ class Config:
     # Runtime Configuration
     strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
     strix_runtime_backend = "docker"
-    strix_sandbox_execution_timeout = "500"
+    strix_sandbox_execution_timeout = "120"
     strix_sandbox_connect_timeout = "10"

     # Telemetry
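NOTE on the cancel_current_execution change above: asyncio.Task.cancel() is only
safe to call from the thread that runs the task's event loop, so cross-thread
callers have to marshal the call through loop.call_soon_threadsafe; the
RuntimeError fallback covers a loop that is already closed. A minimal,
self-contained sketch of the same pattern (illustrative names, not Strix code):

    import asyncio
    import contextlib
    import threading

    async def long_job() -> None:
        await asyncio.sleep(60)  # stands in for an agent iteration

    def cancel_from_other_thread(task: "asyncio.Task[None]") -> None:
        # Task.cancel() alone is not thread-safe; hand it to the loop
        # that owns the task, as the patch does.
        task.get_loop().call_soon_threadsafe(task.cancel)

    async def main() -> None:
        task = asyncio.create_task(long_job())
        threading.Timer(0.1, cancel_from_other_thread, args=(task,)).start()
        with contextlib.suppress(asyncio.CancelledError):
            await task

    asyncio.run(main())

The new _force_stop flag covers the complementary case: if the cancel lands
between iterations, the run loop sees the flag on its next pass and parks the
agent in the waiting state instead of starting another LLM call.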
diff --git a/strix/llm/llm.py b/strix/llm/llm.py
index 34132a8..f4e831e 100644
--- a/strix/llm/llm.py
+++ b/strix/llm/llm.py
@@ -1,23 +1,16 @@
 import asyncio
-import logging
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
-from enum import Enum
 from typing import Any

 import litellm
-from jinja2 import (
-    Environment,
-    FileSystemLoader,
-    select_autoescape,
-)
-from litellm import completion_cost, stream_chunk_builder, supports_reasoning
+from jinja2 import Environment, FileSystemLoader, select_autoescape
+from litellm import acompletion, completion_cost, stream_chunk_builder, supports_reasoning
 from litellm.utils import supports_prompt_caching, supports_vision

 from strix.config import Config
 from strix.llm.config import LLMConfig
 from strix.llm.memory_compressor import MemoryCompressor
-from strix.llm.request_queue import get_global_queue
 from strix.llm.utils import (
     _truncate_to_first_function,
     fix_incomplete_tool_call,
@@ -28,37 +21,9 @@
 from strix.tools import get_tools_prompt
 from strix.utils.resource_paths import get_strix_resource_path


-MAX_RETRIES = 5
-RETRY_MULTIPLIER = 8
-RETRY_MIN = 8
-RETRY_MAX = 64
-
-
-def _should_retry(exception: Exception) -> bool:
-    status_code = None
-    if hasattr(exception, "status_code"):
-        status_code = exception.status_code
-    elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
-        status_code = exception.response.status_code
-    if status_code is not None:
-        return bool(litellm._should_retry(status_code))
-    return True
-
-
-logger = logging.getLogger(__name__)
-
 litellm.drop_params = True
 litellm.modify_params = True

-_LLM_API_KEY = Config.get("llm_api_key")
-_LLM_API_BASE = (
-    Config.get("llm_api_base")
-    or Config.get("openai_api_base")
-    or Config.get("litellm_base_url")
-    or Config.get("ollama_api_base")
-)
-_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")
-

 class LLMRequestFailedError(Exception):
     def __init__(self, message: str, details: str | None = None):
@@ -67,20 +32,11 @@ class LLMRequestFailedError(Exception):
         self.details = details


-class StepRole(str, Enum):
-    AGENT = "agent"
-    USER = "user"
-    SYSTEM = "system"
-
-
 @dataclass
 class LLMResponse:
     content: str
     tool_invocations: list[dict[str, Any]] | None = None
-    scan_id: str | None = None
-    step_number: int = 1
-    role: StepRole = StepRole.AGENT
-    thinking_blocks: list[dict[str, Any]] | None = None  # For reasoning models.
+    thinking_blocks: list[dict[str, Any]] | None = None


 @dataclass
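NOTE on the retry constants removed above: the fixed ladder (five tries, waits
of 8 * 2**attempt clamped to 8-64 s) is replaced by the new generate() further
down, which reads its attempt budget from the strix_llm_max_retries config key
and caps each backoff at 10 s. Comparing the two schedules is pure arithmetic,
assuming every attempt fails:

    # Old: five tries, four sleeps, wait = clamp(8 * 2**attempt, 8, 64)
    old_waits = [max(8, min(64, 8 * 2**a)) for a in range(4)]  # [8, 16, 32, 64]

    # New (default budget of 5 retries): six tries, five sleeps,
    # wait = min(10, 2 * 2**attempt)
    new_waits = [min(10, 2 * (2**a)) for a in range(5)]  # [2, 4, 8, 10, 10]

    print(sum(old_waits), sum(new_waits))  # 120 vs 34 seconds of total backoff

A persistently failing request now surfaces its error after roughly half a
minute of waiting instead of two.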
@@ -88,76 +44,63 @@ class RequestStats:
     input_tokens: int = 0
     output_tokens: int = 0
     cached_tokens: int = 0
-    cache_creation_tokens: int = 0
     cost: float = 0.0
     requests: int = 0
-    failed_requests: int = 0

     def to_dict(self) -> dict[str, int | float]:
         return {
             "input_tokens": self.input_tokens,
             "output_tokens": self.output_tokens,
             "cached_tokens": self.cached_tokens,
-            "cache_creation_tokens": self.cache_creation_tokens,
             "cost": round(self.cost, 4),
             "requests": self.requests,
-            "failed_requests": self.failed_requests,
         }


 class LLM:
-    def __init__(
-        self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
-    ):
+    def __init__(self, config: LLMConfig, agent_name: str | None = None):
         self.config = config
         self.agent_name = agent_name
-        self.agent_id = agent_id
+        self.agent_id: str | None = None
         self._total_stats = RequestStats()
-        self._last_request_stats = RequestStats()
+        self.memory_compressor = MemoryCompressor(model_name=config.model_name)
+        self.system_prompt = self._load_system_prompt(agent_name)

-        if _STRIX_REASONING_EFFORT:
-            self._reasoning_effort = _STRIX_REASONING_EFFORT
-        elif self.config.scan_mode == "quick":
+        reasoning = Config.get("strix_reasoning_effort")
+        if reasoning:
+            self._reasoning_effort = reasoning
+        elif config.scan_mode == "quick":
             self._reasoning_effort = "medium"
         else:
             self._reasoning_effort = "high"

-        self.memory_compressor = MemoryCompressor(
-            model_name=self.config.model_name,
-            timeout=self.config.timeout,
-        )
+    def _load_system_prompt(self, agent_name: str | None) -> str:
+        if not agent_name:
+            return ""

-        if agent_name:
+        try:
             prompt_dir = get_strix_resource_path("agents", agent_name)
             skills_dir = get_strix_resource_path("skills")
-
-            loader = FileSystemLoader([prompt_dir, skills_dir])
-            self.jinja_env = Environment(
-                loader=loader,
+            env = Environment(
+                loader=FileSystemLoader([prompt_dir, skills_dir]),
                 autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
             )

-            try:
-                skills_to_load = list(self.config.skills or [])
-                skills_to_load.append(f"scan_modes/{self.config.scan_mode}")
+            skills_to_load = [
+                *list(self.config.skills or []),
+                f"scan_modes/{self.config.scan_mode}",
+            ]
+            skill_content = load_skills(skills_to_load, env)
+            env.globals["get_skill"] = lambda name: skill_content.get(name, "")

-                skill_content = load_skills(skills_to_load, self.jinja_env)
-
-                def get_skill(name: str) -> str:
-                    return skill_content.get(name, "")
-
-                self.jinja_env.globals["get_skill"] = get_skill
-
-                self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
-                    get_tools_prompt=get_tools_prompt,
-                    loaded_skill_names=list(skill_content.keys()),
-                    **skill_content,
-                )
-            except (FileNotFoundError, OSError, ValueError) as e:
-                logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
-                self.system_prompt = "You are a helpful AI assistant."
-        else:
-            self.system_prompt = "You are a helpful AI assistant."
+            result = env.get_template("system_prompt.jinja").render(
+                get_tools_prompt=get_tools_prompt,
+                loaded_skill_names=list(skill_content.keys()),
+                **skill_content,
+            )
+            return str(result)
+        except Exception:  # noqa: BLE001
+            return ""

     def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
         if agent_name:
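NOTE: _load_system_prompt above keeps the original Jinja2 mechanics while
flattening the nested branches: one environment layered over the agent's prompt
directory plus the shared skills directory, a get_skill global backed by the
loaded skill map, and a system_prompt.jinja entry point. A standalone sketch of
the pattern (directories and skill contents are invented, and the template file
is assumed to exist on disk):

    from jinja2 import Environment, FileSystemLoader, select_autoescape

    env = Environment(
        loader=FileSystemLoader(["./prompts/my_agent", "./prompts/skills"]),
        autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
    )
    # Stand-in for load_skills(...): maps skill name -> rendered skill text.
    skill_content = {"recon": "Enumerate before exploiting."}
    env.globals["get_skill"] = lambda name: skill_content.get(name, "")

    prompt = env.get_template("system_prompt.jinja").render(
        loaded_skill_names=list(skill_content),
        **skill_content,
    )

One behavioral change worth noting: the old code fell back to "You are a helpful
AI assistant." when loading failed, while the new code swallows any exception
and returns an empty system prompt.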
" - "Now go answer the next needed step!\n" - f"{identity_name}\n" - f"{identity_id}\n" - "\n\n" + async def generate( + self, conversation_history: list[dict[str, Any]] + ) -> AsyncIterator[LLMResponse]: + messages = self._prepare_messages(conversation_history) + max_retries = int(Config.get("strix_llm_max_retries") or "5") + + for attempt in range(max_retries + 1): + try: + async for response in self._stream(messages): + yield response + return # noqa: TRY300 + except Exception as e: # noqa: BLE001 + if attempt >= max_retries or not self._should_retry(e): + self._raise_error(e) + wait = min(10, 2 * (2**attempt)) + await asyncio.sleep(wait) + + async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]: + accumulated = "" + chunks: list[Any] = [] + + self._total_stats.requests += 1 + response = await acompletion(**self._build_completion_args(messages), stream=True) + + async for chunk in response: + chunks.append(chunk) + delta = self._get_chunk_content(chunk) + if delta: + accumulated += delta + if "" in accumulated: + accumulated = accumulated[ + : accumulated.find("") + len("") + ] + yield LLMResponse(content=accumulated) + + if chunks: + self._update_usage_stats(stream_chunk_builder(chunks)) + + accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated)) + yield LLMResponse( + content=accumulated, + tool_invocations=parse_tool_invocations(accumulated), + thinking_blocks=self._extract_thinking(chunks), ) - return {"role": "user", "content": content} - - def _add_cache_control_to_content( - self, content: str | list[dict[str, Any]] - ) -> str | list[dict[str, Any]]: - if isinstance(content, str): - return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}] - if isinstance(content, list) and content: - last_item = content[-1] - if isinstance(last_item, dict) and last_item.get("type") == "text": - return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}] - return content - - def _is_anthropic_model(self) -> bool: - if not self.config.model_name: - return False - model_lower = self.config.model_name.lower() - return any(provider in model_lower for provider in ["anthropic/", "claude"]) - - def _calculate_cache_interval(self, total_messages: int) -> int: - if total_messages <= 1: - return 10 - - max_cached_messages = 3 - non_system_messages = total_messages - 1 - - interval = 10 - while non_system_messages // interval > max_cached_messages: - interval += 10 - - return interval - - def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - if ( - not self.config.enable_prompt_caching - or not supports_prompt_caching(self.config.model_name) - or not messages - ): - return messages - - if not self._is_anthropic_model(): - return messages - - cached_messages = list(messages) - - if cached_messages and cached_messages[0].get("role") == "system": - system_message = cached_messages[0].copy() - system_message["content"] = self._add_cache_control_to_content( - system_message["content"] - ) - cached_messages[0] = system_message - - total_messages = len(cached_messages) - if total_messages > 1: - interval = self._calculate_cache_interval(total_messages) - - cached_count = 0 - for i in range(interval, total_messages, interval): - if cached_count >= 3: - break - - if i < len(cached_messages): - message = cached_messages[i].copy() - message["content"] = self._add_cache_control_to_content(message["content"]) - cached_messages[i] = message - cached_count += 1 - - return 
cached_messages def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]: messages = [{"role": "system", "content": self.system_prompt}] - identity_message = self._build_identity_message() - if identity_message: - messages.append(identity_message) - - compressed_history = list(self.memory_compressor.compress_history(conversation_history)) + if self.agent_name: + messages.append( + { + "role": "user", + "content": ( + f"\n\n\n" + f"Internal metadata: do not echo or reference.\n" + f"{self.agent_name}\n" + f"{self.agent_id}\n" + f"\n\n" + ), + } + ) + compressed = list(self.memory_compressor.compress_history(conversation_history)) conversation_history.clear() - conversation_history.extend(compressed_history) - messages.extend(compressed_history) + conversation_history.extend(compressed) + messages.extend(compressed) - return self._prepare_cached_messages(messages) + if self._is_anthropic() and self.config.enable_prompt_caching: + messages = self._add_cache_control(messages) - async def _stream_and_accumulate( - self, - messages: list[dict[str, Any]], - scan_id: str | None, - step_number: int, - ) -> AsyncIterator[LLMResponse]: - accumulated_content = "" - chunks: list[Any] = [] + return messages - async for chunk in self._stream_request(messages): - chunks.append(chunk) - delta = self._extract_chunk_delta(chunk) - if delta: - accumulated_content += delta + def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]: + if not self._supports_vision(): + messages = self._strip_images(messages) - if "" in accumulated_content: - function_end = accumulated_content.find("") + len("") - accumulated_content = accumulated_content[:function_end] - - yield LLMResponse( - scan_id=scan_id, - step_number=step_number, - role=StepRole.AGENT, - content=accumulated_content, - tool_invocations=None, - ) - - if chunks: - complete_response = stream_chunk_builder(chunks) - self._update_usage_stats(complete_response) - - accumulated_content = _truncate_to_first_function(accumulated_content) - if "" in accumulated_content: - function_end = accumulated_content.find("") + len("") - accumulated_content = accumulated_content[:function_end] - - accumulated_content = fix_incomplete_tool_call(accumulated_content) - - tool_invocations = parse_tool_invocations(accumulated_content) - - # Extract thinking blocks from the complete response if available - thinking_blocks = None - if chunks and self._should_include_reasoning_effort(): - complete_response = stream_chunk_builder(chunks) - if ( - hasattr(complete_response, "choices") - and complete_response.choices - and hasattr(complete_response.choices[0], "message") - ): - message = complete_response.choices[0].message - if hasattr(message, "thinking_blocks") and message.thinking_blocks: - thinking_blocks = message.thinking_blocks - - yield LLMResponse( - scan_id=scan_id, - step_number=step_number, - role=StepRole.AGENT, - content=accumulated_content, - tool_invocations=tool_invocations if tool_invocations else None, - thinking_blocks=thinking_blocks, - ) - - def _raise_llm_error(self, e: Exception) -> None: - error_map: list[tuple[type, str]] = [ - (litellm.RateLimitError, "Rate limit exceeded"), - (litellm.AuthenticationError, "Invalid API key"), - (litellm.NotFoundError, "Model not found"), - (litellm.ContextWindowExceededError, "Context too long"), - (litellm.ContentPolicyViolationError, "Content policy violation"), - (litellm.ServiceUnavailableError, "Service unavailable"), - (litellm.Timeout, "Request timed 
out"), - (litellm.UnprocessableEntityError, "Unprocessable entity"), - (litellm.InternalServerError, "Internal server error"), - (litellm.APIConnectionError, "Connection error"), - (litellm.UnsupportedParamsError, "Unsupported parameters"), - (litellm.BudgetExceededError, "Budget exceeded"), - (litellm.APIResponseValidationError, "Response validation error"), - (litellm.JSONSchemaValidationError, "JSON schema validation error"), - (litellm.InvalidRequestError, "Invalid request"), - (litellm.BadRequestError, "Bad request"), - (litellm.APIError, "API error"), - (litellm.OpenAIError, "OpenAI error"), - ] - - from strix.telemetry import posthog - - for error_type, message in error_map: - if isinstance(e, error_type): - posthog.error(f"llm_{error_type.__name__}", message) - raise LLMRequestFailedError(f"LLM request failed: {message}", str(e)) from e - - posthog.error("llm_unknown_error", type(e).__name__) - raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e - - async def generate( - self, - conversation_history: list[dict[str, Any]], - scan_id: str | None = None, - step_number: int = 1, - ) -> AsyncIterator[LLMResponse]: - messages = self._prepare_messages(conversation_history) - - last_error: Exception | None = None - for attempt in range(MAX_RETRIES): - try: - async for response in self._stream_and_accumulate(messages, scan_id, step_number): - yield response - return # noqa: TRY300 - except Exception as e: # noqa: BLE001 - last_error = e - if not _should_retry(e) or attempt == MAX_RETRIES - 1: - break - wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt)) - wait_time = max(RETRY_MIN, wait_time) - await asyncio.sleep(wait_time) - - if last_error: - self._raise_llm_error(last_error) - - def _extract_chunk_delta(self, chunk: Any) -> str: - if chunk.choices and hasattr(chunk.choices[0], "delta"): - delta = chunk.choices[0].delta - return getattr(delta, "content", "") or "" - return "" - - @property - def usage_stats(self) -> dict[str, dict[str, int | float]]: - return { - "total": self._total_stats.to_dict(), - "last_request": self._last_request_stats.to_dict(), - } - - def get_cache_config(self) -> dict[str, bool]: - return { - "enabled": self.config.enable_prompt_caching, - "supported": supports_prompt_caching(self.config.model_name), - } - - def _should_include_reasoning_effort(self) -> bool: - if not self.config.model_name: - return False - try: - return bool(supports_reasoning(model=self.config.model_name)) - except Exception: # noqa: BLE001 - return False - - def _model_supports_vision(self) -> bool: - if not self.config.model_name: - return False - try: - return bool(supports_vision(model=self.config.model_name)) - except Exception: # noqa: BLE001 - return False - - def _filter_images_from_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - filtered_messages = [] - for msg in messages: - content = msg.get("content") - updated_msg = msg - if isinstance(content, list): - filtered_content = [] - for item in content: - if isinstance(item, dict): - if item.get("type") == "image_url": - filtered_content.append( - { - "type": "text", - "text": "[Screenshot removed - model does not support " - "vision. 
Use view_source or execute_js instead.]", - } - ) - else: - filtered_content.append(item) - else: - filtered_content.append(item) - if filtered_content: - text_parts = [ - item.get("text", "") if isinstance(item, dict) else str(item) - for item in filtered_content - ] - all_text = all( - isinstance(item, dict) and item.get("type") == "text" - for item in filtered_content - ) - if all_text: - updated_msg = {**msg, "content": "\n".join(text_parts)} - else: - updated_msg = {**msg, "content": filtered_content} - else: - updated_msg = {**msg, "content": ""} - filtered_messages.append(updated_msg) - return filtered_messages - - async def _stream_request( - self, - messages: list[dict[str, Any]], - ) -> AsyncIterator[Any]: - if not self._model_supports_vision(): - messages = self._filter_images_from_messages(messages) - - completion_args: dict[str, Any] = { + args: dict[str, Any] = { "model": self.config.model_name, "messages": messages, "timeout": self.config.timeout, "stream_options": {"include_usage": True}, + "stop": [""], } - if _LLM_API_KEY: - completion_args["api_key"] = _LLM_API_KEY - if _LLM_API_BASE: - completion_args["api_base"] = _LLM_API_BASE + if api_key := Config.get("llm_api_key"): + args["api_key"] = api_key + if api_base := ( + Config.get("llm_api_base") + or Config.get("openai_api_base") + or Config.get("litellm_base_url") + ): + args["api_base"] = api_base + if self._supports_reasoning(): + args["reasoning_effort"] = self._reasoning_effort - completion_args["stop"] = [""] + return args - if self._should_include_reasoning_effort(): - completion_args["reasoning_effort"] = self._reasoning_effort + def _get_chunk_content(self, chunk: Any) -> str: + if chunk.choices and hasattr(chunk.choices[0], "delta"): + return getattr(chunk.choices[0].delta, "content", "") or "" + return "" - queue = get_global_queue() - self._total_stats.requests += 1 - self._last_request_stats = RequestStats(requests=1) - - async for chunk in queue.stream_request(completion_args): - yield chunk + def _extract_thinking(self, chunks: list[Any]) -> list[dict[str, Any]] | None: + if not chunks or not self._supports_reasoning(): + return None + try: + resp = stream_chunk_builder(chunks) + if resp.choices and hasattr(resp.choices[0].message, "thinking_blocks"): + blocks: list[dict[str, Any]] = resp.choices[0].message.thinking_blocks + return blocks + except Exception: # noqa: BLE001, S110 # nosec B110 + pass + return None def _update_usage_stats(self, response: Any) -> None: try: @@ -497,45 +229,88 @@ class LLM: output_tokens = getattr(response.usage, "completion_tokens", 0) cached_tokens = 0 - cache_creation_tokens = 0 - if hasattr(response.usage, "prompt_tokens_details"): prompt_details = response.usage.prompt_tokens_details if hasattr(prompt_details, "cached_tokens"): cached_tokens = prompt_details.cached_tokens or 0 - if hasattr(response.usage, "cache_creation_input_tokens"): - cache_creation_tokens = response.usage.cache_creation_input_tokens or 0 - else: input_tokens = 0 output_tokens = 0 cached_tokens = 0 - cache_creation_tokens = 0 try: cost = completion_cost(response) or 0.0 - except Exception as e: # noqa: BLE001 - logger.warning(f"Failed to calculate cost: {e}") + except Exception: # noqa: BLE001 cost = 0.0 self._total_stats.input_tokens += input_tokens self._total_stats.output_tokens += output_tokens self._total_stats.cached_tokens += cached_tokens - self._total_stats.cache_creation_tokens += cache_creation_tokens self._total_stats.cost += cost - self._last_request_stats.input_tokens = input_tokens - 
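NOTE: with the queue gone, the request path above is one direct litellm
acompletion(..., stream=True) call built from _build_completion_args. A minimal
sketch of the same call shape (the model name is a placeholder and a matching
provider API key is assumed to be set in the environment):

    import asyncio
    from litellm import acompletion

    async def demo() -> None:
        response = await acompletion(
            model="openai/gpt-4o-mini",  # placeholder model
            messages=[{"role": "user", "content": "ping"}],
            timeout=300,
            stream_options={"include_usage": True},
            stream=True,
        )
        async for chunk in response:
            # Mirrors _get_chunk_content: pull the text delta off each chunk.
            if chunk.choices and getattr(chunk.choices[0].delta, "content", None):
                print(chunk.choices[0].delta.content, end="")

    asyncio.run(demo())

stream_options={"include_usage": True} is what makes the stream carry token
usage, which _update_usage_stats below reads off the stream_chunk_builder
result.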
     def _update_usage_stats(self, response: Any) -> None:
         try:
@@ -497,45 +229,88 @@ class LLM:
             output_tokens = getattr(response.usage, "completion_tokens", 0)

             cached_tokens = 0
-            cache_creation_tokens = 0
-
             if hasattr(response.usage, "prompt_tokens_details"):
                 prompt_details = response.usage.prompt_tokens_details
                 if hasattr(prompt_details, "cached_tokens"):
                     cached_tokens = prompt_details.cached_tokens or 0
-
-            if hasattr(response.usage, "cache_creation_input_tokens"):
-                cache_creation_tokens = response.usage.cache_creation_input_tokens or 0
-
         else:
             input_tokens = 0
             output_tokens = 0
             cached_tokens = 0
-            cache_creation_tokens = 0

         try:
             cost = completion_cost(response) or 0.0
-        except Exception as e:  # noqa: BLE001
-            logger.warning(f"Failed to calculate cost: {e}")
+        except Exception:  # noqa: BLE001
             cost = 0.0

         self._total_stats.input_tokens += input_tokens
         self._total_stats.output_tokens += output_tokens
         self._total_stats.cached_tokens += cached_tokens
-        self._total_stats.cache_creation_tokens += cache_creation_tokens
         self._total_stats.cost += cost

-        self._last_request_stats.input_tokens = input_tokens
-        self._last_request_stats.output_tokens = output_tokens
-        self._last_request_stats.cached_tokens = cached_tokens
-        self._last_request_stats.cache_creation_tokens = cache_creation_tokens
-        self._last_request_stats.cost = cost
+        except Exception:  # noqa: BLE001, S110 # nosec B110
+            pass

-        if cached_tokens > 0:
-            logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
-        if cache_creation_tokens > 0:
-            logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")
+    def _should_retry(self, e: Exception) -> bool:
+        code = getattr(e, "status_code", None) or getattr(
+            getattr(e, "response", None), "status_code", None
+        )
+        return code is None or litellm._should_retry(code)

-            logger.info(f"Usage stats: {self.usage_stats}")
-        except Exception as e:  # noqa: BLE001
-            logger.warning(f"Failed to update usage stats: {e}")
+    def _raise_error(self, e: Exception) -> None:
+        from strix.telemetry import posthog
+
+        posthog.error("llm_error", type(e).__name__)
+        raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
+
+    def _is_anthropic(self) -> bool:
+        if not self.config.model_name:
+            return False
+        return any(p in self.config.model_name.lower() for p in ["anthropic/", "claude"])
+
+    def _supports_vision(self) -> bool:
+        try:
+            return bool(supports_vision(model=self.config.model_name))
+        except Exception:  # noqa: BLE001
+            return False
+
+    def _supports_reasoning(self) -> bool:
+        try:
+            return bool(supports_reasoning(model=self.config.model_name))
+        except Exception:  # noqa: BLE001
+            return False
+
+    def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        result = []
+        for msg in messages:
+            content = msg.get("content")
+            if isinstance(content, list):
+                text_parts = []
+                for item in content:
+                    if isinstance(item, dict) and item.get("type") == "text":
+                        text_parts.append(item.get("text", ""))
+                    elif isinstance(item, dict) and item.get("type") == "image_url":
+                        text_parts.append("[Image removed - model doesn't support vision]")
+                result.append({**msg, "content": "\n".join(text_parts)})
+            else:
+                result.append(msg)
+        return result
+
+    def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        if not messages or not supports_prompt_caching(self.config.model_name):
+            return messages
+
+        result = list(messages)
+
+        if result[0].get("role") == "system":
+            content = result[0]["content"]
+            result[0] = {
+                **result[0],
+                "content": [
+                    {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
+                ]
+                if isinstance(content, str)
+                else content,
+            }
+        return result
diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py
index bfc6480..a9532f8 100644
--- a/strix/llm/memory_compressor.py
+++ b/strix/llm/memory_compressor.py
@@ -86,7 +86,7 @@ def _extract_message_text(msg: dict[str, Any]) -> str:
 def _summarize_messages(
     messages: list[dict[str, Any]],
     model: str,
-    timeout: int = 600,
+    timeout: int = 30,
 ) -> dict[str, Any]:
     if not messages:
         empty_summary = "{text}"
@@ -148,11 +148,11 @@ class MemoryCompressor:
         self,
         max_images: int = 3,
         model_name: str | None = None,
-        timeout: int = 600,
+        timeout: int | None = None,
     ):
         self.max_images = max_images
         self.model_name = model_name or Config.get("strix_llm")
-        self.timeout = timeout
+        self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")

         if not self.model_name:
             raise ValueError("STRIX_LLM environment variable must be set and not empty")
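NOTE: the compressor timeout drops from a hardcoded 600 s to a configurable
value defaulting to 30 s, resolved in MemoryCompressor.__init__ as: explicit
argument, then the strix_memory_compressor_timeout config key, then 30. A tiny
sketch of that fallback chain (using an environment variable as a stand-in for
Config.get, which is an assumption about how Config resolves values):

    import os

    def resolve_timeout(explicit: int | None = None) -> int:
        # Mirrors "timeout or int(Config.get(...) or '30')" above.
        return explicit or int(os.environ.get("STRIX_MEMORY_COMPRESSOR_TIMEOUT") or "30")

    assert resolve_timeout(60) == 60  # explicit argument wins
    assert resolve_timeout() == 30    # nothing configured (env var unset): default

One quirk of the "timeout or ..." idiom: an explicit timeout=0 is falsy and
silently falls through to the configured default.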
diff --git a/strix/llm/request_queue.py b/strix/llm/request_queue.py
deleted file mode 100644
index 35c9725..0000000
--- a/strix/llm/request_queue.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import asyncio
-import threading
-import time
-from collections.abc import AsyncIterator
-from typing import Any
-
-from litellm import acompletion
-from litellm.types.utils import ModelResponseStream
-
-from strix.config import Config
-
-
-class LLMRequestQueue:
-    def __init__(self) -> None:
-        self.delay_between_requests = float(Config.get("llm_rate_limit_delay") or "4.0")
-        self.max_concurrent = int(Config.get("llm_rate_limit_concurrent") or "1")
-        self._semaphore = threading.BoundedSemaphore(self.max_concurrent)
-        self._last_request_time = 0.0
-        self._lock = threading.Lock()
-
-    async def stream_request(
-        self, completion_args: dict[str, Any]
-    ) -> AsyncIterator[ModelResponseStream]:
-        try:
-            while not self._semaphore.acquire(timeout=0.2):
-                await asyncio.sleep(0.1)
-
-            with self._lock:
-                now = time.time()
-                time_since_last = now - self._last_request_time
-                sleep_needed = max(0, self.delay_between_requests - time_since_last)
-                self._last_request_time = now + sleep_needed
-
-            if sleep_needed > 0:
-                await asyncio.sleep(sleep_needed)
-
-            async for chunk in self._stream_request(completion_args):
-                yield chunk
-        finally:
-            self._semaphore.release()
-
-    async def _stream_request(
-        self, completion_args: dict[str, Any]
-    ) -> AsyncIterator[ModelResponseStream]:
-        response = await acompletion(**completion_args, stream=True)
-
-        async for chunk in response:
-            yield chunk
-
-
-_global_queue: LLMRequestQueue | None = None
-
-
-def get_global_queue() -> LLMRequestQueue:
-    global _global_queue  # noqa: PLW0603
-    if _global_queue is None:
-        _global_queue = LLMRequestQueue()
-    return _global_queue
diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py
index af299a9..25af62c 100644
--- a/strix/telemetry/tracer.py
+++ b/strix/telemetry/tracer.py
@@ -430,10 +430,8 @@ class Tracer:
             "input_tokens": 0,
             "output_tokens": 0,
             "cached_tokens": 0,
-            "cache_creation_tokens": 0,
             "cost": 0.0,
             "requests": 0,
-            "failed_requests": 0,
         }

         for agent_instance in _agent_instances.values():
@@ -442,10 +440,8 @@
                 total_stats["input_tokens"] += agent_stats.input_tokens
                 total_stats["output_tokens"] += agent_stats.output_tokens
                 total_stats["cached_tokens"] += agent_stats.cached_tokens
-                total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
                 total_stats["cost"] += agent_stats.cost
                 total_stats["requests"] += agent_stats.requests
-                total_stats["failed_requests"] += agent_stats.failed_requests

         total_stats["cost"] = round(total_stats["cost"], 4)
diff --git a/strix/tools/executor.py b/strix/tools/executor.py
index ad0aeef..db06477 100644
--- a/strix/tools/executor.py
+++ b/strix/tools/executor.py
@@ -20,7 +20,7 @@ from .registry import (
 )


-SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "500")
+SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
 SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
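NOTE: the net effect of deleting request_queue.py is that the client-side
throttle (a 4-second gap between requests and a single request in flight,
enforced with a thread semaphore) is gone entirely. Rate-limit responses are
instead absorbed by the retry loop in LLM.generate, whose _should_retry defers
to litellm's own policy; exceptions carrying no status code at all (dropped
connections and the like) are treated as retryable. A quick way to inspect that
policy, using the same private helper the code above relies on:

    import litellm

    # Prints which HTTP status codes litellm considers worth retrying.
    for code in (400, 401, 429, 500, 503):
        print(code, litellm._should_retry(code))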