12 Commits

Author SHA1 Message Date
0xallam
e5104eb93a chore(release): bump version to 0.6.1 2026-01-14 21:30:14 -08:00
0xallam
d8a08e9a8c chore(prompt): discourage literal \n in tool params 2026-01-14 21:29:06 -08:00
0xallam
f6475cec07 chore(prompt): enforce single tool call per message and remove stop word usage 2026-01-14 19:51:08 -08:00
0xallam
31baa0dfc0 fix: restore ollama_api_base config fallback for Ollama support 2026-01-14 18:54:45 -08:00
0xallam
56526cbf90 fix(agent): fix agent loop hanging and simplify LLM module
- Fix agent loop getting stuck by adding hard stop mechanism
- Add _force_stop flag for immediate task cancellation across threads
- Use thread-safe loop.call_soon_threadsafe for cross-thread cancellation
- Remove request_queue.py (eliminated threading/queue complexity causing hangs)
- Simplify llm.py: direct acompletion calls, cleaner streaming
- Reduce retry wait times to prevent long hangs during retries
- Make timeouts configurable (llm_max_retries, memory_compressor_timeout, sandbox_execution_timeout)
- Keep essential token tracking (input/output/cached tokens, cost, requests)
- Maintain Anthropic prompt caching for system messages
2026-01-14 18:54:45 -08:00
0xallam
47faeb1ef3 fix(agent): use correct agent name in identity instead of class name 2026-01-14 11:24:24 -08:00
0xallam
435ac82d9e chore: add defusedxml dependency 2026-01-14 10:57:32 -08:00
0xallam
f08014cf51 fix(agent): fix tool schemas not retrieved on pyinstaller binary and validate tool call args 2026-01-14 10:57:32 -08:00
dependabot[bot]
bc8e14f68a chore(deps-dev): bump virtualenv from 20.34.0 to 20.36.1
Bumps [virtualenv](https://github.com/pypa/virtualenv) from 20.34.0 to 20.36.1.
- [Release notes](https://github.com/pypa/virtualenv/releases)
- [Changelog](https://github.com/pypa/virtualenv/blob/main/docs/changelog.rst)
- [Commits](https://github.com/pypa/virtualenv/compare/20.34.0...20.36.1)

---
updated-dependencies:
- dependency-name: virtualenv
  dependency-version: 20.36.1
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-13 17:15:58 -08:00
dependabot[bot]
eae2b783c0 chore(deps): bump filelock from 3.20.1 to 3.20.3
Bumps [filelock](https://github.com/tox-dev/py-filelock) from 3.20.1 to 3.20.3.
- [Release notes](https://github.com/tox-dev/py-filelock/releases)
- [Changelog](https://github.com/tox-dev/filelock/blob/main/docs/changelog.rst)
- [Commits](https://github.com/tox-dev/py-filelock/compare/3.20.1...3.20.3)

---
updated-dependencies:
- dependency-name: filelock
  dependency-version: 3.20.3
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-13 17:15:43 -08:00
dependabot[bot]
058cf1abdb chore(deps): bump azure-core from 1.35.0 to 1.38.0
Bumps [azure-core](https://github.com/Azure/azure-sdk-for-python) from 1.35.0 to 1.38.0.
- [Release notes](https://github.com/Azure/azure-sdk-for-python/releases)
- [Commits](https://github.com/Azure/azure-sdk-for-python/compare/azure-core_1.35.0...azure-core_1.38.0)

---
updated-dependencies:
- dependency-name: azure-core
  dependency-version: 1.38.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-13 17:15:22 -08:00
Ahmed Allam
d16bdb277a Update README 2026-01-14 05:00:16 +04:00
19 changed files with 448 additions and 561 deletions

View File

@@ -1,55 +1,61 @@
<p align="center"> <p align="center">
<a href="https://strix.ai/"> <a href="https://strix.ai/">
<img src=".github/logo.png" width="150" alt="Strix Logo"> <img src="https://github.com/usestrix/.github/raw/main/imgs/cover.png" alt="Strix Banner" width="100%">
</a> </a>
</p> </p>
<h1 align="center">Strix</h1>
<h2 align="center">Open-source AI Hackers to secure your Apps</h2>
<div align="center"> <div align="center">
[![Python](https://img.shields.io/pypi/pyversions/strix-agent?color=3776AB)](https://pypi.org/project/strix-agent/) # Strix
[![PyPI](https://img.shields.io/pypi/v/strix-agent?color=10b981)](https://pypi.org/project/strix-agent/)
[![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
[![Docs](https://img.shields.io/badge/Docs-docs.strix.ai-10b981.svg)](https://docs.strix.ai)
[![GitHub Stars](https://img.shields.io/github/stars/usestrix/strix)](https://github.com/usestrix/strix) ### Open-source AI hackers to find and fix your apps vulnerabilities.
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.gg/YjKFvEZSdZ)
[![Website](https://img.shields.io/badge/Website-strix.ai-2d3748.svg)](https://strix.ai)
<a href="https://trendshift.io/repositories/15362" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15362" alt="usestrix%2Fstrix | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <br/>
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/usestrix/strix) <a href="https://docs.strix.ai"><img src="https://img.shields.io/badge/Docs-docs.strix.ai-2b9246?style=for-the-badge&logo=gitbook&logoColor=white" alt="Docs"></a>
<a href="https://strix.ai"><img src="https://img.shields.io/badge/Website-strix.ai-3b82f6?style=for-the-badge&logoColor=white" alt="Website"></a>
<a href="https://pypi.org/project/strix-agent/"><img src="https://img.shields.io/badge/PyPI-strix--agent-f59e0b?style=for-the-badge&logo=pypi&logoColor=white" alt="PyPI"></a>
<a href="https://deepwiki.com/usestrix/strix"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki"></a>
<a href="https://github.com/usestrix/strix"><img src="https://img.shields.io/github/stars/usestrix/strix?style=flat-square" alt="GitHub Stars"></a>
<a href="LICENSE"><img src="https://img.shields.io/badge/License-Apache%202.0-3b82f6?style=flat-square" alt="License"></a>
<a href="https://pypi.org/project/strix-agent/"><img src="https://img.shields.io/pypi/v/strix-agent?style=flat-square" alt="PyPI Version"></a>
<a href="https://discord.gg/YjKFvEZSdZ"><img src="https://github.com/usestrix/.github/raw/main/imgs/Discord.png" height="40" alt="Join Discord"></a>
<a href="https://x.com/strix_ai"><img src="https://github.com/usestrix/.github/raw/main/imgs/X.png" height="40" alt="Follow on X"></a>
<a href="https://trendshift.io/repositories/15362" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15362" alt="usestrix/strix | Trendshift" width="250" height="55"/></a>
</div> </div>
<br> <br/>
<div align="center"> <div align="center">
<img src=".github/screenshot.png" alt="Strix Demo" width="800" style="border-radius: 16px;"> <img src=".github/screenshot.png" alt="Strix Demo" width="900" style="border-radius: 16px;">
</div> </div>
<br> <br>
> [!TIP] > [!TIP]
> **New!** Strix now integrates seamlessly with GitHub Actions and CI/CD pipelines. Automatically scan for vulnerabilities on every pull request and block insecure code before it reaches production! > **New!** Strix integrates seamlessly with GitHub Actions and CI/CD pipelines. Automatically scan for vulnerabilities on every pull request and block insecure code before it reaches production!
--- ---
## 🦉 Strix Overview
## Strix Overview
Strix are autonomous AI agents that act just like real hackers - they run your code dynamically, find vulnerabilities, and validate them through actual proof-of-concepts. Built for developers and security teams who need fast, accurate security testing without the overhead of manual pentesting or the false positives of static analysis tools. Strix are autonomous AI agents that act just like real hackers - they run your code dynamically, find vulnerabilities, and validate them through actual proof-of-concepts. Built for developers and security teams who need fast, accurate security testing without the overhead of manual pentesting or the false positives of static analysis tools.
**Key Capabilities:** **Key Capabilities:**
- 🔧 **Full hacker toolkit** out of the box - **Full hacker toolkit** out of the box
- 🤝 **Teams of agents** that collaborate and scale - **Teams of agents** that collaborate and scale
- **Real validation** with PoCs, not false positives - **Real validation** with PoCs, not false positives
- 💻 **Developerfirst** CLI with actionable reports - **Developerfirst** CLI with actionable reports
- 🔄 **Autofix & reporting** to accelerate remediation - **Autofix & reporting** to accelerate remediation
## 🎯 Use Cases ## 🎯 Use Cases
@@ -87,7 +93,7 @@ strix --target ./app-directory
> [!NOTE] > [!NOTE]
> First run automatically pulls the sandbox Docker image. Results are saved to `strix_runs/<run-name>` > First run automatically pulls the sandbox Docker image. Results are saved to `strix_runs/<run-name>`
## ☁️ Run Strix in Cloud ## Run Strix in Cloud
Want to skip the local setup, API keys, and unpredictable LLM costs? Run the hosted cloud version of Strix at **[app.strix.ai](https://strix.ai)**. Want to skip the local setup, API keys, and unpredictable LLM costs? Run the hosted cloud version of Strix at **[app.strix.ai](https://strix.ai)**.
@@ -104,7 +110,7 @@ Launch a scan in just a few minutes—no setup or configuration required—and y
## ✨ Features ## ✨ Features
### 🛠️ Agentic Security Tools ### Agentic Security Tools
Strix agents come equipped with a comprehensive security testing toolkit: Strix agents come equipped with a comprehensive security testing toolkit:
@@ -116,7 +122,7 @@ Strix agents come equipped with a comprehensive security testing toolkit:
- **Code Analysis** - Static and dynamic analysis capabilities - **Code Analysis** - Static and dynamic analysis capabilities
- **Knowledge Management** - Structured findings and attack documentation - **Knowledge Management** - Structured findings and attack documentation
### 🎯 Comprehensive Vulnerability Detection ### Comprehensive Vulnerability Detection
Strix can identify and validate a wide range of security vulnerabilities: Strix can identify and validate a wide range of security vulnerabilities:
@@ -128,7 +134,7 @@ Strix can identify and validate a wide range of security vulnerabilities:
- **Authentication** - JWT vulnerabilities, session management - **Authentication** - JWT vulnerabilities, session management
- **Infrastructure** - Misconfigurations, exposed services - **Infrastructure** - Misconfigurations, exposed services
### 🕸️ Graph of Agents ### Graph of Agents
Advanced multi-agent orchestration for comprehensive security testing: Advanced multi-agent orchestration for comprehensive security testing:
@@ -138,7 +144,7 @@ Advanced multi-agent orchestration for comprehensive security testing:
--- ---
## 💻 Usage Examples ## Usage Examples
### Basic Usage ### Basic Usage
@@ -169,7 +175,7 @@ strix --target api.your-app.com --instruction "Focus on business logic flaws and
strix --target api.your-app.com --instruction-file ./instruction.md strix --target api.your-app.com --instruction-file ./instruction.md
``` ```
### 🤖 Headless Mode ### Headless Mode
Run Strix programmatically without interactive UI using the `-n/--non-interactive` flag—perfect for servers and automated jobs. The CLI prints real-time vulnerability findings, and the final report before exiting. Exits with non-zero code when vulnerabilities are found. Run Strix programmatically without interactive UI using the `-n/--non-interactive` flag—perfect for servers and automated jobs. The CLI prints real-time vulnerability findings, and the final report before exiting. Exits with non-zero code when vulnerabilities are found.
@@ -177,7 +183,7 @@ Run Strix programmatically without interactive UI using the `-n/--non-interactiv
strix -n --target https://your-app.com strix -n --target https://your-app.com
``` ```
### 🔄 CI/CD (GitHub Actions) ### CI/CD (GitHub Actions)
Strix can be added to your pipeline to run a security test on pull requests with a lightweight GitHub Actions workflow: Strix can be added to your pipeline to run a security test on pull requests with a lightweight GitHub Actions workflow:
@@ -204,7 +210,7 @@ jobs:
run: strix -n -t ./ --scan-mode quick run: strix -n -t ./ --scan-mode quick
``` ```
### ⚙️ Configuration ### Configuration
```bash ```bash
export STRIX_LLM="openai/gpt-5" export STRIX_LLM="openai/gpt-5"
@@ -227,22 +233,23 @@ export STRIX_REASONING_EFFORT="high" # control thinking effort (default: high,
See the [LLM Providers documentation](https://docs.strix.ai/llm-providers/overview) for all supported providers including Vertex AI, Bedrock, Azure, and local models. See the [LLM Providers documentation](https://docs.strix.ai/llm-providers/overview) for all supported providers including Vertex AI, Bedrock, Azure, and local models.
## 📚 Documentation ## Documentation
Full documentation is available at **[docs.strix.ai](https://docs.strix.ai)** — including detailed guides for usage, CI/CD integrations, skills, and advanced configuration. Full documentation is available at **[docs.strix.ai](https://docs.strix.ai)** — including detailed guides for usage, CI/CD integrations, skills, and advanced configuration.
## 🤝 Contributing ## Contributing
We welcome contributions of code, docs, and new skills - check out our [Contributing Guide](https://docs.strix.ai/contributing) to get started or open a [pull request](https://github.com/usestrix/strix/pulls)/[issue](https://github.com/usestrix/strix/issues). We welcome contributions of code, docs, and new skills - check out our [Contributing Guide](https://docs.strix.ai/contributing) to get started or open a [pull request](https://github.com/usestrix/strix/pulls)/[issue](https://github.com/usestrix/strix/issues).
## 👥 Join Our Community ## Join Our Community
Have questions? Found a bug? Want to contribute? **[Join our Discord!](https://discord.gg/YjKFvEZSdZ)** Have questions? Found a bug? Want to contribute? **[Join our Discord!](https://discord.gg/YjKFvEZSdZ)**
## 🌟 Support the Project ## Support the Project
**Love Strix?** Give us a ⭐ on GitHub! **Love Strix?** Give us a ⭐ on GitHub!
## 🙏 Acknowledgements
## Acknowledgements
Strix builds on the incredible work of open-source projects like [LiteLLM](https://github.com/BerriAI/litellm), [Caido](https://github.com/caido/caido), [ProjectDiscovery](https://github.com/projectdiscovery), [Playwright](https://github.com/microsoft/playwright), and [Textual](https://github.com/Textualize/textual). Huge thanks to their maintainers! Strix builds on the incredible work of open-source projects like [LiteLLM](https://github.com/BerriAI/litellm), [Caido](https://github.com/caido/caido), [ProjectDiscovery](https://github.com/projectdiscovery), [Playwright](https://github.com/microsoft/playwright), and [Textual](https://github.com/Textualize/textual). Huge thanks to their maintainers!

35
poetry.lock generated
View File

@@ -379,19 +379,18 @@ files = [
[[package]] [[package]]
name = "azure-core" name = "azure-core"
version = "1.35.0" version = "1.38.0"
description = "Microsoft Azure Core Library for Python" description = "Microsoft Azure Core Library for Python"
optional = false optional = false
python-versions = ">=3.9" python-versions = ">=3.9"
groups = ["main"] groups = ["main"]
files = [ files = [
{file = "azure_core-1.35.0-py3-none-any.whl", hash = "sha256:8db78c72868a58f3de8991eb4d22c4d368fae226dac1002998d6c50437e7dad1"}, {file = "azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335"},
{file = "azure_core-1.35.0.tar.gz", hash = "sha256:c0be528489485e9ede59b6971eb63c1eaacf83ef53001bfe3904e475e972be5c"}, {file = "azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993"},
] ]
[package.dependencies] [package.dependencies]
requests = ">=2.21.0" requests = ">=2.21.0"
six = ">=1.11.0"
typing-extensions = ">=4.6.0" typing-extensions = ">=4.6.0"
[package.extras] [package.extras]
@@ -1199,6 +1198,18 @@ files = [
{file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"}, {file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"},
] ]
[[package]]
name = "defusedxml"
version = "0.7.1"
description = "XML bomb protection for Python stdlib modules"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
groups = ["main"]
files = [
{file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"},
{file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
]
[[package]] [[package]]
name = "dill" name = "dill"
version = "0.4.0" version = "0.4.0"
@@ -1490,14 +1501,14 @@ files = [
[[package]] [[package]]
name = "filelock" name = "filelock"
version = "3.20.1" version = "3.20.3"
description = "A platform independent file lock." description = "A platform independent file lock."
optional = false optional = false
python-versions = ">=3.10" python-versions = ">=3.10"
groups = ["main", "dev"] groups = ["main", "dev"]
files = [ files = [
{file = "filelock-3.20.1-py3-none-any.whl", hash = "sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a"}, {file = "filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1"},
{file = "filelock-3.20.1.tar.gz", hash = "sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c"}, {file = "filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1"},
] ]
[[package]] [[package]]
@@ -7095,19 +7106,19 @@ test = ["aiohttp (>=3.10.5)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil",
[[package]] [[package]]
name = "virtualenv" name = "virtualenv"
version = "20.34.0" version = "20.36.1"
description = "Virtual Python Environment builder" description = "Virtual Python Environment builder"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
groups = ["dev"] groups = ["dev"]
files = [ files = [
{file = "virtualenv-20.34.0-py3-none-any.whl", hash = "sha256:341f5afa7eee943e4984a9207c025feedd768baff6753cd660c857ceb3e36026"}, {file = "virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f"},
{file = "virtualenv-20.34.0.tar.gz", hash = "sha256:44815b2c9dee7ed86e387b842a84f20b93f7f417f95886ca1996a72a4138eb1a"}, {file = "virtualenv-20.36.1.tar.gz", hash = "sha256:8befb5c81842c641f8ee658481e42641c68b5eab3521d8e092d18320902466ba"},
] ]
[package.dependencies] [package.dependencies]
distlib = ">=0.3.7,<1" distlib = ">=0.3.7,<1"
filelock = ">=3.12.2,<4" filelock = {version = ">=3.20.1,<4", markers = "python_version >= \"3.10\""}
platformdirs = ">=3.9.1,<5" platformdirs = ">=3.9.1,<5"
[package.extras] [package.extras]
@@ -7424,4 +7435,4 @@ vertex = ["google-cloud-aiplatform"]
[metadata] [metadata]
lock-version = "2.1" lock-version = "2.1"
python-versions = "^3.12" python-versions = "^3.12"
content-hash = "91f49e313e5690bbef87e17730441f26d366daeccb16b5020e03e581fbb9d4d5" content-hash = "0424a0e82fe49501f3a80166676e257a9dae97093d9bc730489789195f523735"

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "strix-agent" name = "strix-agent"
version = "0.6.0" version = "0.6.1"
description = "Open-source AI Hackers for your apps" description = "Open-source AI Hackers for your apps"
authors = ["Strix <hi@usestrix.com>"] authors = ["Strix <hi@usestrix.com>"]
readme = "README.md" readme = "README.md"
@@ -69,6 +69,7 @@ gql = { version = "^3.5.3", extras = ["requests"], optional = true }
pyte = { version = "^0.8.1", optional = true } pyte = { version = "^0.8.1", optional = true }
libtmux = { version = "^0.46.2", optional = true } libtmux = { version = "^0.46.2", optional = true }
numpydoc = { version = "^1.8.0", optional = true } numpydoc = { version = "^1.8.0", optional = true }
defusedxml = "^0.7.1"
[tool.poetry.extras] [tool.poetry.extras]
vertex = ["google-cloud-aiplatform"] vertex = ["google-cloud-aiplatform"]

View File

@@ -111,7 +111,6 @@ hiddenimports = [
'strix.llm.llm', 'strix.llm.llm',
'strix.llm.config', 'strix.llm.config',
'strix.llm.utils', 'strix.llm.utils',
'strix.llm.request_queue',
'strix.llm.memory_compressor', 'strix.llm.memory_compressor',
'strix.runtime', 'strix.runtime',
'strix.runtime.runtime', 'strix.runtime.runtime',

View File

@@ -308,17 +308,18 @@ Tool calls use XML format:
CRITICAL RULES: CRITICAL RULES:
0. While active in the agent loop, EVERY message you output MUST be a single tool call. Do not send plain text-only responses. 0. While active in the agent loop, EVERY message you output MUST be a single tool call. Do not send plain text-only responses.
1. One tool call per message 1. Exactly one tool call per message — never include more than one <function>...</function> block in a single LLM message.
2. Tool call must be last in message 2. Tool call must be last in message
3. End response after </function> tag. It's your stop word. Do not continue after it. 3. EVERY tool call MUST end with </function>. This is MANDATORY. Never omit the closing tag. End your response immediately after </function>.
4. Use ONLY the exact XML format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters. 4. Use ONLY the exact XML format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters.
5. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants). 5. When sending ANY multi-line content in tool parameters, use real newlines (actual line breaks). Do NOT emit literal "\n" sequences. If you send "\n" instead of real line breaks inside the XML parameter value, tools may fail or behave incorrectly.
6. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants).
- Correct: <function=think> ... </function> - Correct: <function=think> ... </function>
- Incorrect: <thinking_tools.think> ... </function> - Incorrect: <thinking_tools.think> ... </function>
- Incorrect: <think> ... </think> - Incorrect: <think> ... </think>
- Incorrect: {"think": {...}} - Incorrect: {"think": {...}}
6. Parameters must use <parameter=param_name>value</parameter> exactly. Do NOT pass parameters as JSON or key:value lines. Do NOT add quotes/braces around values. 7. Parameters must use <parameter=param_name>value</parameter> exactly. Do NOT pass parameters as JSON or key:value lines. Do NOT add quotes/braces around values.
7. Do NOT wrap tool calls in markdown/code fences or add any text before or after the tool block. 8. Do NOT wrap tool calls in markdown/code fences or add any text before or after the tool block.
Example (agent creation tool): Example (agent creation tool):
<function=create_agent> <function=create_agent>
@@ -331,6 +332,8 @@ SPRAYING EXECUTION NOTE:
- When performing large payload sprays or fuzzing, encapsulate the entire spraying loop inside a single python or terminal tool call (e.g., a Python script using asyncio/aiohttp). Do not issue one tool call per payload. - When performing large payload sprays or fuzzing, encapsulate the entire spraying loop inside a single python or terminal tool call (e.g., a Python script using asyncio/aiohttp). Do not issue one tool call per payload.
- Favor batch-mode CLI tools (sqlmap, ffuf, nuclei, zaproxy, arjun) where appropriate and check traffic via the proxy when beneficial - Favor batch-mode CLI tools (sqlmap, ffuf, nuclei, zaproxy, arjun) where appropriate and check traffic via the proxy when beneficial
REMINDER: Always close each tool call with </function> before going into the next. Incomplete tool calls will fail.
{{ get_tools_prompt() }} {{ get_tools_prompt() }}
</tool_usage> </tool_usage>

View File

@@ -1,7 +1,6 @@
import asyncio import asyncio
import contextlib import contextlib
import logging import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@@ -18,6 +17,7 @@ from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content from strix.llm.utils import clean_content
from strix.runtime import SandboxInitializationError from strix.runtime import SandboxInitializationError
from strix.tools import process_tool_invocations from strix.tools import process_tool_invocations
from strix.utils.resource_paths import get_strix_resource_path
from .state import AgentState from .state import AgentState
@@ -35,8 +35,7 @@ class AgentMeta(type):
if name == "BaseAgent": if name == "BaseAgent":
return new_cls return new_cls
agents_dir = Path(__file__).parent prompt_dir = get_strix_resource_path("agents", name)
prompt_dir = agents_dir / name
new_cls.agent_name = name new_cls.agent_name = name
new_cls.jinja_env = Environment( new_cls.jinja_env = Environment(
@@ -66,20 +65,21 @@ class BaseAgent(metaclass=AgentMeta):
self.llm_config = config.get("llm_config", self.default_llm_config) self.llm_config = config.get("llm_config", self.default_llm_config)
if self.llm_config is None: if self.llm_config is None:
raise ValueError("llm_config is required but not provided") raise ValueError("llm_config is required but not provided")
self.llm = LLM(self.llm_config, agent_name=self.agent_name)
state_from_config = config.get("state") state_from_config = config.get("state")
if state_from_config is not None: if state_from_config is not None:
self.state = state_from_config self.state = state_from_config
else: else:
self.state = AgentState( self.state = AgentState(
agent_name=self.agent_name, agent_name="Root Agent",
max_iterations=self.max_iterations, max_iterations=self.max_iterations,
) )
self.llm = LLM(self.llm_config, agent_name=self.agent_name)
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self.llm.set_agent_identity(self.agent_name, self.state.agent_id) self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id)
self._current_task: asyncio.Task[Any] | None = None self._current_task: asyncio.Task[Any] | None = None
self._force_stop = False
from strix.telemetry.tracer import get_global_tracer from strix.telemetry.tracer import get_global_tracer
@@ -157,6 +157,11 @@ class BaseAgent(metaclass=AgentMeta):
return self._handle_sandbox_error(e, tracer) return self._handle_sandbox_error(e, tracer)
while True: while True:
if self._force_stop:
self._force_stop = False
await self._enter_waiting_state(tracer, was_cancelled=True)
continue
self._check_agent_messages(self.state) self._check_agent_messages(self.state)
if self.state.is_waiting_for_input(): if self.state.is_waiting_for_input():
@@ -247,7 +252,8 @@ class BaseAgent(metaclass=AgentMeta):
continue continue
async def _wait_for_input(self) -> None: async def _wait_for_input(self) -> None:
import asyncio if self._force_stop:
return
if self.state.has_waiting_timeout(): if self.state.has_waiting_timeout():
self.state.resume_from_waiting() self.state.resume_from_waiting()
@@ -340,6 +346,7 @@ class BaseAgent(metaclass=AgentMeta):
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool: async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
final_response = None final_response = None
async for response in self.llm.generate(self.state.get_conversation_history()): async for response in self.llm.generate(self.state.get_conversation_history()):
final_response = response final_response = response
if tracer and response.content: if tracer and response.content:
@@ -585,6 +592,11 @@ class BaseAgent(metaclass=AgentMeta):
return True return True
def cancel_current_execution(self) -> None: def cancel_current_execution(self) -> None:
self._force_stop = True
if self._current_task and not self._current_task.done(): if self._current_task and not self._current_task.done():
self._current_task.cancel() try:
loop = self._current_task.get_loop()
loop.call_soon_threadsafe(self._current_task.cancel)
except RuntimeError:
self._current_task.cancel()
self._current_task = None self._current_task = None

View File

@@ -16,9 +16,9 @@ class Config:
litellm_base_url = None litellm_base_url = None
ollama_api_base = None ollama_api_base = None
strix_reasoning_effort = "high" strix_reasoning_effort = "high"
strix_llm_max_retries = "5"
strix_memory_compressor_timeout = "30"
llm_timeout = "300" llm_timeout = "300"
llm_rate_limit_delay = "4.0"
llm_rate_limit_concurrent = "1"
# Tool & Feature Configuration # Tool & Feature Configuration
perplexity_api_key = None perplexity_api_key = None
@@ -27,7 +27,7 @@ class Config:
# Runtime Configuration # Runtime Configuration
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10" strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
strix_runtime_backend = "docker" strix_runtime_backend = "docker"
strix_sandbox_execution_timeout = "500" strix_sandbox_execution_timeout = "120"
strix_sandbox_connect_timeout = "10" strix_sandbox_connect_timeout = "10"
# Telemetry # Telemetry

View File

@@ -1,60 +1,29 @@
import asyncio import asyncio
import logging
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any from typing import Any
import litellm import litellm
from jinja2 import ( from jinja2 import Environment, FileSystemLoader, select_autoescape
Environment, from litellm import acompletion, completion_cost, stream_chunk_builder, supports_reasoning
FileSystemLoader,
select_autoescape,
)
from litellm import completion_cost, stream_chunk_builder, supports_reasoning
from litellm.utils import supports_prompt_caching, supports_vision from litellm.utils import supports_prompt_caching, supports_vision
from strix.config import Config from strix.config import Config
from strix.llm.config import LLMConfig from strix.llm.config import LLMConfig
from strix.llm.memory_compressor import MemoryCompressor from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.request_queue import get_global_queue from strix.llm.utils import (
from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations _truncate_to_first_function,
fix_incomplete_tool_call,
parse_tool_invocations,
)
from strix.skills import load_skills from strix.skills import load_skills
from strix.tools import get_tools_prompt from strix.tools import get_tools_prompt
from strix.utils.resource_paths import get_strix_resource_path
MAX_RETRIES = 5
RETRY_MULTIPLIER = 8
RETRY_MIN = 8
RETRY_MAX = 64
def _should_retry(exception: Exception) -> bool:
status_code = None
if hasattr(exception, "status_code"):
status_code = exception.status_code
elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
status_code = exception.response.status_code
if status_code is not None:
return bool(litellm._should_retry(status_code))
return True
logger = logging.getLogger(__name__)
litellm.drop_params = True litellm.drop_params = True
litellm.modify_params = True litellm.modify_params = True
_LLM_API_KEY = Config.get("llm_api_key")
_LLM_API_BASE = (
Config.get("llm_api_base")
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
)
_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")
class LLMRequestFailedError(Exception): class LLMRequestFailedError(Exception):
def __init__(self, message: str, details: str | None = None): def __init__(self, message: str, details: str | None = None):
@@ -63,20 +32,11 @@ class LLMRequestFailedError(Exception):
self.details = details self.details = details
class StepRole(str, Enum):
AGENT = "agent"
USER = "user"
SYSTEM = "system"
@dataclass @dataclass
class LLMResponse: class LLMResponse:
content: str content: str
tool_invocations: list[dict[str, Any]] | None = None tool_invocations: list[dict[str, Any]] | None = None
scan_id: str | None = None thinking_blocks: list[dict[str, Any]] | None = None
step_number: int = 1
role: StepRole = StepRole.AGENT
thinking_blocks: list[dict[str, Any]] | None = None # For reasoning models.
@dataclass @dataclass
@@ -84,76 +44,63 @@ class RequestStats:
input_tokens: int = 0 input_tokens: int = 0
output_tokens: int = 0 output_tokens: int = 0
cached_tokens: int = 0 cached_tokens: int = 0
cache_creation_tokens: int = 0
cost: float = 0.0 cost: float = 0.0
requests: int = 0 requests: int = 0
failed_requests: int = 0
def to_dict(self) -> dict[str, int | float]: def to_dict(self) -> dict[str, int | float]:
return { return {
"input_tokens": self.input_tokens, "input_tokens": self.input_tokens,
"output_tokens": self.output_tokens, "output_tokens": self.output_tokens,
"cached_tokens": self.cached_tokens, "cached_tokens": self.cached_tokens,
"cache_creation_tokens": self.cache_creation_tokens,
"cost": round(self.cost, 4), "cost": round(self.cost, 4),
"requests": self.requests, "requests": self.requests,
"failed_requests": self.failed_requests,
} }
class LLM: class LLM:
def __init__( def __init__(self, config: LLMConfig, agent_name: str | None = None):
self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
):
self.config = config self.config = config
self.agent_name = agent_name self.agent_name = agent_name
self.agent_id = agent_id self.agent_id: str | None = None
self._total_stats = RequestStats() self._total_stats = RequestStats()
self._last_request_stats = RequestStats() self.memory_compressor = MemoryCompressor(model_name=config.model_name)
self.system_prompt = self._load_system_prompt(agent_name)
if _STRIX_REASONING_EFFORT: reasoning = Config.get("strix_reasoning_effort")
self._reasoning_effort = _STRIX_REASONING_EFFORT if reasoning:
elif self.config.scan_mode == "quick": self._reasoning_effort = reasoning
elif config.scan_mode == "quick":
self._reasoning_effort = "medium" self._reasoning_effort = "medium"
else: else:
self._reasoning_effort = "high" self._reasoning_effort = "high"
self.memory_compressor = MemoryCompressor( def _load_system_prompt(self, agent_name: str | None) -> str:
model_name=self.config.model_name, if not agent_name:
timeout=self.config.timeout, return ""
)
if agent_name: try:
prompt_dir = Path(__file__).parent.parent / "agents" / agent_name prompt_dir = get_strix_resource_path("agents", agent_name)
skills_dir = Path(__file__).parent.parent / "skills" skills_dir = get_strix_resource_path("skills")
env = Environment(
loader = FileSystemLoader([prompt_dir, skills_dir]) loader=FileSystemLoader([prompt_dir, skills_dir]),
self.jinja_env = Environment(
loader=loader,
autoescape=select_autoescape(enabled_extensions=(), default_for_string=False), autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
) )
try: skills_to_load = [
skills_to_load = list(self.config.skills or []) *list(self.config.skills or []),
skills_to_load.append(f"scan_modes/{self.config.scan_mode}") f"scan_modes/{self.config.scan_mode}",
]
skill_content = load_skills(skills_to_load, env)
env.globals["get_skill"] = lambda name: skill_content.get(name, "")
skill_content = load_skills(skills_to_load, self.jinja_env) result = env.get_template("system_prompt.jinja").render(
get_tools_prompt=get_tools_prompt,
def get_skill(name: str) -> str: loaded_skill_names=list(skill_content.keys()),
return skill_content.get(name, "") **skill_content,
)
self.jinja_env.globals["get_skill"] = get_skill return str(result)
except Exception: # noqa: BLE001
self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render( return ""
get_tools_prompt=get_tools_prompt,
loaded_skill_names=list(skill_content.keys()),
**skill_content,
)
except (FileNotFoundError, OSError, ValueError) as e:
logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
self.system_prompt = "You are a helpful AI assistant."
else:
self.system_prompt = "You are a helpful AI assistant."
def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None: def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
if agent_name: if agent_name:
@@ -161,328 +108,121 @@ class LLM:
if agent_id: if agent_id:
self.agent_id = agent_id self.agent_id = agent_id
def _build_identity_message(self) -> dict[str, Any] | None: async def generate(
if not (self.agent_name and str(self.agent_name).strip()): self, conversation_history: list[dict[str, Any]]
return None ) -> AsyncIterator[LLMResponse]:
identity_name = self.agent_name messages = self._prepare_messages(conversation_history)
identity_id = self.agent_id max_retries = int(Config.get("strix_llm_max_retries") or "5")
content = (
"\n\n"
"<agent_identity>\n"
"<meta>Internal metadata: do not echo or reference; "
"not part of history or tool calls.</meta>\n"
"<note>You are now assuming the role of this agent. "
"Act strictly as this agent and maintain self-identity for this step. "
"Now go answer the next needed step!</note>\n"
f"<agent_name>{identity_name}</agent_name>\n"
f"<agent_id>{identity_id}</agent_id>\n"
"</agent_identity>\n\n"
)
return {"role": "user", "content": content}
def _add_cache_control_to_content( for attempt in range(max_retries + 1):
self, content: str | list[dict[str, Any]] try:
) -> str | list[dict[str, Any]]: async for response in self._stream(messages):
if isinstance(content, str): yield response
return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}] return # noqa: TRY300
if isinstance(content, list) and content: except Exception as e: # noqa: BLE001
last_item = content[-1] if attempt >= max_retries or not self._should_retry(e):
if isinstance(last_item, dict) and last_item.get("type") == "text": self._raise_error(e)
return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}] wait = min(10, 2 * (2**attempt))
return content await asyncio.sleep(wait)
def _is_anthropic_model(self) -> bool: async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
if not self.config.model_name: accumulated = ""
return False chunks: list[Any] = []
model_lower = self.config.model_name.lower()
return any(provider in model_lower for provider in ["anthropic/", "claude"])
def _calculate_cache_interval(self, total_messages: int) -> int: self._total_stats.requests += 1
if total_messages <= 1: response = await acompletion(**self._build_completion_args(messages), stream=True)
return 10
max_cached_messages = 3 async for chunk in response:
non_system_messages = total_messages - 1 chunks.append(chunk)
delta = self._get_chunk_content(chunk)
interval = 10 if delta:
while non_system_messages // interval > max_cached_messages: accumulated += delta
interval += 10 if "</function>" in accumulated:
accumulated = accumulated[
return interval : accumulated.find("</function>") + len("</function>")
]
def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: yield LLMResponse(content=accumulated)
if (
not self.config.enable_prompt_caching
or not supports_prompt_caching(self.config.model_name)
or not messages
):
return messages
if not self._is_anthropic_model():
return messages
cached_messages = list(messages)
if cached_messages and cached_messages[0].get("role") == "system":
system_message = cached_messages[0].copy()
system_message["content"] = self._add_cache_control_to_content(
system_message["content"]
)
cached_messages[0] = system_message
total_messages = len(cached_messages)
if total_messages > 1:
interval = self._calculate_cache_interval(total_messages)
cached_count = 0
for i in range(interval, total_messages, interval):
if cached_count >= 3:
break break
yield LLMResponse(content=accumulated)
if i < len(cached_messages): if chunks:
message = cached_messages[i].copy() self._update_usage_stats(stream_chunk_builder(chunks))
message["content"] = self._add_cache_control_to_content(message["content"])
cached_messages[i] = message
cached_count += 1
return cached_messages accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
yield LLMResponse(
content=accumulated,
tool_invocations=parse_tool_invocations(accumulated),
thinking_blocks=self._extract_thinking(chunks),
)
def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]: def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
messages = [{"role": "system", "content": self.system_prompt}] messages = [{"role": "system", "content": self.system_prompt}]
identity_message = self._build_identity_message() if self.agent_name:
if identity_message: messages.append(
messages.append(identity_message) {
"role": "user",
compressed_history = list(self.memory_compressor.compress_history(conversation_history)) "content": (
f"\n\n<agent_identity>\n"
f"<meta>Internal metadata: do not echo or reference.</meta>\n"
f"<agent_name>{self.agent_name}</agent_name>\n"
f"<agent_id>{self.agent_id}</agent_id>\n"
f"</agent_identity>\n\n"
),
}
)
compressed = list(self.memory_compressor.compress_history(conversation_history))
conversation_history.clear() conversation_history.clear()
conversation_history.extend(compressed_history) conversation_history.extend(compressed)
messages.extend(compressed_history) messages.extend(compressed)
return self._prepare_cached_messages(messages) if self._is_anthropic() and self.config.enable_prompt_caching:
messages = self._add_cache_control(messages)
async def _stream_and_accumulate( return messages
self,
messages: list[dict[str, Any]],
scan_id: str | None,
step_number: int,
) -> AsyncIterator[LLMResponse]:
accumulated_content = ""
chunks: list[Any] = []
async for chunk in self._stream_request(messages): def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
chunks.append(chunk) if not self._supports_vision():
delta = self._extract_chunk_delta(chunk) messages = self._strip_images(messages)
if delta:
accumulated_content += delta
if "</function>" in accumulated_content: args: dict[str, Any] = {
function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end]
yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content=accumulated_content,
tool_invocations=None,
)
if chunks:
complete_response = stream_chunk_builder(chunks)
self._update_usage_stats(complete_response)
accumulated_content = _truncate_to_first_function(accumulated_content)
if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end]
tool_invocations = parse_tool_invocations(accumulated_content)
# Extract thinking blocks from the complete response if available
thinking_blocks = None
if chunks and self._should_include_reasoning_effort():
complete_response = stream_chunk_builder(chunks)
if (
hasattr(complete_response, "choices")
and complete_response.choices
and hasattr(complete_response.choices[0], "message")
):
message = complete_response.choices[0].message
if hasattr(message, "thinking_blocks") and message.thinking_blocks:
thinking_blocks = message.thinking_blocks
yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content=accumulated_content,
tool_invocations=tool_invocations if tool_invocations else None,
thinking_blocks=thinking_blocks,
)
def _raise_llm_error(self, e: Exception) -> None:
error_map: list[tuple[type, str]] = [
(litellm.RateLimitError, "Rate limit exceeded"),
(litellm.AuthenticationError, "Invalid API key"),
(litellm.NotFoundError, "Model not found"),
(litellm.ContextWindowExceededError, "Context too long"),
(litellm.ContentPolicyViolationError, "Content policy violation"),
(litellm.ServiceUnavailableError, "Service unavailable"),
(litellm.Timeout, "Request timed out"),
(litellm.UnprocessableEntityError, "Unprocessable entity"),
(litellm.InternalServerError, "Internal server error"),
(litellm.APIConnectionError, "Connection error"),
(litellm.UnsupportedParamsError, "Unsupported parameters"),
(litellm.BudgetExceededError, "Budget exceeded"),
(litellm.APIResponseValidationError, "Response validation error"),
(litellm.JSONSchemaValidationError, "JSON schema validation error"),
(litellm.InvalidRequestError, "Invalid request"),
(litellm.BadRequestError, "Bad request"),
(litellm.APIError, "API error"),
(litellm.OpenAIError, "OpenAI error"),
]
from strix.telemetry import posthog
for error_type, message in error_map:
if isinstance(e, error_type):
posthog.error(f"llm_{error_type.__name__}", message)
raise LLMRequestFailedError(f"LLM request failed: {message}", str(e)) from e
posthog.error("llm_unknown_error", type(e).__name__)
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
async def generate(
self,
conversation_history: list[dict[str, Any]],
scan_id: str | None = None,
step_number: int = 1,
) -> AsyncIterator[LLMResponse]:
messages = self._prepare_messages(conversation_history)
last_error: Exception | None = None
for attempt in range(MAX_RETRIES):
try:
async for response in self._stream_and_accumulate(messages, scan_id, step_number):
yield response
return # noqa: TRY300
except Exception as e: # noqa: BLE001
last_error = e
if not _should_retry(e) or attempt == MAX_RETRIES - 1:
break
wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt))
wait_time = max(RETRY_MIN, wait_time)
await asyncio.sleep(wait_time)
if last_error:
self._raise_llm_error(last_error)
def _extract_chunk_delta(self, chunk: Any) -> str:
if chunk.choices and hasattr(chunk.choices[0], "delta"):
delta = chunk.choices[0].delta
return getattr(delta, "content", "") or ""
return ""
@property
def usage_stats(self) -> dict[str, dict[str, int | float]]:
return {
"total": self._total_stats.to_dict(),
"last_request": self._last_request_stats.to_dict(),
}
def get_cache_config(self) -> dict[str, bool]:
return {
"enabled": self.config.enable_prompt_caching,
"supported": supports_prompt_caching(self.config.model_name),
}
def _should_include_reasoning_effort(self) -> bool:
if not self.config.model_name:
return False
try:
return bool(supports_reasoning(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _model_supports_vision(self) -> bool:
if not self.config.model_name:
return False
try:
return bool(supports_vision(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _filter_images_from_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
filtered_messages = []
for msg in messages:
content = msg.get("content")
updated_msg = msg
if isinstance(content, list):
filtered_content = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "image_url":
filtered_content.append(
{
"type": "text",
"text": "[Screenshot removed - model does not support "
"vision. Use view_source or execute_js instead.]",
}
)
else:
filtered_content.append(item)
else:
filtered_content.append(item)
if filtered_content:
text_parts = [
item.get("text", "") if isinstance(item, dict) else str(item)
for item in filtered_content
]
all_text = all(
isinstance(item, dict) and item.get("type") == "text"
for item in filtered_content
)
if all_text:
updated_msg = {**msg, "content": "\n".join(text_parts)}
else:
updated_msg = {**msg, "content": filtered_content}
else:
updated_msg = {**msg, "content": ""}
filtered_messages.append(updated_msg)
return filtered_messages
async def _stream_request(
self,
messages: list[dict[str, Any]],
) -> AsyncIterator[Any]:
if not self._model_supports_vision():
messages = self._filter_images_from_messages(messages)
completion_args: dict[str, Any] = {
"model": self.config.model_name, "model": self.config.model_name,
"messages": messages, "messages": messages,
"timeout": self.config.timeout, "timeout": self.config.timeout,
"stream_options": {"include_usage": True}, "stream_options": {"include_usage": True},
} }
if _LLM_API_KEY: if api_key := Config.get("llm_api_key"):
completion_args["api_key"] = _LLM_API_KEY args["api_key"] = api_key
if _LLM_API_BASE: if api_base := (
completion_args["api_base"] = _LLM_API_BASE Config.get("llm_api_base")
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
):
args["api_base"] = api_base
if self._supports_reasoning():
args["reasoning_effort"] = self._reasoning_effort
completion_args["stop"] = ["</function>"] return args
if self._should_include_reasoning_effort(): def _get_chunk_content(self, chunk: Any) -> str:
completion_args["reasoning_effort"] = self._reasoning_effort if chunk.choices and hasattr(chunk.choices[0], "delta"):
return getattr(chunk.choices[0].delta, "content", "") or ""
return ""
queue = get_global_queue() def _extract_thinking(self, chunks: list[Any]) -> list[dict[str, Any]] | None:
self._total_stats.requests += 1 if not chunks or not self._supports_reasoning():
self._last_request_stats = RequestStats(requests=1) return None
try:
async for chunk in queue.stream_request(completion_args): resp = stream_chunk_builder(chunks)
yield chunk if resp.choices and hasattr(resp.choices[0].message, "thinking_blocks"):
blocks: list[dict[str, Any]] = resp.choices[0].message.thinking_blocks
return blocks
except Exception: # noqa: BLE001, S110 # nosec B110
pass
return None
def _update_usage_stats(self, response: Any) -> None: def _update_usage_stats(self, response: Any) -> None:
try: try:
@@ -491,45 +231,88 @@ class LLM:
output_tokens = getattr(response.usage, "completion_tokens", 0) output_tokens = getattr(response.usage, "completion_tokens", 0)
cached_tokens = 0 cached_tokens = 0
cache_creation_tokens = 0
if hasattr(response.usage, "prompt_tokens_details"): if hasattr(response.usage, "prompt_tokens_details"):
prompt_details = response.usage.prompt_tokens_details prompt_details = response.usage.prompt_tokens_details
if hasattr(prompt_details, "cached_tokens"): if hasattr(prompt_details, "cached_tokens"):
cached_tokens = prompt_details.cached_tokens or 0 cached_tokens = prompt_details.cached_tokens or 0
if hasattr(response.usage, "cache_creation_input_tokens"):
cache_creation_tokens = response.usage.cache_creation_input_tokens or 0
else: else:
input_tokens = 0 input_tokens = 0
output_tokens = 0 output_tokens = 0
cached_tokens = 0 cached_tokens = 0
cache_creation_tokens = 0
try: try:
cost = completion_cost(response) or 0.0 cost = completion_cost(response) or 0.0
except Exception as e: # noqa: BLE001 except Exception: # noqa: BLE001
logger.warning(f"Failed to calculate cost: {e}")
cost = 0.0 cost = 0.0
self._total_stats.input_tokens += input_tokens self._total_stats.input_tokens += input_tokens
self._total_stats.output_tokens += output_tokens self._total_stats.output_tokens += output_tokens
self._total_stats.cached_tokens += cached_tokens self._total_stats.cached_tokens += cached_tokens
self._total_stats.cache_creation_tokens += cache_creation_tokens
self._total_stats.cost += cost self._total_stats.cost += cost
self._last_request_stats.input_tokens = input_tokens except Exception: # noqa: BLE001, S110 # nosec B110
self._last_request_stats.output_tokens = output_tokens pass
self._last_request_stats.cached_tokens = cached_tokens
self._last_request_stats.cache_creation_tokens = cache_creation_tokens
self._last_request_stats.cost = cost
if cached_tokens > 0: def _should_retry(self, e: Exception) -> bool:
logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens") code = getattr(e, "status_code", None) or getattr(
if cache_creation_tokens > 0: getattr(e, "response", None), "status_code", None
logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache") )
return code is None or litellm._should_retry(code)
logger.info(f"Usage stats: {self.usage_stats}") def _raise_error(self, e: Exception) -> None:
except Exception as e: # noqa: BLE001 from strix.telemetry import posthog
logger.warning(f"Failed to update usage stats: {e}")
posthog.error("llm_error", type(e).__name__)
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
def _is_anthropic(self) -> bool:
if not self.config.model_name:
return False
return any(p in self.config.model_name.lower() for p in ["anthropic/", "claude"])
def _supports_vision(self) -> bool:
try:
return bool(supports_vision(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _supports_reasoning(self) -> bool:
try:
return bool(supports_reasoning(model=self.config.model_name))
except Exception: # noqa: BLE001
return False
def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
result = []
for msg in messages:
content = msg.get("content")
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif isinstance(item, dict) and item.get("type") == "image_url":
text_parts.append("[Image removed - model doesn't support vision]")
result.append({**msg, "content": "\n".join(text_parts)})
else:
result.append(msg)
return result
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
if not messages or not supports_prompt_caching(self.config.model_name):
return messages
result = list(messages)
if result[0].get("role") == "system":
content = result[0]["content"]
result[0] = {
**result[0],
"content": [
{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
]
if isinstance(content, str)
else content,
}
return result

View File

@@ -86,7 +86,7 @@ def _extract_message_text(msg: dict[str, Any]) -> str:
def _summarize_messages( def _summarize_messages(
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
model: str, model: str,
timeout: int = 600, timeout: int = 30,
) -> dict[str, Any]: ) -> dict[str, Any]:
if not messages: if not messages:
empty_summary = "<context_summary message_count='0'>{text}</context_summary>" empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
@@ -148,11 +148,11 @@ class MemoryCompressor:
self, self,
max_images: int = 3, max_images: int = 3,
model_name: str | None = None, model_name: str | None = None,
timeout: int = 600, timeout: int | None = None,
): ):
self.max_images = max_images self.max_images = max_images
self.model_name = model_name or Config.get("strix_llm") self.model_name = model_name or Config.get("strix_llm")
self.timeout = timeout self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")
if not self.model_name: if not self.model_name:
raise ValueError("STRIX_LLM environment variable must be set and not empty") raise ValueError("STRIX_LLM environment variable must be set and not empty")

View File

@@ -1,58 +0,0 @@
import asyncio
import threading
import time
from collections.abc import AsyncIterator
from typing import Any
from litellm import acompletion
from litellm.types.utils import ModelResponseStream
from strix.config import Config
class LLMRequestQueue:
def __init__(self) -> None:
self.delay_between_requests = float(Config.get("llm_rate_limit_delay") or "4.0")
self.max_concurrent = int(Config.get("llm_rate_limit_concurrent") or "1")
self._semaphore = threading.BoundedSemaphore(self.max_concurrent)
self._last_request_time = 0.0
self._lock = threading.Lock()
async def stream_request(
self, completion_args: dict[str, Any]
) -> AsyncIterator[ModelResponseStream]:
try:
while not self._semaphore.acquire(timeout=0.2):
await asyncio.sleep(0.1)
with self._lock:
now = time.time()
time_since_last = now - self._last_request_time
sleep_needed = max(0, self.delay_between_requests - time_since_last)
self._last_request_time = now + sleep_needed
if sleep_needed > 0:
await asyncio.sleep(sleep_needed)
async for chunk in self._stream_request(completion_args):
yield chunk
finally:
self._semaphore.release()
async def _stream_request(
self, completion_args: dict[str, Any]
) -> AsyncIterator[ModelResponseStream]:
response = await acompletion(**completion_args, stream=True)
async for chunk in response:
yield chunk
_global_queue: LLMRequestQueue | None = None
def get_global_queue() -> LLMRequestQueue:
global _global_queue # noqa: PLW0603
if _global_queue is None:
_global_queue = LLMRequestQueue()
return _global_queue

View File

@@ -18,7 +18,7 @@ def _truncate_to_first_function(content: str) -> str:
def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None: def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
content = _fix_stopword(content) content = fix_incomplete_tool_call(content)
tool_invocations: list[dict[str, Any]] = [] tool_invocations: list[dict[str, Any]] = []
@@ -46,16 +46,15 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
return tool_invocations if tool_invocations else None return tool_invocations if tool_invocations else None
def _fix_stopword(content: str) -> str: def fix_incomplete_tool_call(content: str) -> str:
"""Fix incomplete tool calls by adding missing </function> tag."""
if ( if (
"<function=" in content "<function=" in content
and content.count("<function=") == 1 and content.count("<function=") == 1
and "</function>" not in content and "</function>" not in content
): ):
if content.endswith("</"): content = content.rstrip()
content = content.rstrip() + "function>" content = content + "function>" if content.endswith("</") else content + "\n</function>"
else:
content = content + "\n</function>"
return content return content
@@ -74,7 +73,7 @@ def clean_content(content: str) -> str:
if not content: if not content:
return "" return ""
content = _fix_stopword(content) content = fix_incomplete_tool_call(content)
tool_pattern = r"<function=[^>]+>.*?</function>" tool_pattern = r"<function=[^>]+>.*?</function>"
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL) cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)

View File

@@ -1,11 +1,14 @@
from pathlib import Path
from jinja2 import Environment from jinja2 import Environment
from strix.utils.resource_paths import get_strix_resource_path
def get_available_skills() -> dict[str, list[str]]: def get_available_skills() -> dict[str, list[str]]:
skills_dir = Path(__file__).parent skills_dir = get_strix_resource_path("skills")
available_skills = {} available_skills: dict[str, list[str]] = {}
if not skills_dir.exists():
return available_skills
for category_dir in skills_dir.iterdir(): for category_dir in skills_dir.iterdir():
if category_dir.is_dir() and not category_dir.name.startswith("__"): if category_dir.is_dir() and not category_dir.name.startswith("__"):
@@ -72,7 +75,7 @@ def load_skills(skill_names: list[str], jinja_env: Environment) -> dict[str, str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
skill_content = {} skill_content = {}
skills_dir = Path(__file__).parent skills_dir = get_strix_resource_path("skills")
available_skills = get_available_skills() available_skills = get_available_skills()

View File

@@ -430,10 +430,8 @@ class Tracer:
"input_tokens": 0, "input_tokens": 0,
"output_tokens": 0, "output_tokens": 0,
"cached_tokens": 0, "cached_tokens": 0,
"cache_creation_tokens": 0,
"cost": 0.0, "cost": 0.0,
"requests": 0, "requests": 0,
"failed_requests": 0,
} }
for agent_instance in _agent_instances.values(): for agent_instance in _agent_instances.values():
@@ -442,10 +440,8 @@ class Tracer:
total_stats["input_tokens"] += agent_stats.input_tokens total_stats["input_tokens"] += agent_stats.input_tokens
total_stats["output_tokens"] += agent_stats.output_tokens total_stats["output_tokens"] += agent_stats.output_tokens
total_stats["cached_tokens"] += agent_stats.cached_tokens total_stats["cached_tokens"] += agent_stats.cached_tokens
total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
total_stats["cost"] += agent_stats.cost total_stats["cost"] += agent_stats.cost
total_stats["requests"] += agent_stats.requests total_stats["requests"] += agent_stats.requests
total_stats["failed_requests"] += agent_stats.failed_requests
total_stats["cost"] = round(total_stats["cost"], 4) total_stats["cost"] = round(total_stats["cost"], 4)

View File

@@ -14,12 +14,13 @@ from .argument_parser import convert_arguments
from .registry import ( from .registry import (
get_tool_by_name, get_tool_by_name,
get_tool_names, get_tool_names,
get_tool_param_schema,
needs_agent_state, needs_agent_state,
should_execute_in_sandbox, should_execute_in_sandbox,
) )
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "500") SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10") SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
@@ -110,14 +111,51 @@ async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwarg
def validate_tool_availability(tool_name: str | None) -> tuple[bool, str]: def validate_tool_availability(tool_name: str | None) -> tuple[bool, str]:
if tool_name is None: if tool_name is None:
return False, "Tool name is missing" available = ", ".join(sorted(get_tool_names()))
return False, f"Tool name is missing. Available tools: {available}"
if tool_name not in get_tool_names(): if tool_name not in get_tool_names():
return False, f"Tool '{tool_name}' is not available" available = ", ".join(sorted(get_tool_names()))
return False, f"Tool '{tool_name}' is not available. Available tools: {available}"
return True, "" return True, ""
def _validate_tool_arguments(tool_name: str, kwargs: dict[str, Any]) -> str | None:
param_schema = get_tool_param_schema(tool_name)
if not param_schema or not param_schema.get("has_params"):
return None
allowed_params: set[str] = param_schema.get("params", set())
required_params: set[str] = param_schema.get("required", set())
optional_params = allowed_params - required_params
schema_hint = _format_schema_hint(tool_name, required_params, optional_params)
unknown_params = set(kwargs.keys()) - allowed_params
if unknown_params:
unknown_list = ", ".join(sorted(unknown_params))
return f"Tool '{tool_name}' received unknown parameter(s): {unknown_list}\n{schema_hint}"
missing_required = [
param for param in required_params if param not in kwargs or kwargs.get(param) in (None, "")
]
if missing_required:
missing_list = ", ".join(sorted(missing_required))
return f"Tool '{tool_name}' missing required parameter(s): {missing_list}\n{schema_hint}"
return None
def _format_schema_hint(tool_name: str, required: set[str], optional: set[str]) -> str:
parts = [f"Valid parameters for '{tool_name}':"]
if required:
parts.append(f" Required: {', '.join(sorted(required))}")
if optional:
parts.append(f" Optional: {', '.join(sorted(optional))}")
return "\n".join(parts)
async def execute_tool_with_validation( async def execute_tool_with_validation(
tool_name: str | None, agent_state: Any | None = None, **kwargs: Any tool_name: str | None, agent_state: Any | None = None, **kwargs: Any
) -> Any: ) -> Any:
@@ -127,6 +165,10 @@ async def execute_tool_with_validation(
assert tool_name is not None assert tool_name is not None
arg_error = _validate_tool_arguments(tool_name, kwargs)
if arg_error:
return f"Error: {arg_error}"
try: try:
result = await execute_tool(tool_name, agent_state, **kwargs) result = await execute_tool(tool_name, agent_state, **kwargs)
except Exception as e: # noqa: BLE001 except Exception as e: # noqa: BLE001

View File

@@ -55,6 +55,7 @@
- Print statements and stdout are captured - Print statements and stdout are captured
- Variables persist between executions in the same session - Variables persist between executions in the same session
- Imports, function definitions, etc. persist in the session - Imports, function definitions, etc. persist in the session
- IMPORTANT (multiline): Put real line breaks in <parameter=code>. Do NOT emit literal "\n" sequences.
- IPython magic commands are fully supported (%pip, %time, %whos, %%writefile, etc.) - IPython magic commands are fully supported (%pip, %time, %whos, %%writefile, etc.)
- Line magics (%) and cell magics (%%) work as expected - Line magics (%) and cell magics (%%) work as expected
6. CLOSE: Terminates the session completely and frees memory 6. CLOSE: Terminates the session completely and frees memory
@@ -73,6 +74,14 @@
print("Security analysis session started")</parameter> print("Security analysis session started")</parameter>
</function> </function>
<function=python_action>
<parameter=action>execute</parameter>
<parameter=code>import requests
url = "https://example.com"
resp = requests.get(url, timeout=10)
print(resp.status_code)</parameter>
</function>
# Analyze security data in the default session # Analyze security data in the default session
<function=python_action> <function=python_action>
<parameter=action>execute</parameter> <parameter=action>execute</parameter>

View File

@@ -7,9 +7,14 @@ from inspect import signature
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import defusedxml.ElementTree as DefusedET
from strix.utils.resource_paths import get_strix_resource_path
tools: list[dict[str, Any]] = [] tools: list[dict[str, Any]] = []
_tools_by_name: dict[str, Callable[..., Any]] = {} _tools_by_name: dict[str, Callable[..., Any]] = {}
_tool_param_schemas: dict[str, dict[str, Any]] = {}
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -82,6 +87,34 @@ def _load_xml_schema(path: Path) -> Any:
return tools_dict return tools_dict
def _parse_param_schema(tool_xml: str) -> dict[str, Any]:
params: set[str] = set()
required: set[str] = set()
params_start = tool_xml.find("<parameters>")
params_end = tool_xml.find("</parameters>")
if params_start == -1 or params_end == -1:
return {"params": set(), "required": set(), "has_params": False}
params_section = tool_xml[params_start : params_end + len("</parameters>")]
try:
root = DefusedET.fromstring(params_section)
except DefusedET.ParseError:
return {"params": set(), "required": set(), "has_params": False}
for param in root.findall(".//parameter"):
name = param.attrib.get("name")
if not name:
continue
params.add(name)
if param.attrib.get("required", "false").lower() == "true":
required.add(name)
return {"params": params, "required": required, "has_params": bool(params or required)}
def _get_module_name(func: Callable[..., Any]) -> str: def _get_module_name(func: Callable[..., Any]) -> str:
module = inspect.getmodule(func) module = inspect.getmodule(func)
if not module: if not module:
@@ -95,6 +128,27 @@ def _get_module_name(func: Callable[..., Any]) -> str:
return "unknown" return "unknown"
def _get_schema_path(func: Callable[..., Any]) -> Path | None:
module = inspect.getmodule(func)
if not module or not module.__name__:
return None
module_name = module.__name__
if ".tools." not in module_name:
return None
parts = module_name.split(".tools.")[-1].split(".")
if len(parts) < 2:
return None
folder = parts[0]
file_stem = parts[1]
schema_file = f"{file_stem}_schema.xml"
return get_strix_resource_path("tools", folder, schema_file)
def register_tool( def register_tool(
func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
) -> Callable[..., Any]: ) -> Callable[..., Any]:
@@ -109,11 +163,8 @@ def register_tool(
sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
if not sandbox_mode: if not sandbox_mode:
try: try:
module_path = Path(inspect.getfile(f)) schema_path = _get_schema_path(f)
schema_file_name = f"{module_path.stem}_schema.xml" xml_tools = _load_xml_schema(schema_path) if schema_path else None
schema_path = module_path.parent / schema_file_name
xml_tools = _load_xml_schema(schema_path)
if xml_tools is not None and f.__name__ in xml_tools: if xml_tools is not None and f.__name__ in xml_tools:
func_dict["xml_schema"] = xml_tools[f.__name__] func_dict["xml_schema"] = xml_tools[f.__name__]
@@ -131,6 +182,11 @@ def register_tool(
"</tool>" "</tool>"
) )
if not sandbox_mode:
xml_schema = func_dict.get("xml_schema")
param_schema = _parse_param_schema(xml_schema if isinstance(xml_schema, str) else "")
_tool_param_schemas[str(func_dict["name"])] = param_schema
tools.append(func_dict) tools.append(func_dict)
_tools_by_name[str(func_dict["name"])] = f _tools_by_name[str(func_dict["name"])] = f
@@ -153,6 +209,10 @@ def get_tool_names() -> list[str]:
return list(_tools_by_name.keys()) return list(_tools_by_name.keys())
def get_tool_param_schema(name: str) -> dict[str, Any] | None:
return _tool_param_schemas.get(name)
def needs_agent_state(tool_name: str) -> bool: def needs_agent_state(tool_name: str) -> bool:
tool_func = get_tool_by_name(tool_name) tool_func = get_tool_by_name(tool_name)
if not tool_func: if not tool_func:
@@ -194,3 +254,4 @@ def get_tools_prompt() -> str:
def clear_registry() -> None: def clear_registry() -> None:
tools.clear() tools.clear()
_tools_by_name.clear() _tools_by_name.clear()
_tool_param_schemas.clear()

View File

@@ -95,6 +95,12 @@
<parameter=command>ls -la</parameter> <parameter=command>ls -la</parameter>
</function> </function>
<function=terminal_execute>
<parameter=command>cd /workspace
pwd
ls -la</parameter>
</function>
# Run a command with custom timeout # Run a command with custom timeout
<function=terminal_execute> <function=terminal_execute>
<parameter=command>npm install</parameter> <parameter=command>npm install</parameter>

0
strix/utils/__init__.py Normal file
View File

View File

@@ -0,0 +1,13 @@
import sys
from pathlib import Path
def get_strix_resource_path(*parts: str) -> Path:
frozen_base = getattr(sys, "_MEIPASS", None)
if frozen_base:
base = Path(frozen_base) / "strix"
if base.exists():
return base.joinpath(*parts)
base = Path(__file__).resolve().parent.parent
return base.joinpath(*parts)