Open-source release for Alpha version
This commit is contained in:
64
strix/tools/__init__.py
Normal file
64
strix/tools/__init__.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import os

from .executor import (
    execute_tool,
    execute_tool_invocation,
    execute_tool_with_validation,
    extract_screenshot_from_result,
    process_tool_invocations,
    remove_screenshot_from_result,
    validate_tool_availability,
)
from .registry import (
    ImplementedInClientSideOnlyError,
    get_tool_by_name,
    get_tool_names,
    get_tools_prompt,
    needs_agent_state,
    register_tool,
    tools,
)


# When true, this process is running inside the sandbox and only the reduced
# tool set below is imported.
SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"

# Web search is only wired up when a Perplexity API key is configured.
HAS_PERPLEXITY_API = bool(os.getenv("PERPLEXITY_API_KEY"))

if not SANDBOX_MODE:
    # Host process: import the full tool set.  Importing each module registers
    # its tools as a side effect (via the @register_tool decorator).
    from .agents_graph import *  # noqa: F403
    from .browser import *  # noqa: F403
    from .file_edit import *  # noqa: F403
    from .finish import *  # noqa: F403
    from .notes import *  # noqa: F403
    from .proxy import *  # noqa: F403
    from .python import *  # noqa: F403
    from .reporting import *  # noqa: F403
    from .terminal import *  # noqa: F403
    from .thinking import *  # noqa: F403

    if HAS_PERPLEXITY_API:
        from .web_search import *  # noqa: F403
else:
    # Sandbox process: restricted subset (no agent graph, reporting,
    # finishing, thinking, or web search tools).
    from .browser import *  # noqa: F403
    from .file_edit import *  # noqa: F403
    from .notes import *  # noqa: F403
    from .proxy import *  # noqa: F403
    from .python import *  # noqa: F403
    from .terminal import *  # noqa: F403

# Explicit public API of this package (star-imported tool names are exposed
# dynamically by the registry; only the stable helpers are listed here).
__all__ = [
    "ImplementedInClientSideOnlyError",
    "execute_tool",
    "execute_tool_invocation",
    "execute_tool_with_validation",
    "extract_screenshot_from_result",
    "get_tool_by_name",
    "get_tool_names",
    "get_tools_prompt",
    "needs_agent_state",
    "process_tool_invocations",
    "register_tool",
    "remove_screenshot_from_result",
    "tools",
    "validate_tool_availability",
]
|
||||
16
strix/tools/agents_graph/__init__.py
Normal file
16
strix/tools/agents_graph/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Re-export the agent-graph tool functions as the package's public API.
from .agents_graph_actions import (
    agent_finish,
    create_agent,
    send_message_to_agent,
    view_agent_graph,
    wait_for_message,
)


__all__ = [
    "agent_finish",
    "create_agent",
    "send_message_to_agent",
    "view_agent_graph",
    "wait_for_message",
]
|
||||
610
strix/tools/agents_graph/agents_graph_actions.py
Normal file
610
strix/tools/agents_graph/agents_graph_actions.py
Normal file
@@ -0,0 +1,610 @@
|
||||
import threading
from datetime import UTC, datetime
from typing import Any, Literal

from strix.tools.registry import register_tool


# Process-wide view of the multi-agent system.  "nodes" maps agent_id ->
# metadata dict (name, task, status, result, timestamps, ...); "edges" holds
# delegation and message records between agents.
# NOTE(review): these structures are mutated from multiple threads without a
# lock — confirm the intended thread-safety guarantees.
_agent_graph: dict[str, Any] = {
    "nodes": {},
    "edges": [],
}

# Identifier of the root (top-level) agent, once known.
_root_agent_id: str | None = None

# Per-agent inbox: agent_id -> list of pending message dicts.
_agent_messages: dict[str, list[dict[str, Any]]] = {}

# Threads currently executing an agent loop, keyed by agent_id.
_running_agents: dict[str, threading.Thread] = {}

# Live agent objects, keyed by agent_id (removed when the agent finishes).
_agent_instances: dict[str, Any] = {}

# Agent state objects, keyed by agent_id (used e.g. to request stops).
_agent_states: dict[str, Any] = {}
|
||||
|
||||
|
||||
def _run_agent_in_thread(
    agent: Any, state: Any, inherited_messages: list[dict[str, Any]]
) -> dict[str, Any]:
    """Run one sub-agent's loop to completion inside a worker thread.

    Seeds the agent's conversation with any inherited parent context plus a
    delegation prompt, runs ``agent.agent_loop`` on a fresh asyncio event
    loop, and records status/result on the agent's node in ``_agent_graph``.

    Returns a ``{"result": ...}`` dict (ignored when used as a Thread target).
    Re-raises any failure after marking the node as "error".
    """
    try:
        if inherited_messages:
            # Wrap the parent's history in explicit markers so the sub-agent
            # can distinguish background context from its own task.
            state.add_message("user", "<inherited_context_from_parent>")
            for msg in inherited_messages:
                state.add_message(msg["role"], msg["content"])
            state.add_message("user", "</inherited_context_from_parent>")

        parent_info = _agent_graph["nodes"].get(state.parent_id, {})
        parent_name = parent_info.get("name", "Unknown Parent")

        context_status = (
            "inherited conversation context from your parent for background understanding"
            if inherited_messages
            else "started with a fresh context"
        )

        task_xml = f"""<agent_delegation>
<identity>
⚠️ You are NOT your parent agent. You are a NEW, SEPARATE sub-agent (not root).

Your Info: {state.agent_name} ({state.agent_id})
Parent Info: {parent_name} ({state.parent_id})
</identity>

<your_task>{state.task}</your_task>

<instructions>
- You have {context_status}
- Inherited context is for BACKGROUND ONLY - don't continue parent's work
- Focus EXCLUSIVELY on your delegated task above
- Work independently with your own approach
- Use agent_finish when complete to report back to parent
- You are a SPECIALIST for this specific task
</instructions>
</agent_delegation>"""

        state.add_message("user", task_xml)

        _agent_states[state.agent_id] = state

        # Snapshot the state onto the graph node.
        # NOTE(review): assumes the node for this agent already exists in
        # _agent_graph["nodes"] — confirm where node registration happens.
        _agent_graph["nodes"][state.agent_id]["state"] = state.model_dump()

        import asyncio

        # Each worker thread needs its own event loop; close it when done.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            result = loop.run_until_complete(agent.agent_loop(state.task))
        finally:
            loop.close()

    except Exception as e:
        # Record the failure on the node, clean up registries, then re-raise
        # (the exception surfaces in the daemon thread).
        _agent_graph["nodes"][state.agent_id]["status"] = "error"
        _agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat()
        _agent_graph["nodes"][state.agent_id]["result"] = {"error": str(e)}
        _running_agents.pop(state.agent_id, None)
        _agent_instances.pop(state.agent_id, None)
        raise
    else:
        # Distinguish user-requested stops from natural completion.
        if state.stop_requested:
            _agent_graph["nodes"][state.agent_id]["status"] = "stopped"
        else:
            _agent_graph["nodes"][state.agent_id]["status"] = "completed"
        _agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat()
        _agent_graph["nodes"][state.agent_id]["result"] = result
        _running_agents.pop(state.agent_id, None)
        _agent_instances.pop(state.agent_id, None)

        return {"result": result}
|
||||
|
||||
|
||||
@register_tool(sandbox_execution=False)
def view_agent_graph(agent_state: Any) -> dict[str, Any]:
    """Render the current agent tree and per-status counts.

    Returns a dict with a human-readable ``graph_structure`` string plus a
    ``summary`` of status counts, or an ``error`` entry on failure.
    """
    try:
        structure_lines = ["=== AGENT GRAPH STRUCTURE ==="]

        def _build_tree(agent_id: str, depth: int = 0) -> None:
            # Depth-first render of one node and its delegation children.
            node = _agent_graph["nodes"][agent_id]
            indent = " " * depth

            # Mark the caller's own node so it can orient itself.
            you_indicator = " ← This is you" if agent_id == agent_state.agent_id else ""

            structure_lines.append(f"{indent}* {node['name']} ({agent_id}){you_indicator}")
            structure_lines.append(f"{indent} Task: {node['task']}")
            structure_lines.append(f"{indent} Status: {node['status']}")

            children = [
                edge["to"]
                for edge in _agent_graph["edges"]
                if edge["from"] == agent_id and edge["type"] == "delegation"
            ]

            if children:
                structure_lines.append(f"{indent} Children:")
                for child_id in children:
                    _build_tree(child_id, depth + 2)

        # Prefer the recorded root; otherwise fall back to any parentless
        # node, then to an arbitrary node.
        root_agent_id = _root_agent_id
        if not root_agent_id and _agent_graph["nodes"]:
            for agent_id, node in _agent_graph["nodes"].items():
                if node.get("parent_id") is None:
                    root_agent_id = agent_id
                    break
            if not root_agent_id:
                root_agent_id = next(iter(_agent_graph["nodes"].keys()))

        if root_agent_id and root_agent_id in _agent_graph["nodes"]:
            _build_tree(root_agent_id)
        else:
            structure_lines.append("No agents in the graph yet")

        graph_structure = "\n".join(structure_lines)

        # Status tallies for the summary block below.
        total_nodes = len(_agent_graph["nodes"])
        running_count = sum(
            1 for node in _agent_graph["nodes"].values() if node["status"] == "running"
        )
        waiting_count = sum(
            1 for node in _agent_graph["nodes"].values() if node["status"] == "waiting"
        )
        stopping_count = sum(
            1 for node in _agent_graph["nodes"].values() if node["status"] == "stopping"
        )
        completed_count = sum(
            1 for node in _agent_graph["nodes"].values() if node["status"] == "completed"
        )
        stopped_count = sum(
            1 for node in _agent_graph["nodes"].values() if node["status"] == "stopped"
        )
        # "failed" and "error" are both treated as failures.
        failed_count = sum(
            1 for node in _agent_graph["nodes"].values() if node["status"] in ["failed", "error"]
        )

    except Exception as e:  # noqa: BLE001
        return {
            "error": f"Failed to view agent graph: {e}",
            "graph_structure": "Error retrieving graph structure",
        }
    else:
        return {
            "graph_structure": graph_structure,
            "summary": {
                "total_agents": total_nodes,
                "running": running_count,
                "waiting": waiting_count,
                "stopping": stopping_count,
                "completed": completed_count,
                "stopped": stopped_count,
                "failed": failed_count,
            },
        }
|
||||
|
||||
|
||||
@register_tool(sandbox_execution=False)
def create_agent(
    agent_state: Any,
    task: str,
    name: str,
    inherit_context: bool = True,
    prompt_modules: str | None = None,
) -> dict[str, Any]:
    """Spawn a sub-agent for *task* and run it asynchronously.

    Validates the optional comma-separated ``prompt_modules`` (max 3, and the
    reserved "root_agent" module is rejected), builds a new agent with the
    caller as parent, optionally inherits the caller's conversation history,
    and starts the agent loop on a daemon thread.  Returns immediately with
    the new agent's id; never raises (errors are returned in the dict).
    """
    try:
        parent_id = agent_state.agent_id

        # Parse the comma-separated module list, dropping empty entries.
        module_list = []
        if prompt_modules:
            module_list = [m.strip() for m in prompt_modules.split(",") if m.strip()]

            # "root_agent" is reserved for the main agent only.
            if "root_agent" in module_list:
                return {
                    "success": False,
                    "error": (
                        "The 'root_agent' module is reserved for the main agent "
                        "and cannot be used by sub-agents"
                    ),
                    "agent_id": None,
                }

            if len(module_list) > 3:
                return {
                    "success": False,
                    "error": (
                        "Cannot specify more than 3 prompt modules for an agent "
                        "(use comma-separated format)"
                    ),
                    "agent_id": None,
                }

        if module_list:
            # Imported lazily to avoid import cycles at module load time.
            from strix.prompts import get_all_module_names, validate_module_names

            validation = validate_module_names(module_list)
            if validation["invalid"]:
                available_modules = list(get_all_module_names())
                return {
                    "success": False,
                    "error": (
                        f"Invalid prompt modules: {validation['invalid']}. "
                        f"Available modules: {', '.join(available_modules)}"
                    ),
                    "agent_id": None,
                }

        from strix.agents import StrixAgent
        from strix.agents.state import AgentState
        from strix.llm.config import LLMConfig

        state = AgentState(task=task, agent_name=name, parent_id=parent_id, max_iterations=200)

        llm_config = LLMConfig(prompt_modules=module_list)
        agent = StrixAgent(
            {
                "llm_config": llm_config,
                "state": state,
            }
        )

        # Copy the parent's history so the child starts with background
        # context (wrapped in markers by _run_agent_in_thread).
        inherited_messages = []
        if inherit_context:
            inherited_messages = agent_state.get_conversation_history()

        _agent_instances[state.agent_id] = agent

        # Daemon thread: the child must not block process shutdown.
        thread = threading.Thread(
            target=_run_agent_in_thread,
            args=(agent, state, inherited_messages),
            daemon=True,
            name=f"Agent-{name}-{state.agent_id}",
        )
        thread.start()
        _running_agents[state.agent_id] = thread

    except Exception as e:  # noqa: BLE001
        return {"success": False, "error": f"Failed to create agent: {e}", "agent_id": None}
    else:
        return {
            "success": True,
            "agent_id": state.agent_id,
            "message": f"Agent '{name}' created and started asynchronously",
            "agent_info": {
                "id": state.agent_id,
                "name": name,
                "status": "running",
                "parent_id": parent_id,
            },
        }
|
||||
|
||||
|
||||
@register_tool(sandbox_execution=False)
def send_message_to_agent(
    agent_state: Any,
    target_agent_id: str,
    message: str,
    message_type: Literal["query", "instruction", "information"] = "information",
    priority: Literal["low", "normal", "high", "urgent"] = "normal",
) -> dict[str, Any]:
    """Deliver a message from the calling agent to another agent in the graph.

    Appends the message to the target's inbox, records a "message" edge in
    the agent graph, and reports delivery status.  Never raises: failures are
    returned as ``{"success": False, ...}``.
    """
    try:
        nodes = _agent_graph["nodes"]
        if target_agent_id not in nodes:
            return {
                "success": False,
                "error": f"Target agent '{target_agent_id}' not found in graph",
                "message_id": None,
            }

        sender_id = agent_state.agent_id

        from uuid import uuid4

        message_id = f"msg_{uuid4().hex[:8]}"
        payload = {
            "id": message_id,
            "from": sender_id,
            "to": target_agent_id,
            "content": message,
            "message_type": message_type,
            "priority": priority,
            "timestamp": datetime.now(UTC).isoformat(),
            "delivered": False,
            "read": False,
        }

        # Queue the message in the target's inbox, creating it on first use.
        _agent_messages.setdefault(target_agent_id, []).append(payload)

        # Record the communication as an edge for graph introspection.
        _agent_graph["edges"].append(
            {
                "from": sender_id,
                "to": target_agent_id,
                "type": "message",
                "message_id": message_id,
                "message_type": message_type,
                "priority": priority,
                "created_at": datetime.now(UTC).isoformat(),
            }
        )

        # Delivery is synchronous: once queued, the message is delivered.
        payload["delivered"] = True

        target_name = nodes[target_agent_id]["name"]
        sender_name = nodes[sender_id]["name"]

        return {
            "success": True,
            "message_id": message_id,
            "message": f"Message sent from '{sender_name}' to '{target_name}'",
            "delivery_status": "delivered",
            "target_agent": {
                "id": target_agent_id,
                "name": target_name,
                "status": nodes[target_agent_id]["status"],
            },
        }

    except Exception as e:  # noqa: BLE001
        return {"success": False, "error": f"Failed to send message: {e}", "message_id": None}
|
||||
|
||||
|
||||
@register_tool(sandbox_execution=False)
def agent_finish(
    agent_state: Any,
    result_summary: str,
    findings: list[str] | None = None,
    success: bool = True,
    report_to_parent: bool = True,
    final_recommendations: list[str] | None = None,
) -> dict[str, Any]:
    """Mark a sub-agent's task as finished and optionally report to its parent.

    Only callable by sub-agents (agents with a parent); root agents must use
    finish_scan.  Records the result on the agent's graph node and, when
    ``report_to_parent`` is set, queues an XML completion report in the
    parent's inbox.  Never raises: failures are returned in the dict.
    """
    try:
        # Root agents have no parent — refuse and point them at finish_scan.
        if not hasattr(agent_state, "parent_id") or agent_state.parent_id is None:
            return {
                "agent_completed": False,
                "error": (
                    "This tool can only be used by subagents. "
                    "Root/main agents must use finish_scan instead."
                ),
                "parent_notified": False,
            }

        agent_id = agent_state.agent_id

        if agent_id not in _agent_graph["nodes"]:
            return {"agent_completed": False, "error": "Current agent not found in graph"}

        agent_node = _agent_graph["nodes"][agent_id]

        # Persist outcome on the node so the graph reflects completion.
        agent_node["status"] = "finished" if success else "failed"
        agent_node["finished_at"] = datetime.now(UTC).isoformat()
        agent_node["result"] = {
            "summary": result_summary,
            "findings": findings or [],
            "success": success,
            "recommendations": final_recommendations or [],
        }

        parent_notified = False

        if report_to_parent and agent_node["parent_id"]:
            parent_id = agent_node["parent_id"]

            if parent_id in _agent_graph["nodes"]:
                # Build the XML fragments for findings/recommendations.
                findings_xml = "\n".join(
                    f" <finding>{finding}</finding>" for finding in (findings or [])
                )
                recommendations_xml = "\n".join(
                    f" <recommendation>{rec}</recommendation>"
                    for rec in (final_recommendations or [])
                )

                report_message = f"""<agent_completion_report>
<agent_info>
<agent_name>{agent_node["name"]}</agent_name>
<agent_id>{agent_id}</agent_id>
<task>{agent_node["task"]}</task>
<status>{"SUCCESS" if success else "FAILED"}</status>
<completion_time>{agent_node["finished_at"]}</completion_time>
</agent_info>
<results>
<summary>{result_summary}</summary>
<findings>
{findings_xml}
</findings>
<recommendations>
{recommendations_xml}
</recommendations>
</results>
</agent_completion_report>"""

                if parent_id not in _agent_messages:
                    _agent_messages[parent_id] = []

                from uuid import uuid4

                # High-priority so the (possibly waiting) parent resumes.
                _agent_messages[parent_id].append(
                    {
                        "id": f"report_{uuid4().hex[:8]}",
                        "from": agent_id,
                        "to": parent_id,
                        "content": report_message,
                        "message_type": "information",
                        "priority": "high",
                        "timestamp": datetime.now(UTC).isoformat(),
                        "delivered": True,
                        "read": False,
                    }
                )

                parent_notified = True

        # The agent's thread is done with its work; drop the registry entry.
        _running_agents.pop(agent_id, None)

        return {
            "agent_completed": True,
            "parent_notified": parent_notified,
            "completion_summary": {
                "agent_id": agent_id,
                "agent_name": agent_node["name"],
                "task": agent_node["task"],
                "success": success,
                "findings_count": len(findings or []),
                "has_recommendations": bool(final_recommendations),
                "finished_at": agent_node["finished_at"],
            },
        }

    except Exception as e:  # noqa: BLE001
        return {
            "agent_completed": False,
            "error": f"Failed to complete agent: {e}",
            "parent_notified": False,
        }
|
||||
|
||||
|
||||
def stop_agent(agent_id: str) -> dict[str, Any]:
    """Request a graceful stop of the agent identified by *agent_id*.

    Signals the agent's state (and live instance, if any) to stop, marks the
    node as "stopping", and best-effort notifies the CLI tracer.  The agent
    itself exits after its current iteration.  Never raises: failures are
    returned as ``{"success": False, ...}``.
    """
    try:
        node = _agent_graph["nodes"].get(agent_id)
        if node is None:
            return {
                "success": False,
                "error": f"Agent '{agent_id}' not found in graph",
                "agent_id": agent_id,
            }

        # Already terminal — nothing to do.
        if node["status"] in ("completed", "error", "failed", "stopped"):
            return {
                "success": True,
                "message": f"Agent '{node['name']}' was already stopped",
                "agent_id": agent_id,
                "previous_status": node["status"],
            }

        # Flag the stop on the tracked state object...
        state = _agent_states.get(agent_id)
        if state is not None:
            state.request_stop()

        # ...and on the live agent instance, cancelling any in-flight work.
        instance = _agent_instances.get(agent_id)
        if instance is not None:
            if hasattr(instance, "state"):
                instance.state.request_stop()
            if hasattr(instance, "cancel_current_execution"):
                instance.cancel_current_execution()

        node["status"] = "stopping"

        # Best-effort UI update; the tracer is optional.
        try:
            from strix.cli.tracer import get_global_tracer

            tracer = get_global_tracer()
            if tracer:
                tracer.update_agent_status(agent_id, "stopping")
        except (ImportError, AttributeError):
            pass

        node["result"] = {
            "summary": "Agent stop requested by user",
            "success": False,
            "stopped_by_user": True,
        }

        return {
            "success": True,
            "message": f"Stop request sent to agent '{node['name']}'",
            "agent_id": agent_id,
            "agent_name": node["name"],
            "note": "Agent will stop gracefully after current iteration",
        }

    except Exception as e:  # noqa: BLE001
        return {
            "success": False,
            "error": f"Failed to stop agent: {e}",
            "agent_id": agent_id,
        }
|
||||
|
||||
|
||||
def send_user_message_to_agent(agent_id: str, message: str) -> dict[str, Any]:
    """Queue a high-priority instruction from the human user into an agent's inbox.

    Never raises: failures are returned as ``{"success": False, ...}``.
    """
    try:
        node = _agent_graph["nodes"].get(agent_id)
        if node is None:
            return {
                "success": False,
                "error": f"Agent '{agent_id}' not found in graph",
                "agent_id": agent_id,
            }

        from uuid import uuid4

        # User messages are instructions and always high priority, so a
        # waiting agent resumes promptly.
        _agent_messages.setdefault(agent_id, []).append(
            {
                "id": f"user_msg_{uuid4().hex[:8]}",
                "from": "user",
                "to": agent_id,
                "content": message,
                "message_type": "instruction",
                "priority": "high",
                "timestamp": datetime.now(UTC).isoformat(),
                "delivered": True,
                "read": False,
            }
        )

        return {
            "success": True,
            "message": f"Message sent to agent '{node['name']}'",
            "agent_id": agent_id,
            "agent_name": node["name"],
        }

    except Exception as e:  # noqa: BLE001
        return {
            "success": False,
            "error": f"Failed to send message to agent: {e}",
            "agent_id": agent_id,
        }
|
||||
|
||||
|
||||
@register_tool(sandbox_execution=False)
def wait_for_message(
    agent_state: Any,
    reason: str = "Waiting for messages from other agents or user input",
) -> dict[str, Any]:
    """Put the calling agent into a waiting state until a message arrives.

    Updates the agent's state object and graph node to "waiting" and
    best-effort notifies the CLI tracer.  Never raises: failures are returned
    as ``{"success": False, ...}``.
    """
    try:
        agent_id = agent_state.agent_id
        agent_name = agent_state.agent_name

        agent_state.enter_waiting_state()

        node = _agent_graph["nodes"].get(agent_id)
        if node is not None:
            node["status"] = "waiting"
            node["waiting_reason"] = reason

        # Best-effort UI update; the tracer is optional.
        try:
            from strix.cli.tracer import get_global_tracer

            tracer = get_global_tracer()
            if tracer:
                tracer.update_agent_status(agent_id, "waiting")
        except (ImportError, AttributeError):
            pass

    except Exception as e:  # noqa: BLE001
        return {"success": False, "error": f"Failed to enter waiting state: {e}", "status": "error"}

    return {
        "success": True,
        "status": "waiting",
        "message": f"Agent '{agent_name}' is now waiting for messages",
        "reason": reason,
        "agent_info": {
            "id": agent_id,
            "name": agent_name,
            "status": "waiting",
        },
        "resume_conditions": [
            "Message from another agent",
            "Message from user",
            "Direct communication",
        ],
    }
|
||||
223
strix/tools/agents_graph/agents_graph_actions_schema.xml
Normal file
223
strix/tools/agents_graph/agents_graph_actions_schema.xml
Normal file
@@ -0,0 +1,223 @@
|
||||
<tools>
|
||||
<tool name="agent_finish">
|
||||
<description>Mark a subagent's task as completed and optionally report results to parent agent.
|
||||
|
||||
IMPORTANT: This tool can ONLY be used by subagents (agents with a parent).
|
||||
Root/main agents must use finish_scan instead.
|
||||
|
||||
This tool should be called when a subagent completes its assigned subtask to:
|
||||
- Mark the subagent's task as completed
|
||||
- Report findings back to the parent agent
|
||||
|
||||
Use this tool when:
|
||||
- You are a subagent working on a specific subtask
|
||||
- You have completed your assigned task
|
||||
- You want to report your findings to the parent agent
|
||||
- You are ready to terminate this subagent's execution</description>
|
||||
<details>This tool is for sub-agent completion only; root/main agents report overall
completion via finish_scan instead. When a sub-agent finishes, it can report its
findings back to the parent agent for coordination.</details>
|
||||
<parameters>
|
||||
<parameter name="result_summary" type="string" required="true">
|
||||
<description>Summary of what the agent accomplished and discovered</description>
|
||||
</parameter>
|
||||
<parameter name="findings" type="string" required="false">
|
||||
<description>List of specific findings, vulnerabilities, or discoveries</description>
|
||||
</parameter>
|
||||
<parameter name="success" type="boolean" required="false">
|
||||
<description>Whether the agent's task completed successfully</description>
|
||||
</parameter>
|
||||
<parameter name="report_to_parent" type="boolean" required="false">
|
||||
<description>Whether to send results back to the parent agent</description>
|
||||
</parameter>
|
||||
<parameter name="final_recommendations" type="string" required="false">
|
||||
<description>Recommendations for next steps or follow-up actions</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - agent_completed: Whether the agent was marked as completed - parent_notified: Whether parent was notified (if applicable) - completion_summary: Summary of completion status</description>
|
||||
</returns>
|
||||
<examples>
|
||||
# Sub-agent completing subdomain enumeration task
|
||||
<function=agent_finish>
|
||||
<parameter=result_summary>Completed comprehensive subdomain enumeration for target.com.
|
||||
Discovered 47 subdomains including several interesting ones with admin/dev
|
||||
in the name. Found 3 subdomains with exposed services on non-standard
|
||||
ports.</parameter>
|
||||
<parameter=findings>["admin.target.com - exposed phpMyAdmin",
|
||||
"dev-api.target.com - unauth API endpoints",
|
||||
"staging.target.com - directory listing enabled",
|
||||
"mail.target.com - POP3/IMAP services"]</parameter>
|
||||
<parameter=success>true</parameter>
|
||||
<parameter=report_to_parent>true</parameter>
|
||||
<parameter=final_recommendations>["Prioritize testing admin.target.com for default creds",
|
||||
"Enumerate dev-api.target.com API endpoints",
|
||||
"Check staging.target.com for sensitive files"]</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
<tool name="create_agent">
|
||||
<description>Create and spawn a new agent to handle a specific subtask.
|
||||
|
||||
MANDATORY REQUIREMENT: You MUST call view_agent_graph FIRST before creating any new agent to check if there is already an agent working on the same or similar task. Only create a new agent if no existing agent is handling the specific task.</description>
|
||||
<details>The new agent inherits the parent's conversation history and context up to the point
|
||||
of creation, then continues with its assigned subtask. This enables decomposition
|
||||
of complex penetration testing tasks into specialized sub-agents.
|
||||
|
||||
The agent runs asynchronously and independently, allowing the parent to continue
|
||||
immediately while the new agent executes its task in the background.
|
||||
|
||||
CRITICAL: Before calling this tool, you MUST first use view_agent_graph to:
|
||||
- Examine all existing agents and their current tasks
|
||||
- Verify no agent is already working on the same or similar objective
|
||||
- Avoid duplication of effort and resource waste
|
||||
- Ensure efficient coordination across the multi-agent system
|
||||
|
||||
If you as a parent agent don't absolutely have anything to do while your subagents are running, you can use wait_for_message tool. The subagent will continue to run in the background, and update you when it's done.
|
||||
</details>
|
||||
<parameters>
|
||||
<parameter name="task" type="string" required="true">
|
||||
<description>The specific task/objective for the new agent to accomplish</description>
|
||||
</parameter>
|
||||
<parameter name="name" type="string" required="true">
|
||||
<description>Human-readable name for the agent (for tracking purposes)</description>
|
||||
</parameter>
|
||||
<parameter name="inherit_context" type="boolean" required="false">
|
||||
<description>Whether the new agent should inherit parent's conversation history and context</description>
|
||||
</parameter>
|
||||
<parameter name="prompt_modules" type="string" required="false">
|
||||
<description>Comma-separated list of prompt modules to use for the agent. Most agents should have at least one module in order to be useful. {{DYNAMIC_MODULES_DESCRIPTION}}</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - agent_id: Unique identifier for the created agent - success: Whether the agent was created successfully - message: Status message - agent_info: Details about the created agent</description>
|
||||
</returns>
|
||||
<examples>
|
||||
# REQUIRED: First check agent graph before creating any new agent
|
||||
<function=view_agent_graph>
|
||||
</function>
|
||||
# REQUIRED: Check agent graph again before creating another agent
|
||||
<function=view_agent_graph>
|
||||
</function>
|
||||
|
||||
# After confirming no SQL testing agent exists, create agent for vulnerability validation
|
||||
<function=create_agent>
|
||||
<parameter=task>Validate and exploit the suspected SQL injection vulnerability found in
|
||||
the login form. Confirm exploitability and document proof of concept.</parameter>
|
||||
<parameter=name>SQLi Validator</parameter>
|
||||
<parameter=prompt_modules>sql_injection</parameter>
|
||||
</function>
|
||||
|
||||
# Create specialized authentication testing agent with multiple modules (comma-separated)
|
||||
<function=create_agent>
|
||||
<parameter=task>Test authentication mechanisms, JWT implementation, and session management
|
||||
for security vulnerabilities and bypass techniques.</parameter>
|
||||
<parameter=name>Auth Specialist</parameter>
|
||||
<parameter=prompt_modules>authentication_jwt, business_logic</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
<tool name="send_message_to_agent">
|
||||
<description>Send a message to another agent in the graph for coordination and communication.</description>
|
||||
<details>This enables agents to communicate with each other during execution for:
|
||||
- Sharing discovered information or findings
|
||||
- Asking questions or requesting assistance
|
||||
- Providing instructions or coordination
|
||||
- Reporting status or results</details>
|
||||
<parameters>
|
||||
<parameter name="target_agent_id" type="string" required="true">
|
||||
<description>ID of the agent to send the message to</description>
|
||||
</parameter>
|
||||
<parameter name="message" type="string" required="true">
|
||||
<description>The message content to send</description>
|
||||
</parameter>
|
||||
<parameter name="message_type" type="string" required="false">
|
||||
<description>Type of message being sent: - "query": Question requiring a response - "instruction": Command or directive for the target agent - "information": Informational message (findings, status, etc.)</description>
|
||||
</parameter>
|
||||
<parameter name="priority" type="string" required="false">
|
||||
<description>Priority level of the message</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - success: Whether the message was sent successfully - message_id: Unique identifier for the message - delivery_status: Status of message delivery</description>
|
||||
</returns>
|
||||
<examples>
|
||||
# Share discovered vulnerability information
|
||||
<function=send_message_to_agent>
|
||||
<parameter=target_agent_id>agent_abc123</parameter>
|
||||
<parameter=message>Found SQL injection vulnerability in /login.php parameter 'username'.
|
||||
Payload: admin' OR '1'='1' -- successfully bypassed authentication.
|
||||
You should focus your testing on the authenticated areas of the
|
||||
application.</parameter>
|
||||
<parameter=message_type>information</parameter>
|
||||
<parameter=priority>high</parameter>
|
||||
</function>
|
||||
|
||||
# Request assistance from specialist agent
|
||||
<function=send_message_to_agent>
|
||||
<parameter=target_agent_id>agent_def456</parameter>
|
||||
<parameter=message>I've identified what appears to be a custom encryption implementation
|
||||
in the API responses. Can you analyze the cryptographic strength and look
|
||||
for potential weaknesses?</parameter>
|
||||
<parameter=message_type>query</parameter>
|
||||
<parameter=priority>normal</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
<tool name="view_agent_graph">
|
||||
<description>View the current agent graph showing all agents, their relationships, and status.</description>
|
||||
<details>This provides a comprehensive overview of the multi-agent system including:
|
||||
- All agent nodes with their tasks, status, and metadata
|
||||
- Parent-child relationships between agents
|
||||
- Message communication patterns
|
||||
- Current execution state</details>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - graph_structure: Human-readable representation of the agent graph - summary: High-level statistics about the graph</description>
|
||||
</returns>
|
||||
</tool>
|
||||
<tool name="wait_for_message">
|
||||
<description>Pause the agent loop indefinitely until receiving a message from another agent or user.
|
||||
|
||||
This tool puts the agent into a waiting state where it remains idle until it receives any form of communication. The agent will automatically resume execution when a message arrives.
|
||||
|
||||
IMPORTANT: This tool causes the agent to stop all activity until a message is received. Use it when you need to:
|
||||
- Wait for subagent completion reports
|
||||
- Coordinate with other agents before proceeding
|
||||
- Pause for user input or decisions
|
||||
- Synchronize multi-agent workflows
|
||||
|
||||
NOTE: If you are waiting for an agent that is NOT your subagent, you first tell it to message you with updates before waiting for it. Otherwise, you will wait forever!
|
||||
</description>
|
||||
<details>When this tool is called, the agent enters a waiting state and will not continue execution until:
|
||||
- Another agent sends it a message via send_message_to_agent
|
||||
- A user sends it a direct message through the CLI
|
||||
- Any other form of inter-agent or user communication occurs
|
||||
|
||||
The agent will automatically resume from where it left off once a message is received.
|
||||
This is particularly useful for parent agents waiting for subagent results or for coordination points in multi-agent workflows.</details>
|
||||
<parameters>
|
||||
<parameter name="reason" type="string" required="false">
|
||||
<description>Explanation for why the agent is waiting (for logging and monitoring purposes)</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - success: Whether the agent successfully entered waiting state - status: Current agent status ("waiting") - reason: The reason for waiting - agent_info: Details about the waiting agent - resume_conditions: List of conditions that will resume the agent</description>
|
||||
</returns>
|
||||
<examples>
|
||||
# Wait for subagents to complete their tasks
|
||||
<function=wait_for_message>
|
||||
<parameter=reason>Waiting for subdomain enumeration and port scanning subagents to complete their tasks and report findings</parameter>
|
||||
</function>
|
||||
|
||||
# Wait for user input on next steps
|
||||
<function=wait_for_message>
|
||||
<parameter=reason>Waiting for user decision on whether to proceed with exploitation of discovered SQL injection vulnerability</parameter>
|
||||
</function>
|
||||
|
||||
# Coordinate with other agents
|
||||
<function=wait_for_message>
|
||||
<parameter=reason>Waiting for vulnerability assessment agent to share discovered attack vectors before proceeding with exploitation phase</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
</tools>
|
||||
120
strix/tools/argument_parser.py
Normal file
120
strix/tools/argument_parser.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import contextlib
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Union, get_args, get_origin
|
||||
|
||||
|
||||
class ArgumentConversionError(Exception):
    """Raised when a tool argument string cannot be coerced to its annotated type.

    Carries the offending parameter name (when known) so callers can
    produce structured error reports.
    """

    def __init__(self, message: str, param_name: str | None = None) -> None:
        # Keep the parameter name available to error-handling code.
        self.param_name = param_name
        super().__init__(message)
|
||||
|
||||
|
||||
def convert_arguments(func: Callable[..., Any], kwargs: dict[str, Any]) -> dict[str, Any]:
    """Coerce string-valued kwargs to the types annotated on *func*.

    Unknown parameters, unannotated parameters, ``None`` values and
    non-string values pass through untouched.  A failed conversion is
    reported as :class:`ArgumentConversionError` carrying the parameter
    name; a failure to inspect the signature itself is reported without
    a parameter name.
    """
    try:
        parameters = inspect.signature(func).parameters
        converted: dict[str, Any] = {}

        for name, raw in kwargs.items():
            param = parameters.get(name)
            annotation = param.annotation if param is not None else inspect.Parameter.empty

            # Pass through anything we cannot (or should not) convert.
            passthrough = (
                param is None
                or annotation == inspect.Parameter.empty
                or raw is None
                or not isinstance(raw, str)
            )
            if passthrough:
                converted[name] = raw
                continue

            try:
                converted[name] = convert_string_to_type(raw, annotation)
            except (ValueError, TypeError, json.JSONDecodeError) as exc:
                raise ArgumentConversionError(
                    f"Failed to convert argument '{name}' to type {annotation}: {exc}",
                    param_name=name,
                ) from exc

    except (ValueError, TypeError, AttributeError) as exc:
        raise ArgumentConversionError(f"Failed to process function arguments: {exc}") from exc

    return converted
|
||||
|
||||
|
||||
def convert_string_to_type(value: str, param_type: Any) -> Any:
    """Convert *value* to *param_type*, unwrapping Optional/Union first.

    For unions, every non-``None`` member type is tried in order and the
    first successful conversion wins; if none succeeds, the raw string
    is returned.  Plain types and non-union generics are delegated to
    ``_convert_basic_types``.
    """
    origin = get_origin(param_type)

    # PEP 604 unions (X | Y) and typing.Union / typing.Optional.
    if origin is Union or origin is type(str | None):
        for member in get_args(param_type):
            if member is type(None):
                continue
            with contextlib.suppress(ValueError, TypeError, json.JSONDecodeError):
                return convert_string_to_type(value, member)
        return value

    # Optional-style annotation that get_origin did not classify as a union.
    args = getattr(param_type, "__args__", ())
    if len(args) == 2 and type(None) in args:
        target = args[0] if args[1] is type(None) else args[1]
        with contextlib.suppress(ValueError, TypeError, json.JSONDecodeError):
            return convert_string_to_type(value, target)
        return value

    return _convert_basic_types(value, param_type, origin)
|
||||
|
||||
|
||||
def _convert_basic_types(value: str, param_type: Any, origin: Any = None) -> Any:
    """Convert *value* to a basic annotated type.

    Scalars (int/float/bool/str) are dispatched to dedicated
    converters, then list/dict are handled (as the bare type or as a
    generic origin), and finally a JSON decode is attempted; a string
    that is not valid JSON is returned unchanged.
    """
    scalar_converters: dict[Any, Callable[[str], Any]] = {
        int: int,
        float: float,
        bool: _convert_to_bool,
        str: str,
    }
    scalar = scalar_converters.get(param_type)
    if scalar is not None:
        return scalar(value)

    if param_type is list or origin is list:
        return _convert_to_list(value)
    if param_type is dict or origin is dict:
        return _convert_to_dict(value)

    with contextlib.suppress(json.JSONDecodeError):
        return json.loads(value)
    return value
|
||||
|
||||
|
||||
def _convert_to_bool(value: str) -> bool:
|
||||
if value.lower() in ("true", "1", "yes", "on"):
|
||||
return True
|
||||
if value.lower() in ("false", "0", "no", "off"):
|
||||
return False
|
||||
return bool(value)
|
||||
|
||||
|
||||
def _convert_to_list(value: str) -> list[Any]:
|
||||
try:
|
||||
parsed = json.loads(value)
|
||||
if isinstance(parsed, list):
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
if "," in value:
|
||||
return [item.strip() for item in value.split(",")]
|
||||
return [value]
|
||||
else:
|
||||
return [parsed]
|
||||
|
||||
|
||||
def _convert_to_dict(value: str) -> dict[str, Any]:
|
||||
try:
|
||||
parsed = json.loads(value)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
4
strix/tools/browser/__init__.py
Normal file
4
strix/tools/browser/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .browser_actions import browser_action
|
||||
|
||||
|
||||
__all__ = ["browser_action"]
|
||||
236
strix/tools/browser/browser_actions.py
Normal file
236
strix/tools/browser/browser_actions.py
Normal file
@@ -0,0 +1,236 @@
|
||||
from typing import Any, Literal, NoReturn
|
||||
|
||||
from strix.tools.registry import register_tool
|
||||
|
||||
from .tab_manager import BrowserTabManager, get_browser_tab_manager
|
||||
|
||||
|
||||
# Every operation the browser_action tool understands.  The names group
# into navigation (launch/goto/back/forward), interaction
# (click/type/scroll_*/double_click/hover/press_key), tab management
# (new_tab/switch_tab/close_tab/list_tabs) and utilities
# (wait/execute_js/save_pdf/get_console_logs/view_source/close).
BrowserAction = Literal[
    "launch",
    "goto",
    "click",
    "type",
    "scroll_down",
    "scroll_up",
    "back",
    "forward",
    "new_tab",
    "switch_tab",
    "close_tab",
    "wait",
    "execute_js",
    "double_click",
    "hover",
    "press_key",
    "save_pdf",
    "get_console_logs",
    "view_source",
    "close",
    "list_tabs",
]
|
||||
|
||||
|
||||
def _validate_url(action_name: str, url: str | None) -> None:
|
||||
if not url:
|
||||
raise ValueError(f"url parameter is required for {action_name} action")
|
||||
|
||||
|
||||
def _validate_coordinate(action_name: str, coordinate: str | None) -> None:
|
||||
if not coordinate:
|
||||
raise ValueError(f"coordinate parameter is required for {action_name} action")
|
||||
|
||||
|
||||
def _validate_text(action_name: str, text: str | None) -> None:
|
||||
if not text:
|
||||
raise ValueError(f"text parameter is required for {action_name} action")
|
||||
|
||||
|
||||
def _validate_tab_id(action_name: str, tab_id: str | None) -> None:
|
||||
if not tab_id:
|
||||
raise ValueError(f"tab_id parameter is required for {action_name} action")
|
||||
|
||||
|
||||
def _validate_js_code(action_name: str, js_code: str | None) -> None:
|
||||
if not js_code:
|
||||
raise ValueError(f"js_code parameter is required for {action_name} action")
|
||||
|
||||
|
||||
def _validate_duration(action_name: str, duration: float | None) -> None:
|
||||
if duration is None:
|
||||
raise ValueError(f"duration parameter is required for {action_name} action")
|
||||
|
||||
|
||||
def _validate_key(action_name: str, key: str | None) -> None:
|
||||
if not key:
|
||||
raise ValueError(f"key parameter is required for {action_name} action")
|
||||
|
||||
|
||||
def _validate_file_path(action_name: str, file_path: str | None) -> None:
|
||||
if not file_path:
|
||||
raise ValueError(f"file_path parameter is required for {action_name} action")
|
||||
|
||||
|
||||
def _handle_navigation_actions(
    manager: BrowserTabManager,
    action: str,
    url: str | None = None,
    tab_id: str | None = None,
) -> dict[str, Any]:
    """Dispatch launch/goto/back/forward to the tab manager.

    Raises ValueError for any other action name or (via validation) for
    a 'goto' without a url.
    """
    if action == "launch":
        return manager.launch_browser(url)
    if action == "back":
        return manager.back(tab_id)
    if action == "forward":
        return manager.forward(tab_id)
    if action == "goto":
        _validate_url(action, url)
        assert url is not None  # narrowed by _validate_url
        return manager.goto_url(url, tab_id)
    raise ValueError(f"Unknown navigation action: {action}")
|
||||
|
||||
|
||||
def _handle_interaction_actions(
    manager: BrowserTabManager,
    action: str,
    coordinate: str | None = None,
    text: str | None = None,
    key: str | None = None,
    tab_id: str | None = None,
) -> dict[str, Any]:
    """Dispatch pointer, scroll and keyboard actions to the tab manager.

    Pointer actions (click/double_click/hover) require a coordinate,
    'type' requires text and 'press_key' requires a key; anything else
    raises ValueError.
    """
    if action in {"click", "double_click", "hover"}:
        _validate_coordinate(action, coordinate)
        assert coordinate is not None  # narrowed by _validate_coordinate
        pointer_handlers = {
            "click": manager.click,
            "double_click": manager.double_click,
            "hover": manager.hover,
        }
        return pointer_handlers[action](coordinate, tab_id)

    if action in {"scroll_down", "scroll_up"}:
        return manager.scroll("down" if action == "scroll_down" else "up", tab_id)

    if action == "type":
        _validate_text(action, text)
        assert text is not None  # narrowed by _validate_text
        return manager.type_text(text, tab_id)

    if action == "press_key":
        _validate_key(action, key)
        assert key is not None  # narrowed by _validate_key
        return manager.press_key(key, tab_id)

    raise ValueError(f"Unknown interaction action: {action}")
|
||||
|
||||
|
||||
def _raise_unknown_action(action: str) -> NoReturn:
|
||||
raise ValueError(f"Unknown action: {action}")
|
||||
|
||||
|
||||
def _handle_tab_actions(
    manager: BrowserTabManager,
    action: str,
    url: str | None = None,
    tab_id: str | None = None,
) -> dict[str, Any]:
    """Dispatch tab lifecycle actions to the tab manager.

    'switch_tab' and 'close_tab' require a tab_id; any unrecognized
    action name raises ValueError.
    """
    if action == "new_tab":
        return manager.new_tab(url)
    if action == "list_tabs":
        return manager.list_tabs()
    if action in {"switch_tab", "close_tab"}:
        _validate_tab_id(action, tab_id)
        assert tab_id is not None  # narrowed by _validate_tab_id
        if action == "switch_tab":
            return manager.switch_tab(tab_id)
        return manager.close_tab(tab_id)
    raise ValueError(f"Unknown tab action: {action}")
|
||||
|
||||
|
||||
def _handle_utility_actions(
    manager: BrowserTabManager,
    action: str,
    duration: float | None = None,
    js_code: str | None = None,
    file_path: str | None = None,
    tab_id: str | None = None,
    clear: bool = False,
) -> dict[str, Any]:
    """Dispatch wait/JS/PDF/log/source/close actions to the tab manager.

    'wait' requires a duration, 'execute_js' requires js_code and
    'save_pdf' requires a file_path; anything else raises ValueError.
    """
    if action == "close":
        return manager.close_browser()
    if action == "view_source":
        return manager.view_source(tab_id)
    if action == "get_console_logs":
        return manager.get_console_logs(tab_id, clear)
    if action == "wait":
        _validate_duration(action, duration)
        assert duration is not None  # narrowed by _validate_duration
        return manager.wait_browser(duration, tab_id)
    if action == "execute_js":
        _validate_js_code(action, js_code)
        assert js_code is not None  # narrowed by _validate_js_code
        return manager.execute_js(js_code, tab_id)
    if action == "save_pdf":
        _validate_file_path(action, file_path)
        assert file_path is not None  # narrowed by _validate_file_path
        return manager.save_pdf(file_path, tab_id)
    raise ValueError(f"Unknown utility action: {action}")
|
||||
|
||||
|
||||
@register_tool
def browser_action(
    action: BrowserAction,
    url: str | None = None,
    coordinate: str | None = None,
    text: str | None = None,
    tab_id: str | None = None,
    js_code: str | None = None,
    duration: float | None = None,
    key: str | None = None,
    file_path: str | None = None,
    clear: bool = False,
) -> dict[str, Any]:
    """Single entry point for all browser operations.

    Routes *action* to the navigation, interaction, tab-management or
    utility handler.  Validation and runtime failures are returned
    in-band as an error payload rather than raised, so callers always
    receive a dict.
    """
    manager = get_browser_tab_manager()

    try:
        if action in {"launch", "goto", "back", "forward"}:
            return _handle_navigation_actions(manager, action, url, tab_id)

        if action in {
            "click",
            "type",
            "double_click",
            "hover",
            "press_key",
            "scroll_down",
            "scroll_up",
        }:
            return _handle_interaction_actions(manager, action, coordinate, text, key, tab_id)

        if action in {"new_tab", "switch_tab", "close_tab", "list_tabs"}:
            return _handle_tab_actions(manager, action, url, tab_id)

        if action in {
            "wait",
            "execute_js",
            "save_pdf",
            "get_console_logs",
            "view_source",
            "close",
        }:
            return _handle_utility_actions(
                manager, action, duration, js_code, file_path, tab_id, clear
            )

        _raise_unknown_action(action)

    except (ValueError, RuntimeError) as e:
        # Surface failures as data; downstream expects a dict, not an exception.
        return {
            "error": str(e),
            "tab_id": tab_id,
            "screenshot": "",
            "is_running": False,
        }
|
||||
183
strix/tools/browser/browser_actions_schema.xml
Normal file
183
strix/tools/browser/browser_actions_schema.xml
Normal file
@@ -0,0 +1,183 @@
|
||||
<?xml version="1.0" ?>
|
||||
<tools>
|
||||
<tool name="browser_action">
|
||||
<description>Perform browser actions using a Playwright-controlled browser with multiple tabs.
|
||||
The browser is PERSISTENT and remains active until explicitly closed, allowing for
|
||||
multi-step workflows and long-running processes across multiple tabs.</description>
|
||||
<parameters>
|
||||
<parameter name="action" type="string" required="true">
|
||||
</parameter>
|
||||
<parameter name="url" type="string" required="false">
|
||||
<description>Required for 'launch', 'goto', and optionally for 'new_tab' actions. The URL to launch the browser at, navigate to, or load in new tab. Must include appropriate protocol (e.g., http://, https://, file://).</description>
|
||||
</parameter>
|
||||
<parameter name="coordinate" type="string" required="false">
|
||||
<description>Required for 'click', 'double_click', and 'hover' actions. Format: "x,y" (e.g., "432,321"). Coordinates should target the center of elements (buttons, links, etc.). Must be within the browser viewport resolution. Be very careful to calculate the coordinates correctly based on the previous screenshot.</description>
|
||||
</parameter>
|
||||
<parameter name="text" type="string" required="false">
|
||||
<description>Required for 'type' action. The text to type in the field.</description>
|
||||
</parameter>
|
||||
<parameter name="tab_id" type="string" required="false">
|
||||
<description>Required for 'switch_tab' and 'close_tab' actions. Optional for other actions to specify which tab to operate on. The ID of the tab to operate on. The first tab created during 'launch' has ID "tab_1". If not provided, actions will operate on the currently active tab.</description>
|
||||
</parameter>
|
||||
<parameter name="js_code" type="string" required="false">
|
||||
<description>Required for 'execute_js' action. JavaScript code to execute in the page context. The code runs in the context of the current page and has access to the DOM and all page-defined variables and functions. The last evaluated expression's value is returned in the response.</description>
|
||||
</parameter>
|
||||
<parameter name="duration" type="string" required="false">
|
||||
<description>Required for 'wait' action. Number of seconds to pause execution. Can be fractional (e.g., 0.5 for half a second).</description>
|
||||
</parameter>
|
||||
<parameter name="key" type="string" required="false">
|
||||
<description>Required for 'press_key' action. The key to press. Valid values include: - Single characters: 'a'-'z', 'A'-'Z', '0'-'9' - Special keys: 'Enter', 'Escape', 'ArrowLeft', 'ArrowRight', etc. - Modifier keys: 'Shift', 'Control', 'Alt', 'Meta' - Function keys: 'F1'-'F12'</description>
|
||||
</parameter>
|
||||
<parameter name="file_path" type="string" required="false">
|
||||
<description>Required for 'save_pdf' action. The file path where to save the PDF.</description>
|
||||
</parameter>
|
||||
<parameter name="clear" type="boolean" required="false">
|
||||
<description>For 'get_console_logs' action: whether to clear console logs after retrieving them. Default is False (keep logs).</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - screenshot: Base64 encoded PNG of the current page state - url: Current page URL - title: Current page title - viewport: Current browser viewport dimensions - tab_id: ID of the current active tab - all_tabs: Dict of all open tab IDs and their URLs - message: Status message about the action performed - js_result: Result of JavaScript execution (for execute_js action) - pdf_saved: File path of saved PDF (for save_pdf action) - console_logs: Array of console messages (for get_console_logs action) Limited to 50KB total and 200 most recent logs. Individual messages truncated at 1KB. - page_source: HTML source code (for view_source action) Large pages are truncated to 100KB (keeping beginning and end sections).</description>
|
||||
</returns>
|
||||
<notes>
|
||||
Important usage rules:
|
||||
1. PERSISTENCE: The browser remains active and maintains its state until
|
||||
explicitly closed with the 'close' action. This allows for multi-step workflows
|
||||
across multiple tool calls and tabs.
|
||||
2. Browser interaction MUST start with 'launch' and end with 'close'.
|
||||
3. Only one action can be performed per call.
|
||||
4. To visit a new URL not reachable from current page, either:
|
||||
- Use 'goto' action
|
||||
- Open a new tab with the URL
|
||||
- Close browser and relaunch
|
||||
5. Click coordinates must be derived from the most recent screenshot.
|
||||
6. You MUST click on the center of the element, not the edge. You MUST calculate
|
||||
the coordinates correctly based on the previous screenshot, otherwise the click
|
||||
will fail. After clicking, check the new screenshot to verify the click was
|
||||
successful.
|
||||
7. Tab management:
|
||||
- First tab from 'launch' is "tab_1"
|
||||
- New tabs are numbered sequentially ("tab_2", "tab_3", etc.)
|
||||
- Must have at least one tab open at all times
|
||||
- Actions affect the currently active tab unless tab_id is specified
|
||||
8. JavaScript execution (following Playwright evaluation patterns):
|
||||
- Code runs in the browser page context, not the tool context
|
||||
- Has access to DOM (document, window, etc.) and page variables/functions
|
||||
- The LAST EVALUATED EXPRESSION is automatically returned - no return statement needed
|
||||
- For simple values: document.title (returns the title)
|
||||
- For objects: {title: document.title, url: location.href} (returns the object)
|
||||
- For async operations: Use await and the promise result will be returned
|
||||
- AVOID explicit return statements - they can break evaluation
|
||||
- Object literals must be wrapped in parentheses when they are the final expression
|
||||
- Variables from tool context are NOT available - pass data as parameters if needed
|
||||
- Examples of correct patterns:
|
||||
* Single value: document.querySelectorAll('img').length
|
||||
* Object result: {images: document.images.length, links: document.links.length}
|
||||
* Async operation: await fetch(location.href).then(r => r.status)
|
||||
* DOM manipulation: document.body.style.backgroundColor = 'red'; 'background changed'
|
||||
|
||||
9. Wait action:
|
||||
- Time is specified in seconds
|
||||
- Can be used to wait for page loads, animations, etc.
|
||||
- Can be fractional (e.g., 0.5 seconds)
|
||||
- Screenshot is captured after the wait
|
||||
10. The browser can operate concurrently with other tools. You may invoke
|
||||
terminal, python, or other tools (in separate assistant messages) while maintaining
|
||||
the active browser session, enabling sophisticated multi-tool workflows.
|
||||
11. Keyboard actions:
|
||||
- Use press_key for individual key presses
|
||||
- Use type for typing regular text
|
||||
- Some keys have special names based on Playwright's key documentation
|
||||
12. All code in the js_code parameter is executed as-is - there's no need to
|
||||
escape special characters or worry about formatting. Just write your JavaScript
|
||||
code normally. It can be single line or multi-line.
|
||||
13. For form filling, click on the field first, then use 'type' to enter text.
|
||||
14. The browser runs in headless mode using Chrome engine for security and performance.
|
||||
</notes>
|
||||
<examples>
|
||||
# Launch browser at URL (creates tab_1)
|
||||
<function=browser_action>
|
||||
<parameter=action>launch</parameter>
|
||||
<parameter=url>https://example.com</parameter>
|
||||
</function>
|
||||
|
||||
# Navigate to different URL
|
||||
<function=browser_action>
|
||||
<parameter=action>goto</parameter>
|
||||
<parameter=url>https://github.com</parameter>
|
||||
</function>
|
||||
|
||||
# Open new tab with different URL
|
||||
<function=browser_action>
|
||||
<parameter=action>new_tab</parameter>
|
||||
<parameter=url>https://another-site.com</parameter>
|
||||
</function>
|
||||
|
||||
# Wait for page load
|
||||
<function=browser_action>
|
||||
<parameter=action>wait</parameter>
|
||||
<parameter=duration>2.5</parameter>
|
||||
</function>
|
||||
|
||||
# Click login button at coordinates from screenshot
|
||||
<function=browser_action>
|
||||
<parameter=action>click</parameter>
|
||||
<parameter=coordinate>450,300</parameter>
|
||||
</function>
|
||||
|
||||
# Click username field and type
|
||||
<function=browser_action>
|
||||
<parameter=action>click</parameter>
|
||||
<parameter=coordinate>400,200</parameter>
|
||||
</function>
|
||||
|
||||
<function=browser_action>
|
||||
<parameter=action>type</parameter>
|
||||
<parameter=text>user@example.com</parameter>
|
||||
</function>
|
||||
|
||||
# Click password field and type
|
||||
<function=browser_action>
|
||||
<parameter=action>click</parameter>
|
||||
<parameter=coordinate>400,250</parameter>
|
||||
</function>
|
||||
|
||||
<function=browser_action>
|
||||
<parameter=action>type</parameter>
|
||||
<parameter=text>mypassword123</parameter>
|
||||
</function>
|
||||
|
||||
# Press Enter key
|
||||
<function=browser_action>
|
||||
<parameter=action>press_key</parameter>
|
||||
<parameter=key>Enter</parameter>
|
||||
</function>
|
||||
|
||||
# Execute JavaScript to get page stats (correct pattern - no return statement)
|
||||
<function=browser_action>
|
||||
<parameter=action>execute_js</parameter>
|
||||
<parameter=js_code>const images = document.querySelectorAll('img');
|
||||
const links = document.querySelectorAll('a');
|
||||
{
|
||||
images: images.length,
|
||||
links: links.length,
|
||||
title: document.title
|
||||
}</parameter>
|
||||
</function>
|
||||
|
||||
# Scroll down
|
||||
<function=browser_action>
|
||||
<parameter=action>scroll_down</parameter>
|
||||
</function>
|
||||
|
||||
# Get console logs
|
||||
<function=browser_action>
|
||||
<parameter=action>get_console_logs</parameter>
|
||||
</function>
|
||||
|
||||
# View page source
|
||||
<function=browser_action>
|
||||
<parameter=action>view_source</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
</tools>
|
||||
533
strix/tools/browser/browser_instance.py
Normal file
533
strix/tools/browser/browser_instance.py
Normal file
@@ -0,0 +1,533 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Truncation limits applied to data sent back to the agent, keeping tool
# responses to a manageable size.
MAX_PAGE_SOURCE_LENGTH = 20_000  # cap on view_source output (chars)
MAX_CONSOLE_LOG_LENGTH = 30_000  # cap on the total console-log payload (chars)
MAX_INDIVIDUAL_LOG_LENGTH = 1_000  # cap per console message (chars)
MAX_CONSOLE_LOGS_COUNT = 200  # number of most-recent console messages kept
MAX_JS_RESULT_LENGTH = 5_000  # cap on execute_js results (chars)
|
||||
|
||||
|
||||
class BrowserInstance:
|
||||
def __init__(self) -> None:
|
||||
self.is_running = True
|
||||
self._execution_lock = threading.Lock()
|
||||
|
||||
self.playwright: Playwright | None = None
|
||||
self.browser: Browser | None = None
|
||||
self.context: BrowserContext | None = None
|
||||
self.pages: dict[str, Page] = {}
|
||||
self.current_page_id: str | None = None
|
||||
self._next_tab_id = 1
|
||||
|
||||
self.console_logs: dict[str, list[dict[str, Any]]] = {}
|
||||
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._loop_thread: threading.Thread | None = None
|
||||
|
||||
self._start_event_loop()
|
||||
|
||||
def _start_event_loop(self) -> None:
|
||||
def run_loop() -> None:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_forever()
|
||||
|
||||
self._loop_thread = threading.Thread(target=run_loop, daemon=True)
|
||||
self._loop_thread.start()
|
||||
|
||||
while self._loop is None:
|
||||
threading.Event().wait(0.01)
|
||||
|
||||
def _run_async(self, coro: Any) -> dict[str, Any]:
|
||||
if not self._loop or not self.is_running:
|
||||
raise RuntimeError("Browser instance is not running")
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
return cast("dict[str, Any]", future.result(timeout=30)) # 30 second timeout
|
||||
|
||||
async def _setup_console_logging(self, page: Page, tab_id: str) -> None:
|
||||
self.console_logs[tab_id] = []
|
||||
|
||||
def handle_console(msg: Any) -> None:
|
||||
text = msg.text
|
||||
if len(text) > MAX_INDIVIDUAL_LOG_LENGTH:
|
||||
text = text[:MAX_INDIVIDUAL_LOG_LENGTH] + "... [TRUNCATED]"
|
||||
|
||||
log_entry = {
|
||||
"type": msg.type,
|
||||
"text": text,
|
||||
"location": msg.location,
|
||||
"timestamp": asyncio.get_event_loop().time(),
|
||||
}
|
||||
|
||||
self.console_logs[tab_id].append(log_entry)
|
||||
|
||||
if len(self.console_logs[tab_id]) > MAX_CONSOLE_LOGS_COUNT:
|
||||
self.console_logs[tab_id] = self.console_logs[tab_id][-MAX_CONSOLE_LOGS_COUNT:]
|
||||
|
||||
page.on("console", handle_console)
|
||||
|
||||
async def _launch_browser(self, url: str | None = None) -> dict[str, Any]:
|
||||
self.playwright = await async_playwright().start()
|
||||
|
||||
self.browser = await self.playwright.chromium.launch(
|
||||
headless=True,
|
||||
args=[
|
||||
"--no-sandbox",
|
||||
"--disable-dev-shm-usage",
|
||||
"--disable-gpu",
|
||||
"--disable-web-security",
|
||||
"--disable-features=VizDisplayCompositor",
|
||||
],
|
||||
)
|
||||
|
||||
self.context = await self.browser.new_context(
|
||||
viewport={"width": 1280, "height": 720},
|
||||
user_agent=(
|
||||
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
|
||||
"(KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
),
|
||||
)
|
||||
|
||||
page = await self.context.new_page()
|
||||
tab_id = f"tab_{self._next_tab_id}"
|
||||
self._next_tab_id += 1
|
||||
self.pages[tab_id] = page
|
||||
self.current_page_id = tab_id
|
||||
|
||||
await self._setup_console_logging(page, tab_id)
|
||||
|
||||
if url:
|
||||
await page.goto(url, wait_until="domcontentloaded")
|
||||
|
||||
return await self._get_page_state(tab_id)
|
||||
|
||||
async def _get_page_state(self, tab_id: str | None = None) -> dict[str, Any]:
|
||||
if not tab_id:
|
||||
tab_id = self.current_page_id
|
||||
|
||||
if not tab_id or tab_id not in self.pages:
|
||||
raise ValueError(f"Tab '{tab_id}' not found")
|
||||
|
||||
page = self.pages[tab_id]
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
screenshot_bytes = await page.screenshot(type="png", full_page=False)
|
||||
screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")
|
||||
|
||||
url = page.url
|
||||
title = await page.title()
|
||||
viewport = page.viewport_size
|
||||
|
||||
all_tabs = {}
|
||||
for tid, tab_page in self.pages.items():
|
||||
all_tabs[tid] = {
|
||||
"url": tab_page.url,
|
||||
"title": await tab_page.title() if not tab_page.is_closed() else "Closed",
|
||||
}
|
||||
|
||||
return {
|
||||
"screenshot": screenshot_b64,
|
||||
"url": url,
|
||||
"title": title,
|
||||
"viewport": viewport,
|
||||
"tab_id": tab_id,
|
||||
"all_tabs": all_tabs,
|
||||
}
|
||||
|
||||
def launch(self, url: str | None = None) -> dict[str, Any]:
|
||||
with self._execution_lock:
|
||||
if self.browser is not None:
|
||||
raise ValueError("Browser is already launched")
|
||||
|
||||
return self._run_async(self._launch_browser(url))
|
||||
|
||||
def goto(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._execution_lock:
|
||||
return self._run_async(self._goto(url, tab_id))
|
||||
|
||||
async def _goto(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
if not tab_id:
|
||||
tab_id = self.current_page_id
|
||||
|
||||
if not tab_id or tab_id not in self.pages:
|
||||
raise ValueError(f"Tab '{tab_id}' not found")
|
||||
|
||||
page = self.pages[tab_id]
|
||||
await page.goto(url, wait_until="domcontentloaded")
|
||||
|
||||
return await self._get_page_state(tab_id)
|
||||
|
||||
def click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._execution_lock:
|
||||
return self._run_async(self._click(coordinate, tab_id))
|
||||
|
||||
async def _click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
if not tab_id:
|
||||
tab_id = self.current_page_id
|
||||
|
||||
if not tab_id or tab_id not in self.pages:
|
||||
raise ValueError(f"Tab '{tab_id}' not found")
|
||||
|
||||
try:
|
||||
x, y = map(int, coordinate.split(","))
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid coordinate format: {coordinate}. Use 'x,y'") from e
|
||||
|
||||
page = self.pages[tab_id]
|
||||
await page.mouse.click(x, y)
|
||||
|
||||
return await self._get_page_state(tab_id)
|
||||
|
||||
def type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._execution_lock:
|
||||
return self._run_async(self._type_text(text, tab_id))
|
||||
|
||||
async def _type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
if not tab_id:
|
||||
tab_id = self.current_page_id
|
||||
|
||||
if not tab_id or tab_id not in self.pages:
|
||||
raise ValueError(f"Tab '{tab_id}' not found")
|
||||
|
||||
page = self.pages[tab_id]
|
||||
await page.keyboard.type(text)
|
||||
|
||||
return await self._get_page_state(tab_id)
|
||||
|
||||
def scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._execution_lock:
|
||||
return self._run_async(self._scroll(direction, tab_id))
|
||||
|
||||
async def _scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
if not tab_id:
|
||||
tab_id = self.current_page_id
|
||||
|
||||
if not tab_id or tab_id not in self.pages:
|
||||
raise ValueError(f"Tab '{tab_id}' not found")
|
||||
|
||||
page = self.pages[tab_id]
|
||||
|
||||
if direction == "down":
|
||||
await page.keyboard.press("PageDown")
|
||||
elif direction == "up":
|
||||
await page.keyboard.press("PageUp")
|
||||
else:
|
||||
raise ValueError(f"Invalid scroll direction: {direction}")
|
||||
|
||||
return await self._get_page_state(tab_id)
|
||||
|
||||
def back(self, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._execution_lock:
|
||||
return self._run_async(self._back(tab_id))
|
||||
|
||||
async def _back(self, tab_id: str | None = None) -> dict[str, Any]:
|
||||
if not tab_id:
|
||||
tab_id = self.current_page_id
|
||||
|
||||
if not tab_id or tab_id not in self.pages:
|
||||
raise ValueError(f"Tab '{tab_id}' not found")
|
||||
|
||||
page = self.pages[tab_id]
|
||||
await page.go_back(wait_until="domcontentloaded")
|
||||
|
||||
return await self._get_page_state(tab_id)
|
||||
|
||||
def forward(self, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._execution_lock:
|
||||
return self._run_async(self._forward(tab_id))
|
||||
|
||||
async def _forward(self, tab_id: str | None = None) -> dict[str, Any]:
|
||||
if not tab_id:
|
||||
tab_id = self.current_page_id
|
||||
|
||||
if not tab_id or tab_id not in self.pages:
|
||||
raise ValueError(f"Tab '{tab_id}' not found")
|
||||
|
||||
page = self.pages[tab_id]
|
||||
await page.go_forward(wait_until="domcontentloaded")
|
||||
|
||||
return await self._get_page_state(tab_id)
|
||||
|
||||
def new_tab(self, url: str | None = None) -> dict[str, Any]:
|
||||
with self._execution_lock:
|
||||
return self._run_async(self._new_tab(url))
|
||||
|
||||
async def _new_tab(self, url: str | None = None) -> dict[str, Any]:
|
||||
if not self.context:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
page = await self.context.new_page()
|
||||
tab_id = f"tab_{self._next_tab_id}"
|
||||
self._next_tab_id += 1
|
||||
self.pages[tab_id] = page
|
||||
self.current_page_id = tab_id
|
||||
|
||||
await self._setup_console_logging(page, tab_id)
|
||||
|
||||
if url:
|
||||
await page.goto(url, wait_until="domcontentloaded")
|
||||
|
||||
return await self._get_page_state(tab_id)
|
||||
|
||||
def switch_tab(self, tab_id: str) -> dict[str, Any]:
    """Synchronously make ``tab_id`` the current tab."""
    with self._execution_lock:
        pending = self._switch_tab(tab_id)
        return self._run_async(pending)
|
||||
|
||||
async def _switch_tab(self, tab_id: str) -> dict[str, Any]:
    """Make ``tab_id`` the current tab and return its state.

    Raises:
        ValueError: If ``tab_id`` is not an open tab.
    """
    if tab_id not in self.pages:
        raise ValueError(f"Tab '{tab_id}' not found")

    self.current_page_id = tab_id
    return await self._get_page_state(tab_id)
|
||||
|
||||
def close_tab(self, tab_id: str) -> dict[str, Any]:
    """Synchronously close ``tab_id`` (closing the last tab is refused)."""
    with self._execution_lock:
        pending = self._close_tab(tab_id)
        return self._run_async(pending)
|
||||
|
||||
async def _close_tab(self, tab_id: str) -> dict[str, Any]:
    """Close ``tab_id`` and return the state of the tab that becomes current.

    Raises:
        ValueError: If the tab is unknown or it is the only open tab.
    """
    if tab_id not in self.pages:
        raise ValueError(f"Tab '{tab_id}' not found")
    if len(self.pages) == 1:
        raise ValueError("Cannot close the last tab")

    doomed = self.pages.pop(tab_id)
    await doomed.close()

    # Drop the closed tab's captured console output as well.
    self.console_logs.pop(tab_id, None)

    if self.current_page_id == tab_id:
        # Promote an arbitrary surviving tab to current.
        self.current_page_id = next(iter(self.pages))

    return await self._get_page_state(self.current_page_id)
|
||||
|
||||
def wait(self, duration: float, tab_id: str | None = None) -> dict[str, Any]:
    """Synchronously pause for ``duration`` seconds, then return the tab state."""
    with self._execution_lock:
        pending = self._wait(duration, tab_id)
        return self._run_async(pending)
|
||||
|
||||
async def _wait(self, duration: float, tab_id: str | None = None) -> dict[str, Any]:
    """Sleep for ``duration`` seconds, then return the state of ``tab_id``.

    NOTE(review): unlike the other actions, ``tab_id`` is not resolved or
    validated here; presumably ``_get_page_state`` handles ``None`` — confirm.
    """
    await asyncio.sleep(duration)
    return await self._get_page_state(tab_id)
|
||||
|
||||
def execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]:
    """Synchronously evaluate ``js_code`` in the given (or current) tab."""
    with self._execution_lock:
        pending = self._execute_js(js_code, tab_id)
        return self._run_async(pending)
|
||||
|
||||
async def _execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]:
    """Evaluate ``js_code`` in the tab and return the page state with the result.

    Evaluation errors are folded into a structured dict instead of raising,
    and over-long results are truncated to ``MAX_JS_RESULT_LENGTH``.

    Raises:
        ValueError: If the resolved tab id is unknown.
    """
    target = tab_id or self.current_page_id
    if not target or target not in self.pages:
        raise ValueError(f"Tab '{target}' not found")

    page = self.pages[target]

    try:
        result = await page.evaluate(js_code)
    except Exception as e:  # noqa: BLE001 - surface any JS/driver failure to the caller
        result = {
            "error": True,
            "error_type": type(e).__name__,
            "error_message": str(e),
        }

    result_str = str(result)
    if len(result_str) > MAX_JS_RESULT_LENGTH:
        # Fix: the notice previously hardcoded "5k chars" regardless of the
        # configured limit; report the actual limit instead.
        result = (
            result_str[:MAX_JS_RESULT_LENGTH]
            + f"... [JS result truncated at {MAX_JS_RESULT_LENGTH} chars]"
        )

    state = await self._get_page_state(target)
    state["js_result"] = result
    return state
|
||||
|
||||
def get_console_logs(self, tab_id: str | None = None, clear: bool = False) -> dict[str, Any]:
    """Synchronously fetch (and optionally clear) the tab's console logs."""
    with self._execution_lock:
        pending = self._get_console_logs(tab_id, clear)
        return self._run_async(pending)
|
||||
|
||||
async def _get_console_logs(
    self, tab_id: str | None = None, clear: bool = False
) -> dict[str, Any]:
    """Return the tab state with captured console logs attached.

    Logs are size-capped: when their combined stringified length exceeds
    ``MAX_CONSOLE_LOG_LENGTH``, the newest logs that fit are kept and a
    synthetic "info" entry describing the truncation is prepended.

    Args:
        tab_id: Tab to read; defaults to the current tab.
        clear: When True, the tab's stored logs are emptied after reading.

    Raises:
        ValueError: If the resolved tab id is unknown.
    """
    if not tab_id:
        tab_id = self.current_page_id

    if not tab_id or tab_id not in self.pages:
        raise ValueError(f"Tab '{tab_id}' not found")

    logs = self.console_logs.get(tab_id, [])

    # Size is measured on str(log) of each entry, matching the per-entry
    # check inside the loop below.
    total_length = sum(len(str(log)) for log in logs)
    if total_length > MAX_CONSOLE_LOG_LENGTH:
        truncated_logs: list[dict[str, Any]] = []
        current_length = 0

        # Walk newest-first so the most recent logs survive truncation;
        # insert(0, ...) restores chronological order.
        for log in reversed(logs):
            log_length = len(str(log))
            if current_length + log_length <= MAX_CONSOLE_LOG_LENGTH:
                truncated_logs.insert(0, log)
                current_length += log_length
            else:
                break

        if len(truncated_logs) < len(logs):
            # NOTE(review): the notice itself is not counted against the
            # budget, so the final payload can slightly exceed the limit.
            truncation_notice = {
                "type": "info",
                "text": (
                    f"[TRUNCATED: {len(logs) - len(truncated_logs)} older logs "
                    f"removed to stay within {MAX_CONSOLE_LOG_LENGTH} character limit]"
                ),
                "location": {},
                "timestamp": 0,
            }
            truncated_logs.insert(0, truncation_notice)

        logs = truncated_logs

    if clear:
        self.console_logs[tab_id] = []

    state = await self._get_page_state(tab_id)
    state["console_logs"] = logs
    return state
|
||||
|
||||
def view_source(self, tab_id: str | None = None) -> dict[str, Any]:
    """Synchronously fetch the tab's HTML source (size-capped)."""
    with self._execution_lock:
        pending = self._view_source(tab_id)
        return self._run_async(pending)
|
||||
|
||||
async def _view_source(self, tab_id: str | None = None) -> dict[str, Any]:
    """Return the tab state with its HTML source attached.

    Sources longer than ``MAX_PAGE_SOURCE_LENGTH`` keep an equal head and
    tail slice around an inline HTML comment marking the removed middle.

    Raises:
        ValueError: If the resolved tab id is unknown.
    """
    target = tab_id or self.current_page_id
    if not target or target not in self.pages:
        raise ValueError(f"Tab '{target}' not found")

    source = await self.pages[target].content()
    original_length = len(source)

    if original_length > MAX_PAGE_SOURCE_LENGTH:
        truncation_message = (
            f"\n\n<!-- [TRUNCATED: {original_length - MAX_PAGE_SOURCE_LENGTH} "
            "characters removed] -->\n\n"
        )
        # Budget what remains after the marker, split evenly between
        # the start and the end of the document.
        keep = (MAX_PAGE_SOURCE_LENGTH - len(truncation_message)) // 2
        source = source[:keep] + truncation_message + source[-keep:]

    state = await self._get_page_state(target)
    state["page_source"] = source
    return state
|
||||
|
||||
def double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
    """Synchronously double-click at an "x,y" pixel coordinate."""
    with self._execution_lock:
        pending = self._double_click(coordinate, tab_id)
        return self._run_async(pending)
|
||||
|
||||
async def _double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
    """Double-click at an "x,y" pixel coordinate and return the tab state.

    Raises:
        ValueError: If the tab is unknown or the coordinate is malformed.
    """
    target = tab_id or self.current_page_id
    if not target or target not in self.pages:
        raise ValueError(f"Tab '{target}' not found")

    try:
        x, y = map(int, coordinate.split(","))
    except ValueError as e:
        raise ValueError(f"Invalid coordinate format: {coordinate}. Use 'x,y'") from e

    await self.pages[target].mouse.dblclick(x, y)
    return await self._get_page_state(target)
|
||||
|
||||
def hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
    """Synchronously move the mouse to an "x,y" pixel coordinate."""
    with self._execution_lock:
        pending = self._hover(coordinate, tab_id)
        return self._run_async(pending)
|
||||
|
||||
async def _hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
    """Move the mouse to an "x,y" pixel coordinate and return the tab state.

    Raises:
        ValueError: If the tab is unknown or the coordinate is malformed.
    """
    target = tab_id or self.current_page_id
    if not target or target not in self.pages:
        raise ValueError(f"Tab '{target}' not found")

    try:
        x, y = map(int, coordinate.split(","))
    except ValueError as e:
        raise ValueError(f"Invalid coordinate format: {coordinate}. Use 'x,y'") from e

    await self.pages[target].mouse.move(x, y)
    return await self._get_page_state(target)
|
||||
|
||||
def press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]:
    """Synchronously press a keyboard key in the given (or current) tab."""
    with self._execution_lock:
        pending = self._press_key(key, tab_id)
        return self._run_async(pending)
|
||||
|
||||
async def _press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]:
    """Send a keyboard key press to the tab and return its refreshed state.

    Raises:
        ValueError: If the resolved tab id is unknown.
    """
    target = tab_id or self.current_page_id
    if not target or target not in self.pages:
        raise ValueError(f"Tab '{target}' not found")

    await self.pages[target].keyboard.press(key)
    return await self._get_page_state(target)
|
||||
|
||||
def save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]:
    """Synchronously render the given (or current) tab to a PDF file."""
    with self._execution_lock:
        pending = self._save_pdf(file_path, tab_id)
        return self._run_async(pending)
|
||||
|
||||
async def _save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]:
    """Render the tab to a PDF at ``file_path`` and return the tab state.

    The resolved path is recorded in the returned state under "pdf_saved".

    Raises:
        ValueError: If the resolved tab id is unknown.
    """
    target = tab_id or self.current_page_id
    if not target or target not in self.pages:
        raise ValueError(f"Tab '{target}' not found")

    # Relative paths are anchored in the sandbox workspace.
    if not Path(file_path).is_absolute():
        file_path = str(Path("/workspace") / file_path)

    await self.pages[target].pdf(path=file_path)

    state = await self._get_page_state(target)
    state["pdf_saved"] = file_path
    return state
|
||||
|
||||
def close(self) -> None:
    """Shut down the browser and stop the background event loop thread.

    Marks the instance as no longer running, schedules the async browser
    teardown on the loop thread, stops the loop, and waits (up to 5s) for
    the loop thread to exit.
    """
    with self._execution_lock:
        self.is_running = False
        if self._loop:
            # NOTE(review): the teardown future is not awaited before the
            # loop is stopped, so _close_browser may not run to completion
            # -- confirm this best-effort shutdown is intentional.
            asyncio.run_coroutine_threadsafe(self._close_browser(), self._loop)

            self._loop.call_soon_threadsafe(self._loop.stop)

        if self._loop_thread:
            self._loop_thread.join(timeout=5)
|
||||
|
||||
async def _close_browser(self) -> None:
    """Close the browser, then stop the playwright driver, in that order."""
    try:
        if self.browser:
            await self.browser.close()
        if self.playwright:
            await self.playwright.stop()
    except (OSError, RuntimeError) as e:
        # Shutdown is best-effort: failures are logged, never raised.
        logger.warning(f"Error closing browser: {e}")
|
||||
|
||||
def is_alive(self) -> bool:
    """Return True while the instance is marked running and the browser is connected."""
    if not self.is_running:
        return False
    browser = self.browser
    return browser is not None and browser.is_connected()
|
||||
342
strix/tools/browser/tab_manager.py
Normal file
342
strix/tools/browser/tab_manager.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import atexit
import contextlib
import signal
import sys
import threading
from collections.abc import Callable
from typing import Any

from .browser_instance import BrowserInstance
|
||||
|
||||
|
||||
class BrowserTabManager:
    """Thread-safe facade over a single shared BrowserInstance.

    Every public action takes the manager lock, checks that a browser is
    running, delegates to the underlying BrowserInstance, attaches a
    human-readable "message" to the returned state dict, and re-raises
    driver failures as RuntimeError.  Cleanup handlers (atexit + signals)
    are installed on construction.

    The original code repeated the identical lock/guard/try/except
    boilerplate in every wrapper; it is consolidated into ``_dispatch``.
    """

    def __init__(self) -> None:
        # The single managed browser; None until launch_browser() succeeds.
        self.browser_instance: BrowserInstance | None = None
        self._lock = threading.Lock()

        self._register_cleanup_handlers()

    def _dispatch(
        self,
        action: Callable[[BrowserInstance], dict[str, Any]],
        failure: str,
        message: str | Callable[[dict[str, Any]], str],
    ) -> dict[str, Any]:
        """Run one browser operation under the lock with uniform error handling.

        Args:
            action: Callable applied to the live BrowserInstance.
            failure: Prefix for the RuntimeError raised when the action fails.
            message: Success message, or a callable deriving it from the result.

        Returns:
            The state dict returned by the action with "message" filled in.

        Raises:
            ValueError: If no browser has been launched.
            RuntimeError: If the underlying operation fails.
        """
        with self._lock:
            if self.browser_instance is None:
                raise ValueError("Browser not launched")

            try:
                result = action(self.browser_instance)
                result["message"] = message(result) if callable(message) else message
            except (OSError, ValueError, RuntimeError) as e:
                raise RuntimeError(f"{failure}: {e}") from e
            else:
                return result

    def launch_browser(self, url: str | None = None) -> dict[str, Any]:
        """Start the browser (optionally at ``url``); fails if one is already running."""
        with self._lock:
            if self.browser_instance is not None:
                raise ValueError("Browser is already launched")

            try:
                self.browser_instance = BrowserInstance()
                result = self.browser_instance.launch(url)
                result["message"] = "Browser launched successfully"
            except (OSError, ValueError, RuntimeError) as e:
                # Roll back the half-constructed instance on failure.
                if self.browser_instance:
                    self.browser_instance = None
                raise RuntimeError(f"Failed to launch browser: {e}") from e
            else:
                return result

    def goto_url(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
        """Navigate the given (or current) tab to ``url``."""
        return self._dispatch(
            lambda b: b.goto(url, tab_id), "Failed to navigate to URL", f"Navigated to {url}"
        )

    def click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
        """Click at an "x,y" coordinate in the given (or current) tab."""
        return self._dispatch(
            lambda b: b.click(coordinate, tab_id), "Failed to click", f"Clicked at {coordinate}"
        )

    def type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]:
        """Type ``text`` into the given (or current) tab."""
        preview = f"Typed text: {text[:50]}{'...' if len(text) > 50 else ''}"
        return self._dispatch(lambda b: b.type_text(text, tab_id), "Failed to type text", preview)

    def scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]:
        """Scroll the given (or current) tab in ``direction``."""
        return self._dispatch(
            lambda b: b.scroll(direction, tab_id), "Failed to scroll", f"Scrolled {direction}"
        )

    def back(self, tab_id: str | None = None) -> dict[str, Any]:
        """Go back one history entry in the given (or current) tab."""
        return self._dispatch(lambda b: b.back(tab_id), "Failed to go back", "Navigated back")

    def forward(self, tab_id: str | None = None) -> dict[str, Any]:
        """Go forward one history entry in the given (or current) tab."""
        return self._dispatch(
            lambda b: b.forward(tab_id), "Failed to go forward", "Navigated forward"
        )

    def new_tab(self, url: str | None = None) -> dict[str, Any]:
        """Open a new tab, optionally navigating it to ``url``."""
        return self._dispatch(
            lambda b: b.new_tab(url),
            "Failed to create new tab",
            lambda result: f"Created new tab {result.get('tab_id', '')}",
        )

    def switch_tab(self, tab_id: str) -> dict[str, Any]:
        """Make ``tab_id`` the current tab."""
        return self._dispatch(
            lambda b: b.switch_tab(tab_id), "Failed to switch tab", f"Switched to tab {tab_id}"
        )

    def close_tab(self, tab_id: str) -> dict[str, Any]:
        """Close ``tab_id`` (the instance refuses to close its last tab)."""
        return self._dispatch(
            lambda b: b.close_tab(tab_id), "Failed to close tab", f"Closed tab {tab_id}"
        )

    def wait_browser(self, duration: float, tab_id: str | None = None) -> dict[str, Any]:
        """Pause for ``duration`` seconds, then return the tab state."""
        return self._dispatch(
            lambda b: b.wait(duration, tab_id), "Failed to wait", f"Waited {duration}s"
        )

    def execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]:
        """Evaluate ``js_code`` in the given (or current) tab."""
        return self._dispatch(
            lambda b: b.execute_js(js_code, tab_id),
            "Failed to execute JavaScript",
            "JavaScript executed successfully",
        )

    def double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
        """Double-click at an "x,y" coordinate in the given (or current) tab."""
        return self._dispatch(
            lambda b: b.double_click(coordinate, tab_id),
            "Failed to double click",
            f"Double clicked at {coordinate}",
        )

    def hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
        """Move the mouse to an "x,y" coordinate in the given (or current) tab."""
        return self._dispatch(
            lambda b: b.hover(coordinate, tab_id), "Failed to hover", f"Hovered at {coordinate}"
        )

    def press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]:
        """Press a keyboard key in the given (or current) tab."""
        return self._dispatch(
            lambda b: b.press_key(key, tab_id), "Failed to press key", f"Pressed key {key}"
        )

    def save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]:
        """Save the given (or current) tab as a PDF at ``file_path``."""
        return self._dispatch(
            lambda b: b.save_pdf(file_path, tab_id),
            "Failed to save PDF",
            f"Page saved as PDF: {file_path}",
        )

    def view_source(self, tab_id: str | None = None) -> dict[str, Any]:
        """Return the (possibly truncated) HTML source of the tab."""
        return self._dispatch(
            lambda b: b.view_source(tab_id), "Failed to get page source", "Page source retrieved"
        )

    def get_console_logs(self, tab_id: str | None = None, clear: bool = False) -> dict[str, Any]:
        """Return (and optionally clear) the console logs captured for the tab."""

        def describe(result: dict[str, Any]) -> str:
            # The success message reflects whether logs were cleared and
            # whether the instance truncated them to fit its size budget.
            action_text = "cleared and retrieved" if clear else "retrieved"
            logs = result.get("console_logs", [])
            truncated = any(log.get("text", "").startswith("[TRUNCATED:") for log in logs)
            truncated_text = " (truncated)" if truncated else ""
            return (
                f"Console logs {action_text} for tab "
                f"{result.get('tab_id', 'current')}{truncated_text}"
            )

        return self._dispatch(
            lambda b: b.get_console_logs(tab_id, clear), "Failed to get console logs", describe
        )

    def list_tabs(self) -> dict[str, Any]:
        """Describe all open tabs; safe to call before launch (empty result).

        Raises:
            RuntimeError: If tab enumeration fails.
        """
        with self._lock:
            if self.browser_instance is None:
                return {"tabs": {}, "total_count": 0, "current_tab": None}

            try:
                tab_info = {}
                for tid, tab_page in self.browser_instance.pages.items():
                    try:
                        tab_info[tid] = {
                            "url": tab_page.url,
                            "title": "Unknown" if tab_page.is_closed() else "Active",
                            "is_current": tid == self.browser_instance.current_page_id,
                        }
                    except (AttributeError, RuntimeError):
                        # A tab torn down mid-iteration is reported, not fatal.
                        tab_info[tid] = {
                            "url": "Unknown",
                            "title": "Closed",
                            "is_current": False,
                        }

                return {
                    "tabs": tab_info,
                    "total_count": len(tab_info),
                    "current_tab": self.browser_instance.current_page_id,
                }
            except (OSError, ValueError, RuntimeError) as e:
                raise RuntimeError(f"Failed to list tabs: {e}") from e

    def close_browser(self) -> dict[str, Any]:
        """Shut the browser down and forget the instance.

        Raises:
            ValueError: If no browser has been launched.
            RuntimeError: If shutdown fails (the instance is kept in that case).
        """
        with self._lock:
            if self.browser_instance is None:
                raise ValueError("Browser not launched")

            try:
                self.browser_instance.close()
                self.browser_instance = None
            except (OSError, ValueError, RuntimeError) as e:
                raise RuntimeError(f"Failed to close browser: {e}") from e
            else:
                return {
                    "message": "Browser closed successfully",
                    "screenshot": "",
                    "is_running": False,
                }

    def cleanup_dead_browser(self) -> None:
        """Drop the instance if its browser process died underneath us."""
        with self._lock:
            if self.browser_instance and not self.browser_instance.is_alive():
                with contextlib.suppress(Exception):
                    self.browser_instance.close()
                self.browser_instance = None

    def close_all(self) -> None:
        """Best-effort teardown used by atexit and signal handlers."""
        with self._lock:
            if self.browser_instance:
                with contextlib.suppress(Exception):
                    self.browser_instance.close()
                self.browser_instance = None

    def _register_cleanup_handlers(self) -> None:
        """Ensure the browser dies with the process (normal exit or signal)."""
        atexit.register(self.close_all)

        signal.signal(signal.SIGTERM, self._signal_handler)
        signal.signal(signal.SIGINT, self._signal_handler)

        # SIGHUP does not exist on Windows.
        if hasattr(signal, "SIGHUP"):
            signal.signal(signal.SIGHUP, self._signal_handler)

    def _signal_handler(self, _signum: int, _frame: Any) -> None:
        """Tear everything down, then exit cleanly."""
        self.close_all()
        sys.exit(0)
|
||||
|
||||
|
||||
# Process-wide singleton, created when the module is imported (its
# constructor also installs the atexit/signal cleanup handlers).
_browser_tab_manager = BrowserTabManager()


def get_browser_tab_manager() -> BrowserTabManager:
    """Return the process-wide singleton BrowserTabManager."""
    return _browser_tab_manager
|
||||
302
strix/tools/executor.py
Normal file
302
strix/tools/executor.py
Normal file
@@ -0,0 +1,302 @@
|
||||
import inspect
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
|
||||
from strix.runtime import get_runtime
|
||||
|
||||
from .argument_parser import convert_arguments
|
||||
from .registry import (
|
||||
get_tool_by_name,
|
||||
get_tool_names,
|
||||
needs_agent_state,
|
||||
should_execute_in_sandbox,
|
||||
)
|
||||
|
||||
|
||||
async def execute_tool(tool_name: str, agent_state: Any | None = None, **kwargs: Any) -> Any:
    """Route a tool call to the sandbox tool server or run it in-process.

    A tool registered for sandbox execution is forwarded to the sandbox
    unless this process already runs inside one (STRIX_SANDBOX_MODE=true).
    """
    inside_sandbox = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
    if should_execute_in_sandbox(tool_name) and not inside_sandbox:
        return await _execute_tool_in_sandbox(tool_name, agent_state, **kwargs)
    return await _execute_tool_locally(tool_name, agent_state, **kwargs)
|
||||
|
||||
|
||||
async def _execute_tool_in_sandbox(tool_name: str, agent_state: Any, **kwargs: Any) -> Any:
    """Forward a tool call to the sandbox's HTTP tool server and return its result.

    Requires ``agent_state`` to carry ``sandbox_id``, ``sandbox_token`` and a
    ``sandbox_info`` mapping containing "tool_server_port".

    Raises:
        ValueError: If any of the required sandbox fields are missing/empty.
        RuntimeError: On HTTP/transport failures or an error reported by
            the tool server.
    """
    if not hasattr(agent_state, "sandbox_id") or not agent_state.sandbox_id:
        raise ValueError("Agent state with a valid sandbox_id is required for sandbox execution.")

    if not hasattr(agent_state, "sandbox_token") or not agent_state.sandbox_token:
        raise ValueError(
            "Agent state with a valid sandbox_token is required for sandbox execution."
        )

    if (
        not hasattr(agent_state, "sandbox_info")
        or "tool_server_port" not in agent_state.sandbox_info
    ):
        raise ValueError(
            "Agent state with a valid sandbox_info containing tool_server_port is required."
        )

    # Resolve the externally reachable URL for the sandbox's tool server.
    runtime = get_runtime()
    tool_server_port = agent_state.sandbox_info["tool_server_port"]
    server_url = await runtime.get_sandbox_url(agent_state.sandbox_id, tool_server_port)
    request_url = f"{server_url}/execute"

    request_data = {
        "tool_name": tool_name,
        "kwargs": kwargs,
    }

    headers = {
        "Authorization": f"Bearer {agent_state.sandbox_token}",
        "Content-Type": "application/json",
    }

    # timeout=None: tool executions may legitimately run for a long time.
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                request_url, json=request_data, headers=headers, timeout=None
            )
            response.raise_for_status()
            response_data = response.json()
            # A RuntimeError raised here is NOT caught by the httpx except
            # clauses below; it propagates to the caller directly.
            if response_data.get("error"):
                raise RuntimeError(f"Sandbox execution error: {response_data['error']}")
            return response_data.get("result")
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 401:
                raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e
            raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e
        except httpx.RequestError as e:
            raise RuntimeError(f"Request error calling tool server: {e}") from e
|
||||
|
||||
|
||||
async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwargs: Any) -> Any:
    """Look up a registered tool and invoke it in this process.

    Arguments are coerced via ``convert_arguments``; tools that declare a
    need for agent state receive it as the ``agent_state`` keyword.
    Awaitable results are awaited before returning.

    Raises:
        ValueError: If the tool is unknown, or requires agent state that
            was not supplied.
    """
    tool_func = get_tool_by_name(tool_name)
    if tool_func is None:
        raise ValueError(f"Tool '{tool_name}' not found")

    call_kwargs = convert_arguments(tool_func, kwargs)

    if not needs_agent_state(tool_name):
        outcome = tool_func(**call_kwargs)
    else:
        if agent_state is None:
            raise ValueError(f"Tool '{tool_name}' requires agent_state but none was provided.")
        outcome = tool_func(agent_state=agent_state, **call_kwargs)

    if inspect.isawaitable(outcome):
        return await outcome
    return outcome
|
||||
|
||||
|
||||
def validate_tool_availability(tool_name: str | None) -> tuple[bool, str]:
    """Check that ``tool_name`` is present and registered.

    Returns:
        ``(True, "")`` when the tool is available, otherwise
        ``(False, reason)``.
    """
    if tool_name is None:
        return False, "Tool name is missing"
    if tool_name in get_tool_names():
        return True, ""
    return False, f"Tool '{tool_name}' is not available"
|
||||
|
||||
|
||||
async def execute_tool_with_validation(
    tool_name: str | None, agent_state: Any | None = None, **kwargs: Any
) -> Any:
    """Validate availability, run the tool, and fold failures into a string.

    Instead of raising, validation and execution failures are returned as
    "Error ..." strings so the caller can feed them back to the model.
    """
    is_valid, error_msg = validate_tool_availability(tool_name)
    if not is_valid:
        return f"Error: {error_msg}"

    assert tool_name is not None  # narrowed by the validation above

    try:
        result = await execute_tool(tool_name, agent_state, **kwargs)
    except Exception as e:  # noqa: BLE001 - any tool failure becomes a readable result
        message = str(e)
        if len(message) > 500:
            # Keep error observations short enough for the conversation.
            message = message[:500] + "... [truncated]"
        return f"Error executing {tool_name}: {message}"
    else:
        return result
|
||||
|
||||
|
||||
async def execute_tool_invocation(tool_inv: dict[str, Any], agent_state: Any | None = None) -> Any:
    """Unpack a ``{"toolName": ..., "args": {...}}`` invocation and execute it."""
    name = tool_inv.get("toolName")
    args = tool_inv.get("args", {})
    return await execute_tool_with_validation(name, agent_state, **args)
|
||||
|
||||
|
||||
def _check_error_result(result: Any) -> tuple[bool, Any]:
|
||||
is_error = False
|
||||
error_payload: Any = None
|
||||
|
||||
if (isinstance(result, dict) and "error" in result) or (
|
||||
isinstance(result, str) and result.strip().lower().startswith("error:")
|
||||
):
|
||||
is_error = True
|
||||
error_payload = result
|
||||
|
||||
return is_error, error_payload
|
||||
|
||||
|
||||
def _update_tracer_with_result(
    tracer: Any, execution_id: Any, is_error: bool, result: Any, error_payload: Any
) -> None:
    """Record a finished tool execution ("completed" or "error") on the tracer.

    A no-op when either ``tracer`` or ``execution_id`` is missing.

    Raises:
        ConnectionError, RuntimeError: Re-raised after attempting to record
            the failure itself on the tracer.
    """
    if not tracer or not execution_id:
        return

    try:
        if is_error:
            tracer.update_tool_execution(execution_id, "error", error_payload)
        else:
            tracer.update_tool_execution(execution_id, "completed", result)
    except (ConnectionError, RuntimeError) as e:
        error_msg = str(e)
        # NOTE(review): the guard below is always true here (checked above),
        # and this retry uses the same channel that just failed -- if it
        # raises again, that new exception shadows the original. Confirm
        # this fallback is intentional.
        if tracer and execution_id:
            tracer.update_tool_execution(execution_id, "error", error_msg)
        raise
|
||||
|
||||
|
||||
def _format_tool_result(tool_name: str, result: Any) -> tuple[str, list[dict[str, Any]]]:
    """Render a tool result as <tool_result> XML plus extracted screenshot images.

    Any base64 screenshot in the result is moved into an image_url entry
    and replaced by a placeholder in the textual result; long results keep
    only their head and tail.

    Returns:
        ``(observation_xml, images)``.
    """
    images: list[dict[str, Any]] = []

    screenshot_data = extract_screenshot_from_result(result)
    if screenshot_data:
        images.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{screenshot_data}"},
            }
        )
        payload = remove_screenshot_from_result(result)
    else:
        payload = result

    if payload is None:
        text = f"Tool {tool_name} executed successfully"
    else:
        text = str(payload)
        if len(text) > 10000:
            # Keep the head and tail; the middle is least informative.
            text = text[:4000] + "\n\n... [middle content truncated] ...\n\n" + text[-4000:]

    observation_xml = (
        f"<tool_result>\n<tool_name>{tool_name}</tool_name>\n"
        f"<result>{text}</result>\n</tool_result>"
    )
    return observation_xml, images
|
||||
|
||||
|
||||
async def _execute_single_tool(
    tool_inv: dict[str, Any],
    agent_state: Any | None,
    tracer: Any | None,
    agent_id: str,
) -> tuple[str, list[dict[str, Any]], bool]:
    """Execute one tool invocation with tracer bookkeeping.

    Returns:
        ``(observation_xml, images, should_agent_finish)`` where the last
        flag is True when finish_scan/agent_finish reported completion.

    Raises:
        ConnectionError, RuntimeError, ValueError, TypeError, OSError:
            Re-raised after recording the failure on the tracer.
    """
    tool_name = tool_inv.get("toolName", "unknown")
    args = tool_inv.get("args", {})
    execution_id = None
    should_agent_finish = False

    if tracer:
        execution_id = tracer.log_tool_execution_start(agent_id, tool_name, args)

    try:
        result = await execute_tool_invocation(tool_inv, agent_state)

        is_error, error_payload = _check_error_result(result)

        # Only the two finish tools can end the agent, and only on a
        # successful dict result carrying the matching completion flag.
        if (
            tool_name in ("finish_scan", "agent_finish")
            and not is_error
            and isinstance(result, dict)
        ):
            if tool_name == "finish_scan":
                should_agent_finish = result.get("scan_completed", False)
            elif tool_name == "agent_finish":
                should_agent_finish = result.get("agent_completed", False)

        _update_tracer_with_result(tracer, execution_id, is_error, result, error_payload)

    except (ConnectionError, RuntimeError, ValueError, TypeError, OSError) as e:
        error_msg = str(e)
        if tracer and execution_id:
            tracer.update_tool_execution(execution_id, "error", error_msg)
        raise

    observation_xml, images = _format_tool_result(tool_name, result)
    return observation_xml, images, should_agent_finish
|
||||
|
||||
|
||||
def _get_tracer_and_agent_id(agent_state: Any | None) -> tuple[Any | None, str]:
    """Best-effort lookup of the CLI tracer and the current agent id.

    Falls back to ``(None, "unknown_agent")`` when the CLI package is not
    importable or the agent state lacks an ``agent_id`` attribute.
    """
    try:
        from strix.cli.tracer import get_global_tracer

        tracer = get_global_tracer()
        agent_id = agent_state.agent_id if agent_state else "unknown_agent"
    except (ImportError, AttributeError):
        # NOTE(review): an AttributeError on agent_id also discards an
        # already-resolved tracer; both fall back together here -- confirm
        # that is intended rather than keeping the tracer.
        tracer = None
        agent_id = "unknown_agent"

    return tracer, agent_id
|
||||
|
||||
|
||||
async def process_tool_invocations(
    tool_invocations: list[dict[str, Any]],
    conversation_history: list[dict[str, Any]],
    agent_state: Any | None = None,
) -> bool:
    """Execute each invocation in order and append the results to history.

    A single user message is appended; when screenshots were extracted it
    carries structured content (text + image parts), otherwise plain text.

    Returns:
        True when any tool signalled that the agent should finish.
    """
    observation_parts: list[str] = []
    all_images: list[dict[str, Any]] = []
    should_agent_finish = False

    tracer, agent_id = _get_tracer_and_agent_id(agent_state)

    for tool_inv in tool_invocations:
        observation_xml, images, finish_requested = await _execute_single_tool(
            tool_inv, agent_state, tracer, agent_id
        )
        observation_parts.append(observation_xml)
        all_images.extend(images)
        should_agent_finish = should_agent_finish or finish_requested

    summary = "Tool Results:\n\n" + "\n\n".join(observation_parts)
    if all_images:
        content: list[dict[str, Any]] = [{"type": "text", "text": summary}]
        content.extend(all_images)
        conversation_history.append({"role": "user", "content": content})
    else:
        conversation_history.append({"role": "user", "content": summary})

    return should_agent_finish
|
||||
|
||||
|
||||
def extract_screenshot_from_result(result: Any) -> str | None:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
|
||||
screenshot = result.get("screenshot")
|
||||
if isinstance(screenshot, str) and screenshot:
|
||||
return screenshot
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def remove_screenshot_from_result(result: Any) -> Any:
    """Return a shallow copy of *result* with any screenshot data replaced by a
    short placeholder; non-dict inputs are returned unchanged.
    """
    if not isinstance(result, dict):
        return result

    sanitized = dict(result)
    if "screenshot" in sanitized:
        sanitized["screenshot"] = "[Image data extracted - see attached image]"
    return sanitized
|
||||
4
strix/tools/file_edit/__init__.py
Normal file
4
strix/tools/file_edit/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .file_edit_actions import list_files, search_files, str_replace_editor
|
||||
|
||||
|
||||
__all__ = ["list_files", "search_files", "str_replace_editor"]
|
||||
141
strix/tools/file_edit/file_edit_actions.py
Normal file
141
strix/tools/file_edit/file_edit_actions.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import json
import re
import shlex
from pathlib import Path
from typing import Any, cast

from openhands_aci import file_editor
from openhands_aci.utils.shell import run_shell_cmd

from strix.tools.registry import register_tool
|
||||
|
||||
|
||||
def _parse_file_editor_output(output: str) -> dict[str, Any]:
|
||||
try:
|
||||
pattern = r"<oh_aci_output_[^>]+>\n(.*?)\n</oh_aci_output_[^>]+>"
|
||||
match = re.search(pattern, output, re.DOTALL)
|
||||
|
||||
if match:
|
||||
json_str = match.group(1)
|
||||
data = json.loads(json_str)
|
||||
return cast("dict[str, Any]", data)
|
||||
return {"output": output, "error": None}
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
return {"output": output, "error": None}
|
||||
|
||||
|
||||
@register_tool
def str_replace_editor(
    command: str,
    path: str,
    file_text: str | None = None,
    view_range: list[int] | None = None,
    old_str: str | None = None,
    new_str: str | None = None,
    insert_line: int | None = None,
) -> dict[str, Any]:
    """View, create, or edit a file through the openhands-aci file editor.

    Returns ``{"content": ...}`` on success or ``{"error": ...}`` on failure.
    """
    try:
        # Relative paths are resolved against the sandbox workspace root.
        target = Path(path)
        if not target.is_absolute():
            path = str(Path("/workspace") / target)

        raw_output = file_editor(
            command=command,
            path=path,
            file_text=file_text,
            view_range=view_range,
            old_str=old_str,
            new_str=new_str,
            insert_line=insert_line,
        )

        parsed = _parse_file_editor_output(raw_output)
        if parsed.get("error"):
            return {"error": parsed["error"]}
        return {"content": parsed.get("output", raw_output)}

    except (OSError, ValueError) as exc:
        return {"error": f"Error in {command} operation: {exc!s}"}
|
||||
|
||||
|
||||
@register_tool
def list_files(
    path: str,
    recursive: bool = False,
) -> dict[str, Any]:
    """List files and directories at *path*.

    Args:
        path: Directory to list; relative paths resolve under ``/workspace``.
        recursive: When true, walk subdirectories (output capped at 500 lines
            by ``head``).

    Returns:
        On success: sorted ``files`` and ``directories`` lists plus counts and
        an echo of ``path``/``recursive``; on failure: ``{"error": ...}``.
    """
    try:
        path_obj = Path(path)
        if not path_obj.is_absolute():
            path = str(Path("/workspace") / path_obj)
            path_obj = Path(path)

        if not path_obj.exists():
            return {"error": f"Directory not found: {path}"}

        if not path_obj.is_dir():
            return {"error": f"Path is not a directory: {path}"}

        # shlex.quote prevents paths containing quotes or shell metacharacters
        # from breaking (or injecting into) the command line; the previous
        # manual '...'-wrapping broke on paths containing a single quote.
        quoted = shlex.quote(path)
        cmd = f"find {quoted} -type f -o -type d | head -500" if recursive else f"ls -1a {quoted}"

        exit_code, stdout, stderr = run_shell_cmd(cmd)

        if exit_code != 0:
            return {"error": f"Error listing directory: {stderr}"}

        items = stdout.strip().split("\n") if stdout.strip() else []

        files = []
        dirs = []

        for item in items:
            # `find` emits full paths; `ls -1a` emits names relative to path.
            item_path_obj = Path(item) if recursive else Path(path) / item

            if item_path_obj.is_file():
                files.append(item)
            elif item_path_obj.is_dir():
                dirs.append(item)

        return {
            "files": sorted(files),
            "directories": sorted(dirs),
            "total_files": len(files),
            "total_dirs": len(dirs),
            "path": path,
            "recursive": recursive,
        }

    except (OSError, ValueError) as e:
        return {"error": f"Error listing directory: {e!s}"}
|
||||
|
||||
|
||||
@register_tool
def search_files(
    path: str,
    regex: str,
    file_pattern: str = "*",
) -> dict[str, Any]:
    """Search files under *path* for a regex using ripgrep.

    Args:
        path: Directory to search; relative paths resolve under ``/workspace``.
        regex: Pattern passed to ``rg``.
        file_pattern: Glob used to filter which files are searched.

    Returns:
        ``{"output": ...}`` with the matches (or ``"No matches found"``), or
        ``{"error": ...}`` on failure.
    """
    try:
        path_obj = Path(path)
        if not path_obj.is_absolute():
            path = str(Path("/workspace") / path_obj)

        if not Path(path).exists():
            return {"error": f"Directory not found: {path}"}

        # Quote every user-controlled argument — previously only the regex was
        # hand-escaped, so quotes/metacharacters in the path or glob could
        # break or inject into the shell command.
        cmd = (
            f"rg --line-number --glob {shlex.quote(file_pattern)} "
            f"{shlex.quote(regex)} {shlex.quote(path)}"
        )

        exit_code, stdout, stderr = run_shell_cmd(cmd)

        # rg exits 1 when no matches were found; that is not an error.
        if exit_code not in {0, 1}:
            return {"error": f"Error searching files: {stderr}"}
        return {"output": stdout if stdout else "No matches found"}

    except (OSError, ValueError) as e:
        return {"error": f"Error searching files: {e!s}"}
|
||||
|
||||
|
||||
# ruff: noqa: TRY300
|
||||
128
strix/tools/file_edit/file_edit_actions_schema.xml
Normal file
128
strix/tools/file_edit/file_edit_actions_schema.xml
Normal file
@@ -0,0 +1,128 @@
|
||||
<tools>
|
||||
<tool name="list_files">
|
||||
<description>List files and directories within the specified directory.</description>
|
||||
<parameters>
|
||||
<parameter name="path" type="string" required="true">
|
||||
<description>Directory path to list</description>
|
||||
</parameter>
|
||||
<parameter name="recursive" type="boolean" required="false">
|
||||
<description>Whether to list files recursively</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - files: Sorted list of files - directories: Sorted list of directories - total_files: Total number of files found - total_dirs: Total number of directories found</description>
|
||||
</returns>
|
||||
<notes>
|
||||
- Lists contents alphabetically
|
||||
- Returns maximum 500 results to avoid overwhelming output
|
||||
</notes>
|
||||
<examples>
|
||||
# List directory contents
|
||||
<function=list_files>
|
||||
<parameter=path>/home/user/project/src</parameter>
|
||||
</function>
|
||||
|
||||
# Recursive listing
|
||||
<function=list_files>
|
||||
<parameter=path>/home/user/project/src</parameter>
|
||||
<parameter=recursive>true</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
<tool name="search_files">
|
||||
<description>Perform a regex search across files in a directory.</description>
|
||||
<parameters>
|
||||
<parameter name="path" type="string" required="true">
|
||||
<description>Directory path to search</description>
|
||||
</parameter>
|
||||
<parameter name="regex" type="string" required="true">
|
||||
<description>Regular expression pattern to search for</description>
|
||||
</parameter>
|
||||
<parameter name="file_pattern" type="string" required="false">
|
||||
<description>File pattern to filter (e.g., "*.py", "*.js")</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - output: The search results as a string</description>
|
||||
</returns>
|
||||
<notes>
|
||||
- Searches recursively through subdirectories
|
||||
- Uses ripgrep for fast searching
|
||||
</notes>
|
||||
<examples>
|
||||
# Search Python files for a pattern
|
||||
<function=search_files>
|
||||
<parameter=path>/home/user/project/src</parameter>
|
||||
<parameter=regex>def\s+process_data</parameter>
|
||||
<parameter=file_pattern>*.py</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
<tool name="str_replace_editor">
|
||||
<description>A text editor tool for viewing, creating and editing files.</description>
|
||||
<parameters>
|
||||
<parameter name="command" type="string" required="true">
|
||||
<description>Editor command to execute</description>
|
||||
</parameter>
|
||||
<parameter name="path" type="string" required="true">
|
||||
<description>Path to the file to edit</description>
|
||||
</parameter>
|
||||
<parameter name="file_text" type="string" required="false">
|
||||
<description>Required parameter of create command, with the content of the file to be created</description>
|
||||
</parameter>
|
||||
<parameter name="view_range" type="array" required="false">
|
||||
<description>Optional parameter of view command when path points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting [start_line, -1] shows all lines from start_line to the end of the file</description>
|
||||
</parameter>
|
||||
<parameter name="old_str" type="string" required="false">
|
||||
<description>Required parameter of str_replace command containing the string in path to replace</description>
|
||||
</parameter>
|
||||
<parameter name="new_str" type="string" required="false">
|
||||
<description>Optional parameter of str_replace command containing the new string (if not given, no string will be added). Required parameter of insert command containing the string to insert</description>
|
||||
</parameter>
|
||||
<parameter name="insert_line" type="integer" required="false">
|
||||
<description>Required parameter of insert command. The new_str will be inserted AFTER the line insert_line of path</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing the result of the operation</description>
|
||||
</returns>
|
||||
<notes>
|
||||
Command details:
|
||||
- view: Show file contents, optionally with line range
|
||||
- create: Create a new file with given content
|
||||
- str_replace: Replace old_str with new_str in file
|
||||
- insert: Insert new_str after the specified line number
|
||||
- undo_edit: Revert the last edit made to the file
|
||||
</notes>
|
||||
<examples>
|
||||
# View a file
|
||||
<function=str_replace_editor>
|
||||
<parameter=command>view</parameter>
|
||||
<parameter=path>/home/user/project/file.py</parameter>
|
||||
</function>
|
||||
|
||||
# Create a file
|
||||
<function=str_replace_editor>
|
||||
<parameter=command>create</parameter>
|
||||
<parameter=path>/home/user/project/new_file.py</parameter>
|
||||
<parameter=file_text>print("Hello World")</parameter>
|
||||
</function>
|
||||
|
||||
# Replace text in file
|
||||
<function=str_replace_editor>
|
||||
<parameter=command>str_replace</parameter>
|
||||
<parameter=path>/home/user/project/file.py</parameter>
|
||||
<parameter=old_str>old_function()</parameter>
|
||||
<parameter=new_str>new_function()</parameter>
|
||||
</function>
|
||||
|
||||
# Insert text after line 10
|
||||
<function=str_replace_editor>
|
||||
<parameter=command>insert</parameter>
|
||||
<parameter=path>/home/user/project/file.py</parameter>
|
||||
<parameter=insert_line>10</parameter>
|
||||
<parameter=new_str>print("Inserted line")</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
</tools>
|
||||
4
strix/tools/finish/__init__.py
Normal file
4
strix/tools/finish/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .finish_actions import finish_scan
|
||||
|
||||
|
||||
__all__ = ["finish_scan"]
|
||||
174
strix/tools/finish/finish_actions.py
Normal file
174
strix/tools/finish/finish_actions.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from typing import Any
|
||||
|
||||
from strix.tools.registry import register_tool
|
||||
|
||||
|
||||
def _validate_root_agent(agent_state: Any) -> dict[str, Any] | None:
|
||||
if (
|
||||
agent_state is not None
|
||||
and hasattr(agent_state, "parent_id")
|
||||
and agent_state.parent_id is not None
|
||||
):
|
||||
return {
|
||||
"success": False,
|
||||
"message": (
|
||||
"This tool can only be used by the root/main agent. "
|
||||
"Subagents must use agent_finish instead."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _validate_content(content: str) -> dict[str, Any] | None:
|
||||
if not content or not content.strip():
|
||||
return {"success": False, "message": "Content cannot be empty"}
|
||||
return None
|
||||
|
||||
|
||||
def _check_active_agents(agent_state: Any = None) -> dict[str, Any] | None:
    """Return an error payload if any *other* agent is still running or stopping.

    Inspects the in-memory agent graph and, when agents besides the caller are
    in the "running" or "stopping" state, builds a human-readable refusal
    message plus structured counts.  Returns ``None`` when finishing is safe
    (or when the agent-graph module cannot be imported).
    """
    try:
        from strix.tools.agents_graph.agents_graph_actions import _agent_graph

        # The calling agent is excluded from the "still active" check.
        current_agent_id = None
        if agent_state and hasattr(agent_state, "agent_id"):
            current_agent_id = agent_state.agent_id

        running_agents = []
        stopping_agents = []

        for agent_id, node in _agent_graph.get("nodes", {}).items():
            if agent_id == current_agent_id:
                continue

            status = node.get("status", "")
            if status == "running":
                running_agents.append(
                    {
                        "id": agent_id,
                        "name": node.get("name", "Unknown"),
                        "task": node.get("task", "No task description"),
                    }
                )
            elif status == "stopping":
                stopping_agents.append(
                    {
                        "id": agent_id,
                        "name": node.get("name", "Unknown"),
                    }
                )

        if running_agents or stopping_agents:
            # Assemble a multi-section refusal message listing each active
            # agent and what the caller should do about it.
            message_parts = ["Cannot finish scan while other agents are still active:"]

            if running_agents:
                message_parts.append("\n\nRunning agents:")
                message_parts.extend(
                    [
                        f" - {agent['name']} ({agent['id']}): {agent['task']}"
                        for agent in running_agents
                    ]
                )

            if stopping_agents:
                message_parts.append("\n\nStopping agents:")
                message_parts.extend(
                    [f" - {agent['name']} ({agent['id']})" for agent in stopping_agents]
                )

            message_parts.extend(
                [
                    "\n\nSuggested actions:",
                    "1. Use wait_for_message to wait for all agents to complete",
                    "2. Send messages to agents asking them to finish if urgent",
                    "3. Use view_agent_graph to monitor agent status",
                ]
            )

            return {
                "success": False,
                "message": "\n".join(message_parts),
                "active_agents": {
                    "running": len(running_agents),
                    "stopping": len(stopping_agents),
                    "details": {
                        "running": running_agents,
                        "stopping": stopping_agents,
                    },
                },
            }

    except ImportError:
        import logging

        # Best-effort: if the agents_graph module is unavailable we cannot
        # block the finish, so log and fall through to "no active agents".
        logging.warning("Could not check agent graph status - agents_graph module unavailable")

    return None
|
||||
|
||||
|
||||
def _finalize_with_tracer(content: str, success: bool) -> dict[str, Any]:
|
||||
try:
|
||||
from strix.cli.tracer import get_global_tracer
|
||||
|
||||
tracer = get_global_tracer()
|
||||
if tracer:
|
||||
tracer.set_final_scan_result(
|
||||
content=content.strip(),
|
||||
success=success,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"scan_completed": True,
|
||||
"message": "Scan completed successfully"
|
||||
if success
|
||||
else "Scan completed with errors",
|
||||
"vulnerabilities_found": len(tracer.vulnerability_reports),
|
||||
}
|
||||
|
||||
import logging
|
||||
|
||||
logging.warning("Global tracer not available - final scan result not stored")
|
||||
|
||||
return { # noqa: TRY300
|
||||
"success": True,
|
||||
"scan_completed": True,
|
||||
"message": "Scan completed successfully (not persisted)"
|
||||
if success
|
||||
else "Scan completed with errors (not persisted)",
|
||||
"warning": "Final result could not be persisted - tracer unavailable",
|
||||
}
|
||||
|
||||
except ImportError:
|
||||
return {
|
||||
"success": True,
|
||||
"scan_completed": True,
|
||||
"message": "Scan completed successfully (not persisted)"
|
||||
if success
|
||||
else "Scan completed with errors (not persisted)",
|
||||
"warning": "Final result could not be persisted - tracer module unavailable",
|
||||
}
|
||||
|
||||
|
||||
@register_tool(sandbox_execution=False)
def finish_scan(
    content: str,
    success: bool = True,
    agent_state: Any = None,
) -> dict[str, Any]:
    """Conclude the entire scan: verify the caller is the root agent, the
    report content is non-empty, and no other agents are active, then persist
    the final result.
    """
    try:
        # Checks run in order and short-circuit on the first failure.
        blocker = _validate_root_agent(agent_state) or _validate_content(content)
        if blocker is None:
            blocker = _check_active_agents(agent_state)
        if blocker is not None:
            return blocker

        return _finalize_with_tracer(content, success)

    except (ValueError, TypeError, KeyError) as e:
        return {"success": False, "message": f"Failed to complete scan: {e!s}"}
|
||||
45
strix/tools/finish/finish_actions_schema.xml
Normal file
45
strix/tools/finish/finish_actions_schema.xml
Normal file
@@ -0,0 +1,45 @@
|
||||
<tools>
|
||||
<tool name="finish_scan">
|
||||
<description>Complete the main security scan and generate final report.
|
||||
|
||||
IMPORTANT: This tool can ONLY be used by the root/main agent.
|
||||
Subagents must use agent_finish from agents_graph tool instead.
|
||||
|
||||
IMPORTANT: This tool will NOT allow finishing if any agents are still running or stopping.
|
||||
You must wait for all agents to complete before using this tool.
|
||||
|
||||
This tool MUST be called at the very end of the security assessment to:
|
||||
- Verify all agents have completed their tasks
|
||||
- Generate the final comprehensive scan report
|
||||
- Mark the entire scan as completed
|
||||
- Stop the agent from running
|
||||
|
||||
Use this tool when:
|
||||
- You are the main/root agent conducting the security assessment
|
||||
- ALL subagents have completed their tasks (no agents are "running" or "stopping")
|
||||
- You have completed all testing phases
|
||||
- You are ready to conclude the entire security assessment
|
||||
|
||||
IMPORTANT: Calling this tool multiple times will OVERWRITE any previous scan report.
|
||||
Make sure you include ALL findings and details in a single comprehensive report.
|
||||
|
||||
If agents are still running, this tool will:
|
||||
- Show you which agents are still active
|
||||
- Suggest using wait_for_message to wait for completion
|
||||
- Suggest messaging agents if immediate completion is needed
|
||||
|
||||
Put ALL details in the content - methodology, tools used, vulnerability counts, key findings, recommendations,
|
||||
compliance notes, risk assessments, next steps, etc. Be comprehensive and include everything relevant.</description>
|
||||
<parameters>
|
||||
<parameter name="content" type="string" required="true">
|
||||
<description>Complete scan report including executive summary, methodology, findings, vulnerability details, recommendations, compliance notes, risk assessment, and conclusions. Include everything relevant to the assessment.</description>
|
||||
</parameter>
|
||||
<parameter name="success" type="boolean" required="false">
|
||||
<description>Whether the scan completed successfully without critical errors</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing success status and completion message. If agents are still running, returns details about active agents and suggested actions.</description>
|
||||
</returns>
|
||||
</tool>
|
||||
</tools>
|
||||
14
strix/tools/notes/__init__.py
Normal file
14
strix/tools/notes/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .notes_actions import (
|
||||
create_note,
|
||||
delete_note,
|
||||
list_notes,
|
||||
update_note,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"create_note",
|
||||
"delete_note",
|
||||
"list_notes",
|
||||
"update_note",
|
||||
]
|
||||
191
strix/tools/notes/notes_actions.py
Normal file
191
strix/tools/notes/notes_actions.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from strix.tools.registry import register_tool
|
||||
|
||||
|
||||
# In-memory note store keyed by short note id; process-local and not persisted.
_notes_storage: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _filter_notes(
    category: str | None = None,
    tags: list[str] | None = None,
    priority: str | None = None,
    search_query: str | None = None,
) -> list[dict[str, Any]]:
    """Return notes matching every provided filter, newest first.

    Each returned note is a copy with its ``note_id`` added.
    """
    matches: list[dict[str, Any]] = []

    for note_id, note in _notes_storage.items():
        if category and note.get("category") != category:
            continue
        if priority and note.get("priority") != priority:
            continue
        if tags and not any(tag in note.get("tags", []) for tag in tags):
            continue
        if search_query:
            needle = search_query.lower()
            haystacks = (note.get("title", ""), note.get("content", ""))
            if not any(needle in text.lower() for text in haystacks):
                continue

        matches.append({**note, "note_id": note_id})

    matches.sort(key=lambda entry: entry.get("created_at", ""), reverse=True)
    return matches
|
||||
|
||||
|
||||
@register_tool
def create_note(
    title: str,
    content: str,
    category: str = "general",
    tags: list[str] | None = None,
    priority: str = "normal",
) -> dict[str, Any]:
    """Create a personal note in the in-memory store.

    Args:
        title: Non-empty note title (stored stripped).
        content: Non-empty note body (stored stripped).
        category: One of "general", "findings", "methodology", "todo",
            "questions", "plan".
        tags: Optional free-form tags.
        priority: One of "low", "normal", "high", "urgent".

    Returns:
        ``{"success": True, "note_id": ..., "message": ...}`` on success,
        otherwise ``{"success": False, "error": ..., "note_id": None}``.
    """
    try:
        if not title or not title.strip():
            return {"success": False, "error": "Title cannot be empty", "note_id": None}

        if not content or not content.strip():
            return {"success": False, "error": "Content cannot be empty", "note_id": None}

        valid_categories = ["general", "findings", "methodology", "todo", "questions", "plan"]
        if category not in valid_categories:
            return {
                "success": False,
                "error": f"Invalid category. Must be one of: {', '.join(valid_categories)}",
                "note_id": None,
            }

        valid_priorities = ["low", "normal", "high", "urgent"]
        if priority not in valid_priorities:
            return {
                "success": False,
                "error": f"Invalid priority. Must be one of: {', '.join(valid_priorities)}",
                "note_id": None,
            }

        # 5-char ids are convenient but can collide; regenerate until unique
        # instead of silently overwriting an existing note.
        note_id = str(uuid.uuid4())[:5]
        while note_id in _notes_storage:
            note_id = str(uuid.uuid4())[:5]

        timestamp = datetime.now(UTC).isoformat()

        note = {
            "title": title.strip(),
            "content": content.strip(),
            "category": category,
            "tags": tags or [],
            "priority": priority,
            "created_at": timestamp,
            "updated_at": timestamp,
        }

        _notes_storage[note_id] = note

    except (ValueError, TypeError) as e:
        return {"success": False, "error": f"Failed to create note: {e}", "note_id": None}
    else:
        return {
            "success": True,
            "note_id": note_id,
            "message": f"Note '{title}' created successfully",
        }
|
||||
|
||||
|
||||
@register_tool
def list_notes(
    category: str | None = None,
    tags: list[str] | None = None,
    priority: str | None = None,
    search: str | None = None,
) -> dict[str, Any]:
    """List stored notes, optionally filtered by category, tags, priority,
    and/or a title/content search query.
    """
    try:
        found = _filter_notes(
            category=category, tags=tags, priority=priority, search_query=search
        )
    except (ValueError, TypeError) as exc:
        return {
            "success": False,
            "error": f"Failed to list notes: {exc}",
            "notes": [],
            "total_count": 0,
        }
    return {"success": True, "notes": found, "total_count": len(found)}
|
||||
|
||||
|
||||
@register_tool
def update_note(
    note_id: str,
    title: str | None = None,
    content: str | None = None,
    tags: list[str] | None = None,
    priority: str | None = None,
) -> dict[str, Any]:
    """Apply partial updates to an existing note; only provided fields change."""
    try:
        note = _notes_storage.get(note_id)
        if note is None:
            return {"success": False, "error": f"Note with ID '{note_id}' not found"}

        if title is not None:
            if not title.strip():
                return {"success": False, "error": "Title cannot be empty"}
            note["title"] = title.strip()

        if content is not None:
            if not content.strip():
                return {"success": False, "error": "Content cannot be empty"}
            note["content"] = content.strip()

        if tags is not None:
            note["tags"] = tags

        if priority is not None:
            allowed = ["low", "normal", "high", "urgent"]
            if priority not in allowed:
                return {
                    "success": False,
                    "error": f"Invalid priority. Must be one of: {', '.join(allowed)}",
                }
            note["priority"] = priority

        note["updated_at"] = datetime.now(UTC).isoformat()

        return {
            "success": True,
            "message": f"Note '{note['title']}' updated successfully",
        }

    except (ValueError, TypeError) as exc:
        return {"success": False, "error": f"Failed to update note: {exc}"}
|
||||
|
||||
|
||||
@register_tool
def delete_note(note_id: str) -> dict[str, Any]:
    """Delete a note by id, reporting its title on success."""
    try:
        if note_id not in _notes_storage:
            return {"success": False, "error": f"Note with ID '{note_id}' not found"}

        removed = _notes_storage.pop(note_id)

    except (ValueError, TypeError) as exc:
        return {"success": False, "error": f"Failed to delete note: {exc}"}
    else:
        return {
            "success": True,
            "message": f"Note '{removed['title']}' deleted successfully",
        }
|
||||
150
strix/tools/notes/notes_actions_schema.xml
Normal file
150
strix/tools/notes/notes_actions_schema.xml
Normal file
@@ -0,0 +1,150 @@
|
||||
<tools>
|
||||
<tool name="create_note">
|
||||
<description>Create a personal note for TODOs, side notes, plans, and organizational purposes during
|
||||
the scan.</description>
|
||||
<details>Use this tool for quick reminders, action items, planning thoughts, and organizational notes
|
||||
rather than formal vulnerability reports or detailed findings. This is your personal notepad
|
||||
for keeping track of tasks, ideas, and things to remember or follow up on.</details>
|
||||
<parameters>
|
||||
<parameter name="title" type="string" required="true">
|
||||
<description>Title of the note</description>
|
||||
</parameter>
|
||||
<parameter name="content" type="string" required="true">
|
||||
<description>Content of the note</description>
|
||||
</parameter>
|
||||
<parameter name="category" type="string" required="false">
|
||||
<description>Category to organize the note. One of: "general", "findings", "methodology", "todo", "questions", "plan" (default: "general")</description>
|
||||
</parameter>
|
||||
<parameter name="tags" type="string" required="false">
|
||||
<description>Tags for categorization</description>
|
||||
</parameter>
|
||||
<parameter name="priority" type="string" required="false">
|
||||
<description>Priority level of the note ("low", "normal", "high", "urgent")</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - note_id: ID of the created note - success: Whether the note was created successfully</description>
|
||||
</returns>
|
||||
<examples>
|
||||
# Create a TODO reminder
|
||||
<function=create_note>
|
||||
<parameter=title>TODO: Check SSL Certificate Details</parameter>
|
||||
<parameter=content>Remember to verify SSL certificate validity and check for weak ciphers
|
||||
on the HTTPS service discovered on port 443. Also check for certificate
|
||||
transparency logs.</parameter>
|
||||
<parameter=category>todo</parameter>
|
||||
<parameter=tags>["ssl", "certificate", "followup"]</parameter>
|
||||
<parameter=priority>normal</parameter>
|
||||
</function>
|
||||
|
||||
# Planning note
|
||||
<function=create_note>
|
||||
<parameter=title>Scan Strategy Planning</parameter>
|
||||
<parameter=content>Plan for next phase: 1) Complete subdomain enumeration 2) Test discovered
|
||||
web apps for OWASP Top 10 3) Check database services for default creds
|
||||
4) Review any custom applications for business logic flaws</parameter>
|
||||
<parameter=category>plan</parameter>
|
||||
<parameter=tags>["planning", "strategy", "next_steps"]</parameter>
|
||||
</function>
|
||||
|
||||
# Side note for later investigation
|
||||
<function=create_note>
|
||||
<parameter=title>Interesting Directory Found</parameter>
|
||||
<parameter=content>Found /backup/ directory that might contain sensitive files. Low priority
|
||||
for now but worth checking if time permits. Directory listing seems
|
||||
disabled.</parameter>
|
||||
<parameter=category>findings</parameter>
|
||||
<parameter=tags>["directory", "backup", "low_priority"]</parameter>
|
||||
<parameter=priority>low</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
<tool name="delete_note">
|
||||
<description>Delete a note.</description>
|
||||
<parameters>
|
||||
<parameter name="note_id" type="string" required="true">
|
||||
<description>ID of the note to delete</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - success: Whether the note was deleted successfully</description>
|
||||
</returns>
|
||||
<examples>
|
||||
<function=delete_note>
|
||||
<parameter=note_id>note_123</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
<tool name="list_notes">
|
||||
<description>List existing notes with optional filtering and search.</description>
|
||||
<parameters>
|
||||
<parameter name="category" type="string" required="false">
|
||||
<description>Filter by category</description>
|
||||
</parameter>
|
||||
<parameter name="tags" type="string" required="false">
|
||||
<description>Filter by tags (returns notes with any of these tags)</description>
|
||||
</parameter>
|
||||
<parameter name="priority" type="string" required="false">
|
||||
<description>Filter by priority level</description>
|
||||
</parameter>
|
||||
<parameter name="search" type="string" required="false">
|
||||
<description>Search query to find in note titles and content</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - notes: List of matching notes - total_count: Total number of notes found</description>
|
||||
</returns>
|
||||
<examples>
|
||||
# List all findings
|
||||
<function=list_notes>
|
||||
<parameter=category>findings</parameter>
|
||||
</function>
|
||||
|
||||
# List high priority items
|
||||
<function=list_notes>
|
||||
<parameter=priority>high</parameter>
|
||||
</function>
|
||||
|
||||
# Search for SQL injection related notes
|
||||
<function=list_notes>
|
||||
<parameter=search>SQL injection</parameter>
|
||||
</function>
|
||||
|
||||
# Search within a specific category
|
||||
<function=list_notes>
|
||||
<parameter=search>admin</parameter>
|
||||
<parameter=category>findings</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
<tool name="update_note">
|
||||
<description>Update an existing note.</description>
|
||||
<parameters>
|
||||
<parameter name="note_id" type="string" required="true">
|
||||
<description>ID of the note to update</description>
|
||||
</parameter>
|
||||
<parameter name="title" type="string" required="false">
|
||||
<description>New title for the note</description>
|
||||
</parameter>
|
||||
<parameter name="content" type="string" required="false">
|
||||
<description>New content for the note</description>
|
||||
</parameter>
|
||||
<parameter name="tags" type="string" required="false">
|
||||
<description>New tags for the note</description>
|
||||
</parameter>
|
||||
<parameter name="priority" type="string" required="false">
|
||||
<description>New priority level</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - success: Whether the note was updated successfully</description>
|
||||
</returns>
|
||||
<examples>
|
||||
<function=update_note>
|
||||
<parameter=note_id>note_123</parameter>
|
||||
<parameter=content>Updated content with new findings...</parameter>
|
||||
<parameter=priority>urgent</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
</tools>
|
||||
20
strix/tools/proxy/__init__.py
Normal file
20
strix/tools/proxy/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .proxy_actions import (
|
||||
list_requests,
|
||||
list_sitemap,
|
||||
repeat_request,
|
||||
scope_rules,
|
||||
send_request,
|
||||
view_request,
|
||||
view_sitemap_entry,
|
||||
)
|
||||
|
||||
|
||||
# Public tool surface of the proxy package; mirrors the proxy_actions imports above.
__all__ = [
    "list_requests",
    "list_sitemap",
    "repeat_request",
    "scope_rules",
    "send_request",
    "view_request",
    "view_sitemap_entry",
]
|
||||
101
strix/tools/proxy/proxy_actions.py
Normal file
101
strix/tools/proxy/proxy_actions.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from strix.tools.registry import register_tool
|
||||
|
||||
from .proxy_manager import get_proxy_manager
|
||||
|
||||
|
||||
RequestPart = Literal["request", "response"]
|
||||
|
||||
|
||||
@register_tool
def list_requests(
    httpql_filter: str | None = None,
    start_page: int = 1,
    end_page: int = 1,
    page_size: int = 50,
    sort_by: Literal[
        "timestamp",
        "host",
        "method",
        "path",
        "status_code",
        "response_time",
        "response_size",
        "source",
    ] = "timestamp",
    sort_order: Literal["asc", "desc"] = "desc",
    scope_id: str | None = None,
) -> dict[str, Any]:
    """List captured proxy requests, optionally filtered with HTTPQL and sorted."""
    return get_proxy_manager().list_requests(
        httpql_filter=httpql_filter,
        start_page=start_page,
        end_page=end_page,
        page_size=page_size,
        sort_by=sort_by,
        sort_order=sort_order,
        scope_id=scope_id,
    )
|
||||
|
||||
|
||||
@register_tool
def view_request(
    request_id: str,
    part: RequestPart = "request",
    search_pattern: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> dict[str, Any]:
    """View the raw request or response content of a captured proxy request."""
    return get_proxy_manager().view_request(
        request_id=request_id,
        part=part,
        search_pattern=search_pattern,
        page=page,
        page_size=page_size,
    )
|
||||
|
||||
|
||||
@register_tool
def send_request(
    method: str,
    url: str,
    headers: dict[str, str] | None = None,
    body: str = "",
    timeout: int = 30,
) -> dict[str, Any]:
    """Send a one-off HTTP request through the intercepting proxy."""
    manager = get_proxy_manager()
    # Normalize the optional headers mapping before delegating.
    return manager.send_simple_request(
        method, url, headers if headers is not None else {}, body, timeout
    )
|
||||
|
||||
|
||||
@register_tool
def repeat_request(
    request_id: str,
    modifications: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Replay a captured request, optionally modifying URL, params, headers, body or cookies."""
    manager = get_proxy_manager()
    # Normalize the optional modifications mapping before delegating.
    return manager.repeat_request(
        request_id, modifications if modifications is not None else {}
    )
|
||||
|
||||
|
||||
@register_tool
def scope_rules(
    action: Literal["get", "list", "create", "update", "delete"],
    allowlist: list[str] | None = None,
    denylist: list[str] | None = None,
    scope_id: str | None = None,
    scope_name: str | None = None,
) -> dict[str, Any]:
    """Manage Caido proxy scopes (get/list/create/update/delete)."""
    return get_proxy_manager().scope_rules(
        action,
        allowlist=allowlist,
        denylist=denylist,
        scope_id=scope_id,
        scope_name=scope_name,
    )
|
||||
|
||||
|
||||
@register_tool
def list_sitemap(
    scope_id: str | None = None,
    parent_id: str | None = None,
    depth: Literal["DIRECT", "ALL"] = "DIRECT",
    page: int = 1,
) -> dict[str, Any]:
    """Browse the hierarchical sitemap built from captured proxy traffic."""
    proxy = get_proxy_manager()
    return proxy.list_sitemap(scope_id, parent_id, depth, page)
|
||||
|
||||
|
||||
@register_tool
def view_sitemap_entry(
    entry_id: str,
) -> dict[str, Any]:
    """Show full details for a single sitemap entry, including related requests."""
    return get_proxy_manager().view_sitemap_entry(entry_id)
|
||||
267
strix/tools/proxy/proxy_actions_schema.xml
Normal file
267
strix/tools/proxy/proxy_actions_schema.xml
Normal file
@@ -0,0 +1,267 @@
|
||||
<?xml version="1.0" ?>
|
||||
<tools>
|
||||
<tool name="list_requests">
|
||||
<description>List and filter proxy requests using HTTPQL with pagination.</description>
|
||||
<parameters>
|
||||
<parameter name="httpql_filter" type="string" required="false">
|
||||
<description>HTTPQL filter using Caido's syntax:
|
||||
|
||||
Integer fields (port, code, roundtrip, id) - eq, gt, gte, lt, lte, ne:
|
||||
- resp.code.eq:200, resp.code.gte:400, req.port.eq:443
|
||||
|
||||
Text/byte fields (ext, host, method, path, query, raw) - regex:
|
||||
- req.method.regex:"POST", req.path.regex:"/api/.*", req.host.regex:".*.com"
|
||||
|
||||
Date fields (created_at) - gt, lt with ISO formats:
|
||||
- req.created_at.gt:"2024-01-01T00:00:00Z"
|
||||
|
||||
Special: source:intercept, preset:"name"</description>
|
||||
</parameter>
|
||||
<parameter name="start_page" type="integer" required="false">
|
||||
<description>Starting page (1-based)</description>
|
||||
</parameter>
|
||||
<parameter name="end_page" type="integer" required="false">
|
||||
<description>Ending page (1-based, inclusive)</description>
|
||||
</parameter>
|
||||
<parameter name="page_size" type="integer" required="false">
|
||||
<description>Requests per page</description>
|
||||
</parameter>
|
||||
<parameter name="sort_by" type="string" required="false">
|
||||
<description>Sort field from: "timestamp", "host", "method", "path", "status_code", "response_time", "response_size", "source"</description>
|
||||
</parameter>
|
||||
<parameter name="sort_order" type="string" required="false">
|
||||
<description>Sort direction ("asc" or "desc")</description>
|
||||
</parameter>
|
||||
<parameter name="scope_id" type="string" required="false">
|
||||
<description>Scope ID to filter requests (use scope_rules to manage scopes)</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing:
|
||||
- 'requests': Request objects for page range
|
||||
- 'total_count': Total matching requests
|
||||
- 'start_page', 'end_page', 'page_size': Query parameters
|
||||
- 'returned_count': Requests in response</description>
|
||||
</returns>
|
||||
<examples>
|
||||
# POST requests to API with 200 responses
|
||||
<function=list_requests>
|
||||
<parameter=httpql_filter>req.method.eq:"POST" AND req.path.cont:"/api/"</parameter>
|
||||
<parameter=sort_by>response_time</parameter>
|
||||
<parameter=scope_id>scope123</parameter>
|
||||
</function>
|
||||
|
||||
# Requests within specific scope
|
||||
<function=list_requests>
|
||||
<parameter=scope_id>scope123</parameter>
|
||||
<parameter=sort_by>timestamp</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
|
||||
<tool name="view_request">
|
||||
<description>View request/response data with search and pagination.</description>
|
||||
<parameters>
|
||||
<parameter name="request_id" type="string" required="true">
|
||||
<description>Request ID</description>
|
||||
</parameter>
|
||||
<parameter name="part" type="string" required="false">
|
||||
<description>Which part to return ("request" or "response")</description>
|
||||
</parameter>
|
||||
<parameter name="search_pattern" type="string" required="false">
|
||||
<description>Regex pattern to search content. Common patterns:
|
||||
- API endpoints: r"/api/[a-zA-Z0-9._/-]+"
|
||||
- URLs: r"https?://[^\\s<>"\']+"
|
||||
- Parameters: r'[?&][a-zA-Z0-9_]+=([^&\\s<>"\']+)'
|
||||
- Reflections: input_value in content</description>
|
||||
</parameter>
|
||||
<parameter name="page" type="integer" required="false">
|
||||
<description>Page number for pagination</description>
|
||||
</parameter>
|
||||
<parameter name="page_size" type="integer" required="false">
|
||||
<description>Lines per page</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>With search_pattern (COMPACT):
|
||||
- 'matches': [{match, before, after, position}] - max 20
|
||||
- 'total_matches': Total found
|
||||
- 'truncated': If limited to 20
|
||||
|
||||
Without search_pattern (PAGINATION):
|
||||
- 'content': Page content
|
||||
- 'page': Current page
|
||||
- 'showing_lines': Range display
|
||||
- 'has_more': More pages available</description>
|
||||
</returns>
|
||||
<examples>
|
||||
# Find API endpoints in response
|
||||
<function=view_request>
|
||||
<parameter=request_id>123</parameter>
|
||||
<parameter=part>response</parameter>
|
||||
<parameter=search_pattern>/api/[a-zA-Z0-9._/-]+</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
|
||||
<tool name="send_request">
|
||||
<description>Send a simple HTTP request through proxy.</description>
|
||||
<parameters>
|
||||
<parameter name="method" type="string" required="true">
|
||||
<description>HTTP method (GET, POST, etc.)</description>
|
||||
</parameter>
|
||||
<parameter name="url" type="string" required="true">
|
||||
<description>Target URL</description>
|
||||
</parameter>
|
||||
<parameter name="headers" type="dict" required="false">
|
||||
<description>Headers as {"key": "value"}</description>
|
||||
</parameter>
|
||||
<parameter name="body" type="string" required="false">
|
||||
<description>Request body</description>
|
||||
</parameter>
|
||||
<parameter name="timeout" type="integer" required="false">
|
||||
<description>Request timeout</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
</tool>
|
||||
|
||||
<tool name="repeat_request">
|
||||
<description>Repeat an existing proxy request with modifications for pentesting.
|
||||
|
||||
PROPER WORKFLOW:
|
||||
1. Use browser_action to browse the target application
|
||||
2. Use list_requests() to see captured proxy traffic
|
||||
3. Use repeat_request() to modify and test specific requests
|
||||
|
||||
This mirrors real pentesting: browse → capture → modify → test</description>
|
||||
<parameters>
|
||||
<parameter name="request_id" type="string" required="true">
|
||||
<description>ID of the original request to repeat (from list_requests)</description>
|
||||
</parameter>
|
||||
<parameter name="modifications" type="dict" required="false">
|
||||
<description>Changes to apply to the original request:
|
||||
- "url": New URL or modify existing one
|
||||
- "params": Dict to update query parameters
|
||||
- "headers": Dict to add/update headers
|
||||
- "body": New request body (replaces original)
|
||||
- "cookies": Dict to add/update cookies</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response data with status, headers, body, timing, and request details</description>
|
||||
</returns>
|
||||
<examples>
|
||||
# Modify POST body payload
|
||||
<function=repeat_request>
|
||||
<parameter=request_id>req_789</parameter>
|
||||
<parameter=modifications>{"body": "{\"username\":\"admin\",\"password\":\"admin\"}"}</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
|
||||
<tool name="scope_rules">
|
||||
<description>Manage proxy scope patterns for domain/file filtering using Caido's scope system.</description>
|
||||
<parameters>
|
||||
<parameter name="action" type="string" required="true">
|
||||
<description>Scope action:
|
||||
- get: Get specific scope by ID or list all if no ID
|
||||
- update: Update existing scope (requires scope_id and scope_name)
|
||||
- list: List all available scopes
|
||||
- create: Create new scope (requires scope_name)
|
||||
- delete: Delete scope (requires scope_id)</description>
|
||||
</parameter>
|
||||
<parameter name="allowlist" type="list" required="false">
|
||||
<description>Domain patterns to include. Examples: ["*.example.com", "api.test.com"]</description>
|
||||
</parameter>
|
||||
<parameter name="denylist" type="list" required="false">
|
||||
<description>Patterns to exclude. Some common extensions:
|
||||
["*.gif", "*.jpg", "*.png", "*.css", "*.js", "*.ico", "*.svg", "*woff*", "*.ttf"]</description>
|
||||
</parameter>
|
||||
<parameter name="scope_id" type="string" required="false">
|
||||
<description>Specific scope ID to operate on (required for get, update, delete)</description>
|
||||
</parameter>
|
||||
<parameter name="scope_name" type="string" required="false">
|
||||
<description>Name for scope (required for create, update)</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Depending on action:
|
||||
- get: Single scope object or error
|
||||
- list: {"scopes": [...], "count": N}
|
||||
- create/update: {"scope": {...}, "message": "..."}
|
||||
- delete: {"message": "...", "deletedId": "..."}</description>
|
||||
</returns>
|
||||
<notes>
|
||||
- Empty allowlist = allow all domains
|
||||
- Denylist overrides allowlist
|
||||
- Glob patterns: * (any), ? (single), [abc] (one of), [a-z] (range), [^abc] (none of)
|
||||
- Each scope has unique ID and can be used with list_requests(scopeId=...)
|
||||
</notes>
|
||||
<examples>
|
||||
# Create API-only scope
|
||||
<function=scope_rules>
|
||||
<parameter=action>create</parameter>
|
||||
<parameter=scope_name>API Testing</parameter>
|
||||
<parameter=allowlist>["api.example.com", "*.api.com"]</parameter>
|
||||
<parameter=denylist>["*.gif", "*.jpg", "*.png", "*.css", "*.js"]</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
|
||||
<tool name="list_sitemap">
|
||||
<description>View hierarchical sitemap of discovered attack surface from proxied traffic.
|
||||
|
||||
Perfect for bug hunters to understand the application structure and identify
|
||||
interesting endpoints, directories, and entry points discovered during testing.</description>
|
||||
<parameters>
|
||||
<parameter name="scope_id" type="string" required="false">
|
||||
<description>Scope ID to filter sitemap entries (use scope_rules to get/create scope IDs)</description>
|
||||
</parameter>
|
||||
<parameter name="parent_id" type="string" required="false">
|
||||
<description>ID of parent entry to expand. If None, returns root domains.</description>
|
||||
</parameter>
|
||||
<parameter name="depth" type="string" required="false">
|
||||
<description>DIRECT: Only immediate children. ALL: All descendants recursively.</description>
|
||||
</parameter>
|
||||
<parameter name="page" type="integer" required="false">
|
||||
<description>Page number for pagination (30 entries per page)</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing:
|
||||
- 'entries': List of cleaned sitemap entries
|
||||
- 'page', 'total_pages', 'total_count': Pagination info
|
||||
- 'has_more': Whether more pages available
|
||||
- Each entry: id, kind, label, hasDescendants, request (method/path/status only)</description>
|
||||
</returns>
|
||||
<notes>
|
||||
Entry kinds:
|
||||
- DOMAIN: Root domains (example.com)
|
||||
- DIRECTORY: Path directories (/api/, /admin/)
|
||||
- REQUEST: Individual endpoints
|
||||
- REQUEST_BODY: POST/PUT body variations
|
||||
- REQUEST_QUERY: GET parameter variations
|
||||
|
||||
Check hasDescendants=true to identify entries worth expanding.
|
||||
Use parent_id from any entry to drill down into subdirectories.
|
||||
</notes>
|
||||
</tool>
|
||||
|
||||
<tool name="view_sitemap_entry">
|
||||
<description>Get detailed information about a specific sitemap entry and related requests.
|
||||
|
||||
Perfect for understanding what's been discovered under a specific directory
|
||||
or endpoint, including all related requests and response codes.</description>
|
||||
<parameters>
|
||||
<parameter name="entry_id" type="string" required="true">
|
||||
<description>ID of the sitemap entry to examine</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing:
|
||||
- 'entry': Complete entry details including metadata
|
||||
- Entry contains 'requests' with all related HTTP requests
|
||||
- Shows request methods, paths, response codes, timing</description>
|
||||
</returns>
|
||||
</tool>
|
||||
</tools>
|
||||
785
strix/tools/proxy/proxy_manager.py
Normal file
785
strix/tools/proxy/proxy_manager.py
Normal file
@@ -0,0 +1,785 @@
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
import requests
|
||||
from gql import Client, gql
|
||||
from gql.transport.exceptions import TransportQueryError
|
||||
from gql.transport.requests import RequestsHTTPTransport
|
||||
from requests.exceptions import ProxyError, RequestException, Timeout
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
class ProxyManager:
|
||||
    def __init__(self, auth_token: str | None = None):
        """Create a GraphQL client bound to the local Caido proxy instance.

        Args:
            auth_token: Caido API token; falls back to the CAIDO_API_TOKEN
                environment variable when not provided.
        """
        # Caido is expected to run locally (same sandbox/container).
        host = "127.0.0.1"
        port = os.getenv("CAIDO_PORT", "56789")
        self.base_url = f"http://{host}:{port}/graphql"
        # The same endpoint doubles as an HTTP(S) forward proxy so outgoing
        # requests get captured in Caido's history.
        self.proxies = {"http": f"http://{host}:{port}", "https": f"http://{host}:{port}"}
        self.auth_token = auth_token or os.getenv("CAIDO_API_TOKEN")
        self.transport = RequestsHTTPTransport(
            url=self.base_url, headers={"Authorization": f"Bearer {self.auth_token}"}
        )
        # Schema fetching disabled: queries are static strings validated server-side.
        self.client = Client(transport=self.transport, fetch_schema_from_transport=False)
|
||||
|
||||
    def list_requests(
        self,
        httpql_filter: str | None = None,
        start_page: int = 1,
        end_page: int = 1,
        page_size: int = 50,
        sort_by: str = "timestamp",
        sort_order: str = "desc",
        scope_id: str | None = None,
    ) -> dict[str, Any]:
        """Fetch a page range of captured requests from Caido via GraphQL.

        Args:
            httpql_filter: Optional HTTPQL filter expression.
            start_page: First page to fetch (1-based).
            end_page: Last page to fetch (1-based, inclusive).
            page_size: Requests per page.
            sort_by: Friendly sort-field name (mapped to Caido enum values;
                unknown names fall back to CREATED_AT).
            sort_order: "asc" or "desc".
            scope_id: Optional Caido scope ID to restrict results.

        Returns:
            Dict with "requests", "total_count" and paging metadata, or a
            dict containing an "error" key on failure.
        """
        # Translate the page range into one offset/limit window so a single
        # GraphQL round trip covers all requested pages.
        offset = (start_page - 1) * page_size
        limit = (end_page - start_page + 1) * page_size

        # Friendly field names -> Caido RequestResponseOrderInput enum values.
        sort_mapping = {
            "timestamp": "CREATED_AT",
            "host": "HOST",
            "method": "METHOD",
            "path": "PATH",
            "status_code": "RESP_STATUS_CODE",
            "response_time": "RESP_ROUNDTRIP_TIME",
            "response_size": "RESP_LENGTH",
            "source": "SOURCE",
        }

        query = gql("""
            query GetRequests(
                $limit: Int, $offset: Int, $filter: HTTPQL,
                $order: RequestResponseOrderInput, $scopeId: ID
            ) {
                requestsByOffset(
                    limit: $limit, offset: $offset, filter: $filter,
                    order: $order, scopeId: $scopeId
                ) {
                    edges {
                        node {
                            id method host path query createdAt length isTls port
                            source alteration fileExtension
                            response { id statusCode length roundtripTime createdAt }
                        }
                    }
                    count { value }
                }
            }
        """)

        variables = {
            "limit": limit,
            "offset": offset,
            "filter": httpql_filter,
            "order": {
                "by": sort_mapping.get(sort_by, "CREATED_AT"),
                "ordering": sort_order.upper(),
            },
            "scopeId": scope_id,
        }

        try:
            result = self.client.execute(query, variable_values=variables)
            data = result.get("requestsByOffset", {})
            # Flatten the relay-style edges/node envelope.
            nodes = [edge["node"] for edge in data.get("edges", [])]

            # "count" may be null in the payload; guard before .get().
            count_data = data.get("count") or {}
            return {
                "requests": nodes,
                "total_count": count_data.get("value", 0),
                "start_page": start_page,
                "end_page": end_page,
                "page_size": page_size,
                "offset": offset,
                "returned_count": len(nodes),
                "sort_by": sort_by,
                "sort_order": sort_order,
            }
        except (TransportQueryError, ValueError, KeyError) as e:
            # Soft-fail: tool callers get an empty result plus the error text.
            return {"requests": [], "total_count": 0, "error": f"Error fetching requests: {e}"}
|
||||
|
||||
    def view_request(
        self,
        request_id: str,
        part: str = "request",
        search_pattern: str | None = None,
        page: int = 1,
        page_size: int = 50,
    ) -> dict[str, Any]:
        """Return the decoded request or response body, searched or paginated.

        Args:
            request_id: Caido request ID.
            part: "request" or "response".
            search_pattern: Optional regex; when given, matches with context
                are returned instead of paginated content.
            page: Page number used when no search pattern is given.
            page_size: Display lines per page for paginated output.

        Returns:
            Search results or one page of decoded content; a dict with an
            "error" key when the request/part is missing or the query fails.
        """
        # One query per part keeps the payload minimal.
        queries = {
            "request": """query GetRequest($id: ID!) {
                request(id: $id) {
                    id method host path query createdAt length isTls port
                    source alteration edited raw
                }
            }""",
            "response": """query GetRequest($id: ID!) {
                request(id: $id) {
                    id response {
                        id statusCode length roundtripTime createdAt raw
                    }
                }
            }""",
        }

        if part not in queries:
            return {"error": f"Invalid part '{part}'. Use 'request' or 'response'"}

        try:
            result = self.client.execute(gql(queries[part]), variable_values={"id": request_id})
            request_data = result.get("request", {})

            if not request_data:
                return {"error": f"Request {request_id} not found"}

            if part == "request":
                raw_content = request_data.get("raw")
            else:
                # The response may be absent (e.g. request never completed).
                response_data = request_data.get("response") or {}
                raw_content = response_data.get("raw")

            if not raw_content:
                return {"error": "No content available"}

            # Caido returns raw bodies base64-encoded; decode leniently so
            # binary responses don't raise.
            content = base64.b64decode(raw_content).decode("utf-8", errors="replace")

            # Replace the encoded blob with decoded text in the payload.
            if part == "response":
                request_data["response"]["raw"] = content
            else:
                request_data["raw"] = content

            return (
                self._search_content(request_data, content, search_pattern)
                if search_pattern
                else self._paginate_content(request_data, content, page, page_size)
            )

        except (TransportQueryError, ValueError, KeyError, UnicodeDecodeError) as e:
            return {"error": f"Failed to view request: {e}"}
|
||||
|
||||
def _search_content(
|
||||
self, request_data: dict[str, Any], content: str, pattern: str
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
regex = re.compile(pattern, re.IGNORECASE | re.MULTILINE | re.DOTALL)
|
||||
matches = []
|
||||
|
||||
for match in regex.finditer(content):
|
||||
start, end = match.start(), match.end()
|
||||
context_size = 120
|
||||
|
||||
before = re.sub(r"\s+", " ", content[max(0, start - context_size) : start].strip())[
|
||||
-100:
|
||||
]
|
||||
after = re.sub(r"\s+", " ", content[end : end + context_size].strip())[:100]
|
||||
|
||||
matches.append(
|
||||
{"match": match.group(), "before": before, "after": after, "position": start}
|
||||
)
|
||||
|
||||
if len(matches) >= 20:
|
||||
break
|
||||
|
||||
return {
|
||||
"id": request_data.get("id"),
|
||||
"matches": matches,
|
||||
"total_matches": len(matches),
|
||||
"search_pattern": pattern,
|
||||
"truncated": len(matches) >= 20,
|
||||
}
|
||||
except re.error as e:
|
||||
return {"error": f"Invalid regex: {e}"}
|
||||
|
||||
def _paginate_content(
|
||||
self, request_data: dict[str, Any], content: str, page: int, page_size: int
|
||||
) -> dict[str, Any]:
|
||||
display_lines = []
|
||||
for line in content.split("\n"):
|
||||
if len(line) <= 80:
|
||||
display_lines.append(line)
|
||||
else:
|
||||
display_lines.extend(
|
||||
[
|
||||
line[i : i + 80] + (" \\" if i + 80 < len(line) else "")
|
||||
for i in range(0, len(line), 80)
|
||||
]
|
||||
)
|
||||
|
||||
total_lines = len(display_lines)
|
||||
total_pages = (total_lines + page_size - 1) // page_size
|
||||
page = max(1, min(page, total_pages))
|
||||
|
||||
start_line = (page - 1) * page_size
|
||||
end_line = min(total_lines, start_line + page_size)
|
||||
|
||||
return {
|
||||
"id": request_data.get("id"),
|
||||
"content": "\n".join(display_lines[start_line:end_line]),
|
||||
"page": page,
|
||||
"total_pages": total_pages,
|
||||
"showing_lines": f"{start_line + 1}-{end_line} of {total_lines}",
|
||||
"has_more": page < total_pages,
|
||||
}
|
||||
|
||||
    def send_simple_request(
        self,
        method: str,
        url: str,
        headers: dict[str, str] | None = None,
        body: str = "",
        timeout: int = 30,
    ) -> dict[str, Any]:
        """Send an ad-hoc HTTP request through the Caido forward proxy.

        Args:
            method: HTTP method (GET, POST, ...).
            url: Absolute target URL.
            headers: Optional request headers.
            body: Optional request body.
            timeout: Timeout in seconds.

        Returns:
            Status, headers, (truncated) body and timing, or an "error" dict.
        """
        if headers is None:
            headers = {}
        try:
            start_time = time.time()
            response = requests.request(
                method=method,
                url=url,
                headers=headers,
                # Empty string becomes None so no body is sent at all.
                data=body or None,
                # Route through Caido so the traffic is captured.
                proxies=self.proxies,
                timeout=timeout,
                # TLS verification intentionally off: the intercepting proxy
                # terminates TLS with its own certificate.
                verify=False,
            )
            response_time = int((time.time() - start_time) * 1000)

            body_content = response.text
            # Cap the body so huge responses don't flood the tool output.
            if len(body_content) > 10000:
                body_content = body_content[:10000] + "\n... [truncated]"

            return {
                "status_code": response.status_code,
                "headers": dict(response.headers),
                "body": body_content,
                "response_time_ms": response_time,
                "url": response.url,
                "message": (
                    "Request sent through proxy - check list_requests() for captured traffic"
                ),
            }
        except (RequestException, ProxyError, Timeout) as e:
            return {"error": f"Request failed: {type(e).__name__}", "details": str(e), "url": url}
|
||||
|
||||
def repeat_request(
|
||||
self, request_id: str, modifications: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
if modifications is None:
|
||||
modifications = {}
|
||||
|
||||
original = self.view_request(request_id, "request")
|
||||
if "error" in original:
|
||||
return {"error": f"Could not retrieve original request: {original['error']}"}
|
||||
|
||||
raw_content = original.get("content", "")
|
||||
if not raw_content:
|
||||
return {"error": "No raw request content found"}
|
||||
|
||||
request_components = self._parse_http_request(raw_content)
|
||||
if "error" in request_components:
|
||||
return request_components
|
||||
|
||||
full_url = self._build_full_url(request_components, modifications)
|
||||
if "error" in full_url:
|
||||
return full_url
|
||||
|
||||
modified_request = self._apply_modifications(
|
||||
request_components, modifications, full_url["url"]
|
||||
)
|
||||
|
||||
return self._send_modified_request(modified_request, request_id, modifications)
|
||||
|
||||
def _parse_http_request(self, raw_content: str) -> dict[str, Any]:
|
||||
lines = raw_content.split("\n")
|
||||
request_line = lines[0].strip().split(" ")
|
||||
if len(request_line) < 2:
|
||||
return {"error": "Invalid request line format"}
|
||||
|
||||
method, url_path = request_line[0], request_line[1]
|
||||
|
||||
headers = {}
|
||||
body_start = 0
|
||||
for i, line in enumerate(lines[1:], 1):
|
||||
if line.strip() == "":
|
||||
body_start = i + 1
|
||||
break
|
||||
if ":" in line:
|
||||
key, value = line.split(":", 1)
|
||||
headers[key.strip()] = value.strip()
|
||||
|
||||
body = "\n".join(lines[body_start:]).strip() if body_start < len(lines) else ""
|
||||
|
||||
return {"method": method, "url_path": url_path, "headers": headers, "body": body}
|
||||
|
||||
def _build_full_url(
|
||||
self, components: dict[str, Any], modifications: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
headers = components["headers"]
|
||||
host = headers.get("Host", "")
|
||||
if not host:
|
||||
return {"error": "No Host header found"}
|
||||
|
||||
protocol = (
|
||||
"https" if ":443" in host or "https" in headers.get("Referer", "").lower() else "http"
|
||||
)
|
||||
full_url = f"{protocol}://{host}{components['url_path']}"
|
||||
|
||||
if "url" in modifications:
|
||||
full_url = modifications["url"]
|
||||
|
||||
return {"url": full_url}
|
||||
|
||||
def _apply_modifications(
|
||||
self, components: dict[str, Any], modifications: dict[str, Any], full_url: str
|
||||
) -> dict[str, Any]:
|
||||
headers = components["headers"].copy()
|
||||
body = components["body"]
|
||||
final_url = full_url
|
||||
|
||||
if "params" in modifications:
|
||||
parsed = urlparse(final_url)
|
||||
params = {k: v[0] if v else "" for k, v in parse_qs(parsed.query).items()}
|
||||
params.update(modifications["params"])
|
||||
final_url = urlunparse(parsed._replace(query=urlencode(params)))
|
||||
|
||||
if "headers" in modifications:
|
||||
headers.update(modifications["headers"])
|
||||
|
||||
if "body" in modifications:
|
||||
body = modifications["body"]
|
||||
|
||||
if "cookies" in modifications:
|
||||
cookies = {}
|
||||
if headers.get("Cookie"):
|
||||
for cookie in headers["Cookie"].split(";"):
|
||||
if "=" in cookie:
|
||||
k, v = cookie.split("=", 1)
|
||||
cookies[k.strip()] = v.strip()
|
||||
cookies.update(modifications["cookies"])
|
||||
headers["Cookie"] = "; ".join([f"{k}={v}" for k, v in cookies.items()])
|
||||
|
||||
return {
|
||||
"method": components["method"],
|
||||
"url": final_url,
|
||||
"headers": headers,
|
||||
"body": body,
|
||||
}
|
||||
|
||||
    def _send_modified_request(
        self, request_data: dict[str, Any], request_id: str, modifications: dict[str, Any]
    ) -> dict[str, Any]:
        """Send the rebuilt request through the proxy and summarize the response.

        Args:
            request_data: Dict with "method", "url", "headers" and "body".
            request_id: ID of the original request (echoed back for context).
            modifications: The modifications that were applied (echoed back).

        Returns:
            Trimmed response details, or an "error" dict on failure.
        """
        try:
            start_time = time.time()
            response = requests.request(
                method=request_data["method"],
                url=request_data["url"],
                headers=request_data["headers"],
                # Empty string becomes None so no body is sent at all.
                data=request_data["body"] or None,
                # Route through Caido so the replay is captured too.
                proxies=self.proxies,
                timeout=30,
                # Proxy terminates TLS with its own cert; verification is off.
                verify=False,
            )
            response_time = int((time.time() - start_time) * 1000)

            response_body = response.text
            # Cap the body to keep tool output manageable.
            truncated = len(response_body) > 10000
            if truncated:
                response_body = response_body[:10000] + "\n... [truncated]"

            return {
                "status_code": response.status_code,
                "status_text": response.reason,
                # Only the headers most useful for analysis are returned.
                "headers": {
                    k: v
                    for k, v in response.headers.items()
                    if k.lower()
                    in ["content-type", "content-length", "server", "set-cookie", "location"]
                },
                "body": response_body,
                "body_truncated": truncated,
                "body_size": len(response.content),
                "response_time_ms": response_time,
                "url": response.url,
                "original_request_id": request_id,
                "modifications_applied": modifications,
                "request": {
                    "method": request_data["method"],
                    "url": request_data["url"],
                    "headers": request_data["headers"],
                    "has_body": bool(request_data["body"]),
                },
            }

        except ProxyError as e:
            # Most likely cause: Caido is not up or the port is wrong.
            return {
                "error": "Proxy connection failed - is Caido running?",
                "details": str(e),
                "original_request_id": request_id,
            }
        except (RequestException, Timeout) as e:
            return {
                "error": f"Failed to repeat request: {type(e).__name__}",
                "details": str(e),
                "original_request_id": request_id,
            }
|
||||
|
||||
    def _handle_scope_list(self) -> dict[str, Any]:
        """Return all Caido scopes with their allow/deny glob patterns."""
        result = self.client.execute(gql("query { scopes { id name allowlist denylist indexed } }"))
        scopes = result.get("scopes", [])
        return {"scopes": scopes, "count": len(scopes)}
|
||||
|
||||
    def _handle_scope_get(self, scope_id: str | None) -> dict[str, Any]:
        """Fetch one scope by ID; without an ID, fall back to listing all scopes."""
        if not scope_id:
            return self._handle_scope_list()

        result = self.client.execute(
            gql(
                "query GetScope($id: ID!) { scope(id: $id) { id name allowlist denylist indexed } }"
            ),
            variable_values={"id": scope_id},
        )
        scope = result.get("scope")
        if not scope:
            return {"error": f"Scope {scope_id} not found"}
        return {"scope": scope}
|
||||
|
||||
    def _handle_scope_create(
        self, scope_name: str, allowlist: list[str] | None, denylist: list[str] | None
    ) -> dict[str, Any]:
        """Create a new Caido scope with the given allow/deny glob patterns.

        Args:
            scope_name: Display name for the new scope (required).
            allowlist: Glob patterns to include; empty/None allows all.
            denylist: Glob patterns to exclude.

        Returns:
            {"scope", "message"} on success, or an "error" dict.
        """
        if not scope_name:
            return {"error": "scope_name required for create"}

        mutation = gql("""
            mutation CreateScope($input: CreateScopeInput!) {
                createScope(input: $input) {
                    scope { id name allowlist denylist indexed }
                    error {
                        ... on InvalidGlobTermsUserError { code terms }
                        ... on OtherUserError { code }
                    }
                }
            }
        """)

        result = self.client.execute(
            mutation,
            variable_values={
                "input": {
                    "name": scope_name,
                    "allowlist": allowlist or [],
                    "denylist": denylist or [],
                }
            },
        )

        payload = result.get("createScope", {})
        # Caido reports validation failures in-band rather than as transport errors.
        if payload.get("error"):
            error = payload["error"]
            return {"error": f"Invalid glob patterns: {error.get('terms', error.get('code'))}"}

        return {"scope": payload.get("scope"), "message": "Scope created successfully"}
|
||||
|
||||
def _handle_scope_update(
    self,
    scope_id: str,
    scope_name: str,
    allowlist: list[str] | None,
    denylist: list[str] | None,
) -> dict[str, Any]:
    """Rename a scope and replace its allow/deny glob lists.

    Returns the updated scope on success, or an ``error`` entry when the
    server rejects the glob patterns.
    """
    if not scope_id or not scope_name:
        return {"error": "scope_id and scope_name required"}

    mutation = gql("""
        mutation UpdateScope($id: ID!, $input: UpdateScopeInput!) {
            updateScope(id: $id, input: $input) {
                scope { id name allowlist denylist indexed }
                error {
                    ... on InvalidGlobTermsUserError { code terms }
                    ... on OtherUserError { code }
                }
            }
        }
    """)

    variables = {
        "id": scope_id,
        "input": {
            "name": scope_name,
            # The API expects concrete lists, so None becomes an empty list.
            "allowlist": allowlist or [],
            "denylist": denylist or [],
        },
    }
    response = self.client.execute(mutation, variable_values=variables)

    payload = response.get("updateScope", {})
    error = payload.get("error")
    if error:
        # InvalidGlobTermsUserError carries "terms"; other errors only "code".
        return {"error": f"Invalid glob patterns: {error.get('terms', error.get('code'))}"}

    return {"scope": payload.get("scope"), "message": "Scope updated successfully"}
|
||||
|
||||
def _handle_scope_delete(self, scope_id: str) -> dict[str, Any]:
    """Delete a scope by id, confirming success via the returned deletedId."""
    if not scope_id:
        return {"error": "scope_id required for delete"}

    mutation = gql("mutation DeleteScope($id: ID!) { deleteScope(id: $id) { deletedId } }")
    response = self.client.execute(mutation, variable_values={"id": scope_id})

    deleted = response.get("deleteScope", {}).get("deletedId")
    if not deleted:
        return {"error": f"Failed to delete scope {scope_id}"}
    return {"message": f"Scope {scope_id} deleted", "deletedId": deleted}
|
||||
|
||||
def scope_rules(
    self,
    action: str,
    allowlist: list[str] | None = None,
    denylist: list[str] | None = None,
    scope_id: str | None = None,
    scope_name: str | None = None,
) -> dict[str, Any]:
    """Dispatch a scope CRUD operation.

    Supported actions: 'list', 'get', 'create', 'update', 'delete'.
    Transport or data errors are converted into an ``error`` dict rather
    than propagated to the caller.
    """

    def _create() -> dict[str, Any]:
        if not scope_name:
            return {"error": "scope_name required for create"}
        return self._handle_scope_create(scope_name, allowlist, denylist)

    def _update() -> dict[str, Any]:
        if not scope_id or not scope_name:
            return {"error": "scope_id and scope_name required"}
        return self._handle_scope_update(scope_id, scope_name, allowlist, denylist)

    def _delete() -> dict[str, Any]:
        if not scope_id:
            return {"error": "scope_id required for delete"}
        return self._handle_scope_delete(scope_id)

    handlers: dict[str, Callable[[], dict[str, Any]]] = {
        "list": self._handle_scope_list,
        "get": lambda: self._handle_scope_get(scope_id),
        "create": _create,
        "update": _update,
        "delete": _delete,
    }

    handler = handlers.get(action)
    if handler is None:
        return {
            "error": f"Unsupported action: {action}. Use 'get', 'list', 'create', "
            "'update', or 'delete'"
        }

    try:
        return handler()
    except (TransportQueryError, ValueError, KeyError) as e:
        return {"error": f"Scope operation failed: {e}"}
|
||||
|
||||
def list_sitemap(
    self,
    scope_id: str | None = None,
    parent_id: str | None = None,
    depth: str = "DIRECT",
    page: int = 1,
    page_size: int = 30,
) -> dict[str, Any]:
    """List sitemap entries with client-side pagination.

    With ``parent_id``, lists that entry's descendants at the given ``depth``;
    otherwise lists root entries (optionally filtered by ``scope_id``).

    Fix: the node-cleaning logic was duplicated inline here even though the
    class already provides ``_process_sitemap_metadata`` and
    ``_process_sitemap_request`` with byte-identical behavior — this now
    reuses those helpers (same output, single source of truth).

    Returns a dict with ``entries`` plus pagination bookkeeping, or an
    ``error`` entry on failure.
    """
    try:
        skip_count = (page - 1) * page_size

        if parent_id:
            query = gql("""
                query GetSitemapDescendants($parentId: ID!, $depth: SitemapDescendantsDepth!) {
                    sitemapDescendantEntries(parentId: $parentId, depth: $depth) {
                        edges {
                            node {
                                id kind label hasDescendants
                                request { method path response { statusCode } }
                            }
                        }
                        count { value }
                    }
                }
            """)
            result = self.client.execute(
                query, variable_values={"parentId": parent_id, "depth": depth}
            )
            data = result.get("sitemapDescendantEntries", {})
        else:
            query = gql("""
                query GetSitemapRoots($scopeId: ID) {
                    sitemapRootEntries(scopeId: $scopeId) {
                        edges { node {
                            id kind label hasDescendants
                            metadata { ... on SitemapEntryMetadataDomain { isTls port } }
                            request { method path response { statusCode } }
                        } }
                        count { value }
                    }
                }
            """)
            result = self.client.execute(query, variable_values={"scopeId": scope_id})
            data = result.get("sitemapRootEntries", {})

        all_nodes = [edge["node"] for edge in data.get("edges", [])]
        # "count" may be present but null, hence the `or {}` guard.
        total_count = (data.get("count") or {}).get("value", 0)

        # The API returns the full result set; pagination is applied locally.
        cleaned_nodes = []
        for node in all_nodes[skip_count : skip_count + page_size]:
            cleaned = self._process_sitemap_metadata(node)
            if node.get("request"):
                cleaned_req = self._process_sitemap_request(node["request"])
                if cleaned_req:
                    cleaned["request"] = cleaned_req
            cleaned_nodes.append(cleaned)

        total_pages = (total_count + page_size - 1) // page_size

        return {
            "entries": cleaned_nodes,
            "page": page,
            "page_size": page_size,
            "total_pages": total_pages,
            "total_count": total_count,
            "has_more": page < total_pages,
            "showing": (
                f"{skip_count + 1}-{min(skip_count + page_size, total_count)} of {total_count}"
            ),
        }

    except (TransportQueryError, ValueError, KeyError) as e:
        return {"error": f"Failed to fetch sitemap: {e}"}
|
||||
|
||||
def _process_sitemap_metadata(self, node: dict[str, Any]) -> dict[str, Any]:
    """Extract a sitemap node's core fields, keeping metadata only when populated."""
    cleaned: dict[str, Any] = {
        key: node[key] for key in ("id", "kind", "label", "hasDescendants")
    }

    metadata = node.get("metadata")
    # Keep metadata only when it actually carries TLS/port information.
    if metadata and (metadata.get("isTls") is not None or metadata.get("port")):
        cleaned["metadata"] = metadata

    return cleaned
|
||||
|
||||
def _process_sitemap_request(self, req: dict[str, Any]) -> dict[str, Any] | None:
    """Condense a request node to method/path/status; None when nothing is set."""
    condensed: dict[str, Any] = {}
    for field in ("method", "path"):
        if req.get(field):
            condensed[field] = req[field]

    status = (req.get("response") or {}).get("statusCode")
    if status:
        condensed["status"] = status

    return condensed or None
|
||||
|
||||
def _process_sitemap_response(self, resp: dict[str, Any]) -> dict[str, Any]:
    """Map raw response fields onto the compact status/size/time_ms shape."""
    # Source key -> output key; only truthy source values are carried over.
    renames = {"statusCode": "status", "length": "size", "roundtripTime": "time_ms"}
    return {out: resp[src] for src, out in renames.items() if resp.get(src)}
|
||||
|
||||
def view_sitemap_entry(self, entry_id: str) -> dict[str, Any]:
    """Show one sitemap entry plus its 30 most recent related requests."""
    try:
        query = gql("""
            query GetSitemapEntry($id: ID!) {
                sitemapEntry(id: $id) {
                    id kind label hasDescendants
                    metadata { ... on SitemapEntryMetadataDomain { isTls port } }
                    request { method path response { statusCode length roundtripTime } }
                    requests(first: 30, order: {by: CREATED_AT, ordering: DESC}) {
                        edges { node { method path response { statusCode length } } }
                        count { value }
                    }
                }
            }
        """)

        result = self.client.execute(query, variable_values={"id": entry_id})
        entry = result.get("sitemapEntry")
        if not entry:
            return {"error": f"Sitemap entry {entry_id} not found"}

        cleaned = self._process_sitemap_metadata(entry)

        # The entry's primary request keeps full response detail (size/time),
        # unlike the related-request list below which only keeps status.
        primary = entry.get("request")
        if primary:
            condensed: dict[str, Any] = {}
            for field in ("method", "path"):
                if primary.get(field):
                    condensed[field] = primary[field]
            if primary.get("response"):
                condensed["response"] = self._process_sitemap_response(primary["response"])
            if condensed:
                cleaned["request"] = condensed

        requests_data = entry.get("requests", {})
        nodes = [edge["node"] for edge in requests_data.get("edges", [])]
        related = [
            processed
            for processed in (self._process_sitemap_request(node) for node in nodes)
            if processed is not None
        ]

        cleaned["related_requests"] = {
            "requests": related,
            "total_count": (requests_data.get("count") or {}).get("value", 0),
            "showing": f"Latest {len(related)} requests",
        }

        return {"entry": cleaned} if cleaned else {"error": "Failed to process sitemap entry"}  # noqa: TRY300

    except (TransportQueryError, ValueError, KeyError) as e:
        return {"error": f"Failed to fetch sitemap entry: {e}"}
|
||||
|
||||
def close(self) -> None:
    """Intentional no-op — nothing visible here needs explicit teardown."""
|
||||
|
||||
|
||||
# Module-level singleton cache for the proxy manager.
_PROXY_MANAGER: ProxyManager | None = None


def get_proxy_manager() -> ProxyManager:
    """Return the process-wide ProxyManager, creating and caching it on first use.

    Bug fix: the previous version returned a fresh ``ProxyManager()`` without
    ever assigning ``_PROXY_MANAGER``, so the singleton was never cached and
    every caller got a brand-new manager instance.
    """
    global _PROXY_MANAGER
    if _PROXY_MANAGER is None:
        _PROXY_MANAGER = ProxyManager()
    return _PROXY_MANAGER
|
||||
4
strix/tools/python/__init__.py
Normal file
4
strix/tools/python/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .python_actions import python_action
|
||||
|
||||
|
||||
__all__ = ["python_action"]
|
||||
47
strix/tools/python/python_actions.py
Normal file
47
strix/tools/python/python_actions.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from strix.tools.registry import register_tool
|
||||
|
||||
from .python_manager import get_python_session_manager
|
||||
|
||||
|
||||
PythonAction = Literal["new_session", "execute", "close", "list_sessions"]
|
||||
|
||||
|
||||
@register_tool
def python_action(
    action: PythonAction,
    code: str | None = None,
    timeout: int = 30,
    session_id: str | None = None,
) -> dict[str, Any]:
    """Route a Python-session action to the session manager.

    Actions: 'new_session' (code optional), 'execute' (code required),
    'close', 'list_sessions'. Validation/runtime errors are returned as a
    result dict with the message in ``stderr`` rather than raised.
    """
    manager = get_python_session_manager()

    try:
        if action == "new_session":
            return manager.create_session(session_id, code, timeout)

        if action == "execute":
            if not code:
                raise ValueError(f"code parameter is required for {action} action")
            return manager.execute_code(session_id, code, timeout)

        if action == "close":
            return manager.close_session(session_id)

        if action == "list_sessions":
            return manager.list_sessions()

        raise ValueError(f"Unknown action: {action}")

    except (ValueError, RuntimeError) as e:
        return {"stderr": str(e), "session_id": session_id, "stdout": "", "is_running": False}
|
||||
131
strix/tools/python/python_actions_schema.xml
Normal file
131
strix/tools/python/python_actions_schema.xml
Normal file
@@ -0,0 +1,131 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<tools>
|
||||
<tool name="python_action">
|
||||
<description>Perform Python actions using persistent interpreter sessions for cybersecurity tasks.</description>
|
||||
<details>Common Use Cases:
|
||||
- Security script development and testing (payload generation, exploit scripts)
|
||||
- Data analysis of security logs, network traffic, or vulnerability scans
|
||||
- Cryptographic operations and security tool automation
|
||||
- Interactive penetration testing workflows and proof-of-concept development
|
||||
- Processing security data formats (JSON, XML, CSV from security tools)
|
||||
- HTTP proxy interaction for web security testing (all proxy functions are pre-imported)
|
||||
|
||||
Each session instance is PERSISTENT and maintains its own global and local namespaces
|
||||
until explicitly closed, allowing for multi-step security workflows and stateful computations.
|
||||
|
||||
PROXY FUNCTIONS PRE-IMPORTED:
|
||||
All proxy action functions are automatically imported into every Python session, enabling
|
||||
seamless HTTP traffic analysis and web security testing
|
||||
|
||||
This is particularly useful for:
|
||||
- Analyzing captured HTTP traffic during web application testing
|
||||
- Automating request manipulation and replay attacks
|
||||
- Building custom security testing workflows combining proxy data with Python analysis
|
||||
- Correlating multiple requests for advanced attack scenarios</details>
|
||||
<parameters>
|
||||
<parameter name="action" type="string" required="true">
|
||||
<description>The Python action to perform: - new_session: Create a new Python interpreter session. This MUST be the first action for each session. - execute: Execute Python code in the specified session. - close: Close the specified session instance. - list_sessions: List all active Python sessions.</description>
|
||||
</parameter>
|
||||
<parameter name="code" type="string" required="false">
|
||||
<description>Required for 'new_session' (as initial code) and 'execute' actions. The Python code to execute.</description>
|
||||
</parameter>
|
||||
<parameter name="timeout" type="integer" required="false">
|
||||
<description>Maximum execution time in seconds for code execution. Applies to both 'new_session' (when initial code is provided) and 'execute' actions. Default is 30 seconds.</description>
|
||||
</parameter>
|
||||
<parameter name="session_id" type="string" required="false">
|
||||
<description>Unique identifier for the Python session. If not provided, uses the default session ID.</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - session_id: the ID of the session that was operated on - stdout: captured standard output from code execution (for execute action) - stderr: any error message if execution failed - result: string representation of the last expression result - execution_time: time taken to execute the code - message: status message about the action performed - Various session info depending on the action</description>
|
||||
</returns>
|
||||
<notes>
|
||||
Important usage rules:
|
||||
1. PERSISTENCE: Session instances remain active and maintain their state (variables,
|
||||
imports, function definitions) until explicitly closed with the 'close' action.
|
||||
This allows for multi-step workflows across multiple tool calls.
|
||||
2. MULTIPLE SESSIONS: You can run multiple Python sessions concurrently by using
|
||||
different session_id values. Each session operates independently with its own
|
||||
namespace.
|
||||
3. Session interaction MUST begin with 'new_session' action for each session instance.
|
||||
4. Only one action can be performed per call.
|
||||
5. CODE EXECUTION:
|
||||
- Both expressions and statements are supported
|
||||
- Expressions automatically return their result
|
||||
- Print statements and stdout are captured
|
||||
- Variables persist between executions in the same session
|
||||
- Imports, function definitions, etc. persist in the session
|
||||
- IPython magic commands are fully supported (%pip, %time, %whos, %%writefile, etc.)
|
||||
- Line magics (%) and cell magics (%%) work as expected
|
||||
6. CLOSE: Terminates the session completely and frees memory
|
||||
7. The Python sessions can operate concurrently with other tools. You may invoke
|
||||
terminal, browser, or other tools while maintaining active Python sessions.
|
||||
8. Each session has its own isolated namespace - variables in one session don't
|
||||
affect others.
|
||||
</notes>
|
||||
<examples>
|
||||
# Create new session for security analysis (default session)
|
||||
<function=python_action>
|
||||
<parameter=action>new_session</parameter>
|
||||
<parameter=code>import hashlib
|
||||
import base64
|
||||
import json
|
||||
print("Security analysis session started")</parameter>
|
||||
</function>
|
||||
|
||||
# Analyze security data in the default session
|
||||
<function=python_action>
|
||||
<parameter=action>execute</parameter>
|
||||
<parameter=code>vulnerability_data = {"cve": "CVE-2024-1234", "severity": "high"}
|
||||
encoded_payload = base64.b64encode(json.dumps(vulnerability_data).encode())
|
||||
print(f"Encoded: {encoded_payload.decode()}")</parameter>
|
||||
</function>
|
||||
|
||||
# Long running security scan with custom timeout
|
||||
<function=python_action>
|
||||
<parameter=action>execute</parameter>
|
||||
<parameter=code>import time
|
||||
# Simulate long-running vulnerability scan
|
||||
time.sleep(45)
|
||||
print('Security scan completed!')</parameter>
|
||||
<parameter=timeout>50</parameter>
|
||||
</function>
|
||||
|
||||
# Use IPython magic commands for package management and profiling
|
||||
<function=python_action>
|
||||
<parameter=action>execute</parameter>
|
||||
<parameter=code>%pip install requests
|
||||
%time response = requests.get('https://httpbin.org/json')
|
||||
%whos</parameter>
</function>
|
||||
|
||||
# Analyze requests for potential vulnerabilities
|
||||
<function=python_action>
|
||||
<parameter=action>execute</parameter>
|
||||
<parameter=code># Filter for POST requests that might contain sensitive data
|
||||
post_requests = list_requests(
|
||||
httpql_filter="req.method.eq:POST",
|
||||
page_size=20
|
||||
)
|
||||
|
||||
# Analyze each POST request for potential issues
|
||||
for req in post_requests.get('requests', []):
|
||||
request_id = req['id']
|
||||
# View the request details
|
||||
request_details = view_request(request_id, part="request")
|
||||
|
||||
# Check for potential SQL injection points
|
||||
body = request_details.get('body', '')
|
||||
if any(keyword in body.lower() for keyword in ['select', 'union', 'insert', 'update']):
|
||||
print(f"Potential SQL injection in request {request_id}")
|
||||
|
||||
# Repeat the request with a test payload
|
||||
test_payload = repeat_request(request_id, {
|
||||
'body': body + "' OR '1'='1"
|
||||
})
|
||||
print(f"Test response status: {test_payload.get('status_code')}")
|
||||
|
||||
print("Security analysis complete!")</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
</tools>
|
||||
172
strix/tools/python/python_instance.py
Normal file
172
strix/tools/python/python_instance.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import io
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from IPython.core.interactiveshell import InteractiveShell
|
||||
|
||||
|
||||
MAX_STDOUT_LENGTH = 10_000
|
||||
MAX_STDERR_LENGTH = 5_000
|
||||
|
||||
|
||||
class PythonInstance:
    """One persistent IPython interpreter session.

    Wraps an ``InteractiveShell`` so variables, imports, and definitions
    persist across ``execute_code`` calls. Execution is serialized by a lock
    and bounded by a SIGALRM-based timeout.
    """

    def __init__(self, session_id: str) -> None:
        self.session_id = session_id
        # Flipped to False by close(); checked before every execution.
        self.is_running = True
        # Serializes execute_code so stdout/stderr swapping can't interleave.
        self._execution_lock = threading.Lock()

        import os

        # NOTE(review): process-wide side effect — chdir affects every thread
        # and assumes /workspace exists (presumably the sandbox root; confirm).
        os.chdir("/workspace")

        self.shell = InteractiveShell()
        self.shell.init_completer()
        self.shell.init_history()
        self.shell.init_logger()

        self._setup_proxy_functions()

    def _setup_proxy_functions(self) -> None:
        """Pre-import the proxy tool functions into the shell's namespace.

        Best-effort: silently skipped when the proxy module is unavailable.
        """
        try:
            from strix.tools.proxy import proxy_actions

            proxy_functions = [
                "list_requests",
                "list_sitemap",
                "repeat_request",
                "scope_rules",
                "send_request",
                "view_request",
                "view_sitemap_entry",
            ]

            proxy_dict = {name: getattr(proxy_actions, name) for name in proxy_functions}
            self.shell.user_ns.update(proxy_dict)
        except ImportError:
            pass

    def _validate_session(self) -> dict[str, Any] | None:
        """Return an error-result dict when the session is closed, else None."""
        if not self.is_running:
            return {
                "session_id": self.session_id,
                "stdout": "",
                "stderr": "Session is not running",
                "result": None,
            }
        return None

    def _setup_execution_environment(self, timeout: int) -> tuple[Any, io.StringIO, io.StringIO]:
        """Install the SIGALRM timeout and redirect stdout/stderr to capture buffers.

        Returns (previous SIGALRM handler, stdout buffer, stderr buffer) so
        ``_cleanup_execution_environment`` can restore everything.

        NOTE(review): signal.signal/alarm only work in the main thread —
        assumes execute_code is called from the main thread; confirm.
        """
        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()

        def timeout_handler(signum: int, frame: Any) -> None:
            raise TimeoutError(f"Code execution timed out after {timeout} seconds")

        old_handler = signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(timeout)

        sys.stdout = stdout_capture
        sys.stderr = stderr_capture

        return old_handler, stdout_capture, stderr_capture

    def _cleanup_execution_environment(
        self, old_handler: Any, old_stdout: Any, old_stderr: Any
    ) -> None:
        """Restore the previous SIGALRM handler and the original stdout/stderr."""
        signal.signal(signal.SIGALRM, old_handler)
        sys.stdout = old_stdout
        sys.stderr = old_stderr

    def _truncate_output(self, content: str, max_length: int, suffix: str) -> str:
        """Clip content to max_length characters, appending suffix when clipped."""
        if len(content) > max_length:
            return content[:max_length] + suffix
        return content

    def _format_execution_result(
        self, execution_result: Any, stdout_content: str, stderr_content: str
    ) -> dict[str, Any]:
        """Build the result dict: truncated stdout (+ repr of the last value) and stderr."""
        stdout = self._truncate_output(
            stdout_content, MAX_STDOUT_LENGTH, "... [stdout truncated at 10k chars]"
        )

        # Append the repr of the cell's last expression value after stdout,
        # separated by a newline, mimicking interactive display.
        if execution_result.result is not None:
            if stdout and not stdout.endswith("\n"):
                stdout += "\n"
            result_repr = repr(execution_result.result)
            result_repr = self._truncate_output(
                result_repr, MAX_STDOUT_LENGTH, "... [result truncated at 10k chars]"
            )
            stdout += result_repr

        # Re-truncate the combined stdout+result to the overall cap.
        stdout = self._truncate_output(
            stdout, MAX_STDOUT_LENGTH, "... [output truncated at 10k chars]"
        )

        stderr_content = stderr_content if stderr_content else ""
        stderr_content = self._truncate_output(
            stderr_content, MAX_STDERR_LENGTH, "... [stderr truncated at 5k chars]"
        )

        # run_cell reports failures via these fields; surface at least a
        # generic message when nothing was written to stderr.
        if (
            execution_result.error_before_exec or execution_result.error_in_exec
        ) and not stderr_content:
            stderr_content = "Execution error occurred"

        return {
            "session_id": self.session_id,
            "stdout": stdout,
            "stderr": stderr_content,
            "result": repr(execution_result.result)
            if execution_result.result is not None
            else None,
        }

    def _handle_execution_error(self, error: BaseException) -> dict[str, Any]:
        """Convert a timeout/interrupt/exit into an error-result dict."""
        error_msg = str(error)
        error_msg = self._truncate_output(
            error_msg, MAX_STDERR_LENGTH, "... [error truncated at 5k chars]"
        )

        return {
            "session_id": self.session_id,
            "stdout": "",
            "stderr": error_msg,
            "result": None,
        }

    def execute_code(self, code: str, timeout: int = 30) -> dict[str, Any]:
        """Run code in this session's shell, capturing output, with a timeout.

        Returns a dict with session_id/stdout/stderr/result; never raises for
        timeout, KeyboardInterrupt, or SystemExit from the executed code.
        """
        session_error = self._validate_session()
        if session_error:
            return session_error

        with self._execution_lock:
            # Save the real streams before _setup swaps them out.
            old_stdout, old_stderr = sys.stdout, sys.stderr

            try:
                old_handler, stdout_capture, stderr_capture = self._setup_execution_environment(
                    timeout
                )

                try:
                    execution_result = self.shell.run_cell(code, silent=False, store_history=True)
                    # Cancel the pending alarm as soon as the cell completes.
                    signal.alarm(0)

                    return self._format_execution_result(
                        execution_result, stdout_capture.getvalue(), stderr_capture.getvalue()
                    )

                except (TimeoutError, KeyboardInterrupt, SystemExit) as e:
                    signal.alarm(0)
                    return self._handle_execution_error(e)

            finally:
                # NOTE(review): if _setup_execution_environment itself raised,
                # old_handler is unbound here — confirm this path can't happen.
                self._cleanup_execution_environment(old_handler, old_stdout, old_stderr)

    def close(self) -> None:
        """Mark the session stopped and reset the shell's namespace."""
        self.is_running = False
        self.shell.reset(new_session=False)

    def is_alive(self) -> bool:
        """Whether the session is still accepting code."""
        return self.is_running
|
||||
131
strix/tools/python/python_manager.py
Normal file
131
strix/tools/python/python_manager.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import atexit
|
||||
import contextlib
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from .python_instance import PythonInstance
|
||||
|
||||
|
||||
class PythonSessionManager:
    """Thread-safe registry of named PythonInstance sessions.

    Creating the manager installs atexit and signal handlers so all sessions
    are closed on process shutdown.
    """

    def __init__(self) -> None:
        # session_id -> live PythonInstance
        self.sessions: dict[str, PythonInstance] = {}
        # Guards self.sessions; instances serialize their own execution.
        self._lock = threading.Lock()
        self.default_session_id = "default"

        self._register_cleanup_handlers()

    def create_session(
        self, session_id: str | None = None, initial_code: str | None = None, timeout: int = 30
    ) -> dict[str, Any]:
        """Create a new session (optionally running initial code in it).

        Raises ValueError if a session with that id already exists.
        """
        if session_id is None:
            session_id = self.default_session_id

        with self._lock:
            if session_id in self.sessions:
                raise ValueError(f"Python session '{session_id}' already exists")

            session = PythonInstance(session_id)
            self.sessions[session_id] = session

            # NOTE(review): initial code runs while the manager lock is held,
            # blocking all other session operations for up to `timeout` seconds.
            if initial_code:
                result = session.execute_code(initial_code, timeout)
                result["message"] = (
                    f"Python session '{session_id}' created successfully with initial code"
                )
            else:
                result = {
                    "session_id": session_id,
                    "message": f"Python session '{session_id}' created successfully",
                }

            return result

    def execute_code(
        self, session_id: str | None = None, code: str | None = None, timeout: int = 30
    ) -> dict[str, Any]:
        """Run code in an existing session.

        Raises ValueError when no code is given or the session is unknown.
        """
        if session_id is None:
            session_id = self.default_session_id

        if not code:
            raise ValueError("No code provided for execution")

        with self._lock:
            if session_id not in self.sessions:
                raise ValueError(f"Python session '{session_id}' not found")

            session = self.sessions[session_id]

        # Execute outside the manager lock; the instance has its own lock.
        result = session.execute_code(code, timeout)
        result["message"] = f"Code executed in session '{session_id}'"
        return result

    def close_session(self, session_id: str | None = None) -> dict[str, Any]:
        """Remove a session from the registry and shut it down.

        Raises ValueError when the session is unknown.
        """
        if session_id is None:
            session_id = self.default_session_id

        with self._lock:
            if session_id not in self.sessions:
                raise ValueError(f"Python session '{session_id}' not found")

            session = self.sessions.pop(session_id)

        session.close()
        return {
            "session_id": session_id,
            "message": f"Python session '{session_id}' closed successfully",
            "is_running": False,
        }

    def list_sessions(self) -> dict[str, Any]:
        """Return per-session liveness info plus a total count."""
        with self._lock:
            session_info = {}
            for sid, session in self.sessions.items():
                session_info[sid] = {
                    "is_running": session.is_running,
                    "is_alive": session.is_alive(),
                }

            return {"sessions": session_info, "total_count": len(session_info)}

    def cleanup_dead_sessions(self) -> None:
        """Drop and close any sessions whose is_alive() is False."""
        with self._lock:
            dead_sessions = []
            for sid, session in self.sessions.items():
                if not session.is_alive():
                    dead_sessions.append(sid)

            for sid in dead_sessions:
                session = self.sessions.pop(sid)
                # Best-effort close; a failing close must not abort cleanup.
                with contextlib.suppress(Exception):
                    session.close()

    def close_all_sessions(self) -> None:
        """Close every session; used by atexit/signal shutdown handlers."""
        with self._lock:
            # Snapshot then clear under the lock; close outside it.
            sessions_to_close = list(self.sessions.values())
            self.sessions.clear()

        for session in sessions_to_close:
            with contextlib.suppress(Exception):
                session.close()

    def _register_cleanup_handlers(self) -> None:
        """Hook process shutdown (atexit, SIGTERM/SIGINT/SIGHUP) to close sessions."""
        atexit.register(self.close_all_sessions)

        signal.signal(signal.SIGTERM, self._signal_handler)
        signal.signal(signal.SIGINT, self._signal_handler)

        # SIGHUP does not exist on Windows, hence the feature check.
        if hasattr(signal, "SIGHUP"):
            signal.signal(signal.SIGHUP, self._signal_handler)

    def _signal_handler(self, _signum: int, _frame: Any) -> None:
        """Terminate cleanly on a shutdown signal after closing all sessions."""
        self.close_all_sessions()
        sys.exit(0)
|
||||
|
||||
|
||||
# Eagerly-created module-level singleton; constructing it at import time also
# installs the atexit/signal cleanup handlers exactly once.
_python_session_manager = PythonSessionManager()


def get_python_session_manager() -> PythonSessionManager:
    """Return the process-wide Python session manager singleton."""
    return _python_session_manager
|
||||
196
strix/tools/registry.py
Normal file
196
strix/tools/registry.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
tools: list[dict[str, Any]] = []
|
||||
_tools_by_name: dict[str, Callable[..., Any]] = {}
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplementedInClientSideOnlyError(Exception):
    """Raised for tools that only have a client-side implementation."""

    def __init__(
        self,
        message: str = "This tool is implemented in the client side only",
    ) -> None:
        super().__init__(message)
        # Kept as an attribute so callers can read the text without str(exc).
        self.message = message
|
||||
|
||||
|
||||
def _process_dynamic_content(content: str) -> str:
    """Substitute the {{DYNAMIC_MODULES_DESCRIPTION}} placeholder in schema text.

    When the prompts package cannot be imported, a static fallback description
    is used instead.
    """
    placeholder = "{{DYNAMIC_MODULES_DESCRIPTION}}"
    if placeholder not in content:
        return content

    try:
        from strix.prompts import generate_modules_description

        replacement = generate_modules_description()
    except ImportError:
        logger.warning("Could not import prompts utilities for dynamic schema generation")
        replacement = (
            "List of prompt modules to load for this agent (max 3). Module discovery failed."
        )

    return content.replace(placeholder, replacement)
|
||||
|
||||
|
||||
def _load_xml_schema(path: Path) -> Any:
    """Parse a *_schema.xml file into {tool_name: raw <tool> element string}.

    Uses plain string scanning rather than an XML parser — the schema bodies
    contain unescaped example markup that would break strict XML parsing.
    Returns None when the file is missing or unreadable.
    """
    if not path.exists():
        return None
    try:
        content = path.read_text()

        # Expand dynamic placeholders before slicing out the <tool> elements.
        content = _process_dynamic_content(content)

        start_tag = '<tool name="'
        end_tag = "</tool>"
        tools_dict = {}

        pos = 0
        while True:
            start_pos = content.find(start_tag, pos)
            if start_pos == -1:
                break

            # Tool name is the text between the opening attribute quote marks.
            name_start = start_pos + len(start_tag)
            name_end = content.find('"', name_start)
            if name_end == -1:
                break
            tool_name = content[name_start:name_end]

            end_pos = content.find(end_tag, name_end)
            if end_pos == -1:
                break
            end_pos += len(end_tag)

            # Keep the raw element text, including the <tool ...> wrapper.
            tool_element = content[start_pos:end_pos]
            tools_dict[tool_name] = tool_element

            pos = end_pos

            if pos >= len(content):
                break
    except (IndexError, ValueError, UnicodeError) as e:
        logger.warning(f"Error loading schema file {path}: {e}")
        return None
    else:
        return tools_dict
|
||||
|
||||
|
||||
def _get_module_name(func: Callable[..., Any]) -> str:
    """Derive the tool-category name (the segment right after ``.tools.``) for a callable.

    Returns "unknown" when the module cannot be resolved or is not under a
    ``.tools.`` package.
    """
    module = inspect.getmodule(func)
    if module is None:
        return "unknown"

    qualified = module.__name__
    if ".tools." not in qualified:
        return "unknown"

    tail = qualified.split(".tools.")[-1]
    return tail.split(".")[0]
|
||||
|
||||
|
||||
def register_tool(
    func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
) -> Callable[..., Any]:
    """Decorator registering a function as an agent tool.

    Usable bare (``@register_tool``) or with arguments
    (``@register_tool(sandbox_execution=False)``). Records the tool in the
    module-level ``tools`` list and ``_tools_by_name`` index, and — outside
    sandbox mode — attaches the tool's XML schema loaded from a sibling
    ``<module>_schema.xml`` file.
    """
    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        func_dict = {
            "name": f.__name__,
            "function": f,
            "module": _get_module_name(f),
            "sandbox_execution": sandbox_execution,
        }

        # Schemas are only needed on the client side; skip the file I/O
        # entirely when running inside the sandbox.
        sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
        if not sandbox_mode:
            try:
                # Schema lives next to the module: foo.py -> foo_schema.xml.
                module_path = Path(inspect.getfile(f))
                schema_file_name = f"{module_path.stem}_schema.xml"
                schema_path = module_path.parent / schema_file_name

                xml_tools = _load_xml_schema(schema_path)

                if xml_tools is not None and f.__name__ in xml_tools:
                    func_dict["xml_schema"] = xml_tools[f.__name__]
                else:
                    # Placeholder schema so prompt generation never breaks.
                    func_dict["xml_schema"] = (
                        f'<tool name="{f.__name__}">'
                        "<description>Schema not found for tool.</description>"
                        "</tool>"
                    )
            except (TypeError, FileNotFoundError) as e:
                logger.warning(f"Error loading schema for {f.__name__}: {e}")
                func_dict["xml_schema"] = (
                    f'<tool name="{f.__name__}">'
                    "<description>Error loading schema.</description>"
                    "</tool>"
                )

        tools.append(func_dict)
        # The registry stores the undecorated function, not the wrapper.
        _tools_by_name[str(func_dict["name"])] = f

        @wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return f(*args, **kwargs)

        return wrapper

    # Support both @register_tool and @register_tool(...) invocation forms.
    if func is None:
        return decorator
    return decorator(func)
|
||||
|
||||
|
||||
def get_tool_by_name(name: str) -> Callable[..., Any] | None:
    """Look up a registered tool function, or ``None`` if unregistered."""
    try:
        return _tools_by_name[name]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def get_tool_names() -> list[str]:
    """Return the names of all registered tools, in registration order."""
    return [*_tools_by_name]
|
||||
|
||||
|
||||
def needs_agent_state(tool_name: str) -> bool:
    """Return True when the named tool declares an ``agent_state`` parameter."""
    tool_func = get_tool_by_name(tool_name)
    if tool_func is None:
        # Unknown tools cannot require agent state.
        return False
    return "agent_state" in signature(tool_func).parameters
|
||||
|
||||
|
||||
def should_execute_in_sandbox(tool_name: str) -> bool:
    """Whether the named tool should run in the sandbox (defaults to True)."""
    entry = next((t for t in tools if t.get("name") == tool_name), None)
    if entry is None:
        # Unregistered tools default to the safer sandboxed path.
        return True
    return bool(entry.get("sandbox_execution", True))
|
||||
|
||||
|
||||
def get_tools_prompt() -> str:
    """Render all registered tool schemas as XML grouped by tool module.

    Each module becomes a ``<{module}_tools>`` section containing the raw
    ``xml_schema`` of every tool registered under that module; sections are
    sorted by module name and separated by blank lines.
    """
    tools_by_module: dict[str, list[dict[str, Any]]] = {}
    for tool in tools:
        # setdefault replaces the manual "if key not in dict" grouping.
        tools_by_module.setdefault(tool.get("module", "unknown"), []).append(tool)

    xml_sections = []
    for module, module_tools in sorted(tools_by_module.items()):
        tag_name = f"{module}_tools"
        section_parts = [f"<{tag_name}>"]
        for tool in module_tools:
            tool_xml = tool.get("xml_schema", "")
            if tool_xml:
                # Indent each schema line so it nests under the module tag.
                indented_tool = "\n".join(f" {line}" for line in tool_xml.split("\n"))
                section_parts.append(indented_tool)
        section_parts.append(f"</{tag_name}>")
        xml_sections.append("\n".join(section_parts))

    return "\n\n".join(xml_sections)
|
||||
|
||||
|
||||
def clear_registry() -> None:
    """Drop every registered tool from both registry structures."""
    del tools[:]
    _tools_by_name.clear()
|
||||
6
strix/tools/reporting/__init__.py
Normal file
6
strix/tools/reporting/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .reporting_actions import create_vulnerability_report
|
||||
|
||||
|
||||
__all__ = [
|
||||
"create_vulnerability_report",
|
||||
]
|
||||
63
strix/tools/reporting/reporting_actions.py
Normal file
63
strix/tools/reporting/reporting_actions.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from typing import Any
|
||||
|
||||
from strix.tools.registry import register_tool
|
||||
|
||||
|
||||
@register_tool(sandbox_execution=False)
def create_vulnerability_report(
    title: str,
    content: str,
    severity: str,
) -> dict[str, Any]:
    """Record a verified vulnerability finding via the global tracer.

    Validates the three fields, then hands the report to the CLI tracer when
    one is available; otherwise still reports success but flags that the
    report was not persisted.
    """
    # Guard clauses replace the single accumulated validation-error variable;
    # the first failing check wins, exactly as before.
    if not title or not title.strip():
        return {"success": False, "message": "Title cannot be empty"}
    if not content or not content.strip():
        return {"success": False, "message": "Content cannot be empty"}
    if not severity or not severity.strip():
        return {"success": False, "message": "Severity cannot be empty"}

    valid_severities = ["critical", "high", "medium", "low", "info"]
    if severity.lower() not in valid_severities:
        return {
            "success": False,
            "message": (
                f"Invalid severity '{severity}'. Must be one of: {', '.join(valid_severities)}"
            ),
        }

    try:
        # Imported lazily so sandboxed runs without the CLI package still work.
        from strix.cli.tracer import get_global_tracer

        tracer = get_global_tracer()
        if tracer:
            report_id = tracer.add_vulnerability_report(
                title=title,
                content=content,
                severity=severity,
            )
            return {
                "success": True,
                "message": f"Vulnerability report '{title}' created successfully",
                "report_id": report_id,
                "severity": severity.lower(),
            }

        import logging

        logging.warning("Global tracer not available - vulnerability report not stored")

        return {  # noqa: TRY300
            "success": True,
            "message": f"Vulnerability report '{title}' created successfully (not persisted)",
            "warning": "Report could not be persisted - tracer unavailable",
        }

    except ImportError:
        return {
            "success": True,
            "message": f"Vulnerability report '{title}' created successfully (not persisted)",
            "warning": "Report could not be persisted - tracer module unavailable",
        }
    except (ValueError, TypeError) as e:
        return {"success": False, "message": f"Failed to create vulnerability report: {e!s}"}
|
||||
30
strix/tools/reporting/reporting_actions_schema.xml
Normal file
30
strix/tools/reporting/reporting_actions_schema.xml
Normal file
@@ -0,0 +1,30 @@
|
||||
<tools>
|
||||
<tool name="create_vulnerability_report">
|
||||
<description>Create a vulnerability report for a discovered security issue.
|
||||
|
||||
Use this tool to document a specific verified security vulnerability.
|
||||
Put ALL details in the content field - affected URLs, parameters, proof of concept, remediation steps, CVE references, CVSS scores, technical details, impact assessment, etc.
|
||||
|
||||
DO NOT USE:
|
||||
- For general security observations without specific vulnerabilities
|
||||
- When you don't have concrete vulnerability details
|
||||
- When you don't have a proof of concept, or still not 100% sure if it's a vulnerability
|
||||
- For tracking multiple vulnerabilities (create separate reports)
|
||||
- For reporting multiple vulnerabilities at once. Use a separate create_vulnerability_report for each vulnerability.
|
||||
</description>
|
||||
<parameters>
|
||||
<parameter name="title" type="string" required="true">
|
||||
<description>Clear, concise title of the vulnerability</description>
|
||||
</parameter>
|
||||
<parameter name="content" type="string" required="true">
|
||||
<description>Complete vulnerability details including affected URLs, technical details, impact, proof of concept, remediation steps, and any relevant references. Be comprehensive and include everything relevant.</description>
|
||||
</parameter>
|
||||
<parameter name="severity" type="string" required="true">
|
||||
<description>Severity level: critical, high, medium, low, or info</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing success status and message</description>
|
||||
</returns>
|
||||
</tool>
|
||||
</tools>
|
||||
4
strix/tools/terminal/__init__.py
Normal file
4
strix/tools/terminal/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .terminal_actions import terminal_action
|
||||
|
||||
|
||||
__all__ = ["terminal_action"]
|
||||
53
strix/tools/terminal/terminal_actions.py
Normal file
53
strix/tools/terminal/terminal_actions.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from strix.tools.registry import register_tool
|
||||
|
||||
from .terminal_manager import get_terminal_manager
|
||||
|
||||
|
||||
TerminalAction = Literal["new_terminal", "send_input", "wait", "close"]


@register_tool
def terminal_action(
    action: TerminalAction,
    inputs: list[str] | None = None,
    time: float | None = None,
    terminal_id: str | None = None,
) -> dict[str, Any]:
    """Dispatch one terminal operation to the shared terminal manager.

    Validation failures and manager errors are folded into an error dict so
    the caller always receives a response payload instead of an exception.
    """
    manager = get_terminal_manager()

    try:
        # Plain if-chain dispatch; required-parameter checks are inlined at
        # the branch that needs them.
        if action == "new_terminal":
            return manager.create_terminal(terminal_id, inputs)
        if action == "send_input":
            if not inputs:
                raise ValueError(f"inputs parameter is required for {action} action")
            return manager.send_input(terminal_id, inputs)
        if action == "wait":
            if time is None:
                raise ValueError("time parameter is required for wait action")
            return manager.wait_terminal(terminal_id, time)
        if action == "close":
            return manager.close_terminal(terminal_id)
        raise ValueError(f"Unknown action: {action}")
    except (ValueError, RuntimeError) as e:
        return {"error": str(e), "terminal_id": terminal_id, "snapshot": "", "is_running": False}
|
||||
114
strix/tools/terminal/terminal_actions_schema.xml
Normal file
114
strix/tools/terminal/terminal_actions_schema.xml
Normal file
@@ -0,0 +1,114 @@
|
||||
<tools>
|
||||
<tool name="terminal_action">
|
||||
<description>Perform terminal actions using a terminal emulator instance. Each terminal instance
|
||||
is PERSISTENT and remains active until explicitly closed, allowing for multi-step
|
||||
workflows and long-running processes.</description>
|
||||
<parameters>
|
||||
<parameter name="action" type="string" required="true">
|
||||
<description>The terminal action to perform: - new_terminal: Create a new terminal instance. This MUST be the first action for each terminal tab. - send_input: Send keyboard input to the specified terminal. - wait: Pause execution for specified number of seconds. Can be also used to get the current terminal state (screenshot, output, etc.) after using other tools. - close: Close the specified terminal instance. This MUST be the final action for each terminal tab.</description>
|
||||
</parameter>
|
||||
<parameter name="inputs" type="string" required="false">
|
||||
<description>Required for 'new_terminal' and 'send_input' actions: - List of inputs to send to terminal. Each element in the list MUST be one of the following: - Regular text: "hello", "world", etc. - Literal text (not interpreted as special keys): prefix with "literal:" e.g., "literal:Home", "literal:Escape", "literal:Enter" to send these as text - Enter - Space - Backspace - Escape: "Escape", "^[", "C-[" - Tab: "Tab" - Arrow keys: "Left", "Right", "Up", "Down" - Navigation: "Home", "End", "PageUp", "PageDown" - Function keys: "F1" through "F12" Modifier keys supported with prefixes: - ^ or C- : Control (e.g., "^c", "C-c") - S- : Shift (e.g., "S-F6") - A- : Alt (e.g., "A-Home") - Combined modifiers for arrows: "S-A-Up", "C-S-Left" - Inputs MUST in all cases be sent as a LIST of strings, even if you are only sending one input. - Sending Inputs as a single string will NOT work.</description>
|
||||
</parameter>
|
||||
<parameter name="time" type="string" required="false">
|
||||
<description>Required for 'wait' action. Number of seconds to pause execution. Can be fractional (e.g., 0.5 for half a second).</description>
|
||||
</parameter>
|
||||
<parameter name="terminal_id" type="string" required="false">
|
||||
<description>Identifier for the terminal instance. Required for all actions except the first 'new_terminal' action. Allows managing multiple concurrent terminal tabs. - For 'new_terminal': if not provided, a default terminal is created. If provided, creates a new terminal with that ID. - For other actions: specifies which terminal instance to operate on. - Default terminal ID is "default" if not specified.</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - snapshot: raw representation of current terminal state where you can see the output of the command - terminal_id: the ID of the terminal instance that was operated on</description>
|
||||
</returns>
|
||||
<notes>
|
||||
Important usage rules:
|
||||
1. PERSISTENCE: Terminal instances remain active and maintain their state (environment
|
||||
variables, current directory, running processes) until explicitly closed with the
|
||||
'close' action. This allows for multi-step workflows across multiple tool calls.
|
||||
2. MULTIPLE TERMINALS: You can run multiple terminal instances concurrently by using
|
||||
different terminal_id values. Each terminal operates independently.
|
||||
3. Terminal interaction MUST begin with 'new_terminal' action for each terminal instance.
|
||||
4. Only one action can be performed per call.
|
||||
5. Input handling:
|
||||
- Regular text is sent as-is
|
||||
- Literal text: prefix with "literal:" to send special key names as literal text
|
||||
- Special keys must match supported key names
|
||||
- Modifier combinations follow specific syntax
|
||||
- Control can be specified as ^ or C- prefix
|
||||
- Shift (S-) works with special keys only
|
||||
- Alt (A-) works with any character/key
|
||||
6. Wait action:
|
||||
- Time is specified in seconds
|
||||
- Can be used to wait for command completion
|
||||
- Can be fractional (e.g., 0.5 seconds)
|
||||
- Snapshot and output are captured after the wait
|
||||
- You should estimate the time it will take to run the command and set the wait time accordingly.
|
||||
- It can be from a few seconds to a few minutes, choose wisely depending on the command you are running and the task.
|
||||
7. The terminal can operate concurrently with other tools. You may invoke
|
||||
browser, proxy, or other tools (in separate assistant messages) while maintaining
|
||||
active terminal sessions.
|
||||
8. You do not need to close terminals after you are done, but you can if you want to
|
||||
free up resources.
|
||||
9. You MUST end the inputs list with an "Enter" if you want to run the command, as
|
||||
it is not sent automatically.
|
||||
10. AUTOMATIC SPACING BEHAVIOR:
|
||||
- Consecutive regular text inputs have spaces automatically added between them
|
||||
- This is helpful for shell commands: ["ls", "-la"] becomes "ls -la"
|
||||
- This causes problems for compound commands: [":", "w", "q"] becomes ": w q"
|
||||
- Use "literal:" prefix to bypass spacing: [":", "literal:wq"] becomes ":wq"
|
||||
- Special keys (Enter, Space, etc.) and literal strings never trigger spacing
|
||||
11. WHEN TO USE LITERAL PREFIX:
|
||||
- Vim commands: [":", "literal:wq", "Enter"] instead of [":", "w", "q", "Enter"]
|
||||
- Any sequence where exact character positioning matters
|
||||
- When you need multiple characters sent as a single unit
|
||||
12. Do NOT use terminal actions for file editing or writing. Use the replace_in_file,
|
||||
write_to_file, or read_file tools instead.
|
||||
</notes>
|
||||
<examples>
|
||||
# Create new terminal with Node.js (default terminal)
|
||||
<function=terminal_action>
|
||||
<parameter=action>new_terminal</parameter>
|
||||
<parameter=inputs>["node", "Enter"]</parameter>
|
||||
</function>
|
||||
|
||||
# Create a second (parallel) terminal instance for Python
|
||||
<function=terminal_action>
|
||||
<parameter=action>new_terminal</parameter>
|
||||
<parameter=terminal_id>python_terminal</parameter>
|
||||
<parameter=inputs>["python3", "Enter"]</parameter>
|
||||
</function>
|
||||
|
||||
# Send command to the default terminal
|
||||
<function=terminal_action>
|
||||
<parameter=action>send_input</parameter>
|
||||
<parameter=inputs>["require('crypto').randomBytes(1000000).toString('hex')",
|
||||
"Enter"]</parameter>
|
||||
</function>
|
||||
|
||||
# Wait for previous action on default terminal
|
||||
<function=terminal_action>
|
||||
<parameter=action>wait</parameter>
|
||||
<parameter=time>2.0</parameter>
|
||||
</function>
|
||||
|
||||
# Send multiple inputs with special keys to current terminal
|
||||
<function=terminal_action>
|
||||
<parameter=action>send_input</parameter>
|
||||
<parameter=inputs>["sqlmap -u 'http://example.com/page.php?id=1' --batch", "Enter", "y",
|
||||
"Enter", "n", "Enter", "n", "Enter"]</parameter>
|
||||
</function>
|
||||
|
||||
# WRONG: Vim command with automatic spacing (becomes ": w q")
|
||||
<function=terminal_action>
|
||||
<parameter=action>send_input</parameter>
|
||||
<parameter=inputs>[":", "w", "q", "Enter"]</parameter>
|
||||
</function>
|
||||
|
||||
# CORRECT: Vim command using literal prefix (becomes ":wq")
|
||||
<function=terminal_action>
|
||||
<parameter=action>send_input</parameter>
|
||||
<parameter=inputs>[":", "literal:wq", "Enter"]</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
</tools>
|
||||
231
strix/tools/terminal/terminal_instance.py
Normal file
231
strix/tools/terminal/terminal_instance.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import contextlib
|
||||
import os
|
||||
import pty
|
||||
import select
|
||||
import signal
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import pyte
|
||||
|
||||
|
||||
MAX_TERMINAL_SNAPSHOT_LENGTH = 10_000
|
||||
|
||||
|
||||
class TerminalInstance:
    """One persistent interactive bash session behind a PTY.

    Output is drained by a daemon thread into a ``pyte`` in-memory terminal
    emulator (screen + scrollback history), from which text snapshots are
    rendered on demand.
    """

    def __init__(self, terminal_id: str, initial_command: str | None = None) -> None:
        self.terminal_id = terminal_id
        self.process: subprocess.Popen[bytes] | None = None
        self.master_fd: int | None = None
        self.is_running = False
        # Guards concurrent access to the pyte screen/stream between the
        # reader thread (feed) and snapshot rendering.
        self._output_lock = threading.Lock()
        self._reader_thread: threading.Thread | None = None

        # 80x24 emulated screen with 1000 lines of scrollback history.
        self.screen = pyte.HistoryScreen(80, 24, history=1000)
        self.stream = pyte.ByteStream()
        self.stream.attach(self.screen)

        self._start_terminal(initial_command)

    def _start_terminal(self, initial_command: str | None = None) -> None:
        """Spawn an interactive bash on a new PTY and start the output reader.

        Raises:
            RuntimeError: If the PTY or shell process cannot be created.
        """
        try:
            self.master_fd, slave_fd = pty.openpty()

            shell = "/bin/bash"

            self.process = subprocess.Popen(  # noqa: S603
                [shell, "-i"],
                stdin=slave_fd,
                stdout=slave_fd,
                stderr=slave_fd,
                cwd="/workspace",
                preexec_fn=os.setsid,  # noqa: PLW1509 - Required for PTY functionality
            )

            # The child holds its own copy of the slave end; close ours so we
            # see EOF on the master when the shell exits.
            os.close(slave_fd)

            self.is_running = True

            self._reader_thread = threading.Thread(target=self._read_output, daemon=True)
            self._reader_thread.start()

            # Give the shell a moment to print its prompt before any input.
            time.sleep(0.5)

            if initial_command:
                self._write_to_terminal(initial_command)

        except (OSError, ValueError) as e:
            raise RuntimeError(f"Failed to start terminal: {e}") from e

    def _read_output(self) -> None:
        """Reader-thread loop: drain PTY output into the pyte emulator."""
        while self.is_running and self.master_fd:
            try:
                # Short select timeout so the loop notices is_running flips.
                ready, _, _ = select.select([self.master_fd], [], [], 0.1)
                if ready:
                    data = os.read(self.master_fd, 4096)
                    if data:
                        # TypeError from pyte is suppressed deliberately so a
                        # malformed escape sequence cannot kill the reader.
                        with self._output_lock, contextlib.suppress(TypeError):
                            self.stream.feed(data)
                    else:
                        # Zero-byte read means EOF: the shell has exited.
                        break
            except (OSError, ValueError):
                break

    def _write_to_terminal(self, data: str) -> None:
        """Write UTF-8 encoded *data* to the PTY master.

        Raises:
            RuntimeError: If the PTY has been closed underneath us.
        """
        if self.master_fd and self.is_running:
            try:
                os.write(self.master_fd, data.encode("utf-8"))
            except (OSError, ValueError) as e:
                raise RuntimeError("Terminal is no longer available") from e

    def send_input(self, inputs: list[str]) -> None:
        """Send a sequence of inputs (text, "literal:" text, or key names).

        Consecutive plain-text items get a space inserted between them (so
        ["ls", "-la"] types "ls -la"); special keys and "literal:" items never
        trigger the auto-spacing.

        Raises:
            RuntimeError: If the terminal is not running.
        """
        if not self.is_running:
            raise RuntimeError("Terminal is not running")

        for i, input_item in enumerate(inputs):
            if input_item.startswith("literal:"):
                # Strip the "literal:" prefix (8 chars) and send verbatim.
                literal_text = input_item[8:]
                self._write_to_terminal(literal_text)
            else:
                key_sequence = self._get_key_sequence(input_item)
                if key_sequence:
                    self._write_to_terminal(key_sequence)
                else:
                    self._write_to_terminal(input_item)

            # Small pause between items so the shell can keep up.
            time.sleep(0.05)

            # Auto-spacing: only between two adjacent plain-text items.
            if (
                i < len(inputs) - 1
                and not input_item.startswith("literal:")
                and not self._is_special_key(input_item)
                and not inputs[i + 1].startswith("literal:")
                and not self._is_special_key(inputs[i + 1])
            ):
                self._write_to_terminal(" ")

    def get_snapshot(self) -> dict[str, Any]:
        """Render scrollback + current screen into a text snapshot dict.

        Returns a dict with terminal_id, snapshot text (tail-truncated to
        MAX_TERMINAL_SNAPSHOT_LENGTH), is_running, process_id and truncated.
        """
        with self._output_lock:
            # Scrollback lines: pyte history rows are mappings of column ->
            # character cell; join each row's cell data into a string.
            history_lines = [
                "".join(char.data for char in line_dict.values())
                for line_dict in self.screen.history.top
            ]

            current_lines = self.screen.display

            all_lines = history_lines + current_lines
            rendered_output = "\n".join(all_lines)

            # Keep the tail: the most recent output is the useful part.
            if len(rendered_output) > MAX_TERMINAL_SNAPSHOT_LENGTH:
                rendered_output = rendered_output[-MAX_TERMINAL_SNAPSHOT_LENGTH:]
                truncated = True
            else:
                truncated = False

            return {
                "terminal_id": self.terminal_id,
                "snapshot": rendered_output,
                "is_running": self.is_running,
                "process_id": self.process.pid if self.process else None,
                "truncated": truncated,
            }

    def wait(self, duration: float) -> dict[str, Any]:
        """Sleep for *duration* seconds, then return a fresh snapshot."""
        time.sleep(duration)
        return self.get_snapshot()

    def close(self) -> None:
        """Terminate the shell (SIGTERM, then SIGKILL) and release the PTY."""
        # Flip the flag first so the reader thread's loop condition fails.
        self.is_running = False

        if self.process:
            # getpgid/killpg can raise if the process group is already gone.
            with contextlib.suppress(OSError, ProcessLookupError):
                os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)

                try:
                    self.process.wait(timeout=2)
                except subprocess.TimeoutExpired:
                    # Escalate to SIGKILL if SIGTERM wasn't honored in time.
                    os.killpg(os.getpgid(self.process.pid), signal.SIGKILL)
                    self.process.wait()

        if self.master_fd:
            with contextlib.suppress(OSError):
                os.close(self.master_fd)
            self.master_fd = None

        if self._reader_thread and self._reader_thread.is_alive():
            self._reader_thread.join(timeout=1)

    def _is_special_key(self, key: str) -> bool:
        """Whether *key* names a special key or carries a modifier prefix."""
        special_keys = {
            "Enter",
            "Space",
            "Backspace",
            "Tab",
            "Escape",
            "Up",
            "Down",
            "Left",
            "Right",
            "Home",
            "End",
            "PageUp",
            "PageDown",
            "Insert",
            "Delete",
        } | {f"F{i}" for i in range(1, 13)}

        if key in special_keys:
            return True

        # Modifier prefixes: ^/C- (Control), S- (Shift), A- (Alt).
        return bool(key.startswith(("^", "C-", "S-", "A-")))

    def _get_key_sequence(self, key: str) -> str | None:
        """Translate a key name into its terminal escape sequence.

        Returns None for plain text (caller sends it verbatim).
        NOTE(review): the F1-F4 sequences below use the "\\x1b[11~".."\\x1b[14~"
        style rather than xterm's "\\x1bOP".."\\x1bOS" — presumably intentional
        for the target shell; confirm if F-keys misbehave.
        """
        key_map = {
            "Enter": "\r",
            "Space": " ",
            "Backspace": "\x08",
            "Tab": "\t",
            "Escape": "\x1b",
            "Up": "\x1b[A",
            "Down": "\x1b[B",
            "Right": "\x1b[C",
            "Left": "\x1b[D",
            "Home": "\x1b[H",
            "End": "\x1b[F",
            "PageUp": "\x1b[5~",
            "PageDown": "\x1b[6~",
            "Insert": "\x1b[2~",
            "Delete": "\x1b[3~",
            "F1": "\x1b[11~",
            "F2": "\x1b[12~",
            "F3": "\x1b[13~",
            "F4": "\x1b[14~",
            "F5": "\x1b[15~",
            "F6": "\x1b[17~",
            "F7": "\x1b[18~",
            "F8": "\x1b[19~",
            "F9": "\x1b[20~",
            "F10": "\x1b[21~",
            "F11": "\x1b[23~",
            "F12": "\x1b[24~",
        }

        if key in key_map:
            return key_map[key]

        # Control chords: "^x" or "C-x" map to ASCII control codes 1..26.
        if key.startswith("^") and len(key) == 2:
            char = key[1].lower()
            return chr(ord(char) - ord("a") + 1) if "a" <= char <= "z" else None

        if key.startswith("C-") and len(key) == 3:
            char = key[2].lower()
            return chr(ord(char) - ord("a") + 1) if "a" <= char <= "z" else None

        return None

    def is_alive(self) -> bool:
        """Whether the underlying shell process is still running."""
        if not self.process:
            return False
        return self.process.poll() is None
|
||||
191
strix/tools/terminal/terminal_manager.py
Normal file
191
strix/tools/terminal/terminal_manager.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import atexit
|
||||
import contextlib
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from .terminal_instance import TerminalInstance
|
||||
|
||||
|
||||
class TerminalManager:
    """Thread-safe registry of named TerminalInstance sessions.

    Installs atexit and signal handlers so all terminals are torn down when
    the process exits.
    """

    def __init__(self) -> None:
        self.terminals: dict[str, TerminalInstance] = {}
        # Serializes all terminal map access and per-terminal operations.
        # NOTE(review): the lock is held across terminal.wait() sleeps below,
        # so operations on different terminals serialize — confirm acceptable.
        self._lock = threading.Lock()
        self.default_terminal_id = "default"

        self._register_cleanup_handlers()

    def create_terminal(
        self, terminal_id: str | None = None, inputs: list[str] | None = None
    ) -> dict[str, Any]:
        """Create a new named terminal, optionally running an initial command.

        Raises:
            ValueError: If a terminal with that id already exists.
            RuntimeError: If the terminal cannot be started.
        """
        if terminal_id is None:
            terminal_id = self.default_terminal_id

        with self._lock:
            if terminal_id in self.terminals:
                raise ValueError(f"Terminal '{terminal_id}' already exists")

            # Fold plain-text inputs up to the first "Enter" into one shell
            # command line to run at startup; key-like items are skipped.
            initial_command = None
            if inputs:
                command_parts: list[str] = []
                for input_item in inputs:
                    if input_item == "Enter":
                        initial_command = " ".join(command_parts) + "\n"
                        break
                    if input_item.startswith("literal:"):
                        command_parts.append(input_item[8:])
                    elif input_item not in [
                        "Space",
                        "Tab",
                        "Backspace",
                    ]:
                        command_parts.append(input_item)

            try:
                terminal = TerminalInstance(terminal_id, initial_command)
                self.terminals[terminal_id] = terminal

                # If inputs didn't collapse into a single command (no Enter),
                # replay them as keystrokes and wait a bit longer for output.
                if inputs and not initial_command:
                    terminal.send_input(inputs)
                    result = terminal.wait(2.0)
                else:
                    result = terminal.wait(1.0)

                result["message"] = f"Terminal '{terminal_id}' created successfully"

            except (OSError, ValueError, RuntimeError) as e:
                raise RuntimeError(f"Failed to create terminal '{terminal_id}': {e}") from e
            else:
                return result

    def send_input(
        self, terminal_id: str | None = None, inputs: list[str] | None = None
    ) -> dict[str, Any]:
        """Send keystrokes to a terminal and return its post-input snapshot.

        Raises:
            ValueError: If inputs is empty or the terminal does not exist.
            RuntimeError: If sending fails.
        """
        if terminal_id is None:
            terminal_id = self.default_terminal_id

        if not inputs:
            raise ValueError("No inputs provided")

        with self._lock:
            if terminal_id not in self.terminals:
                raise ValueError(f"Terminal '{terminal_id}' not found")

            terminal = self.terminals[terminal_id]

            try:
                terminal.send_input(inputs)
                # Fixed 2s settle time before snapshotting the output.
                result = terminal.wait(2.0)
                result["message"] = f"Input sent to terminal '{terminal_id}'"
            except (OSError, ValueError, RuntimeError) as e:
                raise RuntimeError(f"Failed to send input to terminal '{terminal_id}': {e}") from e
            else:
                return result

    def wait_terminal(
        self, terminal_id: str | None = None, duration: float = 1.0
    ) -> dict[str, Any]:
        """Sleep for *duration* seconds on a terminal, then snapshot it.

        Raises:
            ValueError: If the terminal does not exist.
            RuntimeError: If the wait/snapshot fails.
        """
        if terminal_id is None:
            terminal_id = self.default_terminal_id

        with self._lock:
            if terminal_id not in self.terminals:
                raise ValueError(f"Terminal '{terminal_id}' not found")

            terminal = self.terminals[terminal_id]

            try:
                result = terminal.wait(duration)
                result["message"] = f"Waited {duration}s on terminal '{terminal_id}'"
            except (OSError, ValueError, RuntimeError) as e:
                raise RuntimeError(f"Failed to wait on terminal '{terminal_id}': {e}") from e
            else:
                return result

    def close_terminal(self, terminal_id: str | None = None) -> dict[str, Any]:
        """Close a terminal and remove it from the registry.

        Raises:
            ValueError: If the terminal does not exist.
            RuntimeError: If shutdown fails (terminal is removed regardless).
        """
        if terminal_id is None:
            terminal_id = self.default_terminal_id

        with self._lock:
            if terminal_id not in self.terminals:
                raise ValueError(f"Terminal '{terminal_id}' not found")

            # pop() before close() so the registry never holds a dead entry.
            terminal = self.terminals.pop(terminal_id)

            try:
                terminal.close()
            except (OSError, ValueError, RuntimeError) as e:
                raise RuntimeError(f"Failed to close terminal '{terminal_id}': {e}") from e
            else:
                return {
                    "terminal_id": terminal_id,
                    "message": f"Terminal '{terminal_id}' closed successfully",
                    "snapshot": "",
                    "is_running": False,
                }

    def get_terminal_snapshot(self, terminal_id: str | None = None) -> dict[str, Any]:
        """Return the current snapshot of a terminal without waiting.

        Raises:
            ValueError: If the terminal does not exist.
        """
        if terminal_id is None:
            terminal_id = self.default_terminal_id

        with self._lock:
            if terminal_id not in self.terminals:
                raise ValueError(f"Terminal '{terminal_id}' not found")

            terminal = self.terminals[terminal_id]

            return terminal.get_snapshot()

    def list_terminals(self) -> dict[str, Any]:
        """Summarize all registered terminals (running state, pid, liveness)."""
        with self._lock:
            terminal_info = {}
            for tid, terminal in self.terminals.items():
                terminal_info[tid] = {
                    "is_running": terminal.is_running,
                    "is_alive": terminal.is_alive(),
                    "process_id": terminal.process.pid if terminal.process else None,
                }

            return {"terminals": terminal_info, "total_count": len(terminal_info)}

    def cleanup_dead_terminals(self) -> None:
        """Remove and close terminals whose shell process has already exited."""
        with self._lock:
            # Collect first, then mutate — never modify the dict while iterating.
            dead_terminals = []
            for tid, terminal in self.terminals.items():
                if not terminal.is_alive():
                    dead_terminals.append(tid)

            for tid in dead_terminals:
                terminal = self.terminals.pop(tid)
                with contextlib.suppress(Exception):
                    terminal.close()

    def close_all_terminals(self) -> None:
        """Close every terminal; best-effort, errors are suppressed."""
        with self._lock:
            terminals_to_close = list(self.terminals.values())
            self.terminals.clear()

            for terminal in terminals_to_close:
                with contextlib.suppress(Exception):
                    terminal.close()

    def _register_cleanup_handlers(self) -> None:
        """Ensure terminals are torn down at interpreter exit or on signals.

        NOTE(review): signal.signal raises ValueError outside the main thread
        and replaces any previously installed handlers — confirm this manager
        is only constructed from the main thread.
        """
        atexit.register(self.close_all_terminals)

        signal.signal(signal.SIGTERM, self._signal_handler)
        signal.signal(signal.SIGINT, self._signal_handler)

        # SIGHUP does not exist on Windows, hence the guard.
        if hasattr(signal, "SIGHUP"):
            signal.signal(signal.SIGHUP, self._signal_handler)

    def _signal_handler(self, _signum: int, _frame: Any) -> None:
        """Signal handler: close all terminals, then exit the process."""
        self.close_all_terminals()
        sys.exit(0)
|
||||
|
||||
|
||||
# Module-level singleton: every terminal tool call in this process shares
# one manager (created at import time, which also installs cleanup handlers).
_terminal_manager = TerminalManager()


def get_terminal_manager() -> TerminalManager:
    """Return the process-wide TerminalManager singleton."""
    return _terminal_manager
|
||||
4
strix/tools/thinking/__init__.py
Normal file
4
strix/tools/thinking/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .thinking_actions import think
|
||||
|
||||
|
||||
__all__ = ["think"]
|
||||
18
strix/tools/thinking/thinking_actions.py
Normal file
18
strix/tools/thinking/thinking_actions.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Any
|
||||
|
||||
from strix.tools.registry import register_tool
|
||||
|
||||
|
||||
@register_tool(sandbox_execution=False)
def think(thought: str) -> dict[str, Any]:
    """Record a free-form reasoning step.

    Purely a scratchpad: nothing is persisted and no external state changes;
    the response just confirms receipt (with the character count).
    """
    try:
        # Normalize once; empty/whitespace-only thoughts are rejected.
        stripped = thought.strip() if thought else ""
        if not stripped:
            return {"success": False, "message": "Thought cannot be empty"}
        return {
            "success": True,
            "message": f"Thought recorded successfully with {len(stripped)} characters",
        }
    except (ValueError, TypeError) as e:
        return {"success": False, "message": f"Failed to record thought: {e!s}"}
|
||||
52
strix/tools/thinking/thinking_actions_schema.xml
Normal file
52
strix/tools/thinking/thinking_actions_schema.xml
Normal file
@@ -0,0 +1,52 @@
|
||||
<tools>
|
||||
<tool name="think">
|
||||
<description>Use the tool to think about something. It will not obtain new information or change the
|
||||
database. Use it when complex reasoning or some cache memory is needed.</description>
|
||||
<details>This tool creates dedicated space for structured thinking during complex tasks,
|
||||
particularly useful for:
|
||||
- Tool output analysis: When you need to carefully process the output of previous tool calls
|
||||
- Policy-heavy environments: When you need to follow detailed guidelines and verify compliance
|
||||
- Sequential decision making: When each action builds on previous ones and mistakes are costly
|
||||
- Multi-step problem solving: When you need to break down complex problems into manageable steps</details>
|
||||
<parameters>
|
||||
<parameter name="thought" type="string" required="true">
|
||||
<description>The thought or reasoning to record</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - success: Whether the thought was recorded successfully - message: Confirmation message with character count or error details</description>
|
||||
</returns>
|
||||
<examples>
|
||||
# Planning and strategy
|
||||
<function=think>
|
||||
<parameter=thought>I need to analyze the scan results systematically. First, let me review
|
||||
the open ports: 22 (SSH), 80 (HTTP), 443 (HTTPS), and 3306 (MySQL). The MySQL port being
|
||||
externally accessible is a high priority security concern. I should check for default
|
||||
credentials and version information. For the web services, I need to enumerate
|
||||
directories and test for common web vulnerabilities.</parameter>
|
||||
</function>
|
||||
|
||||
# Analysis of tool outputs
|
||||
<function=think>
|
||||
<parameter=thought>The Nmap scan revealed 15 open ports, but three stand out as concerning:
|
||||
- Port 3306 (MySQL) - Database should not be exposed externally
|
||||
- Port 5432 (PostgreSQL) - Another database port that's risky when public
|
||||
- Port 6379 (Redis) - Often misconfigured and can lead to data exposure
|
||||
I should prioritize testing these database services for authentication bypass and
|
||||
information disclosure vulnerabilities.</parameter>
|
||||
</function>
|
||||
|
||||
# Decision making and next steps
|
||||
<function=think>
|
||||
<parameter=thought>Based on the vulnerability scan results, I've identified several critical
|
||||
issues that need immediate attention:
|
||||
1. SQL injection in the login form (confirmed with sqlmap)
|
||||
2. Reflected XSS in the search parameter
|
||||
3. Directory traversal in the file upload function
|
||||
I should document these findings with proof-of-concept exploits and assign appropriate
|
||||
CVSS scores. The SQL injection poses the highest risk due to potential data
|
||||
exfiltration.</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
</tools>
|
||||
4
strix/tools/web_search/__init__.py
Normal file
4
strix/tools/web_search/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .web_search_actions import web_search
|
||||
|
||||
|
||||
__all__ = ["web_search"]
|
||||
80
strix/tools/web_search/web_search_actions.py
Normal file
80
strix/tools/web_search/web_search_actions.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from strix.tools.registry import register_tool
|
||||
|
||||
|
||||
SYSTEM_PROMPT = """You are assisting a cybersecurity agent specialized in vulnerability scanning
|
||||
and security assessment running on Kali Linux. When responding to search queries:
|
||||
|
||||
1. Prioritize cybersecurity-relevant information including:
|
||||
- Vulnerability details (CVEs, CVSS scores, impact)
|
||||
- Security tools, techniques, and methodologies
|
||||
- Exploit information and proof-of-concepts
|
||||
- Security best practices and mitigations
|
||||
- Penetration testing approaches
|
||||
- Web application security findings
|
||||
|
||||
2. Provide technical depth appropriate for security professionals
|
||||
3. Include specific versions, configurations, and technical details when available
|
||||
4. Focus on actionable intelligence for security assessment
|
||||
5. Cite reliable security sources (NIST, OWASP, CVE databases, security vendors)
|
||||
6. When providing commands or installation instructions, prioritize Kali Linux compatibility
|
||||
and use apt package manager or tools pre-installed in Kali
|
||||
7. Be detailed and specific - avoid general answers. Always include concrete code examples,
|
||||
command-line instructions, configuration snippets, or practical implementation steps
|
||||
when applicable
|
||||
|
||||
Structure your response to be comprehensive yet concise, emphasizing the most critical
|
||||
security implications and details."""
|
||||
|
||||
|
||||
@register_tool(sandbox_execution=False)
def web_search(query: str) -> dict[str, Any]:
    """Search the web via the Perplexity chat-completions API.

    Requires the PERPLEXITY_API_KEY environment variable. Never raises:
    every failure mode is converted into a dict with ``success=False``,
    an explanatory ``message`` and an empty ``results`` list.

    Returns:
        On success, a dict with ``success``, the original ``query``, the
        model-generated ``content`` and a status ``message``.
    """
    api_key = os.getenv("PERPLEXITY_API_KEY")
    if not api_key:
        return {
            "success": False,
            "message": "PERPLEXITY_API_KEY environment variable not set",
            "results": [],
        }

    try:
        # Long timeout: the reasoning model can take a while to answer.
        resp = requests.post(
            "https://api.perplexity.ai/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": "sonar-reasoning",
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": query},
                ],
            },
            timeout=300,
        )
        resp.raise_for_status()
        answer = resp.json()["choices"][0]["message"]["content"]
    except requests.exceptions.Timeout:
        return {"success": False, "message": "Request timed out", "results": []}
    except requests.exceptions.RequestException as e:
        return {"success": False, "message": f"API request failed: {e!s}", "results": []}
    except KeyError as e:
        return {
            "success": False,
            "message": f"Unexpected API response format: missing {e!s}",
            "results": [],
        }
    except Exception as e:  # noqa: BLE001
        return {"success": False, "message": f"Web search failed: {e!s}", "results": []}

    return {
        "success": True,
        "query": query,
        "content": answer,
        "message": "Web search completed successfully",
    }
|
||||
83
strix/tools/web_search/web_search_actions_schema.xml
Normal file
83
strix/tools/web_search/web_search_actions_schema.xml
Normal file
@@ -0,0 +1,83 @@
|
||||
<tools>
|
||||
<tool name="web_search">
|
||||
<description>Search the web using Perplexity AI for real-time information and current events.
|
||||
|
||||
This is your PRIMARY research tool - use it extensively and liberally for:
|
||||
- Current vulnerabilities, CVEs, and security advisories
|
||||
- Latest attack techniques, exploits, and proof-of-concepts
|
||||
- Technology-specific security research and documentation
|
||||
- Target reconnaissance and OSINT gathering
|
||||
- Security tool documentation and usage guides
|
||||
- Incident response and threat intelligence
|
||||
- Compliance frameworks and security standards
|
||||
- Bug bounty reports and security research findings
|
||||
- Security conference talks and research papers
|
||||
|
||||
The tool provides intelligent, contextual responses with current information that may not be in your training data. Use it early and often during security assessments to gather the most up-to-date factual information.</description>
|
||||
<details>This tool leverages Perplexity AI's sonar-reasoning model to search the web and provide intelligent, contextual responses to queries. It's essential for effective cybersecurity work as it provides access to the latest vulnerabilities, attack vectors, security tools, and defensive techniques. The AI understands security context and can synthesize information from multiple sources.</details>
|
||||
<parameters>
|
||||
<parameter name="query" type="string" required="true">
|
||||
<description>The search query or question you want to research. Be specific and include relevant technical terms, version numbers, or context for better results. Make it as detailed as possible, with the context of the current security assessment.</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing: - success: Whether the search was successful - query: The original search query - content: AI-generated response with current information - message: Status message</description>
|
||||
</returns>
|
||||
<examples>
|
||||
# Found specific service version during reconnaissance
|
||||
<function=web_search>
|
||||
<parameter=query>I found OpenSSH 7.4 running on port 22. Are there any known exploits or privilege escalation techniques for this specific version?</parameter>
|
||||
</function>
|
||||
|
||||
# Encountered WAF blocking attempts
|
||||
<function=web_search>
|
||||
<parameter=query>Cloudflare is blocking my SQLmap attempts on this login form. What are the latest bypass techniques for Cloudflare WAF in 2024?</parameter>
|
||||
</function>
|
||||
|
||||
# Need to exploit discovered CMS
|
||||
<function=web_search>
|
||||
<parameter=query>Target is running WordPress 5.8.3 with WooCommerce 6.1.1. What are the current RCE exploits for this combination?</parameter>
|
||||
</function>
|
||||
|
||||
# Stuck on privilege escalation
|
||||
<function=web_search>
|
||||
<parameter=query>I have low-privilege shell on Ubuntu 20.04 with kernel 5.4.0-74-generic. What local privilege escalation exploits work for this exact kernel version?</parameter>
|
||||
</function>
|
||||
|
||||
# Need lateral movement in Active Directory
|
||||
<function=web_search>
|
||||
<parameter=query>I compromised a domain user account in Windows Server 2019 AD environment. What are the best techniques to escalate to Domain Admin without triggering EDR?</parameter>
|
||||
</function>
|
||||
|
||||
# Encountered specific error during exploitation
|
||||
<function=web_search>
|
||||
<parameter=query>Getting "Access denied" when trying to upload webshell to IIS 10.0. What are alternative file upload bypass techniques for Windows IIS?</parameter>
|
||||
</function>
|
||||
|
||||
# Need to bypass endpoint protection
|
||||
<function=web_search>
|
||||
<parameter=query>Target has CrowdStrike Falcon running. What are the latest techniques to bypass this EDR for payload execution and persistence?</parameter>
|
||||
</function>
|
||||
|
||||
# Research target's infrastructure for attack surface
|
||||
<function=web_search>
|
||||
<parameter=query>I found target company "AcmeCorp" uses Office 365 and Azure. What are the common misconfigurations and attack vectors for this cloud setup?</parameter>
|
||||
</function>
|
||||
|
||||
# Found interesting subdomain during recon
|
||||
<function=web_search>
|
||||
<parameter=query>Discovered staging.target.com running Jenkins 2.401.3. What are the current authentication bypass and RCE exploits for this Jenkins version?</parameter>
|
||||
</function>
|
||||
|
||||
# Need alternative tools when primary fails
|
||||
<function=web_search>
|
||||
<parameter=query>Nmap is being detected and blocked by the target's IPS. What are stealthy alternatives for port scanning that evade modern intrusion prevention systems?</parameter>
|
||||
</function>
|
||||
|
||||
# Finding best security tools for specific tasks
|
||||
<function=web_search>
|
||||
<parameter=query>What is the best Python pip package in 2025 for JWT security testing and manipulation, including cracking weak secrets and algorithm confusion attacks?</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
</tools>
|
||||
Reference in New Issue
Block a user