Open-source release for Alpha version

This commit is contained in:
Ahmed Allam
2025-08-08 20:36:44 -07:00
commit 81ac98e8b9
105 changed files with 22125 additions and 0 deletions

64
strix/tools/__init__.py Normal file
View File

@@ -0,0 +1,64 @@
# Public entry point for the strix tools package.
#
# Re-exports the executor and registry APIs and, depending on the runtime
# environment, imports the individual tool modules so that their
# @register_tool decorators run and populate the registry.
import os

from .executor import (
    execute_tool,
    execute_tool_invocation,
    execute_tool_with_validation,
    extract_screenshot_from_result,
    process_tool_invocations,
    remove_screenshot_from_result,
    validate_tool_availability,
)
from .registry import (
    ImplementedInClientSideOnlyError,
    get_tool_by_name,
    get_tool_names,
    get_tools_prompt,
    needs_agent_state,
    register_tool,
    tools,
)

# "true" (case-insensitive) selects the restricted sandbox tool set below.
SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
# Web search is only registered when a Perplexity API key is configured.
HAS_PERPLEXITY_API = bool(os.getenv("PERPLEXITY_API_KEY"))

if not SANDBOX_MODE:
    # Full tool set, including orchestration/reporting tools that must not
    # run inside the sandbox.
    from .agents_graph import *  # noqa: F403
    from .browser import *  # noqa: F403
    from .file_edit import *  # noqa: F403
    from .finish import *  # noqa: F403
    from .notes import *  # noqa: F403
    from .proxy import *  # noqa: F403
    from .python import *  # noqa: F403
    from .reporting import *  # noqa: F403
    from .terminal import *  # noqa: F403
    from .thinking import *  # noqa: F403

    if HAS_PERPLEXITY_API:
        from .web_search import *  # noqa: F403
else:
    # Sandbox mode: only the subset of tools that executes in the sandbox.
    from .browser import *  # noqa: F403
    from .file_edit import *  # noqa: F403
    from .notes import *  # noqa: F403
    from .proxy import *  # noqa: F403
    from .python import *  # noqa: F403
    from .terminal import *  # noqa: F403

__all__ = [
    "ImplementedInClientSideOnlyError",
    "execute_tool",
    "execute_tool_invocation",
    "execute_tool_with_validation",
    "extract_screenshot_from_result",
    "get_tool_by_name",
    "get_tool_names",
    "get_tools_prompt",
    "needs_agent_state",
    "process_tool_invocations",
    "register_tool",
    "remove_screenshot_from_result",
    "tools",
    "validate_tool_availability",
]

View File

@@ -0,0 +1,16 @@
# Public surface of the agents_graph tool package: re-export the action
# functions implemented in agents_graph_actions.
from .agents_graph_actions import (
    agent_finish,
    create_agent,
    send_message_to_agent,
    view_agent_graph,
    wait_for_message,
)

__all__ = [
    "agent_finish",
    "create_agent",
    "send_message_to_agent",
    "view_agent_graph",
    "wait_for_message",
]

View File

@@ -0,0 +1,610 @@
import threading
from datetime import UTC, datetime
from typing import Any, Literal

from strix.tools.registry import register_tool

# Process-local registry of the multi-agent system.
# "nodes" maps agent_id -> node metadata dict; "edges" is a list of
# delegation/message links between agents.
_agent_graph: dict[str, Any] = {
    "nodes": {},
    "edges": [],
}
# Identifier of the root (parentless) agent, once known.
_root_agent_id: str | None = None
# Per-agent inbox: agent_id -> list of message dicts (newest appended last).
_agent_messages: dict[str, list[dict[str, Any]]] = {}
# Worker thread driving each active agent, keyed by agent_id.
_running_agents: dict[str, threading.Thread] = {}
# Live agent objects, keyed by agent_id (used to request cancellation).
_agent_instances: dict[str, Any] = {}
# AgentState objects, keyed by agent_id (used to request a graceful stop).
_agent_states: dict[str, Any] = {}
def _run_agent_in_thread(
    agent: Any, state: Any, inherited_messages: list[dict[str, Any]]
) -> dict[str, Any]:
    """Drive one agent's full lifecycle on a worker thread.

    Seeds the agent's conversation with any inherited parent context plus a
    delegation prompt, runs the agent loop to completion on a private asyncio
    event loop, and records the outcome on the module-level agent graph.

    Returns {"result": <agent loop result>}; on failure the error is recorded
    in the graph node and the exception is re-raised.
    NOTE(review): assumes the caller already created the graph node for
    state.agent_id (not visible in this module) — confirm against callers.
    """
    try:
        if inherited_messages:
            # Wrap the parent's history in explicit markers so the sub-agent
            # can distinguish inherited background from its own task.
            state.add_message("user", "<inherited_context_from_parent>")
            for msg in inherited_messages:
                state.add_message(msg["role"], msg["content"])
            state.add_message("user", "</inherited_context_from_parent>")
        parent_info = _agent_graph["nodes"].get(state.parent_id, {})
        parent_name = parent_info.get("name", "Unknown Parent")
        context_status = (
            "inherited conversation context from your parent for background understanding"
            if inherited_messages
            else "started with a fresh context"
        )
        task_xml = f"""<agent_delegation>
<identity>
⚠️ You are NOT your parent agent. You are a NEW, SEPARATE sub-agent (not root).
Your Info: {state.agent_name} ({state.agent_id})
Parent Info: {parent_name} ({state.parent_id})
</identity>
<your_task>{state.task}</your_task>
<instructions>
- You have {context_status}
- Inherited context is for BACKGROUND ONLY - don't continue parent's work
- Focus EXCLUSIVELY on your delegated task above
- Work independently with your own approach
- Use agent_finish when complete to report back to parent
- You are a SPECIALIST for this specific task
</instructions>
</agent_delegation>"""
        state.add_message("user", task_xml)
        _agent_states[state.agent_id] = state
        _agent_graph["nodes"][state.agent_id]["state"] = state.model_dump()

        import asyncio

        # Worker threads have no running event loop, so create a private one
        # and close it unconditionally when the agent loop ends.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            result = loop.run_until_complete(agent.agent_loop(state.task))
        finally:
            loop.close()
    except Exception as e:
        # Record the failure on the graph node, detach bookkeeping, re-raise.
        _agent_graph["nodes"][state.agent_id]["status"] = "error"
        _agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat()
        _agent_graph["nodes"][state.agent_id]["result"] = {"error": str(e)}
        _running_agents.pop(state.agent_id, None)
        _agent_instances.pop(state.agent_id, None)
        raise
    else:
        # A user-requested stop is reported as "stopped" rather than
        # "completed" even though the loop ended normally.
        if state.stop_requested:
            _agent_graph["nodes"][state.agent_id]["status"] = "stopped"
        else:
            _agent_graph["nodes"][state.agent_id]["status"] = "completed"
        _agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat()
        _agent_graph["nodes"][state.agent_id]["result"] = result
        _running_agents.pop(state.agent_id, None)
        _agent_instances.pop(state.agent_id, None)
        return {"result": result}
@register_tool(sandbox_execution=False)
def view_agent_graph(agent_state: Any) -> dict[str, Any]:
    """Render the delegation tree of all agents plus per-status summary counts."""
    try:
        lines = ["=== AGENT GRAPH STRUCTURE ==="]
        nodes = _agent_graph["nodes"]

        def _render(node_id: str, depth: int = 0) -> None:
            # Depth-first walk over delegation edges; depth controls padding.
            node = nodes[node_id]
            pad = " " * depth
            marker = " ← This is you" if node_id == agent_state.agent_id else ""
            lines.append(f"{pad}* {node['name']} ({node_id}){marker}")
            lines.append(f"{pad}  Task: {node['task']}")
            lines.append(f"{pad}  Status: {node['status']}")
            child_ids = [
                edge["to"]
                for edge in _agent_graph["edges"]
                if edge["from"] == node_id and edge["type"] == "delegation"
            ]
            if child_ids:
                lines.append(f"{pad}  Children:")
                for child_id in child_ids:
                    _render(child_id, depth + 2)

        # Prefer the recorded root; otherwise fall back to any parentless
        # node, and finally to an arbitrary node.
        root_id = _root_agent_id
        if not root_id and nodes:
            root_id = next(
                (nid for nid, data in nodes.items() if data.get("parent_id") is None),
                next(iter(nodes)),
            )
        if root_id and root_id in nodes:
            _render(root_id)
        else:
            lines.append("No agents in the graph yet")
        graph_structure = "\n".join(lines)

        def _count(*statuses: str) -> int:
            # How many nodes currently sit in any of the given statuses.
            return sum(1 for data in nodes.values() if data["status"] in statuses)

        summary = {
            "total_agents": len(nodes),
            "running": _count("running"),
            "waiting": _count("waiting"),
            "stopping": _count("stopping"),
            "completed": _count("completed"),
            "stopped": _count("stopped"),
            "failed": _count("failed", "error"),
        }
    except Exception as e:  # noqa: BLE001
        return {
            "error": f"Failed to view agent graph: {e}",
            "graph_structure": "Error retrieving graph structure",
        }
    else:
        return {
            "graph_structure": graph_structure,
            "summary": summary,
        }
@register_tool(sandbox_execution=False)
def create_agent(
    agent_state: Any,
    task: str,
    name: str,
    inherit_context: bool = True,
    prompt_modules: str | None = None,
) -> dict[str, Any]:
    """Spawn a new sub-agent for *task* on a daemon thread and return immediately.

    The sub-agent optionally inherits the caller's conversation history and may
    be configured with up to three validated prompt modules.
    """
    try:
        parent_id = agent_state.agent_id

        modules = (
            [part.strip() for part in prompt_modules.split(",") if part.strip()]
            if prompt_modules
            else []
        )
        if "root_agent" in modules:
            # root_agent is reserved; sub-agents may never claim it.
            return {
                "success": False,
                "error": (
                    "The 'root_agent' module is reserved for the main agent "
                    "and cannot be used by sub-agents"
                ),
                "agent_id": None,
            }
        if len(modules) > 3:
            return {
                "success": False,
                "error": (
                    "Cannot specify more than 3 prompt modules for an agent "
                    "(use comma-separated format)"
                ),
                "agent_id": None,
            }
        if modules:
            from strix.prompts import get_all_module_names, validate_module_names

            validation = validate_module_names(modules)
            if validation["invalid"]:
                available_modules = list(get_all_module_names())
                return {
                    "success": False,
                    "error": (
                        f"Invalid prompt modules: {validation['invalid']}. "
                        f"Available modules: {', '.join(available_modules)}"
                    ),
                    "agent_id": None,
                }

        from strix.agents import StrixAgent
        from strix.agents.state import AgentState
        from strix.llm.config import LLMConfig

        state = AgentState(task=task, agent_name=name, parent_id=parent_id, max_iterations=200)
        agent = StrixAgent(
            {
                "llm_config": LLMConfig(prompt_modules=modules),
                "state": state,
            }
        )

        inherited = agent_state.get_conversation_history() if inherit_context else []

        _agent_instances[state.agent_id] = agent
        worker = threading.Thread(
            target=_run_agent_in_thread,
            args=(agent, state, inherited),
            daemon=True,
            name=f"Agent-{name}-{state.agent_id}",
        )
        worker.start()
        _running_agents[state.agent_id] = worker
    except Exception as e:  # noqa: BLE001
        return {"success": False, "error": f"Failed to create agent: {e}", "agent_id": None}
    else:
        return {
            "success": True,
            "agent_id": state.agent_id,
            "message": f"Agent '{name}' created and started asynchronously",
            "agent_info": {
                "id": state.agent_id,
                "name": name,
                "status": "running",
                "parent_id": parent_id,
            },
        }
@register_tool(sandbox_execution=False)
def send_message_to_agent(
    agent_state: Any,
    target_agent_id: str,
    message: str,
    message_type: Literal["query", "instruction", "information"] = "information",
    priority: Literal["low", "normal", "high", "urgent"] = "normal",
) -> dict[str, Any]:
    """Append a message to another agent's inbox and record a message edge."""
    try:
        nodes = _agent_graph["nodes"]
        if target_agent_id not in nodes:
            return {
                "success": False,
                "error": f"Target agent '{target_agent_id}' not found in graph",
                "message_id": None,
            }

        from uuid import uuid4

        sender_id = agent_state.agent_id
        message_id = f"msg_{uuid4().hex[:8]}"
        envelope = {
            "id": message_id,
            "from": sender_id,
            "to": target_agent_id,
            "content": message,
            "message_type": message_type,
            "priority": priority,
            "timestamp": datetime.now(UTC).isoformat(),
            "delivered": False,
            "read": False,
        }
        _agent_messages.setdefault(target_agent_id, []).append(envelope)
        _agent_graph["edges"].append(
            {
                "from": sender_id,
                "to": target_agent_id,
                "type": "message",
                "message_id": message_id,
                "message_type": message_type,
                "priority": priority,
                "created_at": datetime.now(UTC).isoformat(),
            }
        )
        # Delivery is synchronous: the inbox append above is the delivery.
        envelope["delivered"] = True

        target_name = nodes[target_agent_id]["name"]
        sender_name = nodes[sender_id]["name"]
        return {
            "success": True,
            "message_id": message_id,
            "message": f"Message sent from '{sender_name}' to '{target_name}'",
            "delivery_status": "delivered",
            "target_agent": {
                "id": target_agent_id,
                "name": target_name,
                "status": nodes[target_agent_id]["status"],
            },
        }
    except Exception as e:  # noqa: BLE001
        return {"success": False, "error": f"Failed to send message: {e}", "message_id": None}
@register_tool(sandbox_execution=False)
def agent_finish(
    agent_state: Any,
    result_summary: str,
    findings: list[str] | None = None,
    success: bool = True,
    report_to_parent: bool = True,
    final_recommendations: list[str] | None = None,
) -> dict[str, Any]:
    """Mark the calling sub-agent's task as done and optionally notify its parent.

    Only sub-agents (agents with a parent_id) may call this; root agents are
    rejected. Records the outcome on the agent's graph node and, when
    report_to_parent is True, drops an XML completion report into the
    parent's message inbox.
    """
    try:
        # Root agents have no parent to report to; they must use finish_scan.
        if not hasattr(agent_state, "parent_id") or agent_state.parent_id is None:
            return {
                "agent_completed": False,
                "error": (
                    "This tool can only be used by subagents. "
                    "Root/main agents must use finish_scan instead."
                ),
                "parent_notified": False,
            }
        agent_id = agent_state.agent_id
        if agent_id not in _agent_graph["nodes"]:
            return {"agent_completed": False, "error": "Current agent not found in graph"}
        agent_node = _agent_graph["nodes"][agent_id]
        # NOTE(review): this writes "finished" while _run_agent_in_thread
        # writes "completed"; view_agent_graph only counts "completed", so
        # agents ending here fall outside every summary bucket except the
        # total — confirm whether "finished" is intentional.
        agent_node["status"] = "finished" if success else "failed"
        agent_node["finished_at"] = datetime.now(UTC).isoformat()
        agent_node["result"] = {
            "summary": result_summary,
            "findings": findings or [],
            "success": success,
            "recommendations": final_recommendations or [],
        }
        parent_notified = False
        if report_to_parent and agent_node["parent_id"]:
            parent_id = agent_node["parent_id"]
            if parent_id in _agent_graph["nodes"]:
                # Build the XML fragments for the structured completion report.
                findings_xml = "\n".join(
                    f" <finding>{finding}</finding>" for finding in (findings or [])
                )
                recommendations_xml = "\n".join(
                    f" <recommendation>{rec}</recommendation>"
                    for rec in (final_recommendations or [])
                )
                report_message = f"""<agent_completion_report>
<agent_info>
<agent_name>{agent_node["name"]}</agent_name>
<agent_id>{agent_id}</agent_id>
<task>{agent_node["task"]}</task>
<status>{"SUCCESS" if success else "FAILED"}</status>
<completion_time>{agent_node["finished_at"]}</completion_time>
</agent_info>
<results>
<summary>{result_summary}</summary>
<findings>
{findings_xml}
</findings>
<recommendations>
{recommendations_xml}
</recommendations>
</results>
</agent_completion_report>"""
                if parent_id not in _agent_messages:
                    _agent_messages[parent_id] = []

                from uuid import uuid4

                # High-priority so the waiting parent resumes promptly.
                _agent_messages[parent_id].append(
                    {
                        "id": f"report_{uuid4().hex[:8]}",
                        "from": agent_id,
                        "to": parent_id,
                        "content": report_message,
                        "message_type": "information",
                        "priority": "high",
                        "timestamp": datetime.now(UTC).isoformat(),
                        "delivered": True,
                        "read": False,
                    }
                )
                parent_notified = True
        # Drop the thread handle; the worker is about to exit.
        _running_agents.pop(agent_id, None)
        return {
            "agent_completed": True,
            "parent_notified": parent_notified,
            "completion_summary": {
                "agent_id": agent_id,
                "agent_name": agent_node["name"],
                "task": agent_node["task"],
                "success": success,
                "findings_count": len(findings or []),
                "has_recommendations": bool(final_recommendations),
                "finished_at": agent_node["finished_at"],
            },
        }
    except Exception as e:  # noqa: BLE001
        return {
            "agent_completed": False,
            "error": f"Failed to complete agent: {e}",
            "parent_notified": False,
        }
def stop_agent(agent_id: str) -> dict[str, Any]:
    """Request a graceful stop of an agent; it halts after its current iteration."""
    try:
        node = _agent_graph["nodes"].get(agent_id)
        if node is None:
            return {
                "success": False,
                "error": f"Agent '{agent_id}' not found in graph",
                "agent_id": agent_id,
            }
        if node["status"] in ("completed", "error", "failed", "stopped"):
            # Nothing to do; report the terminal state we found.
            return {
                "success": True,
                "message": f"Agent '{node['name']}' was already stopped",
                "agent_id": agent_id,
                "previous_status": node["status"],
            }

        # Signal every handle we track for this agent.
        state = _agent_states.get(agent_id)
        if state is not None:
            state.request_stop()
        instance = _agent_instances.get(agent_id)
        if instance is not None:
            if hasattr(instance, "state"):
                instance.state.request_stop()
            if hasattr(instance, "cancel_current_execution"):
                instance.cancel_current_execution()

        node["status"] = "stopping"
        try:
            from strix.cli.tracer import get_global_tracer

            tracer = get_global_tracer()
            if tracer:
                tracer.update_agent_status(agent_id, "stopping")
        except (ImportError, AttributeError):
            # The CLI tracer is optional; stopping works without it.
            pass

        node["result"] = {
            "summary": "Agent stop requested by user",
            "success": False,
            "stopped_by_user": True,
        }
        return {
            "success": True,
            "message": f"Stop request sent to agent '{node['name']}'",
            "agent_id": agent_id,
            "agent_name": node["name"],
            "note": "Agent will stop gracefully after current iteration",
        }
    except Exception as e:  # noqa: BLE001
        return {
            "success": False,
            "error": f"Failed to stop agent: {e}",
            "agent_id": agent_id,
        }
def send_user_message_to_agent(agent_id: str, message: str) -> dict[str, Any]:
    """Queue a high-priority user instruction into an agent's inbox."""
    try:
        node = _agent_graph["nodes"].get(agent_id)
        if node is None:
            return {
                "success": False,
                "error": f"Agent '{agent_id}' not found in graph",
                "agent_id": agent_id,
            }

        from uuid import uuid4

        # User messages are always delivered as high-priority instructions.
        _agent_messages.setdefault(agent_id, []).append(
            {
                "id": f"user_msg_{uuid4().hex[:8]}",
                "from": "user",
                "to": agent_id,
                "content": message,
                "message_type": "instruction",
                "priority": "high",
                "timestamp": datetime.now(UTC).isoformat(),
                "delivered": True,
                "read": False,
            }
        )
        return {
            "success": True,
            "message": f"Message sent to agent '{node['name']}'",
            "agent_id": agent_id,
            "agent_name": node["name"],
        }
    except Exception as e:  # noqa: BLE001
        return {
            "success": False,
            "error": f"Failed to send message to agent: {e}",
            "agent_id": agent_id,
        }
@register_tool(sandbox_execution=False)
def wait_for_message(
    agent_state: Any,
    reason: str = "Waiting for messages from other agents or user input",
) -> dict[str, Any]:
    """Park this agent until another agent or the user sends it a message."""
    try:
        agent_id = agent_state.agent_id
        agent_name = agent_state.agent_name

        agent_state.enter_waiting_state()

        node = _agent_graph["nodes"].get(agent_id)
        if node is not None:
            node["status"] = "waiting"
            node["waiting_reason"] = reason
        try:
            from strix.cli.tracer import get_global_tracer

            tracer = get_global_tracer()
            if tracer:
                tracer.update_agent_status(agent_id, "waiting")
        except (ImportError, AttributeError):
            # The CLI tracer is optional; waiting works without it.
            pass
    except Exception as e:  # noqa: BLE001
        return {"success": False, "error": f"Failed to enter waiting state: {e}", "status": "error"}
    else:
        return {
            "success": True,
            "status": "waiting",
            "message": f"Agent '{agent_name}' is now waiting for messages",
            "reason": reason,
            "agent_info": {
                "id": agent_id,
                "name": agent_name,
                "status": "waiting",
            },
            "resume_conditions": [
                "Message from another agent",
                "Message from user",
                "Direct communication",
            ],
        }

View File

@@ -0,0 +1,223 @@
<tools>
<tool name="agent_finish">
<description>Mark a subagent's task as completed and optionally report results to parent agent.
IMPORTANT: This tool can ONLY be used by subagents (agents with a parent).
Root/main agents must use finish_scan instead.
This tool should be called when a subagent completes its assigned subtask to:
- Mark the subagent's task as completed
- Report findings back to the parent agent
Use this tool when:
- You are a subagent working on a specific subtask
- You have completed your assigned task
- You want to report your findings to the parent agent
- You are ready to terminate this subagent's execution</description>
<details>Sub-agent counterpart to finish_scan: a sub-agent calls agent_finish to
mark its delegated subtask complete and report findings back to the parent
agent for coordination, while the root agent ends the overall scan with
finish_scan.</details>
<parameters>
<parameter name="result_summary" type="string" required="true">
<description>Summary of what the agent accomplished and discovered</description>
</parameter>
<parameter name="findings" type="string" required="false">
<description>List of specific findings, vulnerabilities, or discoveries</description>
</parameter>
<parameter name="success" type="boolean" required="false">
<description>Whether the agent's task completed successfully</description>
</parameter>
<parameter name="report_to_parent" type="boolean" required="false">
<description>Whether to send results back to the parent agent</description>
</parameter>
<parameter name="final_recommendations" type="string" required="false">
<description>Recommendations for next steps or follow-up actions</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - agent_completed: Whether the agent was marked as completed - parent_notified: Whether parent was notified (if applicable) - completion_summary: Summary of completion status</description>
</returns>
<examples>
# Sub-agent completing subdomain enumeration task
<function=agent_finish>
<parameter=result_summary>Completed comprehensive subdomain enumeration for target.com.
Discovered 47 subdomains including several interesting ones with admin/dev
in the name. Found 3 subdomains with exposed services on non-standard
ports.</parameter>
<parameter=findings>["admin.target.com - exposed phpMyAdmin",
"dev-api.target.com - unauth API endpoints",
"staging.target.com - directory listing enabled",
"mail.target.com - POP3/IMAP services"]</parameter>
<parameter=success>true</parameter>
<parameter=report_to_parent>true</parameter>
<parameter=final_recommendations>["Prioritize testing admin.target.com for default creds",
"Enumerate dev-api.target.com API endpoints",
"Check staging.target.com for sensitive files"]</parameter>
</function>
</examples>
</tool>
<tool name="create_agent">
<description>Create and spawn a new agent to handle a specific subtask.
MANDATORY REQUIREMENT: You MUST call view_agent_graph FIRST before creating any new agent to check if there is already an agent working on the same or similar task. Only create a new agent if no existing agent is handling the specific task.</description>
<details>The new agent inherits the parent's conversation history and context up to the point
of creation, then continues with its assigned subtask. This enables decomposition
of complex penetration testing tasks into specialized sub-agents.
The agent runs asynchronously and independently, allowing the parent to continue
immediately while the new agent executes its task in the background.
CRITICAL: Before calling this tool, you MUST first use view_agent_graph to:
- Examine all existing agents and their current tasks
- Verify no agent is already working on the same or similar objective
- Avoid duplication of effort and resource waste
- Ensure efficient coordination across the multi-agent system
If you as a parent agent don't absolutely have anything to do while your subagents are running, you can use wait_for_message tool. The subagent will continue to run in the background, and update you when it's done.
</details>
<parameters>
<parameter name="task" type="string" required="true">
<description>The specific task/objective for the new agent to accomplish</description>
</parameter>
<parameter name="name" type="string" required="true">
<description>Human-readable name for the agent (for tracking purposes)</description>
</parameter>
<parameter name="inherit_context" type="boolean" required="false">
<description>Whether the new agent should inherit parent's conversation history and context</description>
</parameter>
<parameter name="prompt_modules" type="string" required="false">
<description>Comma-separated list of prompt modules to use for the agent. Most agents should have at least one module in order to be useful. {{DYNAMIC_MODULES_DESCRIPTION}}</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - agent_id: Unique identifier for the created agent - success: Whether the agent was created successfully - message: Status message - agent_info: Details about the created agent</description>
</returns>
<examples>
# REQUIRED: First check agent graph before creating any new agent
<function=view_agent_graph>
</function>
# REQUIRED: Check agent graph again before creating another agent
<function=view_agent_graph>
</function>
# After confirming no SQL testing agent exists, create agent for vulnerability validation
<function=create_agent>
<parameter=task>Validate and exploit the suspected SQL injection vulnerability found in
the login form. Confirm exploitability and document proof of concept.</parameter>
<parameter=name>SQLi Validator</parameter>
<parameter=prompt_modules>sql_injection</parameter>
</function>
# Create specialized authentication testing agent with multiple modules (comma-separated)
<function=create_agent>
<parameter=task>Test authentication mechanisms, JWT implementation, and session management
for security vulnerabilities and bypass techniques.</parameter>
<parameter=name>Auth Specialist</parameter>
<parameter=prompt_modules>authentication_jwt, business_logic</parameter>
</function>
</examples>
</tool>
<tool name="send_message_to_agent">
<description>Send a message to another agent in the graph for coordination and communication.</description>
<details>This enables agents to communicate with each other during execution for:
- Sharing discovered information or findings
- Asking questions or requesting assistance
- Providing instructions or coordination
- Reporting status or results</details>
<parameters>
<parameter name="target_agent_id" type="string" required="true">
<description>ID of the agent to send the message to</description>
</parameter>
<parameter name="message" type="string" required="true">
<description>The message content to send</description>
</parameter>
<parameter name="message_type" type="string" required="false">
<description>Type of message being sent: - "query": Question requiring a response - "instruction": Command or directive for the target agent - "information": Informational message (findings, status, etc.)</description>
</parameter>
<parameter name="priority" type="string" required="false">
<description>Priority level of the message</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - success: Whether the message was sent successfully - message_id: Unique identifier for the message - delivery_status: Status of message delivery</description>
</returns>
<examples>
# Share discovered vulnerability information
<function=send_message_to_agent>
<parameter=target_agent_id>agent_abc123</parameter>
<parameter=message>Found SQL injection vulnerability in /login.php parameter 'username'.
Payload: admin' OR '1'='1' -- successfully bypassed authentication.
You should focus your testing on the authenticated areas of the
application.</parameter>
<parameter=message_type>information</parameter>
<parameter=priority>high</parameter>
</function>
# Request assistance from specialist agent
<function=send_message_to_agent>
<parameter=target_agent_id>agent_def456</parameter>
<parameter=message>I've identified what appears to be a custom encryption implementation
in the API responses. Can you analyze the cryptographic strength and look
for potential weaknesses?</parameter>
<parameter=message_type>query</parameter>
<parameter=priority>normal</parameter>
</function>
</examples>
</tool>
<tool name="view_agent_graph">
<description>View the current agent graph showing all agents, their relationships, and status.</description>
<details>This provides a comprehensive overview of the multi-agent system including:
- All agent nodes with their tasks, status, and metadata
- Parent-child relationships between agents
- Message communication patterns
- Current execution state</details>
<returns type="Dict[str, Any]">
<description>Response containing: - graph_structure: Human-readable representation of the agent graph - summary: High-level statistics about the graph</description>
</returns>
</tool>
<tool name="wait_for_message">
<description>Pause the agent loop indefinitely until receiving a message from another agent or user.
This tool puts the agent into a waiting state where it remains idle until it receives any form of communication. The agent will automatically resume execution when a message arrives.
IMPORTANT: This tool causes the agent to stop all activity until a message is received. Use it when you need to:
- Wait for subagent completion reports
- Coordinate with other agents before proceeding
- Pause for user input or decisions
- Synchronize multi-agent workflows
NOTE: If you are waiting for an agent that is NOT your subagent, you first tell it to message you with updates before waiting for it. Otherwise, you will wait forever!
</description>
<details>When this tool is called, the agent enters a waiting state and will not continue execution until:
- Another agent sends it a message via send_message_to_agent
- A user sends it a direct message through the CLI
- Any other form of inter-agent or user communication occurs
The agent will automatically resume from where it left off once a message is received.
This is particularly useful for parent agents waiting for subagent results or for coordination points in multi-agent workflows.</details>
<parameters>
<parameter name="reason" type="string" required="false">
<description>Explanation for why the agent is waiting (for logging and monitoring purposes)</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - success: Whether the agent successfully entered waiting state - status: Current agent status ("waiting") - reason: The reason for waiting - agent_info: Details about the waiting agent - resume_conditions: List of conditions that will resume the agent</description>
</returns>
<examples>
# Wait for subagents to complete their tasks
<function=wait_for_message>
<parameter=reason>Waiting for subdomain enumeration and port scanning subagents to complete their tasks and report findings</parameter>
</function>
# Wait for user input on next steps
<function=wait_for_message>
<parameter=reason>Waiting for user decision on whether to proceed with exploitation of discovered SQL injection vulnerability</parameter>
</function>
# Coordinate with other agents
<function=wait_for_message>
<parameter=reason>Waiting for vulnerability assessment agent to share discovered attack vectors before proceeding with exploitation phase</parameter>
</function>
</examples>
</tool>
</tools>

View File

@@ -0,0 +1,120 @@
import contextlib
import inspect
import json
from collections.abc import Callable
from typing import Any, Union, get_args, get_origin
class ArgumentConversionError(Exception):
    """Raised when a string argument cannot be coerced to a parameter's type."""

    def __init__(self, message: str, param_name: str | None = None) -> None:
        # Which parameter failed; None when the whole argument set failed.
        self.param_name = param_name
        super().__init__(message)
def convert_arguments(func: Callable[..., Any], kwargs: dict[str, Any]) -> dict[str, Any]:
    """Coerce string-valued kwargs to the annotated types of *func*'s parameters.

    Values that are not strings, are None, or belong to unknown or
    unannotated parameters pass through unchanged.

    Raises ArgumentConversionError when a value cannot be converted or the
    signature cannot be processed.
    """
    try:
        parameters = inspect.signature(func).parameters
        converted: dict[str, Any] = {}
        for name, raw in kwargs.items():
            param = parameters.get(name)
            if param is None:
                # Unknown parameter: leave it for the callee to reject.
                converted[name] = raw
                continue
            annotation = param.annotation
            needs_conversion = (
                annotation != inspect.Parameter.empty
                and raw is not None
                and isinstance(raw, str)
            )
            if not needs_conversion:
                converted[name] = raw
                continue
            try:
                converted[name] = convert_string_to_type(raw, annotation)
            except (ValueError, TypeError, json.JSONDecodeError) as e:
                raise ArgumentConversionError(
                    f"Failed to convert argument '{name}' to type {annotation}: {e}",
                    param_name=name,
                ) from e
    except (ValueError, TypeError, AttributeError) as e:
        raise ArgumentConversionError(f"Failed to process function arguments: {e}") from e
    return converted
def convert_string_to_type(value: str, param_type: Any) -> Any:
origin = get_origin(param_type)
if origin is Union or origin is type(str | None):
args = get_args(param_type)
for arg_type in args:
if arg_type is not type(None):
with contextlib.suppress(ValueError, TypeError, json.JSONDecodeError):
return convert_string_to_type(value, arg_type)
return value
if hasattr(param_type, "__args__"):
args = getattr(param_type, "__args__", ())
if len(args) == 2 and type(None) in args:
non_none_type = args[0] if args[1] is type(None) else args[1]
with contextlib.suppress(ValueError, TypeError, json.JSONDecodeError):
return convert_string_to_type(value, non_none_type)
return value
return _convert_basic_types(value, param_type, origin)
def _convert_basic_types(value: str, param_type: Any, origin: Any = None) -> Any:
basic_type_converters: dict[Any, Callable[[str], Any]] = {
int: int,
float: float,
bool: _convert_to_bool,
str: str,
}
if param_type in basic_type_converters:
return basic_type_converters[param_type](value)
if list in (origin, param_type):
return _convert_to_list(value)
if dict in (origin, param_type):
return _convert_to_dict(value)
with contextlib.suppress(json.JSONDecodeError):
return json.loads(value)
return value
def _convert_to_bool(value: str) -> bool:
if value.lower() in ("true", "1", "yes", "on"):
return True
if value.lower() in ("false", "0", "no", "off"):
return False
return bool(value)
def _convert_to_list(value: str) -> list[Any]:
try:
parsed = json.loads(value)
if isinstance(parsed, list):
return parsed
except json.JSONDecodeError:
if "," in value:
return [item.strip() for item in value.split(",")]
return [value]
else:
return [parsed]
def _convert_to_dict(value: str) -> dict[str, Any]:
try:
parsed = json.loads(value)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
return {}
else:
return {}

View File

@@ -0,0 +1,4 @@
from .browser_actions import browser_action
__all__ = ["browser_action"]

View File

@@ -0,0 +1,236 @@
from typing import Any, Literal, NoReturn
from strix.tools.registry import register_tool
from .tab_manager import BrowserTabManager, get_browser_tab_manager
# Closed set of action names accepted by browser_action; Literal typing lets
# static checkers reject unknown actions at the call site.
BrowserAction = Literal[
    "launch",
    "goto",
    "click",
    "type",
    "scroll_down",
    "scroll_up",
    "back",
    "forward",
    "new_tab",
    "switch_tab",
    "close_tab",
    "wait",
    "execute_js",
    "double_click",
    "hover",
    "press_key",
    "save_pdf",
    "get_console_logs",
    "view_source",
    "close",
    "list_tabs",
]
def _validate_url(action_name: str, url: str | None) -> None:
if not url:
raise ValueError(f"url parameter is required for {action_name} action")
def _validate_coordinate(action_name: str, coordinate: str | None) -> None:
if not coordinate:
raise ValueError(f"coordinate parameter is required for {action_name} action")
def _validate_text(action_name: str, text: str | None) -> None:
if not text:
raise ValueError(f"text parameter is required for {action_name} action")
def _validate_tab_id(action_name: str, tab_id: str | None) -> None:
if not tab_id:
raise ValueError(f"tab_id parameter is required for {action_name} action")
def _validate_js_code(action_name: str, js_code: str | None) -> None:
if not js_code:
raise ValueError(f"js_code parameter is required for {action_name} action")
def _validate_duration(action_name: str, duration: float | None) -> None:
if duration is None:
raise ValueError(f"duration parameter is required for {action_name} action")
def _validate_key(action_name: str, key: str | None) -> None:
if not key:
raise ValueError(f"key parameter is required for {action_name} action")
def _validate_file_path(action_name: str, file_path: str | None) -> None:
if not file_path:
raise ValueError(f"file_path parameter is required for {action_name} action")
def _handle_navigation_actions(
manager: BrowserTabManager,
action: str,
url: str | None = None,
tab_id: str | None = None,
) -> dict[str, Any]:
if action == "launch":
return manager.launch_browser(url)
if action == "goto":
_validate_url(action, url)
assert url is not None
return manager.goto_url(url, tab_id)
if action == "back":
return manager.back(tab_id)
if action == "forward":
return manager.forward(tab_id)
raise ValueError(f"Unknown navigation action: {action}")
def _handle_interaction_actions(
manager: BrowserTabManager,
action: str,
coordinate: str | None = None,
text: str | None = None,
key: str | None = None,
tab_id: str | None = None,
) -> dict[str, Any]:
if action in {"click", "double_click", "hover"}:
_validate_coordinate(action, coordinate)
assert coordinate is not None
action_map = {
"click": manager.click,
"double_click": manager.double_click,
"hover": manager.hover,
}
return action_map[action](coordinate, tab_id)
if action in {"scroll_down", "scroll_up"}:
direction = "down" if action == "scroll_down" else "up"
return manager.scroll(direction, tab_id)
if action == "type":
_validate_text(action, text)
assert text is not None
return manager.type_text(text, tab_id)
if action == "press_key":
_validate_key(action, key)
assert key is not None
return manager.press_key(key, tab_id)
raise ValueError(f"Unknown interaction action: {action}")
def _raise_unknown_action(action: str) -> NoReturn:
raise ValueError(f"Unknown action: {action}")
def _handle_tab_actions(
manager: BrowserTabManager,
action: str,
url: str | None = None,
tab_id: str | None = None,
) -> dict[str, Any]:
if action == "new_tab":
return manager.new_tab(url)
if action == "switch_tab":
_validate_tab_id(action, tab_id)
assert tab_id is not None
return manager.switch_tab(tab_id)
if action == "close_tab":
_validate_tab_id(action, tab_id)
assert tab_id is not None
return manager.close_tab(tab_id)
if action == "list_tabs":
return manager.list_tabs()
raise ValueError(f"Unknown tab action: {action}")
def _handle_utility_actions(
manager: BrowserTabManager,
action: str,
duration: float | None = None,
js_code: str | None = None,
file_path: str | None = None,
tab_id: str | None = None,
clear: bool = False,
) -> dict[str, Any]:
if action == "wait":
_validate_duration(action, duration)
assert duration is not None
return manager.wait_browser(duration, tab_id)
if action == "execute_js":
_validate_js_code(action, js_code)
assert js_code is not None
return manager.execute_js(js_code, tab_id)
if action == "save_pdf":
_validate_file_path(action, file_path)
assert file_path is not None
return manager.save_pdf(file_path, tab_id)
if action == "get_console_logs":
return manager.get_console_logs(tab_id, clear)
if action == "view_source":
return manager.view_source(tab_id)
if action == "close":
return manager.close_browser()
raise ValueError(f"Unknown utility action: {action}")
@register_tool
def browser_action(
    action: BrowserAction,
    url: str | None = None,
    coordinate: str | None = None,
    text: str | None = None,
    tab_id: str | None = None,
    js_code: str | None = None,
    duration: float | None = None,
    key: str | None = None,
    file_path: str | None = None,
    clear: bool = False,
) -> dict[str, Any]:
    """Perform a single browser action via the shared tab manager.

    Dispatches *action* to its handler group (navigation, interaction, tab
    management, or utility). ValueError/RuntimeError failures are converted
    into an error payload instead of propagating to the caller.
    """
    manager = get_browser_tab_manager()
    try:
        dispatch = (
            (
                {"launch", "goto", "back", "forward"},
                lambda: _handle_navigation_actions(manager, action, url, tab_id),
            ),
            (
                {
                    "click",
                    "type",
                    "double_click",
                    "hover",
                    "press_key",
                    "scroll_down",
                    "scroll_up",
                },
                lambda: _handle_interaction_actions(
                    manager, action, coordinate, text, key, tab_id
                ),
            ),
            (
                {"new_tab", "switch_tab", "close_tab", "list_tabs"},
                lambda: _handle_tab_actions(manager, action, url, tab_id),
            ),
            (
                {"wait", "execute_js", "save_pdf", "get_console_logs", "view_source", "close"},
                lambda: _handle_utility_actions(
                    manager, action, duration, js_code, file_path, tab_id, clear
                ),
            ),
        )
        for action_group, handler in dispatch:
            if action in action_group:
                return handler()
        _raise_unknown_action(action)
    except (ValueError, RuntimeError) as e:
        return {
            "error": str(e),
            "tab_id": tab_id,
            "screenshot": "",
            "is_running": False,
        }

View File

@@ -0,0 +1,183 @@
<?xml version="1.0" ?>
<tools>
<tool name="browser_action">
<description>Perform browser actions using a Playwright-controlled browser with multiple tabs.
The browser is PERSISTENT and remains active until explicitly closed, allowing for
multi-step workflows and long-running processes across multiple tabs.</description>
<parameters>
<parameter name="action" type="string" required="true">
</parameter>
<parameter name="url" type="string" required="false">
<description>Required for 'launch', 'goto', and optionally for 'new_tab' actions. The URL to launch the browser at, navigate to, or load in new tab. Must include appropriate protocol (e.g., http://, https://, file://).</description>
</parameter>
<parameter name="coordinate" type="string" required="false">
<description>Required for 'click', 'double_click', and 'hover' actions. Format: "x,y" (e.g., "432,321"). Coordinates should target the center of elements (buttons, links, etc.). Must be within the browser viewport resolution. Be very careful to calculate the coordinates correctly based on the previous screenshot.</description>
</parameter>
<parameter name="text" type="string" required="false">
<description>Required for 'type' action. The text to type in the field.</description>
</parameter>
<parameter name="tab_id" type="string" required="false">
<description>Required for 'switch_tab' and 'close_tab' actions. Optional for other actions to specify which tab to operate on. The ID of the tab to operate on. The first tab created during 'launch' has ID "tab_1". If not provided, actions will operate on the currently active tab.</description>
</parameter>
<parameter name="js_code" type="string" required="false">
<description>Required for 'execute_js' action. JavaScript code to execute in the page context. The code runs in the context of the current page and has access to the DOM and all page-defined variables and functions. The last evaluated expression's value is returned in the response.</description>
</parameter>
<parameter name="duration" type="string" required="false">
<description>Required for 'wait' action. Number of seconds to pause execution. Can be fractional (e.g., 0.5 for half a second).</description>
</parameter>
<parameter name="key" type="string" required="false">
<description>Required for 'press_key' action. The key to press. Valid values include: - Single characters: 'a'-'z', 'A'-'Z', '0'-'9' - Special keys: 'Enter', 'Escape', 'ArrowLeft', 'ArrowRight', etc. - Modifier keys: 'Shift', 'Control', 'Alt', 'Meta' - Function keys: 'F1'-'F12'</description>
</parameter>
<parameter name="file_path" type="string" required="false">
<description>Required for 'save_pdf' action. The file path where to save the PDF.</description>
</parameter>
<parameter name="clear" type="boolean" required="false">
<description>For 'get_console_logs' action: whether to clear console logs after retrieving them. Default is False (keep logs).</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - screenshot: Base64 encoded PNG of the current page state - url: Current page URL - title: Current page title - viewport: Current browser viewport dimensions - tab_id: ID of the current active tab - all_tabs: Dict of all open tab IDs and their URLs - message: Status message about the action performed - js_result: Result of JavaScript execution (for execute_js action) - pdf_saved: File path of saved PDF (for save_pdf action) - console_logs: Array of console messages (for get_console_logs action) Limited to 30KB total and 200 most recent logs. Individual messages truncated at 1KB. - page_source: HTML source code (for view_source action) Large pages are truncated to 20KB (keeping beginning and end sections).</description>
</returns>
<notes>
Important usage rules:
1. PERSISTENCE: The browser remains active and maintains its state until
explicitly closed with the 'close' action. This allows for multi-step workflows
across multiple tool calls and tabs.
2. Browser interaction MUST start with 'launch' and end with 'close'.
3. Only one action can be performed per call.
4. To visit a new URL not reachable from current page, either:
- Use 'goto' action
- Open a new tab with the URL
- Close browser and relaunch
5. Click coordinates must be derived from the most recent screenshot.
6. You MUST click on the center of the element, not the edge. You MUST calculate
the coordinates correctly based on the previous screenshot, otherwise the click
will fail. After clicking, check the new screenshot to verify the click was
successful.
7. Tab management:
- First tab from 'launch' is "tab_1"
- New tabs are numbered sequentially ("tab_2", "tab_3", etc.)
- Must have at least one tab open at all times
- Actions affect the currently active tab unless tab_id is specified
8. JavaScript execution (following Playwright evaluation patterns):
- Code runs in the browser page context, not the tool context
- Has access to DOM (document, window, etc.) and page variables/functions
- The LAST EVALUATED EXPRESSION is automatically returned - no return statement needed
- For simple values: document.title (returns the title)
- For objects: {title: document.title, url: location.href} (returns the object)
- For async operations: Use await and the promise result will be returned
- AVOID explicit return statements - they can break evaluation
- object literals must be wrapped in parentheses when they are the final expression
- Variables from tool context are NOT available - pass data as parameters if needed
- Examples of correct patterns:
* Single value: document.querySelectorAll('img').length
* Object result: {images: document.images.length, links: document.links.length}
* Async operation: await fetch(location.href).then(r => r.status)
* DOM manipulation: document.body.style.backgroundColor = 'red'; 'background changed'
9. Wait action:
- Time is specified in seconds
- Can be used to wait for page loads, animations, etc.
- Can be fractional (e.g., 0.5 seconds)
- Screenshot is captured after the wait
10. The browser can operate concurrently with other tools. You may invoke
terminal, python, or other tools (in separate assistant messages) while maintaining
the active browser session, enabling sophisticated multi-tool workflows.
11. Keyboard actions:
- Use press_key for individual key presses
- Use type for typing regular text
- Some keys have special names based on Playwright's key documentation
12. All code in the js_code parameter is executed as-is - there's no need to
escape special characters or worry about formatting. Just write your JavaScript
code normally. It can be single line or multi-line.
13. For form filling, click on the field first, then use 'type' to enter text.
14. The browser runs in headless mode using Chrome engine for security and performance.
</notes>
<examples>
# Launch browser at URL (creates tab_1)
<function=browser_action>
<parameter=action>launch</parameter>
<parameter=url>https://example.com</parameter>
</function>
# Navigate to different URL
<function=browser_action>
<parameter=action>goto</parameter>
<parameter=url>https://github.com</parameter>
</function>
# Open new tab with different URL
<function=browser_action>
<parameter=action>new_tab</parameter>
<parameter=url>https://another-site.com</parameter>
</function>
# Wait for page load
<function=browser_action>
<parameter=action>wait</parameter>
<parameter=duration>2.5</parameter>
</function>
# Click login button at coordinates from screenshot
<function=browser_action>
<parameter=action>click</parameter>
<parameter=coordinate>450,300</parameter>
</function>
# Click username field and type
<function=browser_action>
<parameter=action>click</parameter>
<parameter=coordinate>400,200</parameter>
</function>
<function=browser_action>
<parameter=action>type</parameter>
<parameter=text>user@example.com</parameter>
</function>
# Click password field and type
<function=browser_action>
<parameter=action>click</parameter>
<parameter=coordinate>400,250</parameter>
</function>
<function=browser_action>
<parameter=action>type</parameter>
<parameter=text>mypassword123</parameter>
</function>
# Press Enter key
<function=browser_action>
<parameter=action>press_key</parameter>
<parameter=key>Enter</parameter>
</function>
# Execute JavaScript to get page stats (correct pattern - no return statement)
<function=browser_action>
<parameter=action>execute_js</parameter>
<parameter=js_code>const images = document.querySelectorAll('img');
const links = document.querySelectorAll('a');
{
images: images.length,
links: links.length,
title: document.title
}</parameter>
</function>
# Scroll down
<function=browser_action>
<parameter=action>scroll_down</parameter>
</function>
# Get console logs
<function=browser_action>
<parameter=action>get_console_logs</parameter>
</function>
# View page source
<function=browser_action>
<parameter=action>view_source</parameter>
</function>
</examples>
</tool>
</tools>

View File

@@ -0,0 +1,533 @@
import asyncio
import base64
import logging
import threading
from pathlib import Path
from typing import Any, cast
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
logger = logging.getLogger(__name__)
# Output-size guards; values are character counts unless noted otherwise.
MAX_PAGE_SOURCE_LENGTH = 20_000  # view_source: max HTML returned (head + tail kept)
MAX_CONSOLE_LOG_LENGTH = 30_000  # get_console_logs: max combined size of returned logs
MAX_INDIVIDUAL_LOG_LENGTH = 1_000  # cap applied to each single console message
MAX_CONSOLE_LOGS_COUNT = 200  # number of most recent entries retained per tab
MAX_JS_RESULT_LENGTH = 5_000  # execute_js: max stringified result length
class BrowserInstance:
    """A single headless Chromium browser with multiple tabs and console capture.

    All Playwright calls run on a private asyncio event loop hosted in a daemon
    thread; the synchronous public methods submit coroutines to that loop and
    block (up to 30s) for the result. ``_execution_lock`` serializes the public
    API so only one browser operation is in flight at a time.
    """
    def __init__(self) -> None:
        # Cleared by close(); checked before every async submission.
        self.is_running = True
        self._execution_lock = threading.Lock()
        self.playwright: Playwright | None = None
        self.browser: Browser | None = None
        self.context: BrowserContext | None = None
        # Maps tab ids ("tab_1", "tab_2", ...) to their Playwright pages.
        self.pages: dict[str, Page] = {}
        self.current_page_id: str | None = None
        self._next_tab_id = 1
        # Per-tab captured console messages (see _setup_console_logging).
        self.console_logs: dict[str, list[dict[str, Any]]] = {}
        self._loop: asyncio.AbstractEventLoop | None = None
        self._loop_thread: threading.Thread | None = None
        self._start_event_loop()
    def _start_event_loop(self) -> None:
        """Start the daemon thread that hosts this instance's asyncio loop."""
        def run_loop() -> None:
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)
            self._loop.run_forever()
        self._loop_thread = threading.Thread(target=run_loop, daemon=True)
        self._loop_thread.start()
        # Poll (10ms ticks) until the thread has published the loop object.
        while self._loop is None:
            threading.Event().wait(0.01)
    def _run_async(self, coro: Any) -> dict[str, Any]:
        """Submit *coro* to the background loop and block for its result."""
        if not self._loop or not self.is_running:
            raise RuntimeError("Browser instance is not running")
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return cast("dict[str, Any]", future.result(timeout=30))  # 30 second timeout
    async def _setup_console_logging(self, page: Page, tab_id: str) -> None:
        """Attach a console listener recording (and capping) messages for *tab_id*."""
        self.console_logs[tab_id] = []
        def handle_console(msg: Any) -> None:
            text = msg.text
            # Cap each individual message at MAX_INDIVIDUAL_LOG_LENGTH chars.
            if len(text) > MAX_INDIVIDUAL_LOG_LENGTH:
                text = text[:MAX_INDIVIDUAL_LOG_LENGTH] + "... [TRUNCATED]"
            log_entry = {
                "type": msg.type,
                "text": text,
                "location": msg.location,
                "timestamp": asyncio.get_event_loop().time(),
            }
            self.console_logs[tab_id].append(log_entry)
            # Keep only the most recent MAX_CONSOLE_LOGS_COUNT entries per tab.
            if len(self.console_logs[tab_id]) > MAX_CONSOLE_LOGS_COUNT:
                self.console_logs[tab_id] = self.console_logs[tab_id][-MAX_CONSOLE_LOGS_COUNT:]
        page.on("console", handle_console)
    async def _launch_browser(self, url: str | None = None) -> dict[str, Any]:
        """Start Playwright/Chromium, open the first tab, optionally goto *url*."""
        self.playwright = await async_playwright().start()
        self.browser = await self.playwright.chromium.launch(
            headless=True,
            args=[
                "--no-sandbox",
                "--disable-dev-shm-usage",
                "--disable-gpu",
                "--disable-web-security",
                "--disable-features=VizDisplayCompositor",
            ],
        )
        self.context = await self.browser.new_context(
            viewport={"width": 1280, "height": 720},
            user_agent=(
                "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
                "(KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            ),
        )
        page = await self.context.new_page()
        tab_id = f"tab_{self._next_tab_id}"
        self._next_tab_id += 1
        self.pages[tab_id] = page
        self.current_page_id = tab_id
        await self._setup_console_logging(page, tab_id)
        if url:
            await page.goto(url, wait_until="domcontentloaded")
        return await self._get_page_state(tab_id)
    async def _get_page_state(self, tab_id: str | None = None) -> dict[str, Any]:
        """Return screenshot + metadata snapshot of *tab_id* (default: active tab)."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        page = self.pages[tab_id]
        # Fixed 2s settle delay before screenshotting — presumably so dynamic
        # pages finish rendering; every action pays this cost (review: confirm).
        await asyncio.sleep(2)
        screenshot_bytes = await page.screenshot(type="png", full_page=False)
        screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")
        url = page.url
        title = await page.title()
        viewport = page.viewport_size
        all_tabs = {}
        for tid, tab_page in self.pages.items():
            all_tabs[tid] = {
                "url": tab_page.url,
                "title": await tab_page.title() if not tab_page.is_closed() else "Closed",
            }
        return {
            "screenshot": screenshot_b64,
            "url": url,
            "title": title,
            "viewport": viewport,
            "tab_id": tab_id,
            "all_tabs": all_tabs,
        }
    def launch(self, url: str | None = None) -> dict[str, Any]:
        """Synchronously launch the browser; raises if already launched."""
        with self._execution_lock:
            if self.browser is not None:
                raise ValueError("Browser is already launched")
            return self._run_async(self._launch_browser(url))
    def goto(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _goto."""
        with self._execution_lock:
            return self._run_async(self._goto(url, tab_id))
    async def _goto(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
        """Navigate the given (or active) tab to *url*."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        page = self.pages[tab_id]
        await page.goto(url, wait_until="domcontentloaded")
        return await self._get_page_state(tab_id)
    def click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _click."""
        with self._execution_lock:
            return self._run_async(self._click(coordinate, tab_id))
    async def _click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
        """Single left-click at viewport coordinate "x,y"."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        try:
            x, y = map(int, coordinate.split(","))
        except ValueError as e:
            raise ValueError(f"Invalid coordinate format: {coordinate}. Use 'x,y'") from e
        page = self.pages[tab_id]
        await page.mouse.click(x, y)
        return await self._get_page_state(tab_id)
    def type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _type_text."""
        with self._execution_lock:
            return self._run_async(self._type_text(text, tab_id))
    async def _type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]:
        """Type *text* via the keyboard into whatever element has focus."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        page = self.pages[tab_id]
        await page.keyboard.type(text)
        return await self._get_page_state(tab_id)
    def scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _scroll."""
        with self._execution_lock:
            return self._run_async(self._scroll(direction, tab_id))
    async def _scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]:
        """Scroll one page via PageDown/PageUp; direction is "down" or "up"."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        page = self.pages[tab_id]
        if direction == "down":
            await page.keyboard.press("PageDown")
        elif direction == "up":
            await page.keyboard.press("PageUp")
        else:
            raise ValueError(f"Invalid scroll direction: {direction}")
        return await self._get_page_state(tab_id)
    def back(self, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _back."""
        with self._execution_lock:
            return self._run_async(self._back(tab_id))
    async def _back(self, tab_id: str | None = None) -> dict[str, Any]:
        """Navigate back in the tab's history."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        page = self.pages[tab_id]
        await page.go_back(wait_until="domcontentloaded")
        return await self._get_page_state(tab_id)
    def forward(self, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _forward."""
        with self._execution_lock:
            return self._run_async(self._forward(tab_id))
    async def _forward(self, tab_id: str | None = None) -> dict[str, Any]:
        """Navigate forward in the tab's history."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        page = self.pages[tab_id]
        await page.go_forward(wait_until="domcontentloaded")
        return await self._get_page_state(tab_id)
    def new_tab(self, url: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _new_tab."""
        with self._execution_lock:
            return self._run_async(self._new_tab(url))
    async def _new_tab(self, url: str | None = None) -> dict[str, Any]:
        """Open a new tab (becomes active), optionally navigating to *url*."""
        if not self.context:
            raise ValueError("Browser not launched")
        page = await self.context.new_page()
        tab_id = f"tab_{self._next_tab_id}"
        self._next_tab_id += 1
        self.pages[tab_id] = page
        self.current_page_id = tab_id
        await self._setup_console_logging(page, tab_id)
        if url:
            await page.goto(url, wait_until="domcontentloaded")
        return await self._get_page_state(tab_id)
    def switch_tab(self, tab_id: str) -> dict[str, Any]:
        """Thread-safe wrapper around _switch_tab."""
        with self._execution_lock:
            return self._run_async(self._switch_tab(tab_id))
    async def _switch_tab(self, tab_id: str) -> dict[str, Any]:
        """Make *tab_id* the active tab."""
        if tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        self.current_page_id = tab_id
        return await self._get_page_state(tab_id)
    def close_tab(self, tab_id: str) -> dict[str, Any]:
        """Thread-safe wrapper around _close_tab."""
        with self._execution_lock:
            return self._run_async(self._close_tab(tab_id))
    async def _close_tab(self, tab_id: str) -> dict[str, Any]:
        """Close *tab_id*; the last remaining tab can never be closed."""
        if tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        if len(self.pages) == 1:
            raise ValueError("Cannot close the last tab")
        page = self.pages.pop(tab_id)
        await page.close()
        if tab_id in self.console_logs:
            del self.console_logs[tab_id]
        # If the active tab was closed, fall back to an arbitrary open tab.
        if self.current_page_id == tab_id:
            self.current_page_id = next(iter(self.pages.keys()))
        return await self._get_page_state(self.current_page_id)
    def wait(self, duration: float, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _wait."""
        with self._execution_lock:
            return self._run_async(self._wait(duration, tab_id))
    async def _wait(self, duration: float, tab_id: str | None = None) -> dict[str, Any]:
        """Sleep *duration* seconds, then capture the page state."""
        await asyncio.sleep(duration)
        return await self._get_page_state(tab_id)
    def execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _execute_js."""
        with self._execution_lock:
            return self._run_async(self._execute_js(js_code, tab_id))
    async def _execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]:
        """Evaluate *js_code* in the page; errors are reported, not raised."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        page = self.pages[tab_id]
        try:
            result = await page.evaluate(js_code)
        except Exception as e:  # noqa: BLE001
            # Deliberately broad: any JS failure becomes a structured error payload.
            result = {
                "error": True,
                "error_type": type(e).__name__,
                "error_message": str(e),
            }
        result_str = str(result)
        # Oversized results are replaced by their truncated string form.
        if len(result_str) > MAX_JS_RESULT_LENGTH:
            result = result_str[:MAX_JS_RESULT_LENGTH] + "... [JS result truncated at 5k chars]"
        state = await self._get_page_state(tab_id)
        state["js_result"] = result
        return state
    def get_console_logs(self, tab_id: str | None = None, clear: bool = False) -> dict[str, Any]:
        """Thread-safe wrapper around _get_console_logs."""
        with self._execution_lock:
            return self._run_async(self._get_console_logs(tab_id, clear))
    async def _get_console_logs(
        self, tab_id: str | None = None, clear: bool = False
    ) -> dict[str, Any]:
        """Return captured console logs for the tab, newest-biased truncation.

        If the combined size exceeds MAX_CONSOLE_LOG_LENGTH, older entries are
        dropped (walking from newest backwards) and a truncation notice is
        prepended. With ``clear=True`` the tab's stored logs are emptied.
        """
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        logs = self.console_logs.get(tab_id, [])
        total_length = sum(len(str(log)) for log in logs)
        if total_length > MAX_CONSOLE_LOG_LENGTH:
            truncated_logs: list[dict[str, Any]] = []
            current_length = 0
            # Walk newest-to-oldest, keeping as many recent logs as fit.
            for log in reversed(logs):
                log_length = len(str(log))
                if current_length + log_length <= MAX_CONSOLE_LOG_LENGTH:
                    truncated_logs.insert(0, log)
                    current_length += log_length
                else:
                    break
            if len(truncated_logs) < len(logs):
                truncation_notice = {
                    "type": "info",
                    "text": (
                        f"[TRUNCATED: {len(logs) - len(truncated_logs)} older logs "
                        f"removed to stay within {MAX_CONSOLE_LOG_LENGTH} character limit]"
                    ),
                    "location": {},
                    "timestamp": 0,
                }
                truncated_logs.insert(0, truncation_notice)
            logs = truncated_logs
        if clear:
            self.console_logs[tab_id] = []
        state = await self._get_page_state(tab_id)
        state["console_logs"] = logs
        return state
    def view_source(self, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _view_source."""
        with self._execution_lock:
            return self._run_async(self._view_source(tab_id))
    async def _view_source(self, tab_id: str | None = None) -> dict[str, Any]:
        """Return the page HTML; large pages keep head + tail around a marker."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        page = self.pages[tab_id]
        source = await page.content()
        original_length = len(source)
        if original_length > MAX_PAGE_SOURCE_LENGTH:
            truncation_message = (
                f"\n\n<!-- [TRUNCATED: {original_length - MAX_PAGE_SOURCE_LENGTH} "
                "characters removed] -->\n\n"
            )
            # Split the budget evenly between the start and end of the document.
            available_space = MAX_PAGE_SOURCE_LENGTH - len(truncation_message)
            truncate_point = available_space // 2
            source = source[:truncate_point] + truncation_message + source[-truncate_point:]
        state = await self._get_page_state(tab_id)
        state["page_source"] = source
        return state
    def double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _double_click."""
        with self._execution_lock:
            return self._run_async(self._double_click(coordinate, tab_id))
    async def _double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
        """Double-click at viewport coordinate "x,y"."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        try:
            x, y = map(int, coordinate.split(","))
        except ValueError as e:
            raise ValueError(f"Invalid coordinate format: {coordinate}. Use 'x,y'") from e
        page = self.pages[tab_id]
        await page.mouse.dblclick(x, y)
        return await self._get_page_state(tab_id)
    def hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _hover."""
        with self._execution_lock:
            return self._run_async(self._hover(coordinate, tab_id))
    async def _hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
        """Move the mouse to viewport coordinate "x,y" without clicking."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        try:
            x, y = map(int, coordinate.split(","))
        except ValueError as e:
            raise ValueError(f"Invalid coordinate format: {coordinate}. Use 'x,y'") from e
        page = self.pages[tab_id]
        await page.mouse.move(x, y)
        return await self._get_page_state(tab_id)
    def press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _press_key."""
        with self._execution_lock:
            return self._run_async(self._press_key(key, tab_id))
    async def _press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]:
        """Press a single key (Playwright key name) in the tab."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        page = self.pages[tab_id]
        await page.keyboard.press(key)
        return await self._get_page_state(tab_id)
    def save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]:
        """Thread-safe wrapper around _save_pdf."""
        with self._execution_lock:
            return self._run_async(self._save_pdf(file_path, tab_id))
    async def _save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]:
        """Save the page as PDF; relative paths are rooted at /workspace."""
        if not tab_id:
            tab_id = self.current_page_id
        if not tab_id or tab_id not in self.pages:
            raise ValueError(f"Tab '{tab_id}' not found")
        if not Path(file_path).is_absolute():
            file_path = str(Path("/workspace") / file_path)
        page = self.pages[tab_id]
        await page.pdf(path=file_path)
        state = await self._get_page_state(tab_id)
        state["pdf_saved"] = file_path
        return state
    def close(self) -> None:
        """Shut down the browser and stop the background event loop.

        NOTE(review): _close_browser is scheduled but its future is not awaited
        before the loop is stopped — browser teardown may race loop shutdown;
        confirm this is intended.
        """
        with self._execution_lock:
            self.is_running = False
            if self._loop:
                asyncio.run_coroutine_threadsafe(self._close_browser(), self._loop)
                self._loop.call_soon_threadsafe(self._loop.stop)
            if self._loop_thread:
                self._loop_thread.join(timeout=5)
    async def _close_browser(self) -> None:
        """Close Chromium and stop Playwright, logging (not raising) failures."""
        try:
            if self.browser:
                await self.browser.close()
            if self.playwright:
                await self.playwright.stop()
        except (OSError, RuntimeError) as e:
            logger.warning(f"Error closing browser: {e}")
    def is_alive(self) -> bool:
        """True while the instance is running and Chromium is still connected."""
        return self.is_running and self.browser is not None and self.browser.is_connected()

View File

@@ -0,0 +1,342 @@
import atexit
import contextlib
import signal
import sys
import threading
from typing import Any
from .browser_instance import BrowserInstance
class BrowserTabManager:
def __init__(self) -> None:
self.browser_instance: BrowserInstance | None = None
self._lock = threading.Lock()
self._register_cleanup_handlers()
def launch_browser(self, url: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is not None:
raise ValueError("Browser is already launched")
try:
self.browser_instance = BrowserInstance()
result = self.browser_instance.launch(url)
result["message"] = "Browser launched successfully"
except (OSError, ValueError, RuntimeError) as e:
if self.browser_instance:
self.browser_instance = None
raise RuntimeError(f"Failed to launch browser: {e}") from e
else:
return result
def goto_url(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.goto(url, tab_id)
result["message"] = f"Navigated to {url}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to navigate to URL: {e}") from e
else:
return result
def click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.click(coordinate, tab_id)
result["message"] = f"Clicked at {coordinate}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to click: {e}") from e
else:
return result
def type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.type_text(text, tab_id)
result["message"] = f"Typed text: {text[:50]}{'...' if len(text) > 50 else ''}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to type text: {e}") from e
else:
return result
def scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.scroll(direction, tab_id)
result["message"] = f"Scrolled {direction}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to scroll: {e}") from e
else:
return result
def back(self, tab_id: str | None = None) -> dict[str, Any]:
    """Go back one entry in the tab's history."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.back(tab_id)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to go back: {e}") from e
        payload["message"] = "Navigated back"
        return payload
def forward(self, tab_id: str | None = None) -> dict[str, Any]:
    """Go forward one entry in the tab's history."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.forward(tab_id)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to go forward: {e}") from e
        payload["message"] = "Navigated forward"
        return payload
def new_tab(self, url: str | None = None) -> dict[str, Any]:
    """Open a new tab, optionally navigating it to *url*."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.new_tab(url)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to create new tab: {e}") from e
        payload["message"] = f"Created new tab {payload.get('tab_id', '')}"
        return payload
def switch_tab(self, tab_id: str) -> dict[str, Any]:
    """Make *tab_id* the current tab."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.switch_tab(tab_id)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to switch tab: {e}") from e
        payload["message"] = f"Switched to tab {tab_id}"
        return payload
def close_tab(self, tab_id: str) -> dict[str, Any]:
    """Close the tab identified by *tab_id*."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.close_tab(tab_id)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to close tab: {e}") from e
        payload["message"] = f"Closed tab {tab_id}"
        return payload
def wait_browser(self, duration: float, tab_id: str | None = None) -> dict[str, Any]:
    """Pause the browser for *duration* seconds (e.g. to let a page settle)."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.wait(duration, tab_id)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to wait: {e}") from e
        payload["message"] = f"Waited {duration}s"
        return payload
def execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]:
    """Evaluate *js_code* in the page context of the given tab."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.execute_js(js_code, tab_id)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to execute JavaScript: {e}") from e
        payload["message"] = "JavaScript executed successfully"
        return payload
def double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
    """Double-click at *coordinate* in the given tab."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.double_click(coordinate, tab_id)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to double click: {e}") from e
        payload["message"] = f"Double clicked at {coordinate}"
        return payload
def hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
    """Move the pointer to *coordinate* without clicking."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.hover(coordinate, tab_id)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to hover: {e}") from e
        payload["message"] = f"Hovered at {coordinate}"
        return payload
def press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]:
    """Press a single keyboard key (e.g. "Enter") in the given tab."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.press_key(key, tab_id)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to press key: {e}") from e
        payload["message"] = f"Pressed key {key}"
        return payload
def save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]:
    """Render the current page to a PDF at *file_path*."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.save_pdf(file_path, tab_id)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to save PDF: {e}") from e
        payload["message"] = f"Page saved as PDF: {file_path}"
        return payload
def get_console_logs(self, tab_id: str | None = None, clear: bool = False) -> dict[str, Any]:
    """Fetch (and optionally clear) the console-log buffer of a tab.

    The message notes when any returned entry was truncated upstream
    (entries whose text starts with the "[TRUNCATED:" marker).
    """
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.get_console_logs(tab_id, clear)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to get console logs: {e}") from e
        action_text = "cleared and retrieved" if clear else "retrieved"
        was_truncated = any(
            entry.get("text", "").startswith("[TRUNCATED:")
            for entry in payload.get("console_logs", [])
        )
        suffix = " (truncated)" if was_truncated else ""
        payload["message"] = (
            f"Console logs {action_text} for tab "
            f"{payload.get('tab_id', 'current')}{suffix}"
        )
        return payload
def view_source(self, tab_id: str | None = None) -> dict[str, Any]:
    """Return the HTML source of the given tab's page."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            payload = self.browser_instance.view_source(tab_id)
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to get page source: {e}") from e
        payload["message"] = "Page source retrieved"
        return payload
def list_tabs(self) -> dict[str, Any]:
    """Describe every open tab; safe to call before the browser is launched."""
    with self._lock:
        if self.browser_instance is None:
            # No browser yet: report an empty tab set instead of raising.
            return {"tabs": {}, "total_count": 0, "current_tab": None}
        try:
            described: dict[str, Any] = {}
            for tid, page in self.browser_instance.pages.items():
                try:
                    described[tid] = {
                        "url": page.url,
                        "title": "Unknown" if page.is_closed() else "Active",
                        "is_current": tid == self.browser_instance.current_page_id,
                    }
                except (AttributeError, RuntimeError):
                    # Page object is gone/unusable; record a placeholder entry.
                    described[tid] = {
                        "url": "Unknown",
                        "title": "Closed",
                        "is_current": False,
                    }
            return {
                "tabs": described,
                "total_count": len(described),
                "current_tab": self.browser_instance.current_page_id,
            }
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to list tabs: {e}") from e
def close_browser(self) -> dict[str, Any]:
    """Shut down the running browser and clear the cached instance."""
    with self._lock:
        if self.browser_instance is None:
            raise ValueError("Browser not launched")
        try:
            self.browser_instance.close()
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to close browser: {e}") from e
        self.browser_instance = None
        return {
            "message": "Browser closed successfully",
            "screenshot": "",
            "is_running": False,
        }
def cleanup_dead_browser(self) -> None:
    """Drop the cached instance if the underlying browser process died."""
    with self._lock:
        instance = self.browser_instance
        if instance and not instance.is_alive():
            # Best-effort close; the process is already dead, so ignore errors.
            with contextlib.suppress(Exception):
                instance.close()
            self.browser_instance = None
def close_all(self) -> None:
    """Tear down the browser unconditionally (used by exit/signal handlers)."""
    with self._lock:
        instance = self.browser_instance
        if instance:
            # Swallow shutdown errors: this runs during interpreter teardown.
            with contextlib.suppress(Exception):
                instance.close()
            self.browser_instance = None
def _register_cleanup_handlers(self) -> None:
    """Ensure the browser is torn down on interpreter exit or fatal signals."""
    atexit.register(self.close_all)
    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, self._signal_handler)
    # SIGHUP does not exist on Windows, so register it conditionally.
    if hasattr(signal, "SIGHUP"):
        signal.signal(signal.SIGHUP, self._signal_handler)
def _signal_handler(self, _signum: int, _frame: Any) -> None:
    """Signal handler: close the browser, then exit the process cleanly."""
    self.close_all()
    sys.exit(0)
# Process-wide singleton: all browser tools share one tab manager instance.
_browser_tab_manager = BrowserTabManager()


def get_browser_tab_manager() -> BrowserTabManager:
    """Return the process-wide BrowserTabManager singleton."""
    return _browser_tab_manager

302
strix/tools/executor.py Normal file
View File

@@ -0,0 +1,302 @@
import inspect
import os
from typing import Any
import httpx
if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
from strix.runtime import get_runtime
from .argument_parser import convert_arguments
from .registry import (
get_tool_by_name,
get_tool_names,
needs_agent_state,
should_execute_in_sandbox,
)
async def execute_tool(tool_name: str, agent_state: Any | None = None, **kwargs: Any) -> Any:
    """Dispatch a tool call to the sandbox tool server or run it locally.

    Sandbox-marked tools are proxied only when this process is NOT itself
    running inside the sandbox (STRIX_SANDBOX_MODE != "true").
    """
    inside_sandbox = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
    if should_execute_in_sandbox(tool_name) and not inside_sandbox:
        return await _execute_tool_in_sandbox(tool_name, agent_state, **kwargs)
    return await _execute_tool_locally(tool_name, agent_state, **kwargs)
async def _execute_tool_in_sandbox(tool_name: str, agent_state: Any, **kwargs: Any) -> Any:
    """Proxy a tool call to the tool server running inside the agent's sandbox.

    Requires ``agent_state`` to carry ``sandbox_id``, ``sandbox_token`` and a
    ``sandbox_info`` mapping containing ``tool_server_port``.

    Raises:
        ValueError: if any required sandbox field is missing or empty.
        RuntimeError: on transport failures, non-2xx responses, or a
            server-reported execution error.
    """
    if not hasattr(agent_state, "sandbox_id") or not agent_state.sandbox_id:
        raise ValueError("Agent state with a valid sandbox_id is required for sandbox execution.")
    if not hasattr(agent_state, "sandbox_token") or not agent_state.sandbox_token:
        raise ValueError(
            "Agent state with a valid sandbox_token is required for sandbox execution."
        )
    if (
        not hasattr(agent_state, "sandbox_info")
        or "tool_server_port" not in agent_state.sandbox_info
    ):
        raise ValueError(
            "Agent state with a valid sandbox_info containing tool_server_port is required."
        )
    runtime = get_runtime()
    tool_server_port = agent_state.sandbox_info["tool_server_port"]
    server_url = await runtime.get_sandbox_url(agent_state.sandbox_id, tool_server_port)
    request_url = f"{server_url}/execute"
    request_data = {
        "tool_name": tool_name,
        "kwargs": kwargs,
    }
    headers = {
        "Authorization": f"Bearer {agent_state.sandbox_token}",
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient() as client:
        try:
            # timeout=None: tool executions (long scans, shells) may run indefinitely.
            response = await client.post(
                request_url, json=request_data, headers=headers, timeout=None
            )
            response.raise_for_status()
            response_data = response.json()
            if response_data.get("error"):
                raise RuntimeError(f"Sandbox execution error: {response_data['error']}")
            return response_data.get("result")
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 401:
                raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e
            raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e
        except httpx.RequestError as e:
            raise RuntimeError(f"Request error calling tool server: {e}") from e
async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwargs: Any) -> Any:
    """Resolve *tool_name* from the registry and invoke it in this process.

    Awaits the result when the tool is a coroutine function.
    """
    tool_func = get_tool_by_name(tool_name)
    if tool_func is None:
        raise ValueError(f"Tool '{tool_name}' not found")
    prepared_kwargs = convert_arguments(tool_func, kwargs)
    if not needs_agent_state(tool_name):
        outcome = tool_func(**prepared_kwargs)
    else:
        if agent_state is None:
            raise ValueError(f"Tool '{tool_name}' requires agent_state but none was provided.")
        outcome = tool_func(agent_state=agent_state, **prepared_kwargs)
    return await outcome if inspect.isawaitable(outcome) else outcome
def validate_tool_availability(tool_name: str | None) -> tuple[bool, str]:
    """Return ``(is_valid, error_message)`` for a requested tool name."""
    if tool_name is None:
        return False, "Tool name is missing"
    if tool_name in get_tool_names():
        return True, ""
    return False, f"Tool '{tool_name}' is not available"
async def execute_tool_with_validation(
    tool_name: str | None, agent_state: Any | None = None, **kwargs: Any
) -> Any:
    """Validate the tool name, execute it, and fold any failure into an
    ``"Error ..."`` string instead of raising (errors are capped at 500 chars)."""
    is_valid, error_msg = validate_tool_availability(tool_name)
    if not is_valid:
        return f"Error: {error_msg}"
    assert tool_name is not None
    try:
        return await execute_tool(tool_name, agent_state, **kwargs)
    except Exception as e:  # noqa: BLE001
        detail = str(e)
        if len(detail) > 500:
            detail = detail[:500] + "... [truncated]"
        return f"Error executing {tool_name}: {detail}"
async def execute_tool_invocation(tool_inv: dict[str, Any], agent_state: Any | None = None) -> Any:
    """Run one parsed invocation dict ({"toolName", "args"}) end to end."""
    return await execute_tool_with_validation(
        tool_inv.get("toolName"), agent_state, **tool_inv.get("args", {})
    )
def _check_error_result(result: Any) -> tuple[bool, Any]:
is_error = False
error_payload: Any = None
if (isinstance(result, dict) and "error" in result) or (
isinstance(result, str) and result.strip().lower().startswith("error:")
):
is_error = True
error_payload = result
return is_error, error_payload
def _update_tracer_with_result(
tracer: Any, execution_id: Any, is_error: bool, result: Any, error_payload: Any
) -> None:
if not tracer or not execution_id:
return
try:
if is_error:
tracer.update_tool_execution(execution_id, "error", error_payload)
else:
tracer.update_tool_execution(execution_id, "completed", result)
except (ConnectionError, RuntimeError) as e:
error_msg = str(e)
if tracer and execution_id:
tracer.update_tool_execution(execution_id, "error", error_msg)
raise
def _format_tool_result(tool_name: str, result: Any) -> tuple[str, list[dict[str, Any]]]:
    """Render a tool result as <tool_result> XML plus any extracted screenshot.

    Screenshots are pulled out as image attachments and replaced by a
    placeholder in the textual result; very long text keeps only its head
    and tail around a truncation marker.
    """
    attachments: list[dict[str, Any]] = []
    screenshot_data = extract_screenshot_from_result(result)
    if screenshot_data:
        attachments.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{screenshot_data}"},
            }
        )
        displayable = remove_screenshot_from_result(result)
    else:
        displayable = result
    if displayable is None:
        rendered = f"Tool {tool_name} executed successfully"
    else:
        rendered = str(displayable)
    if len(rendered) > 10000:
        # Keep both ends so the command echo and the final outcome survive.
        rendered = (
            rendered[:4000] + "\n\n... [middle content truncated] ...\n\n" + rendered[-4000:]
        )
    observation_xml = (
        f"<tool_result>\n<tool_name>{tool_name}</tool_name>\n"
        f"<result>{rendered}</result>\n</tool_result>"
    )
    return observation_xml, attachments
async def _execute_single_tool(
    tool_inv: dict[str, Any],
    agent_state: Any | None,
    tracer: Any | None,
    agent_id: str,
) -> tuple[str, list[dict[str, Any]], bool]:
    """Execute one tool invocation with optional tracing.

    Returns ``(observation_xml, attached_images, should_agent_finish)``.
    The finish flag is only raised by successful ``finish_scan`` /
    ``agent_finish`` results whose payloads confirm completion.
    """
    tool_name = tool_inv.get("toolName", "unknown")
    args = tool_inv.get("args", {})
    execution_id = None
    should_agent_finish = False
    if tracer:
        execution_id = tracer.log_tool_execution_start(agent_id, tool_name, args)
    try:
        result = await execute_tool_invocation(tool_inv, agent_state)
        is_error, error_payload = _check_error_result(result)
        # The two "finish" tools signal completion via their result payloads.
        if (
            tool_name in ("finish_scan", "agent_finish")
            and not is_error
            and isinstance(result, dict)
        ):
            if tool_name == "finish_scan":
                should_agent_finish = result.get("scan_completed", False)
            elif tool_name == "agent_finish":
                should_agent_finish = result.get("agent_completed", False)
        _update_tracer_with_result(tracer, execution_id, is_error, result, error_payload)
    except (ConnectionError, RuntimeError, ValueError, TypeError, OSError) as e:
        # Mark the traced execution failed before propagating the exception.
        error_msg = str(e)
        if tracer and execution_id:
            tracer.update_tool_execution(execution_id, "error", error_msg)
        raise
    observation_xml, images = _format_tool_result(tool_name, result)
    return observation_xml, images, should_agent_finish
def _get_tracer_and_agent_id(agent_state: Any | None) -> tuple[Any | None, str]:
try:
from strix.cli.tracer import get_global_tracer
tracer = get_global_tracer()
agent_id = agent_state.agent_id if agent_state else "unknown_agent"
except (ImportError, AttributeError):
tracer = None
agent_id = "unknown_agent"
return tracer, agent_id
async def process_tool_invocations(
    tool_invocations: list[dict[str, Any]],
    conversation_history: list[dict[str, Any]],
    agent_state: Any | None = None,
) -> bool:
    """Execute invocations in order, append their results to the history, and
    report whether any tool signalled that the agent should finish."""
    tracer, agent_id = _get_tracer_and_agent_id(agent_state)
    observations: list[str] = []
    attachments: list[dict[str, Any]] = []
    finish_requested = False
    for invocation in tool_invocations:
        xml, images, wants_finish = await _execute_single_tool(
            invocation, agent_state, tracer, agent_id
        )
        observations.append(xml)
        attachments.extend(images)
        finish_requested = finish_requested or wants_finish
    combined_text = "Tool Results:\n\n" + "\n\n".join(observations)
    if attachments:
        # Multi-part message: text first, then every extracted image.
        parts: list[dict[str, Any]] = [{"type": "text", "text": combined_text}]
        parts.extend(attachments)
        conversation_history.append({"role": "user", "content": parts})
    else:
        conversation_history.append({"role": "user", "content": combined_text})
    return finish_requested
def extract_screenshot_from_result(result: Any) -> str | None:
if not isinstance(result, dict):
return None
screenshot = result.get("screenshot")
if isinstance(screenshot, str) and screenshot:
return screenshot
return None
def remove_screenshot_from_result(result: Any) -> Any:
    """Return a shallow copy with any screenshot payload replaced by a placeholder.

    Non-dict results are returned untouched; the input dict is never mutated.
    """
    if not isinstance(result, dict):
        return result
    redacted = dict(result)
    if "screenshot" in redacted:
        redacted["screenshot"] = "[Image data extracted - see attached image]"
    return redacted

View File

@@ -0,0 +1,4 @@
from .file_edit_actions import list_files, search_files, str_replace_editor
__all__ = ["list_files", "search_files", "str_replace_editor"]

View File

@@ -0,0 +1,141 @@
import json
import re
from pathlib import Path
from typing import Any, cast
from openhands_aci import file_editor
from openhands_aci.utils.shell import run_shell_cmd
from strix.tools.registry import register_tool
def _parse_file_editor_output(output: str) -> dict[str, Any]:
try:
pattern = r"<oh_aci_output_[^>]+>\n(.*?)\n</oh_aci_output_[^>]+>"
match = re.search(pattern, output, re.DOTALL)
if match:
json_str = match.group(1)
data = json.loads(json_str)
return cast("dict[str, Any]", data)
return {"output": output, "error": None}
except (json.JSONDecodeError, AttributeError):
return {"output": output, "error": None}
@register_tool
def str_replace_editor(
    command: str,
    path: str,
    file_text: str | None = None,
    view_range: list[int] | None = None,
    old_str: str | None = None,
    new_str: str | None = None,
    insert_line: int | None = None,
) -> dict[str, Any]:
    """View/create/edit a file via openhands-aci's ``file_editor``.

    Relative paths are resolved under /workspace. Returns {"content": ...}
    on success or {"error": ...} on failure.
    """
    try:
        if not Path(path).is_absolute():
            path = str(Path("/workspace") / path)
        raw_output = file_editor(
            command=command,
            path=path,
            file_text=file_text,
            view_range=view_range,
            old_str=old_str,
            new_str=new_str,
            insert_line=insert_line,
        )
        parsed = _parse_file_editor_output(raw_output)
        editor_error = parsed.get("error")
        if editor_error:
            return {"error": editor_error}
        return {"content": parsed.get("output", raw_output)}
    except (OSError, ValueError) as e:
        return {"error": f"Error in {command} operation: {e!s}"}
@register_tool
def list_files(
    path: str,
    recursive: bool = False,
) -> dict[str, Any]:
    """List files and directories at *path* (resolved under /workspace if relative).

    Returns a dict with sorted ``files`` and ``directories`` lists plus counts,
    or an ``error`` key on failure. Recursive listings are capped at 500 entries.
    """
    import shlex  # local import: only needed to build the shell command safely

    try:
        path_obj = Path(path)
        if not path_obj.is_absolute():
            path = str(Path("/workspace") / path_obj)
            path_obj = Path(path)
        if not path_obj.exists():
            return {"error": f"Directory not found: {path}"}
        if not path_obj.is_dir():
            return {"error": f"Path is not a directory: {path}"}
        # shlex.quote stops quotes/metacharacters in the path from breaking out
        # of the shell command (the previous f"'{path}'" was injectable via a
        # single quote in the path).
        quoted_path = shlex.quote(path)
        cmd = (
            f"find {quoted_path} -type f -o -type d | head -500"
            if recursive
            else f"ls -1a {quoted_path}"
        )
        exit_code, stdout, stderr = run_shell_cmd(cmd)
        if exit_code != 0:
            return {"error": f"Error listing directory: {stderr}"}
        items = stdout.strip().split("\n") if stdout.strip() else []
        files = []
        dirs = []
        for item in items:
            # `find` prints full paths; `ls` prints names relative to `path`.
            item_path_obj = Path(item) if recursive else Path(path) / item
            if item_path_obj.is_file():
                files.append(item)
            elif item_path_obj.is_dir():
                dirs.append(item)
        return {
            "files": sorted(files),
            "directories": sorted(dirs),
            "total_files": len(files),
            "total_dirs": len(dirs),
            "path": path,
            "recursive": recursive,
        }
    except (OSError, ValueError) as e:
        return {"error": f"Error listing directory: {e!s}"}
@register_tool
def search_files(
    path: str,
    regex: str,
    file_pattern: str = "*",
) -> dict[str, Any]:
    """Run a ripgrep search for *regex* under *path*, filtered by *file_pattern*.

    Relative paths are resolved under /workspace. Returns {"output": ...} with
    the matches (or "No matches found"), or {"error": ...} on failure.
    """
    import shlex  # local import: only needed to build the shell command safely

    try:
        path_obj = Path(path)
        if not path_obj.is_absolute():
            path = str(Path("/workspace") / path_obj)
        if not Path(path).exists():
            return {"error": f"Directory not found: {path}"}
        # Quote every user-influenced token, not just the regex: the previous
        # manual quote-escaping covered only `regex`, leaving `path` and
        # `file_pattern` injectable into the shell command.
        cmd = (
            f"rg --line-number --glob {shlex.quote(file_pattern)} "
            f"{shlex.quote(regex)} {shlex.quote(path)}"
        )
        exit_code, stdout, stderr = run_shell_cmd(cmd)
        # rg exits 1 when nothing matched; that is not an error.
        if exit_code not in {0, 1}:
            return {"error": f"Error searching files: {stderr}"}
        return {"output": stdout if stdout else "No matches found"}
    except (OSError, ValueError) as e:
        return {"error": f"Error searching files: {e!s}"}
# ruff: noqa: TRY300

View File

@@ -0,0 +1,128 @@
<tools>
<tool name="list_files">
<description>List files and directories within the specified directory.</description>
<parameters>
<parameter name="path" type="string" required="true">
<description>Directory path to list</description>
</parameter>
<parameter name="recursive" type="boolean" required="false">
<description>Whether to list files recursively</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - files: List of files - directories: List of directories - total_files: Total number of files found - total_dirs: Total number of directories found</description>
</returns>
<notes>
- Lists contents alphabetically
- Returns maximum 500 results to avoid overwhelming output
</notes>
<examples>
# List directory contents
<function=list_files>
<parameter=path>/home/user/project/src</parameter>
</function>
# Recursive listing
<function=list_files>
<parameter=path>/home/user/project/src</parameter>
<parameter=recursive>true</parameter>
</function>
</examples>
</tool>
<tool name="search_files">
<description>Perform a regex search across files in a directory.</description>
<parameters>
<parameter name="path" type="string" required="true">
<description>Directory path to search</description>
</parameter>
<parameter name="regex" type="string" required="true">
<description>Regular expression pattern to search for</description>
</parameter>
<parameter name="file_pattern" type="string" required="false">
<description>File pattern to filter (e.g., "*.py", "*.js")</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - output: The search results as a string</description>
</returns>
<notes>
- Searches recursively through subdirectories
- Uses ripgrep for fast searching
</notes>
<examples>
# Search Python files for a pattern
<function=search_files>
<parameter=path>/home/user/project/src</parameter>
<parameter=regex>def\s+process_data</parameter>
<parameter=file_pattern>*.py</parameter>
</function>
</examples>
</tool>
<tool name="str_replace_editor">
<description>A text editor tool for viewing, creating and editing files.</description>
<parameters>
<parameter name="command" type="string" required="true">
<description>Editor command to execute</description>
</parameter>
<parameter name="path" type="string" required="true">
<description>Path to the file to edit</description>
</parameter>
<parameter name="file_text" type="string" required="false">
<description>Required parameter of create command, with the content of the file to be created</description>
</parameter>
<parameter name="view_range" type="string" required="false">
<description>Optional parameter of view command when path points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting [start_line, -1] shows all lines from start_line to the end of the file</description>
</parameter>
<parameter name="old_str" type="string" required="false">
<description>Required parameter of str_replace command containing the string in path to replace</description>
</parameter>
<parameter name="new_str" type="string" required="false">
<description>Optional parameter of str_replace command containing the new string (if not given, no string will be added). Required parameter of insert command containing the string to insert</description>
</parameter>
<parameter name="insert_line" type="string" required="false">
<description>Required parameter of insert command. The new_str will be inserted AFTER the line insert_line of path</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing the result of the operation</description>
</returns>
<notes>
Command details:
- view: Show file contents, optionally with line range
- create: Create a new file with given content
- str_replace: Replace old_str with new_str in file
- insert: Insert new_str after the specified line number
- undo_edit: Revert the last edit made to the file
</notes>
<examples>
# View a file
<function=str_replace_editor>
<parameter=command>view</parameter>
<parameter=path>/home/user/project/file.py</parameter>
</function>
# Create a file
<function=str_replace_editor>
<parameter=command>create</parameter>
<parameter=path>/home/user/project/new_file.py</parameter>
<parameter=file_text>print("Hello World")</parameter>
</function>
# Replace text in file
<function=str_replace_editor>
<parameter=command>str_replace</parameter>
<parameter=path>/home/user/project/file.py</parameter>
<parameter=old_str>old_function()</parameter>
<parameter=new_str>new_function()</parameter>
</function>
# Insert text after line 10
<function=str_replace_editor>
<parameter=command>insert</parameter>
<parameter=path>/home/user/project/file.py</parameter>
<parameter=insert_line>10</parameter>
<parameter=new_str>print("Inserted line")</parameter>
</function>
</examples>
</tool>
</tools>

View File

@@ -0,0 +1,4 @@
from .finish_actions import finish_scan
__all__ = ["finish_scan"]

View File

@@ -0,0 +1,174 @@
from typing import Any
from strix.tools.registry import register_tool
def _validate_root_agent(agent_state: Any) -> dict[str, Any] | None:
if (
agent_state is not None
and hasattr(agent_state, "parent_id")
and agent_state.parent_id is not None
):
return {
"success": False,
"message": (
"This tool can only be used by the root/main agent. "
"Subagents must use agent_finish instead."
),
}
return None
def _validate_content(content: str) -> dict[str, Any] | None:
if not content or not content.strip():
return {"success": False, "message": "Content cannot be empty"}
return None
def _check_active_agents(agent_state: Any = None) -> dict[str, Any] | None:
try:
from strix.tools.agents_graph.agents_graph_actions import _agent_graph
current_agent_id = None
if agent_state and hasattr(agent_state, "agent_id"):
current_agent_id = agent_state.agent_id
running_agents = []
stopping_agents = []
for agent_id, node in _agent_graph.get("nodes", {}).items():
if agent_id == current_agent_id:
continue
status = node.get("status", "")
if status == "running":
running_agents.append(
{
"id": agent_id,
"name": node.get("name", "Unknown"),
"task": node.get("task", "No task description"),
}
)
elif status == "stopping":
stopping_agents.append(
{
"id": agent_id,
"name": node.get("name", "Unknown"),
}
)
if running_agents or stopping_agents:
message_parts = ["Cannot finish scan while other agents are still active:"]
if running_agents:
message_parts.append("\n\nRunning agents:")
message_parts.extend(
[
f" - {agent['name']} ({agent['id']}): {agent['task']}"
for agent in running_agents
]
)
if stopping_agents:
message_parts.append("\n\nStopping agents:")
message_parts.extend(
[f" - {agent['name']} ({agent['id']})" for agent in stopping_agents]
)
message_parts.extend(
[
"\n\nSuggested actions:",
"1. Use wait_for_message to wait for all agents to complete",
"2. Send messages to agents asking them to finish if urgent",
"3. Use view_agent_graph to monitor agent status",
]
)
return {
"success": False,
"message": "\n".join(message_parts),
"active_agents": {
"running": len(running_agents),
"stopping": len(stopping_agents),
"details": {
"running": running_agents,
"stopping": stopping_agents,
},
},
}
except ImportError:
import logging
logging.warning("Could not check agent graph status - agents_graph module unavailable")
return None
def _finalize_with_tracer(content: str, success: bool) -> dict[str, Any]:
try:
from strix.cli.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
tracer.set_final_scan_result(
content=content.strip(),
success=success,
)
return {
"success": True,
"scan_completed": True,
"message": "Scan completed successfully"
if success
else "Scan completed with errors",
"vulnerabilities_found": len(tracer.vulnerability_reports),
}
import logging
logging.warning("Global tracer not available - final scan result not stored")
return { # noqa: TRY300
"success": True,
"scan_completed": True,
"message": "Scan completed successfully (not persisted)"
if success
else "Scan completed with errors (not persisted)",
"warning": "Final result could not be persisted - tracer unavailable",
}
except ImportError:
return {
"success": True,
"scan_completed": True,
"message": "Scan completed successfully (not persisted)"
if success
else "Scan completed with errors (not persisted)",
"warning": "Final result could not be persisted - tracer module unavailable",
}
@register_tool(sandbox_execution=False)
def finish_scan(
    content: str,
    success: bool = True,
    agent_state: Any = None,
) -> dict[str, Any]:
    """Finalize the whole scan (root agent only) once every subagent is done.

    Runs the validation checks in order and returns the first failure payload,
    otherwise persists the final result via the tracer.
    """
    try:
        # Check order matters: cheap local checks before the agent-graph scan.
        failure = _validate_root_agent(agent_state) or _validate_content(content)
        if failure is None:
            failure = _check_active_agents(agent_state)
        if failure is not None:
            return failure
        return _finalize_with_tracer(content, success)
    except (ValueError, TypeError, KeyError) as e:
        return {"success": False, "message": f"Failed to complete scan: {e!s}"}

View File

@@ -0,0 +1,45 @@
<tools>
<tool name="finish_scan">
<description>Complete the main security scan and generate final report.
IMPORTANT: This tool can ONLY be used by the root/main agent.
Subagents must use agent_finish from agents_graph tool instead.
IMPORTANT: This tool will NOT allow finishing if any agents are still running or stopping.
You must wait for all agents to complete before using this tool.
This tool MUST be called at the very end of the security assessment to:
- Verify all agents have completed their tasks
- Generate the final comprehensive scan report
- Mark the entire scan as completed
- Stop the agent from running
Use this tool when:
- You are the main/root agent conducting the security assessment
- ALL subagents have completed their tasks (no agents are "running" or "stopping")
- You have completed all testing phases
- You are ready to conclude the entire security assessment
IMPORTANT: Calling this tool multiple times will OVERWRITE any previous scan report.
Make sure you include ALL findings and details in a single comprehensive report.
If agents are still running, this tool will:
- Show you which agents are still active
- Suggest using wait_for_message to wait for completion
- Suggest messaging agents if immediate completion is needed
Put ALL details in the content - methodology, tools used, vulnerability counts, key findings, recommendations,
compliance notes, risk assessments, next steps, etc. Be comprehensive and include everything relevant.</description>
<parameters>
<parameter name="content" type="string" required="true">
<description>Complete scan report including executive summary, methodology, findings, vulnerability details, recommendations, compliance notes, risk assessment, and conclusions. Include everything relevant to the assessment.</description>
</parameter>
<parameter name="success" type="boolean" required="false">
<description>Whether the scan completed successfully without critical errors</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing success status and completion message. If agents are still running, returns details about active agents and suggested actions.</description>
</returns>
</tool>
</tools>

View File

@@ -0,0 +1,14 @@
from .notes_actions import (
create_note,
delete_note,
list_notes,
update_note,
)
__all__ = [
"create_note",
"delete_note",
"list_notes",
"update_note",
]

View File

@@ -0,0 +1,191 @@
import uuid
from datetime import UTC, datetime
from typing import Any
from strix.tools.registry import register_tool
_notes_storage: dict[str, dict[str, Any]] = {}
def _filter_notes(
category: str | None = None,
tags: list[str] | None = None,
priority: str | None = None,
search_query: str | None = None,
) -> list[dict[str, Any]]:
filtered_notes = []
for note_id, note in _notes_storage.items():
if category and note.get("category") != category:
continue
if priority and note.get("priority") != priority:
continue
if tags:
note_tags = note.get("tags", [])
if not any(tag in note_tags for tag in tags):
continue
if search_query:
search_lower = search_query.lower()
title_match = search_lower in note.get("title", "").lower()
content_match = search_lower in note.get("content", "").lower()
if not (title_match or content_match):
continue
note_with_id = note.copy()
note_with_id["note_id"] = note_id
filtered_notes.append(note_with_id)
filtered_notes.sort(key=lambda x: x.get("created_at", ""), reverse=True)
return filtered_notes
@register_tool
def create_note(
    title: str,
    content: str,
    category: str = "general",
    tags: list[str] | None = None,
    priority: str = "normal",
) -> dict[str, Any]:
    """Create a note in the in-memory store.

    Args:
        title: Non-empty note title (surrounding whitespace is trimmed).
        content: Non-empty note body (surrounding whitespace is trimmed).
        category: One of "general", "findings", "methodology", "todo",
            "questions", "plan".
        tags: Optional list of free-form tags.
        priority: One of "low", "normal", "high", "urgent".

    Returns:
        On success: {"success": True, "note_id": <id>, "message": ...}.
        On failure: {"success": False, "error": <reason>, "note_id": None}.
    """
    try:
        if not title or not title.strip():
            return {"success": False, "error": "Title cannot be empty", "note_id": None}
        if not content or not content.strip():
            return {"success": False, "error": "Content cannot be empty", "note_id": None}
        valid_categories = ["general", "findings", "methodology", "todo", "questions", "plan"]
        if category not in valid_categories:
            return {
                "success": False,
                "error": f"Invalid category. Must be one of: {', '.join(valid_categories)}",
                "note_id": None,
            }
        valid_priorities = ["low", "normal", "high", "urgent"]
        if priority not in valid_priorities:
            return {
                "success": False,
                "error": f"Invalid priority. Must be one of: {', '.join(valid_priorities)}",
                "note_id": None,
            }
        # Short IDs are convenient for the agent to reference, but a 5-char
        # UUID prefix can collide; regenerate until unused so an existing
        # note is never silently overwritten.
        note_id = str(uuid.uuid4())[:5]
        while note_id in _notes_storage:
            note_id = str(uuid.uuid4())[:5]
        timestamp = datetime.now(UTC).isoformat()
        note = {
            "title": title.strip(),
            "content": content.strip(),
            "category": category,
            "tags": tags or [],
            "priority": priority,
            "created_at": timestamp,
            "updated_at": timestamp,
        }
        _notes_storage[note_id] = note
    except (ValueError, TypeError) as e:
        return {"success": False, "error": f"Failed to create note: {e}", "note_id": None}
    else:
        return {
            "success": True,
            "note_id": note_id,
            "message": f"Note '{title}' created successfully",
        }
@register_tool
def list_notes(
    category: str | None = None,
    tags: list[str] | None = None,
    priority: str | None = None,
    search: str | None = None,
) -> dict[str, Any]:
    """List stored notes, optionally filtered by category, tags, priority,
    and/or a case-insensitive title/content search string."""
    try:
        matching = _filter_notes(
            category=category,
            tags=tags,
            priority=priority,
            search_query=search,
        )
    except (ValueError, TypeError) as e:
        return {
            "success": False,
            "error": f"Failed to list notes: {e}",
            "notes": [],
            "total_count": 0,
        }
    return {
        "success": True,
        "notes": matching,
        "total_count": len(matching),
    }
@register_tool
def update_note(
    note_id: str,
    title: str | None = None,
    content: str | None = None,
    tags: list[str] | None = None,
    priority: str | None = None,
) -> dict[str, Any]:
    """Update fields of an existing note.

    All supplied fields are validated before anything is written, so a
    failed update never leaves the note partially modified (previously a
    new title was stored even when a later content/priority check failed).

    Args:
        note_id: ID of the note to update.
        title: New title (must be non-empty if given).
        content: New content (must be non-empty if given).
        tags: Replacement tag list.
        priority: One of "low", "normal", "high", "urgent".

    Returns:
        {"success": True, "message": ...} or {"success": False, "error": ...}.
    """
    try:
        if note_id not in _notes_storage:
            return {"success": False, "error": f"Note with ID '{note_id}' not found"}
        # Validate every supplied field first so the update is atomic.
        if title is not None and not title.strip():
            return {"success": False, "error": "Title cannot be empty"}
        if content is not None and not content.strip():
            return {"success": False, "error": "Content cannot be empty"}
        valid_priorities = ["low", "normal", "high", "urgent"]
        if priority is not None and priority not in valid_priorities:
            return {
                "success": False,
                "error": f"Invalid priority. Must be one of: {', '.join(valid_priorities)}",
            }
        note = _notes_storage[note_id]
        if title is not None:
            note["title"] = title.strip()
        if content is not None:
            note["content"] = content.strip()
        if tags is not None:
            note["tags"] = tags
        if priority is not None:
            note["priority"] = priority
        note["updated_at"] = datetime.now(UTC).isoformat()
        return {
            "success": True,
            "message": f"Note '{note['title']}' updated successfully",
        }
    except (ValueError, TypeError) as e:
        return {"success": False, "error": f"Failed to update note: {e}"}
@register_tool
def delete_note(note_id: str) -> dict[str, Any]:
    """Delete the note with the given ID from the in-memory store."""
    try:
        # pop() both checks existence and removes in a single lookup.
        removed = _notes_storage.pop(note_id, None)
        if removed is None:
            return {"success": False, "error": f"Note with ID '{note_id}' not found"}
    except (ValueError, TypeError) as e:
        return {"success": False, "error": f"Failed to delete note: {e}"}
    return {
        "success": True,
        "message": f"Note '{removed['title']}' deleted successfully",
    }

View File

@@ -0,0 +1,150 @@
<tools>
<tool name="create_note">
<description>Create a personal note for TODOs, side notes, plans, and organizational purposes during
the scan.</description>
<details>Use this tool for quick reminders, action items, planning thoughts, and organizational notes
rather than formal vulnerability reports or detailed findings. This is your personal notepad
for keeping track of tasks, ideas, and things to remember or follow up on.</details>
<parameters>
<parameter name="title" type="string" required="true">
<description>Title of the note</description>
</parameter>
<parameter name="content" type="string" required="true">
<description>Content of the note</description>
</parameter>
<parameter name="category" type="string" required="false">
<description>Category to organize the note. One of: "general" (default), "findings", "methodology", "todo", "questions", "plan"</description>
</parameter>
<parameter name="tags" type="list" required="false">
<description>Tags for categorization (a list of strings, e.g. ["ssl", "followup"])</description>
</parameter>
<parameter name="priority" type="string" required="false">
<description>Priority level of the note ("low", "normal", "high", "urgent")</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - note_id: ID of the created note - success: Whether the note was created successfully</description>
</returns>
<examples>
# Create a TODO reminder
<function=create_note>
<parameter=title>TODO: Check SSL Certificate Details</parameter>
<parameter=content>Remember to verify SSL certificate validity and check for weak ciphers
on the HTTPS service discovered on port 443. Also check for certificate
transparency logs.</parameter>
<parameter=category>todo</parameter>
<parameter=tags>["ssl", "certificate", "followup"]</parameter>
<parameter=priority>normal</parameter>
</function>
# Planning note
<function=create_note>
<parameter=title>Scan Strategy Planning</parameter>
<parameter=content>Plan for next phase: 1) Complete subdomain enumeration 2) Test discovered
web apps for OWASP Top 10 3) Check database services for default creds
4) Review any custom applications for business logic flaws</parameter>
<parameter=category>plan</parameter>
<parameter=tags>["planning", "strategy", "next_steps"]</parameter>
</function>
# Side note for later investigation
<function=create_note>
<parameter=title>Interesting Directory Found</parameter>
<parameter=content>Found /backup/ directory that might contain sensitive files. Low priority
for now but worth checking if time permits. Directory listing seems
disabled.</parameter>
<parameter=category>findings</parameter>
<parameter=tags>["directory", "backup", "low_priority"]</parameter>
<parameter=priority>low</parameter>
</function>
</examples>
</tool>
<tool name="delete_note">
<description>Delete a note.</description>
<parameters>
<parameter name="note_id" type="string" required="true">
<description>ID of the note to delete</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - success: Whether the note was deleted successfully</description>
</returns>
<examples>
<function=delete_note>
<parameter=note_id>note_123</parameter>
</function>
</examples>
</tool>
<tool name="list_notes">
<description>List existing notes with optional filtering and search.</description>
<parameters>
<parameter name="category" type="string" required="false">
<description>Filter by category</description>
</parameter>
<parameter name="tags" type="string" required="false">
<description>Filter by tags (returns notes with any of these tags)</description>
</parameter>
<parameter name="priority" type="string" required="false">
<description>Filter by priority level</description>
</parameter>
<parameter name="search" type="string" required="false">
<description>Search query to find in note titles and content</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - notes: List of matching notes - total_count: Total number of notes found</description>
</returns>
<examples>
# List all findings
<function=list_notes>
<parameter=category>findings</parameter>
</function>
# List high priority items
<function=list_notes>
<parameter=priority>high</parameter>
</function>
# Search for SQL injection related notes
<function=list_notes>
<parameter=search>SQL injection</parameter>
</function>
# Search within a specific category
<function=list_notes>
<parameter=search>admin</parameter>
<parameter=category>findings</parameter>
</function>
</examples>
</tool>
<tool name="update_note">
<description>Update an existing note.</description>
<parameters>
<parameter name="note_id" type="string" required="true">
<description>ID of the note to update</description>
</parameter>
<parameter name="title" type="string" required="false">
<description>New title for the note</description>
</parameter>
<parameter name="content" type="string" required="false">
<description>New content for the note</description>
</parameter>
<parameter name="tags" type="string" required="false">
<description>New tags for the note</description>
</parameter>
<parameter name="priority" type="string" required="false">
<description>New priority level</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - success: Whether the note was updated successfully</description>
</returns>
<examples>
<function=update_note>
<parameter=note_id>note_123</parameter>
<parameter=content>Updated content with new findings...</parameter>
<parameter=priority>urgent</parameter>
</function>
</examples>
</tool>
</tools>

View File

@@ -0,0 +1,20 @@
from .proxy_actions import (
list_requests,
list_sitemap,
repeat_request,
scope_rules,
send_request,
view_request,
view_sitemap_entry,
)
__all__ = [
"list_requests",
"list_sitemap",
"repeat_request",
"scope_rules",
"send_request",
"view_request",
"view_sitemap_entry",
]

View File

@@ -0,0 +1,101 @@
from typing import Any, Literal
from strix.tools.registry import register_tool
from .proxy_manager import get_proxy_manager
RequestPart = Literal["request", "response"]
@register_tool
def list_requests(
    httpql_filter: str | None = None,
    start_page: int = 1,
    end_page: int = 1,
    page_size: int = 50,
    sort_by: Literal[
        "timestamp",
        "host",
        "method",
        "path",
        "status_code",
        "response_time",
        "response_size",
        "source",
    ] = "timestamp",
    sort_order: Literal["asc", "desc"] = "desc",
    scope_id: str | None = None,
) -> dict[str, Any]:
    """List captured proxy requests, filtered, sorted, and paginated.

    Thin tool wrapper: all work is delegated to the shared proxy manager.
    """
    return get_proxy_manager().list_requests(
        httpql_filter=httpql_filter,
        start_page=start_page,
        end_page=end_page,
        page_size=page_size,
        sort_by=sort_by,
        sort_order=sort_order,
        scope_id=scope_id,
    )
@register_tool
def view_request(
    request_id: str,
    part: RequestPart = "request",
    search_pattern: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> dict[str, Any]:
    """View a captured request or its response, with optional regex search."""
    return get_proxy_manager().view_request(
        request_id=request_id,
        part=part,
        search_pattern=search_pattern,
        page=page,
        page_size=page_size,
    )
@register_tool
def send_request(
    method: str,
    url: str,
    headers: dict[str, str] | None = None,
    body: str = "",
    timeout: int = 30,
) -> dict[str, Any]:
    """Send a one-off HTTP request through the intercepting proxy."""
    # Normalize a missing header dict here so the manager always gets a dict.
    return get_proxy_manager().send_simple_request(method, url, headers or {}, body, timeout)
@register_tool
def repeat_request(
    request_id: str,
    modifications: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Replay a previously captured request with optional modifications."""
    return get_proxy_manager().repeat_request(request_id, modifications or {})
@register_tool
def scope_rules(
    action: Literal["get", "list", "create", "update", "delete"],
    allowlist: list[str] | None = None,
    denylist: list[str] | None = None,
    scope_id: str | None = None,
    scope_name: str | None = None,
) -> dict[str, Any]:
    """Manage Caido scope definitions (get/list/create/update/delete)."""
    return get_proxy_manager().scope_rules(action, allowlist, denylist, scope_id, scope_name)
@register_tool
def list_sitemap(
    scope_id: str | None = None,
    parent_id: str | None = None,
    depth: Literal["DIRECT", "ALL"] = "DIRECT",
    page: int = 1,
) -> dict[str, Any]:
    """Browse the hierarchical sitemap of attack surface discovered so far."""
    return get_proxy_manager().list_sitemap(scope_id, parent_id, depth, page)
@register_tool
def view_sitemap_entry(
    entry_id: str,
) -> dict[str, Any]:
    """Show full details for a single sitemap entry."""
    return get_proxy_manager().view_sitemap_entry(entry_id)

View File

@@ -0,0 +1,267 @@
<?xml version="1.0" ?>
<tools>
<tool name="list_requests">
<description>List and filter proxy requests using HTTPQL with pagination.</description>
<parameters>
<parameter name="httpql_filter" type="string" required="false">
<description>HTTPQL filter using Caido's syntax:
Integer fields (port, code, roundtrip, id) - eq, gt, gte, lt, lte, ne:
- resp.code.eq:200, resp.code.gte:400, req.port.eq:443
Text/byte fields (ext, host, method, path, query, raw) - regex:
- req.method.regex:"POST", req.path.regex:"/api/.*", req.host.regex:".*.com"
Date fields (created_at) - gt, lt with ISO formats:
- req.created_at.gt:"2024-01-01T00:00:00Z"
Special: source:intercept, preset:"name"</description>
</parameter>
<parameter name="start_page" type="integer" required="false">
<description>Starting page (1-based)</description>
</parameter>
<parameter name="end_page" type="integer" required="false">
<description>Ending page (1-based, inclusive)</description>
</parameter>
<parameter name="page_size" type="integer" required="false">
<description>Requests per page</description>
</parameter>
<parameter name="sort_by" type="string" required="false">
<description>Sort field from: "timestamp", "host", "method", "path", "status_code", "response_time", "response_size", "source"</description>
</parameter>
<parameter name="sort_order" type="string" required="false">
<description>Sort direction ("asc" or "desc")</description>
</parameter>
<parameter name="scope_id" type="string" required="false">
<description>Scope ID to filter requests (use scope_rules to manage scopes)</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing:
- 'requests': Request objects for page range
- 'total_count': Total matching requests
- 'start_page', 'end_page', 'page_size': Query parameters
- 'returned_count': Requests in response</description>
</returns>
<examples>
# POST requests to API endpoints, sorted by response time
<function=list_requests>
<parameter=httpql_filter>req.method.eq:"POST" AND req.path.cont:"/api/"</parameter>
<parameter=sort_by>response_time</parameter>
<parameter=scope_id>scope123</parameter>
</function>
# Requests within specific scope
<function=list_requests>
<parameter=scope_id>scope123</parameter>
<parameter=sort_by>timestamp</parameter>
</function>
</examples>
</tool>
<tool name="view_request">
<description>View request/response data with search and pagination.</description>
<parameters>
<parameter name="request_id" type="string" required="true">
<description>Request ID</description>
</parameter>
<parameter name="part" type="string" required="false">
<description>Which part to return ("request" or "response")</description>
</parameter>
<parameter name="search_pattern" type="string" required="false">
<description>Regex pattern to search content. Common patterns:
- API endpoints: r"/api/[a-zA-Z0-9._/-]+"
- URLs: r"https?://[^\\s<>"\']+"
- Parameters: r'[?&][a-zA-Z0-9_]+=([^&\\s<>"\']+)'
- Reflections: input_value in content</description>
</parameter>
<parameter name="page" type="integer" required="false">
<description>Page number for pagination</description>
</parameter>
<parameter name="page_size" type="integer" required="false">
<description>Lines per page</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>With search_pattern (COMPACT):
- 'matches': [{match, before, after, position}] - max 20
- 'total_matches': Total found
- 'truncated': If limited to 20
Without search_pattern (PAGINATION):
- 'content': Page content
- 'page': Current page
- 'showing_lines': Range display
- 'has_more': More pages available</description>
</returns>
<examples>
# Find API endpoints in response
<function=view_request>
<parameter=request_id>123</parameter>
<parameter=part>response</parameter>
<parameter=search_pattern>/api/[a-zA-Z0-9._/-]+</parameter>
</function>
</examples>
</tool>
<tool name="send_request">
<description>Send a simple HTTP request through proxy.</description>
<parameters>
<parameter name="method" type="string" required="true">
<description>HTTP method (GET, POST, etc.)</description>
</parameter>
<parameter name="url" type="string" required="true">
<description>Target URL</description>
</parameter>
<parameter name="headers" type="dict" required="false">
<description>Headers as {"key": "value"}</description>
</parameter>
<parameter name="body" type="string" required="false">
<description>Request body</description>
</parameter>
<parameter name="timeout" type="integer" required="false">
<description>Request timeout</description>
</parameter>
</parameters>
</tool>
<tool name="repeat_request">
<description>Repeat an existing proxy request with modifications for pentesting.
PROPER WORKFLOW:
1. Use browser_action to browse the target application
2. Use list_requests() to see captured proxy traffic
3. Use repeat_request() to modify and test specific requests
This mirrors real pentesting: browse → capture → modify → test</description>
<parameters>
<parameter name="request_id" type="string" required="true">
<description>ID of the original request to repeat (from list_requests)</description>
</parameter>
<parameter name="modifications" type="dict" required="false">
<description>Changes to apply to the original request:
- "url": New URL or modify existing one
- "params": Dict to update query parameters
- "headers": Dict to add/update headers
- "body": New request body (replaces original)
- "cookies": Dict to add/update cookies</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response data with status, headers, body, timing, and request details</description>
</returns>
<examples>
# Modify POST body payload
<function=repeat_request>
<parameter=request_id>req_789</parameter>
<parameter=modifications>{"body": "{\"username\":\"admin\",\"password\":\"admin\"}"}</parameter>
</function>
</examples>
</tool>
<tool name="scope_rules">
<description>Manage proxy scope patterns for domain/file filtering using Caido's scope system.</description>
<parameters>
<parameter name="action" type="string" required="true">
<description>Scope action:
- get: Get specific scope by ID or list all if no ID
- update: Update existing scope (requires scope_id and scope_name)
- list: List all available scopes
- create: Create new scope (requires scope_name)
- delete: Delete scope (requires scope_id)</description>
</parameter>
<parameter name="allowlist" type="list" required="false">
<description>Domain patterns to include. Examples: ["*.example.com", "api.test.com"]</description>
</parameter>
<parameter name="denylist" type="list" required="false">
<description>Patterns to exclude. Some common extensions:
["*.gif", "*.jpg", "*.png", "*.css", "*.js", "*.ico", "*.svg", "*woff*", "*.ttf"]</description>
</parameter>
<parameter name="scope_id" type="string" required="false">
<description>Specific scope ID to operate on (required for get, update, delete)</description>
</parameter>
<parameter name="scope_name" type="string" required="false">
<description>Name for scope (required for create, update)</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Depending on action:
- get: Single scope object or error
- list: {"scopes": [...], "count": N}
- create/update: {"scope": {...}, "message": "..."}
- delete: {"message": "...", "deletedId": "..."}</description>
</returns>
<notes>
- Empty allowlist = allow all domains
- Denylist overrides allowlist
- Glob patterns: * (any), ? (single), [abc] (one of), [a-z] (range), [^abc] (none of)
- Each scope has unique ID and can be used with list_requests(scopeId=...)
</notes>
<examples>
# Create API-only scope
<function=scope_rules>
<parameter=action>create</parameter>
<parameter=scope_name>API Testing</parameter>
<parameter=allowlist>["api.example.com", "*.api.com"]</parameter>
<parameter=denylist>["*.gif", "*.jpg", "*.png", "*.css", "*.js"]</parameter>
</function>
</examples>
</tool>
<tool name="list_sitemap">
<description>View hierarchical sitemap of discovered attack surface from proxied traffic.
Perfect for bug hunters to understand the application structure and identify
interesting endpoints, directories, and entry points discovered during testing.</description>
<parameters>
<parameter name="scope_id" type="string" required="false">
<description>Scope ID to filter sitemap entries (use scope_rules to get/create scope IDs)</description>
</parameter>
<parameter name="parent_id" type="string" required="false">
<description>ID of parent entry to expand. If None, returns root domains.</description>
</parameter>
<parameter name="depth" type="string" required="false">
<description>DIRECT: Only immediate children. ALL: All descendants recursively.</description>
</parameter>
<parameter name="page" type="integer" required="false">
<description>Page number for pagination (30 entries per page)</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing:
- 'entries': List of cleaned sitemap entries
- 'page', 'total_pages', 'total_count': Pagination info
- 'has_more': Whether more pages available
- Each entry: id, kind, label, hasDescendants, request (method/path/status only)</description>
</returns>
<notes>
Entry kinds:
- DOMAIN: Root domains (example.com)
- DIRECTORY: Path directories (/api/, /admin/)
- REQUEST: Individual endpoints
- REQUEST_BODY: POST/PUT body variations
- REQUEST_QUERY: GET parameter variations
Check hasDescendants=true to identify entries worth expanding.
Use parent_id from any entry to drill down into subdirectories.
</notes>
</tool>
<tool name="view_sitemap_entry">
<description>Get detailed information about a specific sitemap entry and related requests.
Perfect for understanding what's been discovered under a specific directory
or endpoint, including all related requests and response codes.</description>
<parameters>
<parameter name="entry_id" type="string" required="true">
<description>ID of the sitemap entry to examine</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing:
- 'entry': Complete entry details including metadata
- Entry contains 'requests' with all related HTTP requests
- Shows request methods, paths, response codes, timing</description>
</returns>
</tool>
</tools>

View File

@@ -0,0 +1,785 @@
import base64
import os
import re
import time
from typing import TYPE_CHECKING, Any
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
import requests
from gql import Client, gql
from gql.transport.exceptions import TransportQueryError
from gql.transport.requests import RequestsHTTPTransport
from requests.exceptions import ProxyError, RequestException, Timeout
if TYPE_CHECKING:
from collections.abc import Callable
class ProxyManager:
def __init__(self, auth_token: str | None = None):
host = "127.0.0.1"
port = os.getenv("CAIDO_PORT", "56789")
self.base_url = f"http://{host}:{port}/graphql"
self.proxies = {"http": f"http://{host}:{port}", "https": f"http://{host}:{port}"}
self.auth_token = auth_token or os.getenv("CAIDO_API_TOKEN")
self.transport = RequestsHTTPTransport(
url=self.base_url, headers={"Authorization": f"Bearer {self.auth_token}"}
)
self.client = Client(transport=self.transport, fetch_schema_from_transport=False)
    def list_requests(
        self,
        httpql_filter: str | None = None,
        start_page: int = 1,
        end_page: int = 1,
        page_size: int = 50,
        sort_by: str = "timestamp",
        sort_order: str = "desc",
        scope_id: str | None = None,
    ) -> dict[str, Any]:
        """Fetch a page range of captured requests from Caido over GraphQL.

        Args:
            httpql_filter: Optional HTTPQL filter string forwarded to Caido.
            start_page: First page to fetch (1-based).
            end_page: Last page to fetch (inclusive).
            page_size: Number of requests per page.
            sort_by: Friendly sort-field name (see sort_mapping below).
            sort_order: "asc" or "desc".
            scope_id: Optional Caido scope ID restricting results.

        Returns:
            Dict with "requests", "total_count", and the paging/sort
            parameters echoed back; on failure, an "error" key with empty
            results.
        """
        # Collapse the requested page range into a single offset/limit window.
        offset = (start_page - 1) * page_size
        limit = (end_page - start_page + 1) * page_size
        # Map friendly field names onto Caido's RequestResponseOrderInput enum.
        sort_mapping = {
            "timestamp": "CREATED_AT",
            "host": "HOST",
            "method": "METHOD",
            "path": "PATH",
            "status_code": "RESP_STATUS_CODE",
            "response_time": "RESP_ROUNDTRIP_TIME",
            "response_size": "RESP_LENGTH",
            "source": "SOURCE",
        }
        query = gql("""
            query GetRequests(
                $limit: Int, $offset: Int, $filter: HTTPQL,
                $order: RequestResponseOrderInput, $scopeId: ID
            ) {
                requestsByOffset(
                    limit: $limit, offset: $offset, filter: $filter,
                    order: $order, scopeId: $scopeId
                ) {
                    edges {
                        node {
                            id method host path query createdAt length isTls port
                            source alteration fileExtension
                            response { id statusCode length roundtripTime createdAt }
                        }
                    }
                    count { value }
                }
            }
        """)
        variables = {
            "limit": limit,
            "offset": offset,
            "filter": httpql_filter,
            "order": {
                # Unknown sort fields silently fall back to CREATED_AT.
                "by": sort_mapping.get(sort_by, "CREATED_AT"),
                "ordering": sort_order.upper(),
            },
            "scopeId": scope_id,
        }
        try:
            result = self.client.execute(query, variable_values=variables)
            data = result.get("requestsByOffset", {})
            nodes = [edge["node"] for edge in data.get("edges", [])]
            # "count" may be null in the payload; treat that as zero results.
            count_data = data.get("count") or {}
            return {
                "requests": nodes,
                "total_count": count_data.get("value", 0),
                "start_page": start_page,
                "end_page": end_page,
                "page_size": page_size,
                "offset": offset,
                "returned_count": len(nodes),
                "sort_by": sort_by,
                "sort_order": sort_order,
            }
        except (TransportQueryError, ValueError, KeyError) as e:
            return {"requests": [], "total_count": 0, "error": f"Error fetching requests: {e}"}
    def view_request(
        self,
        request_id: str,
        part: str = "request",
        search_pattern: str | None = None,
        page: int = 1,
        page_size: int = 50,
    ) -> dict[str, Any]:
        """Fetch one captured request (or its response) and render its raw body.

        Args:
            request_id: Caido request ID.
            part: "request" or "response" — selects which raw payload to show.
            search_pattern: If given, return regex matches with context
                instead of paginated content.
            page: Page number used when search_pattern is not given.
            page_size: Lines per page for pagination.

        Returns:
            Search-match dict or paginated-content dict (see _search_content
            and _paginate_content), or {"error": ...} on failure.
        """
        # Separate queries so we only pull the (potentially large) raw blob
        # for the part actually requested.
        queries = {
            "request": """query GetRequest($id: ID!) {
                request(id: $id) {
                    id method host path query createdAt length isTls port
                    source alteration edited raw
                }
            }""",
            "response": """query GetRequest($id: ID!) {
                request(id: $id) {
                    id response {
                        id statusCode length roundtripTime createdAt raw
                    }
                }
            }""",
        }
        if part not in queries:
            return {"error": f"Invalid part '{part}'. Use 'request' or 'response'"}
        try:
            result = self.client.execute(gql(queries[part]), variable_values={"id": request_id})
            request_data = result.get("request", {})
            if not request_data:
                return {"error": f"Request {request_id} not found"}
            if part == "request":
                raw_content = request_data.get("raw")
            else:
                # A request may have no response yet; guard against None.
                response_data = request_data.get("response") or {}
                raw_content = response_data.get("raw")
            if not raw_content:
                return {"error": "No content available"}
            # Caido returns the raw payload base64-encoded; decode leniently.
            content = base64.b64decode(raw_content).decode("utf-8", errors="replace")
            # Replace the encoded blob with the decoded text in the metadata
            # we hand to the formatting helpers.
            if part == "response":
                request_data["response"]["raw"] = content
            else:
                request_data["raw"] = content
            return (
                self._search_content(request_data, content, search_pattern)
                if search_pattern
                else self._paginate_content(request_data, content, page, page_size)
            )
        except (TransportQueryError, ValueError, KeyError, UnicodeDecodeError) as e:
            return {"error": f"Failed to view request: {e}"}
def _search_content(
self, request_data: dict[str, Any], content: str, pattern: str
) -> dict[str, Any]:
try:
regex = re.compile(pattern, re.IGNORECASE | re.MULTILINE | re.DOTALL)
matches = []
for match in regex.finditer(content):
start, end = match.start(), match.end()
context_size = 120
before = re.sub(r"\s+", " ", content[max(0, start - context_size) : start].strip())[
-100:
]
after = re.sub(r"\s+", " ", content[end : end + context_size].strip())[:100]
matches.append(
{"match": match.group(), "before": before, "after": after, "position": start}
)
if len(matches) >= 20:
break
return {
"id": request_data.get("id"),
"matches": matches,
"total_matches": len(matches),
"search_pattern": pattern,
"truncated": len(matches) >= 20,
}
except re.error as e:
return {"error": f"Invalid regex: {e}"}
def _paginate_content(
self, request_data: dict[str, Any], content: str, page: int, page_size: int
) -> dict[str, Any]:
display_lines = []
for line in content.split("\n"):
if len(line) <= 80:
display_lines.append(line)
else:
display_lines.extend(
[
line[i : i + 80] + (" \\" if i + 80 < len(line) else "")
for i in range(0, len(line), 80)
]
)
total_lines = len(display_lines)
total_pages = (total_lines + page_size - 1) // page_size
page = max(1, min(page, total_pages))
start_line = (page - 1) * page_size
end_line = min(total_lines, start_line + page_size)
return {
"id": request_data.get("id"),
"content": "\n".join(display_lines[start_line:end_line]),
"page": page,
"total_pages": total_pages,
"showing_lines": f"{start_line + 1}-{end_line} of {total_lines}",
"has_more": page < total_pages,
}
def send_simple_request(
self,
method: str,
url: str,
headers: dict[str, str] | None = None,
body: str = "",
timeout: int = 30,
) -> dict[str, Any]:
if headers is None:
headers = {}
try:
start_time = time.time()
response = requests.request(
method=method,
url=url,
headers=headers,
data=body or None,
proxies=self.proxies,
timeout=timeout,
verify=False,
)
response_time = int((time.time() - start_time) * 1000)
body_content = response.text
if len(body_content) > 10000:
body_content = body_content[:10000] + "\n... [truncated]"
return {
"status_code": response.status_code,
"headers": dict(response.headers),
"body": body_content,
"response_time_ms": response_time,
"url": response.url,
"message": (
"Request sent through proxy - check list_requests() for captured traffic"
),
}
except (RequestException, ProxyError, Timeout) as e:
return {"error": f"Request failed: {type(e).__name__}", "details": str(e), "url": url}
def repeat_request(
self, request_id: str, modifications: dict[str, Any] | None = None
) -> dict[str, Any]:
if modifications is None:
modifications = {}
original = self.view_request(request_id, "request")
if "error" in original:
return {"error": f"Could not retrieve original request: {original['error']}"}
raw_content = original.get("content", "")
if not raw_content:
return {"error": "No raw request content found"}
request_components = self._parse_http_request(raw_content)
if "error" in request_components:
return request_components
full_url = self._build_full_url(request_components, modifications)
if "error" in full_url:
return full_url
modified_request = self._apply_modifications(
request_components, modifications, full_url["url"]
)
return self._send_modified_request(modified_request, request_id, modifications)
def _parse_http_request(self, raw_content: str) -> dict[str, Any]:
lines = raw_content.split("\n")
request_line = lines[0].strip().split(" ")
if len(request_line) < 2:
return {"error": "Invalid request line format"}
method, url_path = request_line[0], request_line[1]
headers = {}
body_start = 0
for i, line in enumerate(lines[1:], 1):
if line.strip() == "":
body_start = i + 1
break
if ":" in line:
key, value = line.split(":", 1)
headers[key.strip()] = value.strip()
body = "\n".join(lines[body_start:]).strip() if body_start < len(lines) else ""
return {"method": method, "url_path": url_path, "headers": headers, "body": body}
def _build_full_url(
self, components: dict[str, Any], modifications: dict[str, Any]
) -> dict[str, Any]:
headers = components["headers"]
host = headers.get("Host", "")
if not host:
return {"error": "No Host header found"}
protocol = (
"https" if ":443" in host or "https" in headers.get("Referer", "").lower() else "http"
)
full_url = f"{protocol}://{host}{components['url_path']}"
if "url" in modifications:
full_url = modifications["url"]
return {"url": full_url}
def _apply_modifications(
self, components: dict[str, Any], modifications: dict[str, Any], full_url: str
) -> dict[str, Any]:
headers = components["headers"].copy()
body = components["body"]
final_url = full_url
if "params" in modifications:
parsed = urlparse(final_url)
params = {k: v[0] if v else "" for k, v in parse_qs(parsed.query).items()}
params.update(modifications["params"])
final_url = urlunparse(parsed._replace(query=urlencode(params)))
if "headers" in modifications:
headers.update(modifications["headers"])
if "body" in modifications:
body = modifications["body"]
if "cookies" in modifications:
cookies = {}
if headers.get("Cookie"):
for cookie in headers["Cookie"].split(";"):
if "=" in cookie:
k, v = cookie.split("=", 1)
cookies[k.strip()] = v.strip()
cookies.update(modifications["cookies"])
headers["Cookie"] = "; ".join([f"{k}={v}" for k, v in cookies.items()])
return {
"method": components["method"],
"url": final_url,
"headers": headers,
"body": body,
}
def _send_modified_request(
    self, request_data: dict[str, Any], request_id: str, modifications: dict[str, Any]
) -> dict[str, Any]:
    """Replay a (possibly modified) request through the proxy and summarize it.

    Sends via the configured proxy, truncates large bodies at 10k chars, and
    keeps only a small allowlist of interesting response headers. Proxy and
    transport failures are returned as error payloads, not raised.
    """
    try:
        started = time.time()
        response = requests.request(
            method=request_data["method"],
            url=request_data["url"],
            headers=request_data["headers"],
            data=request_data["body"] or None,
            proxies=self.proxies,
            timeout=30,
            verify=False,  # intercepting proxy breaks TLS chain verification by design
        )
        elapsed_ms = int((time.time() - started) * 1000)

        body_text = response.text
        was_truncated = len(body_text) > 10000
        if was_truncated:
            body_text = body_text[:10000] + "\n... [truncated]"

        wanted = {"content-type", "content-length", "server", "set-cookie", "location"}
        summary_headers = {k: v for k, v in response.headers.items() if k.lower() in wanted}

        return {
            "status_code": response.status_code,
            "status_text": response.reason,
            "headers": summary_headers,
            "body": body_text,
            "body_truncated": was_truncated,
            "body_size": len(response.content),
            "response_time_ms": elapsed_ms,
            "url": response.url,
            "original_request_id": request_id,
            "modifications_applied": modifications,
            "request": {
                "method": request_data["method"],
                "url": request_data["url"],
                "headers": request_data["headers"],
                "has_body": bool(request_data["body"]),
            },
        }
    except ProxyError as e:
        return {
            "error": "Proxy connection failed - is Caido running?",
            "details": str(e),
            "original_request_id": request_id,
        }
    except (RequestException, Timeout) as e:
        return {
            "error": f"Failed to repeat request: {type(e).__name__}",
            "details": str(e),
            "original_request_id": request_id,
        }
def _handle_scope_list(self) -> dict[str, Any]:
    """Return every Caido scope plus a count, via a single GraphQL query."""
    query = gql("query { scopes { id name allowlist denylist indexed } }")
    found = self.client.execute(query).get("scopes", [])
    return {"scopes": found, "count": len(found)}
def _handle_scope_get(self, scope_id: str | None) -> dict[str, Any]:
    """Fetch one scope by id; with no id, fall back to listing all scopes."""
    if not scope_id:
        return self._handle_scope_list()
    query = gql(
        "query GetScope($id: ID!) { scope(id: $id) { id name allowlist denylist indexed } }"
    )
    data = self.client.execute(query, variable_values={"id": scope_id})
    scope = data.get("scope")
    return {"scope": scope} if scope else {"error": f"Scope {scope_id} not found"}
def _handle_scope_create(
    self, scope_name: str, allowlist: list[str] | None, denylist: list[str] | None
) -> dict[str, Any]:
    """Create a new scope; GraphQL user errors surface as an error payload."""
    if not scope_name:
        return {"error": "scope_name required for create"}
    mutation = gql("""
        mutation CreateScope($input: CreateScopeInput!) {
            createScope(input: $input) {
                scope { id name allowlist denylist indexed }
                error {
                    ... on InvalidGlobTermsUserError { code terms }
                    ... on OtherUserError { code }
                }
            }
        }
    """)
    scope_input = {
        "name": scope_name,
        "allowlist": allowlist or [],
        "denylist": denylist or [],
    }
    result = self.client.execute(mutation, variable_values={"input": scope_input})
    payload = result.get("createScope", {})
    error = payload.get("error")
    if error:
        # InvalidGlobTermsUserError carries "terms"; other errors only "code".
        return {"error": f"Invalid glob patterns: {error.get('terms', error.get('code'))}"}
    return {"scope": payload.get("scope"), "message": "Scope created successfully"}
def _handle_scope_update(
    self,
    scope_id: str,
    scope_name: str,
    allowlist: list[str] | None,
    denylist: list[str] | None,
) -> dict[str, Any]:
    """Update an existing scope's name and allow/deny lists."""
    if not scope_id or not scope_name:
        return {"error": "scope_id and scope_name required"}
    mutation = gql("""
        mutation UpdateScope($id: ID!, $input: UpdateScopeInput!) {
            updateScope(id: $id, input: $input) {
                scope { id name allowlist denylist indexed }
                error {
                    ... on InvalidGlobTermsUserError { code terms }
                    ... on OtherUserError { code }
                }
            }
        }
    """)
    scope_input = {
        "name": scope_name,
        "allowlist": allowlist or [],
        "denylist": denylist or [],
    }
    result = self.client.execute(
        mutation, variable_values={"id": scope_id, "input": scope_input}
    )
    payload = result.get("updateScope", {})
    error = payload.get("error")
    if error:
        # InvalidGlobTermsUserError carries "terms"; other errors only "code".
        return {"error": f"Invalid glob patterns: {error.get('terms', error.get('code'))}"}
    return {"scope": payload.get("scope"), "message": "Scope updated successfully"}
def _handle_scope_delete(self, scope_id: str) -> dict[str, Any]:
    """Delete a scope by id; reports failure when the server returns no deletedId."""
    if not scope_id:
        return {"error": "scope_id required for delete"}
    mutation = gql("mutation DeleteScope($id: ID!) { deleteScope(id: $id) { deletedId } }")
    payload = self.client.execute(mutation, variable_values={"id": scope_id}).get(
        "deleteScope", {}
    )
    deleted = payload.get("deletedId")
    if not deleted:
        return {"error": f"Failed to delete scope {scope_id}"}
    return {"message": f"Scope {scope_id} deleted", "deletedId": deleted}
def scope_rules(
    self,
    action: str,
    allowlist: list[str] | None = None,
    denylist: list[str] | None = None,
    scope_id: str | None = None,
    scope_name: str | None = None,
) -> dict[str, Any]:
    """Dispatch scope CRUD operations to the matching handler.

    Validates the arguments each action requires before delegating; GraphQL
    and data errors are caught and returned as error payloads.
    """
    try:
        if action == "list":
            return self._handle_scope_list()
        if action == "get":
            return self._handle_scope_get(scope_id)
        if action == "create":
            if not scope_name:
                return {"error": "scope_name required for create"}
            return self._handle_scope_create(scope_name, allowlist, denylist)
        if action == "update":
            if not scope_id or not scope_name:
                return {"error": "scope_id and scope_name required"}
            return self._handle_scope_update(scope_id, scope_name, allowlist, denylist)
        if action == "delete":
            if not scope_id:
                return {"error": "scope_id required for delete"}
            return self._handle_scope_delete(scope_id)
        return {
            "error": f"Unsupported action: {action}. Use 'get', 'list', 'create', "
            f"'update', or 'delete'"
        }
    except (TransportQueryError, ValueError, KeyError) as e:
        return {"error": f"Scope operation failed: {e}"}
def list_sitemap(
    self,
    scope_id: str | None = None,
    parent_id: str | None = None,
    depth: str = "DIRECT",
    page: int = 1,
    page_size: int = 30,
) -> dict[str, Any]:
    """Page through sitemap entries.

    With ``parent_id`` set, lists that entry's descendants (controlled by
    ``depth``); otherwise lists root entries, optionally filtered by
    ``scope_id``. Pagination is applied client-side over the fetched edges.

    Fix: the per-node cleaning logic previously duplicated
    ``_process_sitemap_metadata`` / ``_process_sitemap_request`` inline;
    it now reuses those shared helpers so the formats cannot drift.
    """
    try:
        skip_count = (page - 1) * page_size
        if parent_id:
            query = gql("""
                query GetSitemapDescendants($parentId: ID!, $depth: SitemapDescendantsDepth!) {
                    sitemapDescendantEntries(parentId: $parentId, depth: $depth) {
                        edges {
                            node {
                                id kind label hasDescendants
                                request { method path response { statusCode } }
                            }
                        }
                        count { value }
                    }
                }
            """)
            result = self.client.execute(
                query, variable_values={"parentId": parent_id, "depth": depth}
            )
            data = result.get("sitemapDescendantEntries", {})
        else:
            query = gql("""
                query GetSitemapRoots($scopeId: ID) {
                    sitemapRootEntries(scopeId: $scopeId) {
                        edges { node {
                            id kind label hasDescendants
                            metadata { ... on SitemapEntryMetadataDomain { isTls port } }
                            request { method path response { statusCode } }
                        } }
                        count { value }
                    }
                }
            """)
            result = self.client.execute(query, variable_values={"scopeId": scope_id})
            data = result.get("sitemapRootEntries", {})

        all_nodes = [edge["node"] for edge in data.get("edges", [])]
        total_count = (data.get("count") or {}).get("value", 0)

        cleaned_nodes = []
        for node in all_nodes[skip_count : skip_count + page_size]:
            # Reuse the shared cleaners instead of re-implementing them inline.
            cleaned = self._process_sitemap_metadata(node)
            if node.get("request"):
                cleaned_req = self._process_sitemap_request(node["request"])
                if cleaned_req:
                    cleaned["request"] = cleaned_req
            cleaned_nodes.append(cleaned)

        total_pages = (total_count + page_size - 1) // page_size
        return {
            "entries": cleaned_nodes,
            "page": page,
            "page_size": page_size,
            "total_pages": total_pages,
            "total_count": total_count,
            "has_more": page < total_pages,
            "showing": (
                f"{skip_count + 1}-{min(skip_count + page_size, total_count)} of {total_count}"
            ),
        }
    except (TransportQueryError, ValueError, KeyError) as e:
        return {"error": f"Failed to fetch sitemap: {e}"}
def _process_sitemap_metadata(self, node: dict[str, Any]) -> dict[str, Any]:
cleaned = {
"id": node["id"],
"kind": node["kind"],
"label": node["label"],
"hasDescendants": node["hasDescendants"],
}
if node.get("metadata") and (
node["metadata"].get("isTls") is not None or node["metadata"].get("port")
):
cleaned["metadata"] = node["metadata"]
return cleaned
def _process_sitemap_request(self, req: dict[str, Any]) -> dict[str, Any] | None:
cleaned_req = {}
if req.get("method"):
cleaned_req["method"] = req["method"]
if req.get("path"):
cleaned_req["path"] = req["path"]
response_data = req.get("response") or {}
if response_data.get("statusCode"):
cleaned_req["status"] = response_data["statusCode"]
return cleaned_req if cleaned_req else None
def _process_sitemap_response(self, resp: dict[str, Any]) -> dict[str, Any]:
cleaned_resp = {}
if resp.get("statusCode"):
cleaned_resp["status"] = resp["statusCode"]
if resp.get("length"):
cleaned_resp["size"] = resp["length"]
if resp.get("roundtripTime"):
cleaned_resp["time_ms"] = resp["roundtripTime"]
return cleaned_resp
def view_sitemap_entry(self, entry_id: str) -> dict[str, Any]:
    """Fetch one sitemap entry plus its 30 most recent related requests."""
    try:
        query = gql("""
            query GetSitemapEntry($id: ID!) {
                sitemapEntry(id: $id) {
                    id kind label hasDescendants
                    metadata { ... on SitemapEntryMetadataDomain { isTls port } }
                    request { method path response { statusCode length roundtripTime } }
                    requests(first: 30, order: {by: CREATED_AT, ordering: DESC}) {
                        edges { node { method path response { statusCode length } } }
                        count { value }
                    }
                }
            }
        """)
        result = self.client.execute(query, variable_values={"id": entry_id})
        entry = result.get("sitemapEntry")
        if not entry:
            return {"error": f"Sitemap entry {entry_id} not found"}

        summary = self._process_sitemap_metadata(entry)

        primary = entry.get("request")
        if primary:
            # Primary request keeps a nested response summary, unlike the history rows.
            trimmed = {field: primary[field] for field in ("method", "path") if primary.get(field)}
            if primary.get("response"):
                trimmed["response"] = self._process_sitemap_response(primary["response"])
            if trimmed:
                summary["request"] = trimmed

        history = entry.get("requests", {})
        recent = []
        for edge in history.get("edges", []):
            processed = self._process_sitemap_request(edge["node"])
            if processed is not None:
                recent.append(processed)
        total = (history.get("count") or {}).get("value", 0)
        summary["related_requests"] = {
            "requests": recent,
            "total_count": total,
            "showing": f"Latest {len(recent)} requests",
        }
        return {"entry": summary} if summary else {"error": "Failed to process sitemap entry"}  # noqa: TRY300
    except (TransportQueryError, ValueError, KeyError) as e:
        return {"error": f"Failed to fetch sitemap entry: {e}"}
def close(self) -> None:
    """Release resources held by the manager (currently nothing to clean up)."""
    pass
# Module-level singleton cache for the shared proxy manager.
_PROXY_MANAGER: ProxyManager | None = None


def get_proxy_manager() -> ProxyManager:
    """Return the process-wide ProxyManager, creating and caching it on first use.

    Bug fix: the previous implementation never assigned ``_PROXY_MANAGER``,
    so it stayed ``None`` forever and every call constructed a brand-new
    ProxyManager instead of reusing the singleton.
    """
    global _PROXY_MANAGER
    if _PROXY_MANAGER is None:
        _PROXY_MANAGER = ProxyManager()
    return _PROXY_MANAGER

View File

@@ -0,0 +1,4 @@
from .python_actions import python_action
__all__ = ["python_action"]

View File

@@ -0,0 +1,47 @@
from typing import Any, Literal
from strix.tools.registry import register_tool
from .python_manager import get_python_session_manager
PythonAction = Literal["new_session", "execute", "close", "list_sessions"]


@register_tool
def python_action(
    action: PythonAction,
    code: str | None = None,
    timeout: int = 30,
    session_id: str | None = None,
) -> dict[str, Any]:
    """Run one Python-session action and return the manager's payload.

    Validation failures and manager errors are not raised to the caller;
    they come back as a payload with ``stderr`` populated.
    """
    manager = get_python_session_manager()
    try:
        if action == "new_session":
            # Optional initial code runs immediately in the fresh session.
            return manager.create_session(session_id, code, timeout)
        if action == "execute":
            if not code:
                raise ValueError(f"code parameter is required for {action} action")
            return manager.execute_code(session_id, code, timeout)
        if action == "close":
            return manager.close_session(session_id)
        if action == "list_sessions":
            return manager.list_sessions()
        raise ValueError(f"Unknown action: {action}")
    except (ValueError, RuntimeError) as e:
        return {"stderr": str(e), "session_id": session_id, "stdout": "", "is_running": False}

View File

@@ -0,0 +1,131 @@
<?xml version="1.0" encoding="UTF-8"?>
<tools>
<tool name="python_action">
<description>Perform Python actions using persistent interpreter sessions for cybersecurity tasks.</description>
<details>Common Use Cases:
- Security script development and testing (payload generation, exploit scripts)
- Data analysis of security logs, network traffic, or vulnerability scans
- Cryptographic operations and security tool automation
- Interactive penetration testing workflows and proof-of-concept development
- Processing security data formats (JSON, XML, CSV from security tools)
- HTTP proxy interaction for web security testing (all proxy functions are pre-imported)
Each session instance is PERSISTENT and maintains its own global and local namespaces
until explicitly closed, allowing for multi-step security workflows and stateful computations.
PROXY FUNCTIONS PRE-IMPORTED:
All proxy action functions are automatically imported into every Python session, enabling
seamless HTTP traffic analysis and web security testing
This is particularly useful for:
- Analyzing captured HTTP traffic during web application testing
- Automating request manipulation and replay attacks
- Building custom security testing workflows combining proxy data with Python analysis
- Correlating multiple requests for advanced attack scenarios</details>
<parameters>
<parameter name="action" type="string" required="true">
<description>The Python action to perform: - new_session: Create a new Python interpreter session. This MUST be the first action for each session. - execute: Execute Python code in the specified session. - close: Close the specified session instance. - list_sessions: List all active Python sessions.</description>
</parameter>
<parameter name="code" type="string" required="false">
<description>Required for 'new_session' (as initial code) and 'execute' actions. The Python code to execute.</description>
</parameter>
<parameter name="timeout" type="integer" required="false">
<description>Maximum execution time in seconds for code execution. Applies to both 'new_session' (when initial code is provided) and 'execute' actions. Default is 30 seconds.</description>
</parameter>
<parameter name="session_id" type="string" required="false">
<description>Unique identifier for the Python session. If not provided, uses the default session ID.</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - session_id: the ID of the session that was operated on - stdout: captured standard output from code execution (for execute action) - stderr: any error message if execution failed - result: string representation of the last expression result - execution_time: time taken to execute the code - message: status message about the action performed - Various session info depending on the action</description>
</returns>
<notes>
Important usage rules:
1. PERSISTENCE: Session instances remain active and maintain their state (variables,
imports, function definitions) until explicitly closed with the 'close' action.
This allows for multi-step workflows across multiple tool calls.
2. MULTIPLE SESSIONS: You can run multiple Python sessions concurrently by using
different session_id values. Each session operates independently with its own
namespace.
3. Session interaction MUST begin with 'new_session' action for each session instance.
4. Only one action can be performed per call.
5. CODE EXECUTION:
- Both expressions and statements are supported
- Expressions automatically return their result
- Print statements and stdout are captured
- Variables persist between executions in the same session
- Imports, function definitions, etc. persist in the session
- IPython magic commands are fully supported (%pip, %time, %whos, %%writefile, etc.)
- Line magics (%) and cell magics (%%) work as expected
6. CLOSE: Terminates the session completely and frees memory
7. The Python sessions can operate concurrently with other tools. You may invoke
terminal, browser, or other tools while maintaining active Python sessions.
8. Each session has its own isolated namespace - variables in one session don't
affect others.
</notes>
<examples>
# Create new session for security analysis (default session)
<function=python_action>
<parameter=action>new_session</parameter>
<parameter=code>import hashlib
import base64
import json
print("Security analysis session started")</parameter>
</function>
# Analyze security data in the default session
<function=python_action>
<parameter=action>execute</parameter>
<parameter=code>vulnerability_data = {"cve": "CVE-2024-1234", "severity": "high"}
encoded_payload = base64.b64encode(json.dumps(vulnerability_data).encode())
print(f"Encoded: {encoded_payload.decode()}")</parameter>
</function>
# Long running security scan with custom timeout
<function=python_action>
<parameter=action>execute</parameter>
<parameter=code>import time
# Simulate long-running vulnerability scan
time.sleep(45)
print('Security scan completed!')</parameter>
<parameter=timeout>50</parameter>
</function>
# Use IPython magic commands for package management and profiling
<function=python_action>
<parameter=action>execute</parameter>
<parameter=code>%pip install requests
%time response = requests.get('https://httpbin.org/json')
%whos</parameter>
</function>
# Analyze requests for potential vulnerabilities
<function=python_action>
<parameter=action>execute</parameter>
<parameter=code># Filter for POST requests that might contain sensitive data
post_requests = list_requests(
httpql_filter="req.method.eq:POST",
page_size=20
)
# Analyze each POST request for potential issues
for req in post_requests.get('requests', []):
request_id = req['id']
# View the request details
request_details = view_request(request_id, part="request")
# Check for potential SQL injection points
body = request_details.get('body', '')
if any(keyword in body.lower() for keyword in ['select', 'union', 'insert', 'update']):
print(f"Potential SQL injection in request {request_id}")
# Repeat the request with a test payload
test_payload = repeat_request(request_id, {
'body': body + "' OR '1'='1"
})
print(f"Test response status: {test_payload.get('status_code')}")
print("Security analysis complete!")</parameter>
</function>
</examples>
</tool>
</tools>

View File

@@ -0,0 +1,172 @@
import io
import signal
import sys
import threading
from typing import Any
from IPython.core.interactiveshell import InteractiveShell
MAX_STDOUT_LENGTH = 10_000
MAX_STDERR_LENGTH = 5_000
class PythonInstance:
    """One persistent IPython interpreter session.

    Each instance owns its own ``InteractiveShell`` (separate namespace),
    captures stdout/stderr per execution, and enforces a wall-clock timeout
    via SIGALRM. Proxy helper functions are injected into the namespace
    when the proxy tooling is importable.
    """

    def __init__(self, session_id: str) -> None:
        # Identifier echoed back in every result payload.
        self.session_id = session_id
        self.is_running = True
        # Serializes execute_code() calls so output capture cannot interleave.
        self._execution_lock = threading.Lock()

        import os

        # NOTE(review): assumes a /workspace directory exists (sandbox layout) — confirm.
        os.chdir("/workspace")

        self.shell = InteractiveShell()
        self.shell.init_completer()
        self.shell.init_history()
        self.shell.init_logger()
        self._setup_proxy_functions()

    def _setup_proxy_functions(self) -> None:
        """Inject the proxy tool helpers into the shell's user namespace (best effort)."""
        try:
            from strix.tools.proxy import proxy_actions

            proxy_functions = [
                "list_requests",
                "list_sitemap",
                "repeat_request",
                "scope_rules",
                "send_request",
                "view_request",
                "view_sitemap_entry",
            ]
            proxy_dict = {name: getattr(proxy_actions, name) for name in proxy_functions}
            self.shell.user_ns.update(proxy_dict)
        except ImportError:
            # Proxy tooling is optional; sessions still work without it.
            pass

    def _validate_session(self) -> dict[str, Any] | None:
        """Return an error payload if the session is closed, else None."""
        if not self.is_running:
            return {
                "session_id": self.session_id,
                "stdout": "",
                "stderr": "Session is not running",
                "result": None,
            }
        return None

    def _setup_execution_environment(self, timeout: int) -> tuple[Any, io.StringIO, io.StringIO]:
        """Arm a SIGALRM timeout and redirect stdout/stderr into capture buffers.

        Returns (previous_signal_handler, stdout_buffer, stderr_buffer).
        NOTE(review): SIGALRM handlers only work in the main thread — confirm callers.
        """
        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()

        def timeout_handler(signum: int, frame: Any) -> None:
            raise TimeoutError(f"Code execution timed out after {timeout} seconds")

        old_handler = signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(timeout)

        sys.stdout = stdout_capture
        sys.stderr = stderr_capture
        return old_handler, stdout_capture, stderr_capture

    def _cleanup_execution_environment(
        self, old_handler: Any, old_stdout: Any, old_stderr: Any
    ) -> None:
        """Restore the previous SIGALRM handler and the original stdout/stderr."""
        signal.signal(signal.SIGALRM, old_handler)
        sys.stdout = old_stdout
        sys.stderr = old_stderr

    def _truncate_output(self, content: str, max_length: int, suffix: str) -> str:
        """Clip *content* to *max_length* chars, appending *suffix* when clipped."""
        if len(content) > max_length:
            return content[:max_length] + suffix
        return content

    def _format_execution_result(
        self, execution_result: Any, stdout_content: str, stderr_content: str
    ) -> dict[str, Any]:
        """Build the result payload from an IPython ExecutionResult and captured output.

        The last expression's repr is appended to stdout; stdout is capped at
        10k chars and stderr at 5k.
        """
        stdout = self._truncate_output(
            stdout_content, MAX_STDOUT_LENGTH, "... [stdout truncated at 10k chars]"
        )

        if execution_result.result is not None:
            if stdout and not stdout.endswith("\n"):
                stdout += "\n"
            result_repr = repr(execution_result.result)
            result_repr = self._truncate_output(
                result_repr, MAX_STDOUT_LENGTH, "... [result truncated at 10k chars]"
            )
            stdout += result_repr
            # Re-truncate: stdout plus the repr may exceed the cap again.
            stdout = self._truncate_output(
                stdout, MAX_STDOUT_LENGTH, "... [output truncated at 10k chars]"
            )

        stderr_content = stderr_content if stderr_content else ""
        stderr_content = self._truncate_output(
            stderr_content, MAX_STDERR_LENGTH, "... [stderr truncated at 5k chars]"
        )

        # IPython swallows tracebacks into the result object; surface a generic
        # marker when an error occurred but nothing reached stderr.
        if (
            execution_result.error_before_exec or execution_result.error_in_exec
        ) and not stderr_content:
            stderr_content = "Execution error occurred"

        return {
            "session_id": self.session_id,
            "stdout": stdout,
            "stderr": stderr_content,
            "result": repr(execution_result.result)
            if execution_result.result is not None
            else None,
        }

    def _handle_execution_error(self, error: BaseException) -> dict[str, Any]:
        """Convert a timeout/interrupt/exit into an error payload (stderr capped at 5k)."""
        error_msg = str(error)
        error_msg = self._truncate_output(
            error_msg, MAX_STDERR_LENGTH, "... [error truncated at 5k chars]"
        )
        return {
            "session_id": self.session_id,
            "stdout": "",
            "stderr": error_msg,
            "result": None,
        }

    def execute_code(self, code: str, timeout: int = 30) -> dict[str, Any]:
        """Run *code* in this session's shell with output capture and a timeout.

        Returns a payload with session_id/stdout/stderr/result. Timeouts,
        KeyboardInterrupt and SystemExit become error payloads; the SIGALRM
        handler and stdout/stderr are always restored.
        """
        session_error = self._validate_session()
        if session_error:
            return session_error

        with self._execution_lock:
            old_stdout, old_stderr = sys.stdout, sys.stderr
            try:
                old_handler, stdout_capture, stderr_capture = self._setup_execution_environment(
                    timeout
                )
                try:
                    execution_result = self.shell.run_cell(code, silent=False, store_history=True)
                    # Disarm the alarm before formatting so it can't fire late.
                    signal.alarm(0)
                    return self._format_execution_result(
                        execution_result, stdout_capture.getvalue(), stderr_capture.getvalue()
                    )
                except (TimeoutError, KeyboardInterrupt, SystemExit) as e:
                    signal.alarm(0)
                    return self._handle_execution_error(e)
            finally:
                self._cleanup_execution_environment(old_handler, old_stdout, old_stderr)

    def close(self) -> None:
        """Mark the session stopped and reset the shell's namespace."""
        self.is_running = False
        self.shell.reset(new_session=False)

    def is_alive(self) -> bool:
        """Whether the session still accepts code (close() has not been called)."""
        return self.is_running

View File

@@ -0,0 +1,131 @@
import atexit
import contextlib
import signal
import sys
import threading
from typing import Any
from .python_instance import PythonInstance
class PythonSessionManager:
    """Registry of named PythonInstance sessions with process-exit cleanup.

    Sessions are keyed by a caller-supplied id (default "default"). A single
    lock guards the registry; atexit and signal handlers close every session
    when the process terminates.
    """

    def __init__(self) -> None:
        # Live sessions, keyed by session id.
        self.sessions: dict[str, PythonInstance] = {}
        self._lock = threading.Lock()
        self.default_session_id = "default"
        self._register_cleanup_handlers()

    def create_session(
        self, session_id: str | None = None, initial_code: str | None = None, timeout: int = 30
    ) -> dict[str, Any]:
        """Create a new session, optionally running *initial_code* in it.

        Raises ValueError when the id is already in use.
        """
        if session_id is None:
            session_id = self.default_session_id

        with self._lock:
            if session_id in self.sessions:
                raise ValueError(f"Python session '{session_id}' already exists")

            session = PythonInstance(session_id)
            self.sessions[session_id] = session

            if initial_code:
                # NOTE: the registry lock is held while the code runs.
                result = session.execute_code(initial_code, timeout)
                result["message"] = (
                    f"Python session '{session_id}' created successfully with initial code"
                )
            else:
                result = {
                    "session_id": session_id,
                    "message": f"Python session '{session_id}' created successfully",
                }

            return result

    def execute_code(
        self, session_id: str | None = None, code: str | None = None, timeout: int = 30
    ) -> dict[str, Any]:
        """Run *code* in an existing session and return its result payload.

        Raises ValueError for a missing session or empty code.
        NOTE(review): the registry lock is held for the whole execution, so
        long-running code blocks all other session operations — confirm intended.
        """
        if session_id is None:
            session_id = self.default_session_id

        if not code:
            raise ValueError("No code provided for execution")

        with self._lock:
            if session_id not in self.sessions:
                raise ValueError(f"Python session '{session_id}' not found")

            session = self.sessions[session_id]
            result = session.execute_code(code, timeout)
            result["message"] = f"Code executed in session '{session_id}'"
            return result

    def close_session(self, session_id: str | None = None) -> dict[str, Any]:
        """Close a session and remove it from the registry.

        Raises ValueError when the session does not exist.
        """
        if session_id is None:
            session_id = self.default_session_id

        with self._lock:
            if session_id not in self.sessions:
                raise ValueError(f"Python session '{session_id}' not found")

            session = self.sessions.pop(session_id)
            session.close()

            return {
                "session_id": session_id,
                "message": f"Python session '{session_id}' closed successfully",
                "is_running": False,
            }

    def list_sessions(self) -> dict[str, Any]:
        """Return per-session liveness flags plus a total count."""
        with self._lock:
            session_info = {}
            for sid, session in self.sessions.items():
                session_info[sid] = {
                    "is_running": session.is_running,
                    "is_alive": session.is_alive(),
                }
            return {"sessions": session_info, "total_count": len(session_info)}

    def cleanup_dead_sessions(self) -> None:
        """Drop and close any session that reports itself not alive."""
        with self._lock:
            dead_sessions = []
            for sid, session in self.sessions.items():
                if not session.is_alive():
                    dead_sessions.append(sid)

            for sid in dead_sessions:
                session = self.sessions.pop(sid)
                # Best effort: a failing close must not abort the sweep.
                with contextlib.suppress(Exception):
                    session.close()

    def close_all_sessions(self) -> None:
        """Close every session; used by the atexit and signal cleanup paths."""
        with self._lock:
            sessions_to_close = list(self.sessions.values())
            self.sessions.clear()

            for session in sessions_to_close:
                with contextlib.suppress(Exception):
                    session.close()

    def _register_cleanup_handlers(self) -> None:
        """Install atexit and SIGTERM/SIGINT(/SIGHUP) handlers that close all sessions.

        NOTE(review): this replaces any previously installed signal handlers — confirm.
        """
        atexit.register(self.close_all_sessions)

        signal.signal(signal.SIGTERM, self._signal_handler)
        signal.signal(signal.SIGINT, self._signal_handler)

        if hasattr(signal, "SIGHUP"):
            signal.signal(signal.SIGHUP, self._signal_handler)

    def _signal_handler(self, _signum: int, _frame: Any) -> None:
        """Close every session, then terminate the process."""
        self.close_all_sessions()
        sys.exit(0)
# Module-level singleton: one shared manager for all tool invocations in this process.
_python_session_manager = PythonSessionManager()


def get_python_session_manager() -> PythonSessionManager:
    """Return the process-wide PythonSessionManager singleton."""
    return _python_session_manager

196
strix/tools/registry.py Normal file
View File

@@ -0,0 +1,196 @@
import inspect
import logging
import os
from collections.abc import Callable
from functools import wraps
from inspect import signature
from pathlib import Path
from typing import Any
tools: list[dict[str, Any]] = []
_tools_by_name: dict[str, Callable[..., Any]] = {}
logger = logging.getLogger(__name__)
class ImplementedInClientSideOnlyError(Exception):
    """Raised when a tool has no sandbox implementation and must run client-side."""

    def __init__(
        self,
        message: str = "This tool is implemented in the client side only",
    ) -> None:
        super().__init__(message)
        # Keep the message on the instance so callers can inspect it directly.
        self.message = message
def _process_dynamic_content(content: str) -> str:
if "{{DYNAMIC_MODULES_DESCRIPTION}}" in content:
try:
from strix.prompts import generate_modules_description
modules_description = generate_modules_description()
content = content.replace("{{DYNAMIC_MODULES_DESCRIPTION}}", modules_description)
except ImportError:
logger.warning("Could not import prompts utilities for dynamic schema generation")
content = content.replace(
"{{DYNAMIC_MODULES_DESCRIPTION}}",
"List of prompt modules to load for this agent (max 3). Module discovery failed.",
)
return content
def _load_xml_schema(path: Path) -> Any:
    """Extract per-tool XML snippets from a schema file.

    Scans for ``<tool name="...">`` ... ``</tool>`` spans textually (no XML
    parser) and returns {tool_name: element_text}. Returns None when the
    file is missing or cannot be processed.
    """
    if not path.exists():
        return None

    try:
        text = _process_dynamic_content(path.read_text())

        open_marker = '<tool name="'
        close_marker = "</tool>"
        schemas: dict[str, str] = {}
        cursor = 0
        while (begin := text.find(open_marker, cursor)) != -1:
            name_from = begin + len(open_marker)
            name_to = text.find('"', name_from)
            if name_to == -1:
                break
            closing = text.find(close_marker, name_to)
            if closing == -1:
                break
            closing += len(close_marker)
            schemas[text[name_from:name_to]] = text[begin:closing]
            cursor = closing
    except (IndexError, ValueError, UnicodeError) as e:
        logger.warning(f"Error loading schema file {path}: {e}")
        return None
    else:
        return schemas
def _get_module_name(func: Callable[..., Any]) -> str:
module = inspect.getmodule(func)
if not module:
return "unknown"
module_name = module.__name__
if ".tools." in module_name:
parts = module_name.split(".tools.")[-1].split(".")
if len(parts) >= 1:
return parts[0]
return "unknown"
def register_tool(
    func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
) -> Callable[..., Any]:
    """Decorator that records a function in the global tool registry.

    Usable bare (``@register_tool``) or with arguments
    (``@register_tool(sandbox_execution=False)``). Outside sandbox mode it
    also loads the tool's XML schema from ``<module>_schema.xml`` next to
    the defining file. The wrapped function is returned unchanged in
    behavior.
    """
    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        # Registry entry: name, callable, owning tool package, sandbox flag.
        func_dict = {
            "name": f.__name__,
            "function": f,
            "module": _get_module_name(f),
            "sandbox_execution": sandbox_execution,
        }

        # Schemas are only needed client-side; skip file I/O inside the sandbox.
        sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
        if not sandbox_mode:
            try:
                # Look for "<defining_file_stem>_schema.xml" beside the source file.
                module_path = Path(inspect.getfile(f))
                schema_file_name = f"{module_path.stem}_schema.xml"
                schema_path = module_path.parent / schema_file_name

                xml_tools = _load_xml_schema(schema_path)
                if xml_tools is not None and f.__name__ in xml_tools:
                    func_dict["xml_schema"] = xml_tools[f.__name__]
                else:
                    # Fall back to a stub schema so prompts still render.
                    func_dict["xml_schema"] = (
                        f'<tool name="{f.__name__}">'
                        "<description>Schema not found for tool.</description>"
                        "</tool>"
                    )
            except (TypeError, FileNotFoundError) as e:
                logger.warning(f"Error loading schema for {f.__name__}: {e}")
                func_dict["xml_schema"] = (
                    f'<tool name="{f.__name__}">'
                    "<description>Error loading schema.</description>"
                    "</tool>"
                )

        tools.append(func_dict)
        _tools_by_name[str(func_dict["name"])] = f

        @wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return f(*args, **kwargs)

        return wrapper

    # Bare decorator vs. decorator factory usage.
    if func is None:
        return decorator
    return decorator(func)
def get_tool_by_name(name: str) -> Callable[..., Any] | None:
    """Look up a registered tool callable by name; None when unregistered."""
    try:
        return _tools_by_name[name]
    except KeyError:
        return None
def get_tool_names() -> list[str]:
    """Return the names of all registered tools, in registration order."""
    return [*_tools_by_name]
def needs_agent_state(tool_name: str) -> bool:
    """Whether the named tool's signature declares an ``agent_state`` parameter."""
    registered = get_tool_by_name(tool_name)
    if registered is None:
        return False
    return "agent_state" in signature(registered).parameters
def should_execute_in_sandbox(tool_name: str) -> bool:
    """Return the first matching tool's sandbox flag; unknown tools default to True."""
    entry = next((t for t in tools if t.get("name") == tool_name), None)
    if entry is None:
        return True
    return bool(entry.get("sandbox_execution", True))
def get_tools_prompt() -> str:
    """Render every registered tool's XML schema, grouped by tool module.

    Each module becomes a ``<module_tools>...</module_tools>`` section
    (sorted by module name) containing its tools' schemas indented two
    spaces; sections are joined by blank lines. Tools without a schema are
    skipped. Improvement: grouping now uses ``dict.setdefault`` instead of
    the manual membership check + append.
    """
    tools_by_module: dict[str, list[dict[str, Any]]] = {}
    for tool in tools:
        tools_by_module.setdefault(tool.get("module", "unknown"), []).append(tool)

    xml_sections = []
    for module, module_tools in sorted(tools_by_module.items()):
        tag_name = f"{module}_tools"
        section_parts = [f"<{tag_name}>"]
        for tool in module_tools:
            tool_xml = tool.get("xml_schema", "")
            if tool_xml:
                indented_tool = "\n".join(f"  {line}" for line in tool_xml.split("\n"))
                section_parts.append(indented_tool)
        section_parts.append(f"</{tag_name}>")
        xml_sections.append("\n".join(section_parts))

    return "\n\n".join(xml_sections)
def clear_registry() -> None:
    """Reset the module-level tool registry (both the list and the name index)."""
    tools.clear()
    _tools_by_name.clear()

View File

@@ -0,0 +1,6 @@
from .reporting_actions import create_vulnerability_report
__all__ = [
"create_vulnerability_report",
]

View File

@@ -0,0 +1,63 @@
from typing import Any
from strix.tools.registry import register_tool
@register_tool(sandbox_execution=False)
def create_vulnerability_report(
    title: str,
    content: str,
    severity: str,
) -> dict[str, Any]:
    """Record a single verified vulnerability via the global tracer.

    Persistence is best-effort: when the tracer (or its module) is
    unavailable the report is still acknowledged as created, but flagged
    with a warning that it was not stored.
    """
    # Guard clauses: reject blank fields before touching the tracer.
    if not title or not title.strip():
        return {"success": False, "message": "Title cannot be empty"}
    if not content or not content.strip():
        return {"success": False, "message": "Content cannot be empty"}
    if not severity or not severity.strip():
        return {"success": False, "message": "Severity cannot be empty"}

    valid_severities = ["critical", "high", "medium", "low", "info"]
    if severity.lower() not in valid_severities:
        return {
            "success": False,
            "message": f"Invalid severity '{severity}'. Must be one of: {', '.join(valid_severities)}",
        }

    try:
        from strix.cli.tracer import get_global_tracer

        tracer = get_global_tracer()
        if tracer:
            report_id = tracer.add_vulnerability_report(
                title=title,
                content=content,
                severity=severity,
            )
            return {
                "success": True,
                "message": f"Vulnerability report '{title}' created successfully",
                "report_id": report_id,
                "severity": severity.lower(),
            }

        import logging

        logging.warning("Global tracer not available - vulnerability report not stored")
        return {  # noqa: TRY300
            "success": True,
            "message": f"Vulnerability report '{title}' created successfully (not persisted)",
            "warning": "Report could not be persisted - tracer unavailable",
        }
    except ImportError:
        return {
            "success": True,
            "message": f"Vulnerability report '{title}' created successfully (not persisted)",
            "warning": "Report could not be persisted - tracer module unavailable",
        }
    except (ValueError, TypeError) as e:
        return {"success": False, "message": f"Failed to create vulnerability report: {e!s}"}

View File

@@ -0,0 +1,30 @@
<tools>
<tool name="create_vulnerability_report">
<description>Create a vulnerability report for a discovered security issue.
Use this tool to document a specific verified security vulnerability.
Put ALL details in the content field - affected URLs, parameters, proof of concept, remediation steps, CVE references, CVSS scores, technical details, impact assessment, etc.
DO NOT USE:
- For general security observations without specific vulnerabilities
- When you don't have concrete vulnerability details
- When you don't have a proof of concept, or still not 100% sure if it's a vulnerability
- For tracking multiple vulnerabilities (create separate reports)
- For reporting multiple vulnerabilities at once. Use a separate create_vulnerability_report for each vulnerability.
</description>
<parameters>
<parameter name="title" type="string" required="true">
<description>Clear, concise title of the vulnerability</description>
</parameter>
<parameter name="content" type="string" required="true">
<description>Complete vulnerability details including affected URLs, technical details, impact, proof of concept, remediation steps, and any relevant references. Be comprehensive and include everything relevant.</description>
</parameter>
<parameter name="severity" type="string" required="true">
<description>Severity level: critical, high, medium, low, or info</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing success status and message</description>
</returns>
</tool>
</tools>

View File

@@ -0,0 +1,4 @@
from .terminal_actions import terminal_action
__all__ = ["terminal_action"]

View File

@@ -0,0 +1,53 @@
from typing import Any, Literal
from strix.tools.registry import register_tool
from .terminal_manager import get_terminal_manager
TerminalAction = Literal["new_terminal", "send_input", "wait", "close"]
@register_tool
def terminal_action(
    action: TerminalAction,
    inputs: list[str] | None = None,
    time: float | None = None,
    terminal_id: str | None = None,
) -> dict[str, Any]:
    """Dispatch one terminal operation to the shared terminal manager.

    Validation failures and manager errors are converted into an error
    payload instead of propagating, so the caller always receives a dict.
    """
    manager = get_terminal_manager()
    try:
        if action == "new_terminal":
            return manager.create_terminal(terminal_id, inputs)
        if action == "send_input":
            # send_input requires at least one input item.
            if not inputs:
                raise ValueError(f"inputs parameter is required for {action} action")
            return manager.send_input(terminal_id, inputs)
        if action == "wait":
            # wait requires an explicit duration.
            if time is None:
                raise ValueError("time parameter is required for wait action")
            return manager.wait_terminal(terminal_id, time)
        if action == "close":
            return manager.close_terminal(terminal_id)
        raise ValueError(f"Unknown action: {action}")
    except (ValueError, RuntimeError) as e:
        return {"error": str(e), "terminal_id": terminal_id, "snapshot": "", "is_running": False}

View File

@@ -0,0 +1,114 @@
<tools>
<tool name="terminal_action">
<description>Perform terminal actions using a terminal emulator instance. Each terminal instance
is PERSISTENT and remains active until explicitly closed, allowing for multi-step
workflows and long-running processes.</description>
<parameters>
<parameter name="action" type="string" required="true">
<description>The terminal action to perform: - new_terminal: Create a new terminal instance. This MUST be the first action for each terminal tab. - send_input: Send keyboard input to the specified terminal. - wait: Pause execution for specified number of seconds. Can be also used to get the current terminal state (screenshot, output, etc.) after using other tools. - close: Close the specified terminal instance. This MUST be the final action for each terminal tab.</description>
</parameter>
<parameter name="inputs" type="string" required="false">
<description>Required for 'new_terminal' and 'send_input' actions: - List of inputs to send to terminal. Each element in the list MUST be one of the following: - Regular text: "hello", "world", etc. - Literal text (not interpreted as special keys): prefix with "literal:" e.g., "literal:Home", "literal:Escape", "literal:Enter" to send these as text - Enter - Space - Backspace - Escape: "Escape", "^[", "C-[" - Tab: "Tab" - Arrow keys: "Left", "Right", "Up", "Down" - Navigation: "Home", "End", "PageUp", "PageDown" - Function keys: "F1" through "F12" Modifier keys supported with prefixes: - ^ or C- : Control (e.g., "^c", "C-c") - S- : Shift (e.g., "S-F6") - A- : Alt (e.g., "A-Home") - Combined modifiers for arrows: "S-A-Up", "C-S-Left" - Inputs MUST in all cases be sent as a LIST of strings, even if you are only sending one input. - Sending Inputs as a single string will NOT work.</description>
</parameter>
<parameter name="time" type="string" required="false">
<description>Required for 'wait' action. Number of seconds to pause execution. Can be fractional (e.g., 0.5 for half a second).</description>
</parameter>
<parameter name="terminal_id" type="string" required="false">
<description>Identifier for the terminal instance. Required for all actions except the first 'new_terminal' action. Allows managing multiple concurrent terminal tabs. - For 'new_terminal': if not provided, a default terminal is created. If provided, creates a new terminal with that ID. - For other actions: specifies which terminal instance to operate on. - Default terminal ID is "default" if not specified.</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - snapshot: raw representation of current terminal state where you can see the output of the command - terminal_id: the ID of the terminal instance that was operated on</description>
</returns>
<notes>
Important usage rules:
1. PERSISTENCE: Terminal instances remain active and maintain their state (environment
variables, current directory, running processes) until explicitly closed with the
'close' action. This allows for multi-step workflows across multiple tool calls.
2. MULTIPLE TERMINALS: You can run multiple terminal instances concurrently by using
different terminal_id values. Each terminal operates independently.
3. Terminal interaction MUST begin with 'new_terminal' action for each terminal instance.
4. Only one action can be performed per call.
5. Input handling:
- Regular text is sent as-is
- Literal text: prefix with "literal:" to send special key names as literal text
- Special keys must match supported key names
- Modifier combinations follow specific syntax
- Control can be specified as ^ or C- prefix
- Shift (S-) works with special keys only
- Alt (A-) works with any character/key
6. Wait action:
- Time is specified in seconds
- Can be used to wait for command completion
- Can be fractional (e.g., 0.5 seconds)
- Snapshot and output are captured after the wait
- You should estimate the time it will take to run the command and set the wait time accordingly.
- It can be from a few seconds to a few minutes, choose wisely depending on the command you are running and the task.
7. The terminal can operate concurrently with other tools. You may invoke
browser, proxy, or other tools (in separate assistant messages) while maintaining
active terminal sessions.
8. You do not need to close terminals after you are done, but you can if you want to
free up resources.
9. You MUST end the inputs list with an "Enter" if you want to run the command, as
it is not sent automatically.
10. AUTOMATIC SPACING BEHAVIOR:
- Consecutive regular text inputs have spaces automatically added between them
- This is helpful for shell commands: ["ls", "-la"] becomes "ls -la"
- This causes problems for compound commands: [":", "w", "q"] becomes ": w q"
- Use "literal:" prefix to bypass spacing: [":", "literal:wq"] becomes ":wq"
- Special keys (Enter, Space, etc.) and literal strings never trigger spacing
11. WHEN TO USE LITERAL PREFIX:
- Vim commands: [":", "literal:wq", "Enter"] instead of [":", "w", "q", "Enter"]
- Any sequence where exact character positioning matters
- When you need multiple characters sent as a single unit
12. Do NOT use terminal actions for file editing or writing. Use the replace_in_file,
write_to_file, or read_file tools instead.
</notes>
<examples>
# Create new terminal with Node.js (default terminal)
<function=terminal_action>
<parameter=action>new_terminal</parameter>
<parameter=inputs>["node", "Enter"]</parameter>
</function>
# Create a second (parallel) terminal instance for Python
<function=terminal_action>
<parameter=action>new_terminal</parameter>
<parameter=terminal_id>python_terminal</parameter>
<parameter=inputs>["python3", "Enter"]</parameter>
</function>
# Send command to the default terminal
<function=terminal_action>
<parameter=action>send_input</parameter>
<parameter=inputs>["require('crypto').randomBytes(1000000).toString('hex')",
"Enter"]</parameter>
</function>
# Wait for previous action on default terminal
<function=terminal_action>
<parameter=action>wait</parameter>
<parameter=time>2.0</parameter>
</function>
# Send multiple inputs with special keys to current terminal
<function=terminal_action>
<parameter=action>send_input</parameter>
<parameter=inputs>["sqlmap -u 'http://example.com/page.php?id=1' --batch", "Enter", "y",
"Enter", "n", "Enter", "n", "Enter"]</parameter>
</function>
# WRONG: Vim command with automatic spacing (becomes ": w q")
<function=terminal_action>
<parameter=action>send_input</parameter>
<parameter=inputs>[":", "w", "q", "Enter"]</parameter>
</function>
# CORRECT: Vim command using literal prefix (becomes ":wq")
<function=terminal_action>
<parameter=action>send_input</parameter>
<parameter=inputs>[":", "literal:wq", "Enter"]</parameter>
</function>
</examples>
</tool>
</tools>

View File

@@ -0,0 +1,231 @@
import contextlib
import os
import pty
import select
import signal
import subprocess
import threading
import time
from typing import Any
import pyte
# Cap on the rendered snapshot text; when exceeded, only the tail is kept.
MAX_TERMINAL_SNAPSHOT_LENGTH = 10_000
class TerminalInstance:
    """One persistent interactive bash session attached to a pseudo-terminal.

    PTY output bytes are fed into a pyte ``HistoryScreen`` by a background
    reader thread, so callers can fetch a rendered text snapshot (scrollback
    plus current screen) at any time via :meth:`get_snapshot`.
    """
    def __init__(self, terminal_id: str, initial_command: str | None = None) -> None:
        """Spawn the shell immediately; optionally write *initial_command* to it."""
        self.terminal_id = terminal_id
        self.process: subprocess.Popen[bytes] | None = None
        self.master_fd: int | None = None  # master end of the PTY pair
        self.is_running = False
        self._output_lock = threading.Lock()  # guards screen/stream mutation
        self._reader_thread: threading.Thread | None = None
        # 80x24 virtual screen with 1000 lines of scrollback history.
        self.screen = pyte.HistoryScreen(80, 24, history=1000)
        self.stream = pyte.ByteStream()
        self.stream.attach(self.screen)
        self._start_terminal(initial_command)
    def _start_terminal(self, initial_command: str | None = None) -> None:
        """Start an interactive bash on a fresh PTY and launch the reader thread.

        Raises:
            RuntimeError: if the PTY or shell process cannot be created.
        """
        try:
            self.master_fd, slave_fd = pty.openpty()
            shell = "/bin/bash"
            self.process = subprocess.Popen(  # noqa: S603
                [shell, "-i"],
                stdin=slave_fd,
                stdout=slave_fd,
                stderr=slave_fd,
                cwd="/workspace",  # NOTE(review): assumes the sandbox workspace dir exists - confirm
                preexec_fn=os.setsid,  # noqa: PLW1509 - Required for PTY functionality
            )
            os.close(slave_fd)  # child keeps its own copy; parent only needs the master
            self.is_running = True
            self._reader_thread = threading.Thread(target=self._read_output, daemon=True)
            self._reader_thread.start()
            time.sleep(0.5)  # give bash a moment to initialize before the first write
            if initial_command:
                self._write_to_terminal(initial_command)
        except (OSError, ValueError) as e:
            raise RuntimeError(f"Failed to start terminal: {e}") from e
    def _read_output(self) -> None:
        """Reader-thread loop: pump PTY output into the pyte stream until EOF/close."""
        while self.is_running and self.master_fd:
            try:
                ready, _, _ = select.select([self.master_fd], [], [], 0.1)
                if ready:
                    data = os.read(self.master_fd, 4096)
                    if data:
                        # suppress(TypeError): presumably guards against pyte
                        # choking on malformed/partial escape bytes - confirm.
                        with self._output_lock, contextlib.suppress(TypeError):
                            self.stream.feed(data)
                    else:
                        break  # EOF: the shell exited
            except (OSError, ValueError):
                break  # fd closed underneath us during shutdown
    def _write_to_terminal(self, data: str) -> None:
        """Write raw UTF-8 *data* to the shell via the PTY master fd."""
        if self.master_fd and self.is_running:
            try:
                os.write(self.master_fd, data.encode("utf-8"))
            except (OSError, ValueError) as e:
                raise RuntimeError("Terminal is no longer available") from e
    def send_input(self, inputs: list[str]) -> None:
        """Send a sequence of input items, translating special-key names.

        Items prefixed with ``literal:`` are written verbatim; recognized key
        names become escape sequences; anything else is written as plain text.
        A single space is auto-inserted between two adjacent plain-text items.

        Raises:
            RuntimeError: if the terminal is not running.
        """
        if not self.is_running:
            raise RuntimeError("Terminal is not running")
        for i, input_item in enumerate(inputs):
            if input_item.startswith("literal:"):
                literal_text = input_item[8:]  # strip the "literal:" prefix
                self._write_to_terminal(literal_text)
            else:
                key_sequence = self._get_key_sequence(input_item)
                if key_sequence:
                    self._write_to_terminal(key_sequence)
                else:
                    # NOTE(review): S-/A- modified keys pass _is_special_key but
                    # have no mapping in _get_key_sequence, so they fall through
                    # and are sent as literal text - confirm this is intended.
                    self._write_to_terminal(input_item)
            time.sleep(0.05)  # small pacing delay so the shell keeps up
            # Auto-spacing: only between two adjacent plain-text items.
            if (
                i < len(inputs) - 1
                and not input_item.startswith("literal:")
                and not self._is_special_key(input_item)
                and not inputs[i + 1].startswith("literal:")
                and not self._is_special_key(inputs[i + 1])
            ):
                self._write_to_terminal(" ")
    def get_snapshot(self) -> dict[str, Any]:
        """Render scrollback plus current screen as text, capped at the max length."""
        with self._output_lock:
            history_lines = [
                "".join(char.data for char in line_dict.values())
                for line_dict in self.screen.history.top
            ]
            current_lines = self.screen.display
            all_lines = history_lines + current_lines
            rendered_output = "\n".join(all_lines)
            if len(rendered_output) > MAX_TERMINAL_SNAPSHOT_LENGTH:
                # Keep the most recent output (the tail) when truncating.
                rendered_output = rendered_output[-MAX_TERMINAL_SNAPSHOT_LENGTH:]
                truncated = True
            else:
                truncated = False
            return {
                "terminal_id": self.terminal_id,
                "snapshot": rendered_output,
                "is_running": self.is_running,
                "process_id": self.process.pid if self.process else None,
                "truncated": truncated,
            }
    def wait(self, duration: float) -> dict[str, Any]:
        """Sleep for *duration* seconds, then return a fresh snapshot."""
        time.sleep(duration)
        return self.get_snapshot()
    def close(self) -> None:
        """Terminate the shell (SIGTERM then SIGKILL) and release the PTY."""
        self.is_running = False
        if self.process:
            with contextlib.suppress(OSError, ProcessLookupError):
                # Kill the whole process group created by os.setsid in Popen.
                os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
                try:
                    self.process.wait(timeout=2)
                except subprocess.TimeoutExpired:
                    os.killpg(os.getpgid(self.process.pid), signal.SIGKILL)
                    self.process.wait()
        if self.master_fd:
            with contextlib.suppress(OSError):
                os.close(self.master_fd)
            self.master_fd = None
        if self._reader_thread and self._reader_thread.is_alive():
            self._reader_thread.join(timeout=1)
    def _is_special_key(self, key: str) -> bool:
        """Return True for named keys, F-keys, or any modifier-prefixed input."""
        special_keys = {
            "Enter",
            "Space",
            "Backspace",
            "Tab",
            "Escape",
            "Up",
            "Down",
            "Left",
            "Right",
            "Home",
            "End",
            "PageUp",
            "PageDown",
            "Insert",
            "Delete",
        } | {f"F{i}" for i in range(1, 13)}
        if key in special_keys:
            return True
        return bool(key.startswith(("^", "C-", "S-", "A-")))
    def _get_key_sequence(self, key: str) -> str | None:
        """Map a key name to its terminal escape sequence, or None if unknown.

        Handles named keys, F1-F12, and Control chords (``^x`` / ``C-x``).
        """
        key_map = {
            "Enter": "\r",
            "Space": " ",
            "Backspace": "\x08",
            "Tab": "\t",
            "Escape": "\x1b",
            "Up": "\x1b[A",
            "Down": "\x1b[B",
            "Right": "\x1b[C",
            "Left": "\x1b[D",
            "Home": "\x1b[H",
            "End": "\x1b[F",
            "PageUp": "\x1b[5~",
            "PageDown": "\x1b[6~",
            "Insert": "\x1b[2~",
            "Delete": "\x1b[3~",
            "F1": "\x1b[11~",
            "F2": "\x1b[12~",
            "F3": "\x1b[13~",
            "F4": "\x1b[14~",
            "F5": "\x1b[15~",
            "F6": "\x1b[17~",
            "F7": "\x1b[18~",
            "F8": "\x1b[19~",
            "F9": "\x1b[20~",
            "F10": "\x1b[21~",
            "F11": "\x1b[23~",
            "F12": "\x1b[24~",
        }
        if key in key_map:
            return key_map[key]
        # Control chords map to ASCII control characters (Ctrl-A == 0x01).
        if key.startswith("^") and len(key) == 2:
            char = key[1].lower()
            return chr(ord(char) - ord("a") + 1) if "a" <= char <= "z" else None
        if key.startswith("C-") and len(key) == 3:
            char = key[2].lower()
            return chr(ord(char) - ord("a") + 1) if "a" <= char <= "z" else None
        return None
    def is_alive(self) -> bool:
        """Return True while the shell process exists and has not exited."""
        if not self.process:
            return False
        return self.process.poll() is None

View File

@@ -0,0 +1,191 @@
import atexit
import contextlib
import signal
import sys
import threading
from typing import Any
from .terminal_instance import TerminalInstance
class TerminalManager:
    """Thread-safe registry of TerminalInstance objects keyed by terminal id.

    Installs atexit and signal handlers so every spawned shell is torn down
    when the host process exits. Note: the single manager lock is held for the
    full duration of each operation (including the post-action waits), so
    operations on different terminals are serviced serially.
    """
    def __init__(self) -> None:
        self.terminals: dict[str, TerminalInstance] = {}
        self._lock = threading.Lock()  # guards self.terminals and all operations
        self.default_terminal_id = "default"
        self._register_cleanup_handlers()
    def create_terminal(
        self, terminal_id: str | None = None, inputs: list[str] | None = None
    ) -> dict[str, Any]:
        """Create a terminal (default id if omitted), optionally running *inputs*.

        Raises:
            ValueError: if a terminal with that id already exists.
            RuntimeError: if spawning the underlying shell fails.
        """
        if terminal_id is None:
            terminal_id = self.default_terminal_id
        with self._lock:
            if terminal_id in self.terminals:
                raise ValueError(f"Terminal '{terminal_id}' already exists")
            initial_command = None
            if inputs:
                # Fold plain-text items into one command string, terminated at
                # the first "Enter"; other special keys here are skipped rather
                # than encoded.
                command_parts: list[str] = []
                for input_item in inputs:
                    if input_item == "Enter":
                        initial_command = " ".join(command_parts) + "\n"
                        break
                    if input_item.startswith("literal:"):
                        command_parts.append(input_item[8:])
                    elif input_item not in [
                        "Space",
                        "Tab",
                        "Backspace",
                    ]:
                        command_parts.append(input_item)
            try:
                terminal = TerminalInstance(terminal_id, initial_command)
                self.terminals[terminal_id] = terminal
                if inputs and not initial_command:
                    # No "Enter" found: replay the inputs key-by-key instead.
                    terminal.send_input(inputs)
                    result = terminal.wait(2.0)
                else:
                    result = terminal.wait(1.0)
                result["message"] = f"Terminal '{terminal_id}' created successfully"
            except (OSError, ValueError, RuntimeError) as e:
                raise RuntimeError(f"Failed to create terminal '{terminal_id}': {e}") from e
            else:
                return result
    def send_input(
        self, terminal_id: str | None = None, inputs: list[str] | None = None
    ) -> dict[str, Any]:
        """Send *inputs* to a terminal, then wait 2s and return its snapshot.

        Raises:
            ValueError: if no inputs are given or the terminal is unknown.
            RuntimeError: if delivering the input fails.
        """
        if terminal_id is None:
            terminal_id = self.default_terminal_id
        if not inputs:
            raise ValueError("No inputs provided")
        with self._lock:
            if terminal_id not in self.terminals:
                raise ValueError(f"Terminal '{terminal_id}' not found")
            terminal = self.terminals[terminal_id]
            try:
                terminal.send_input(inputs)
                result = terminal.wait(2.0)  # fixed settle time before snapshot
                result["message"] = f"Input sent to terminal '{terminal_id}'"
            except (OSError, ValueError, RuntimeError) as e:
                raise RuntimeError(f"Failed to send input to terminal '{terminal_id}': {e}") from e
            else:
                return result
    def wait_terminal(
        self, terminal_id: str | None = None, duration: float = 1.0
    ) -> dict[str, Any]:
        """Sleep *duration* seconds on the given terminal and return a snapshot."""
        if terminal_id is None:
            terminal_id = self.default_terminal_id
        with self._lock:
            if terminal_id not in self.terminals:
                raise ValueError(f"Terminal '{terminal_id}' not found")
            terminal = self.terminals[terminal_id]
            try:
                result = terminal.wait(duration)
                result["message"] = f"Waited {duration}s on terminal '{terminal_id}'"
            except (OSError, ValueError, RuntimeError) as e:
                raise RuntimeError(f"Failed to wait on terminal '{terminal_id}': {e}") from e
            else:
                return result
    def close_terminal(self, terminal_id: str | None = None) -> dict[str, Any]:
        """Remove a terminal from the registry and terminate its shell."""
        if terminal_id is None:
            terminal_id = self.default_terminal_id
        with self._lock:
            if terminal_id not in self.terminals:
                raise ValueError(f"Terminal '{terminal_id}' not found")
            # pop() first so the registry stays consistent even if close fails.
            terminal = self.terminals.pop(terminal_id)
            try:
                terminal.close()
            except (OSError, ValueError, RuntimeError) as e:
                raise RuntimeError(f"Failed to close terminal '{terminal_id}': {e}") from e
            else:
                return {
                    "terminal_id": terminal_id,
                    "message": f"Terminal '{terminal_id}' closed successfully",
                    "snapshot": "",
                    "is_running": False,
                }
    def get_terminal_snapshot(self, terminal_id: str | None = None) -> dict[str, Any]:
        """Return the current rendered snapshot of a terminal without waiting."""
        if terminal_id is None:
            terminal_id = self.default_terminal_id
        with self._lock:
            if terminal_id not in self.terminals:
                raise ValueError(f"Terminal '{terminal_id}' not found")
            terminal = self.terminals[terminal_id]
            return terminal.get_snapshot()
    def list_terminals(self) -> dict[str, Any]:
        """Summarize all registered terminals (running/alive flags, pids)."""
        with self._lock:
            terminal_info = {}
            for tid, terminal in self.terminals.items():
                terminal_info[tid] = {
                    "is_running": terminal.is_running,
                    "is_alive": terminal.is_alive(),
                    "process_id": terminal.process.pid if terminal.process else None,
                }
            return {"terminals": terminal_info, "total_count": len(terminal_info)}
    def cleanup_dead_terminals(self) -> None:
        """Drop and close terminals whose shell process has already exited."""
        with self._lock:
            dead_terminals = []
            for tid, terminal in self.terminals.items():
                if not terminal.is_alive():
                    dead_terminals.append(tid)
            for tid in dead_terminals:
                terminal = self.terminals.pop(tid)
                with contextlib.suppress(Exception):
                    terminal.close()  # best-effort: resources may be gone already
    def close_all_terminals(self) -> None:
        """Close every terminal; used by atexit and signal handlers."""
        with self._lock:
            # Snapshot then clear under the lock; close outside the registry.
            terminals_to_close = list(self.terminals.values())
            self.terminals.clear()
        for terminal in terminals_to_close:
            with contextlib.suppress(Exception):
                terminal.close()
    def _register_cleanup_handlers(self) -> None:
        """Hook process exit paths so shells never outlive the process.

        NOTE(review): this replaces any previously installed SIGTERM/SIGINT
        handlers - confirm no other component relies on them.
        """
        atexit.register(self.close_all_terminals)
        signal.signal(signal.SIGTERM, self._signal_handler)
        signal.signal(signal.SIGINT, self._signal_handler)
        if hasattr(signal, "SIGHUP"):  # not available on all platforms
            signal.signal(signal.SIGHUP, self._signal_handler)
    def _signal_handler(self, _signum: int, _frame: Any) -> None:
        """Close all terminals and exit the process on termination signals."""
        self.close_all_terminals()
        sys.exit(0)
# Process-wide singleton; created at import time so cleanup handlers are
# installed once, and shared by every terminal_action call.
_terminal_manager = TerminalManager()
def get_terminal_manager() -> TerminalManager:
    """Return the process-wide TerminalManager singleton."""
    return _terminal_manager

View File

@@ -0,0 +1,4 @@
from .thinking_actions import think
__all__ = ["think"]

View File

@@ -0,0 +1,18 @@
from typing import Any
from strix.tools.registry import register_tool
@register_tool(sandbox_execution=False)
def think(thought: str) -> dict[str, Any]:
    """Record a free-form reasoning step; has no effect beyond the ack message."""
    try:
        stripped = thought.strip() if thought else ""
        if not stripped:
            return {"success": False, "message": "Thought cannot be empty"}
        return {
            "success": True,
            "message": f"Thought recorded successfully with {len(stripped)} characters",
        }
    except (ValueError, TypeError) as e:
        return {"success": False, "message": f"Failed to record thought: {e!s}"}

View File

@@ -0,0 +1,52 @@
<tools>
<tool name="think">
<description>Use the tool to think about something. It will not obtain new information or change the
database. Use it when complex reasoning or some cache memory is needed.</description>
<details>This tool creates dedicated space for structured thinking during complex tasks,
particularly useful for:
- Tool output analysis: When you need to carefully process the output of previous tool calls
- Policy-heavy environments: When you need to follow detailed guidelines and verify compliance
- Sequential decision making: When each action builds on previous ones and mistakes are costly
- Multi-step problem solving: When you need to break down complex problems into manageable steps</details>
<parameters>
<parameter name="thought" type="string" required="true">
<description>The thought or reasoning to record</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - success: Whether the thought was recorded successfully - message: Confirmation message with character count or error details</description>
</returns>
<examples>
# Planning and strategy
<function=think>
<parameter=thought>I need to analyze the scan results systematically. First, let me review
the open ports: 22 (SSH), 80 (HTTP), 443 (HTTPS), and 3306 (MySQL). The MySQL port being
externally accessible is a high priority security concern. I should check for default
credentials and version information. For the web services, I need to enumerate
directories and test for common web vulnerabilities.</parameter>
</function>
# Analysis of tool outputs
<function=think>
<parameter=thought>The Nmap scan revealed 15 open ports, but three stand out as concerning:
- Port 3306 (MySQL) - Database should not be exposed externally
- Port 5432 (PostgreSQL) - Another database port that's risky when public
- Port 6379 (Redis) - Often misconfigured and can lead to data exposure
I should prioritize testing these database services for authentication bypass and
information disclosure vulnerabilities.</parameter>
</function>
# Decision making and next steps
<function=think>
<parameter=thought>Based on the vulnerability scan results, I've identified several critical
issues that need immediate attention:
1. SQL injection in the login form (confirmed with sqlmap)
2. Reflected XSS in the search parameter
3. Directory traversal in the file upload function
I should document these findings with proof-of-concept exploits and assign appropriate
CVSS scores. The SQL injection poses the highest risk due to potential data
exfiltration.</parameter>
</function>
</examples>
</tool>
</tools>

View File

@@ -0,0 +1,4 @@
from .web_search_actions import web_search
__all__ = ["web_search"]

View File

@@ -0,0 +1,80 @@
import os
from typing import Any
import requests
from strix.tools.registry import register_tool
SYSTEM_PROMPT = """You are assisting a cybersecurity agent specialized in vulnerability scanning
and security assessment running on Kali Linux. When responding to search queries:
1. Prioritize cybersecurity-relevant information including:
- Vulnerability details (CVEs, CVSS scores, impact)
- Security tools, techniques, and methodologies
- Exploit information and proof-of-concepts
- Security best practices and mitigations
- Penetration testing approaches
- Web application security findings
2. Provide technical depth appropriate for security professionals
3. Include specific versions, configurations, and technical details when available
4. Focus on actionable intelligence for security assessment
5. Cite reliable security sources (NIST, OWASP, CVE databases, security vendors)
6. When providing commands or installation instructions, prioritize Kali Linux compatibility
and use apt package manager or tools pre-installed in Kali
7. Be detailed and specific - avoid general answers. Always include concrete code examples,
command-line instructions, configuration snippets, or practical implementation steps
when applicable
Structure your response to be comprehensive yet concise, emphasizing the most critical
security implications and details."""
@register_tool(sandbox_execution=False)
def web_search(query: str) -> dict[str, Any]:
    """Query the Perplexity chat-completions API and return the answer text.

    Returns a dict with "success" and "message"; on success it also carries
    the original "query" and the AI-generated "content". All failure modes
    (missing key, timeout, HTTP error, malformed response) are reported in
    the return value rather than raised.
    """
    api_key = os.getenv("PERPLEXITY_API_KEY")
    if not api_key:
        return {
            "success": False,
            "message": "PERPLEXITY_API_KEY environment variable not set",
            "results": [],
        }

    request_body = {
        "model": "sonar-reasoning",
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": query},
        ],
    }
    try:
        response = requests.post(
            "https://api.perplexity.ai/chat/completions",
            headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
            json=request_body,
            timeout=300,  # reasoning model responses can take minutes
        )
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
    except requests.exceptions.Timeout:
        return {"success": False, "message": "Request timed out", "results": []}
    except requests.exceptions.RequestException as e:
        return {"success": False, "message": f"API request failed: {e!s}", "results": []}
    except KeyError as e:
        return {
            "success": False,
            "message": f"Unexpected API response format: missing {e!s}",
            "results": [],
        }
    except Exception as e:  # noqa: BLE001
        return {"success": False, "message": f"Web search failed: {e!s}", "results": []}

    return {
        "success": True,
        "query": query,
        "content": content,
        "message": "Web search completed successfully",
    }

View File

@@ -0,0 +1,83 @@
<tools>
<tool name="web_search">
<description>Search the web using Perplexity AI for real-time information and current events.
This is your PRIMARY research tool - use it extensively and liberally for:
- Current vulnerabilities, CVEs, and security advisories
- Latest attack techniques, exploits, and proof-of-concepts
- Technology-specific security research and documentation
- Target reconnaissance and OSINT gathering
- Security tool documentation and usage guides
- Incident response and threat intelligence
- Compliance frameworks and security standards
- Bug bounty reports and security research findings
- Security conference talks and research papers
The tool provides intelligent, contextual responses with current information that may not be in your training data. Use it early and often during security assessments to gather the most up-to-date factual information.</description>
<details>This tool leverages Perplexity AI's sonar-reasoning model to search the web and provide intelligent, contextual responses to queries. It's essential for effective cybersecurity work as it provides access to the latest vulnerabilities, attack vectors, security tools, and defensive techniques. The AI understands security context and can synthesize information from multiple sources.</details>
<parameters>
<parameter name="query" type="string" required="true">
<description>The search query or question you want to research. Be specific and include relevant technical terms, version numbers, or context for better results. Make it as detailed as possible, with the context of the current security assessment.</description>
</parameter>
</parameters>
<returns type="Dict[str, Any]">
<description>Response containing: - success: Whether the search was successful - query: The original search query - content: AI-generated response with current information - message: Status message</description>
</returns>
<examples>
# Found specific service version during reconnaissance
<function=web_search>
<parameter=query>I found OpenSSH 7.4 running on port 22. Are there any known exploits or privilege escalation techniques for this specific version?</parameter>
</function>
# Encountered WAF blocking attempts
<function=web_search>
<parameter=query>Cloudflare is blocking my SQLmap attempts on this login form. What are the latest bypass techniques for Cloudflare WAF in 2024?</parameter>
</function>
# Need to exploit discovered CMS
<function=web_search>
<parameter=query>Target is running WordPress 5.8.3 with WooCommerce 6.1.1. What are the current RCE exploits for this combination?</parameter>
</function>
# Stuck on privilege escalation
<function=web_search>
<parameter=query>I have low-privilege shell on Ubuntu 20.04 with kernel 5.4.0-74-generic. What local privilege escalation exploits work for this exact kernel version?</parameter>
</function>
# Need lateral movement in Active Directory
<function=web_search>
<parameter=query>I compromised a domain user account in Windows Server 2019 AD environment. What are the best techniques to escalate to Domain Admin without triggering EDR?</parameter>
</function>
# Encountered specific error during exploitation
<function=web_search>
<parameter=query>Getting "Access denied" when trying to upload webshell to IIS 10.0. What are alternative file upload bypass techniques for Windows IIS?</parameter>
</function>
# Need to bypass endpoint protection
<function=web_search>
<parameter=query>Target has CrowdStrike Falcon running. What are the latest techniques to bypass this EDR for payload execution and persistence?</parameter>
</function>
# Research target's infrastructure for attack surface
<function=web_search>
<parameter=query>I found target company "AcmeCorp" uses Office 365 and Azure. What are the common misconfigurations and attack vectors for this cloud setup?</parameter>
</function>
# Found interesting subdomain during recon
<function=web_search>
<parameter=query>Discovered staging.target.com running Jenkins 2.401.3. What are the current authentication bypass and RCE exploits for this Jenkins version?</parameter>
</function>
# Need alternative tools when primary fails
<function=web_search>
<parameter=query>Nmap is being detected and blocked by the target's IPS. What are stealthy alternatives for port scanning that evade modern intrusion prevention systems?</parameter>
</function>
# Finding best security tools for specific tasks
<function=web_search>
<parameter=query>What is the best Python pip package in 2025 for JWT security testing and manipulation, including cracking weak secrets and algorithm confusion attacks?</parameter>
</function>
</examples>
</tool>
</tools>