Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e5104eb93a | ||
|
|
d8a08e9a8c | ||
|
|
f6475cec07 | ||
|
|
31baa0dfc0 | ||
|
|
56526cbf90 | ||
|
|
47faeb1ef3 | ||
|
|
435ac82d9e | ||
|
|
f08014cf51 | ||
|
|
bc8e14f68a | ||
|
|
eae2b783c0 | ||
|
|
058cf1abdb | ||
|
|
d16bdb277a |
79
README.md
79
README.md
@@ -1,55 +1,61 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://strix.ai/">
|
<a href="https://strix.ai/">
|
||||||
<img src=".github/logo.png" width="150" alt="Strix Logo">
|
<img src="https://github.com/usestrix/.github/raw/main/imgs/cover.png" alt="Strix Banner" width="100%">
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<h1 align="center">Strix</h1>
|
|
||||||
|
|
||||||
<h2 align="center">Open-source AI Hackers to secure your Apps</h2>
|
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
[](https://pypi.org/project/strix-agent/)
|
# Strix
|
||||||
[](https://pypi.org/project/strix-agent/)
|
|
||||||
[](LICENSE)
|
|
||||||
[](https://docs.strix.ai)
|
|
||||||
|
|
||||||
[](https://github.com/usestrix/strix)
|
### Open-source AI hackers to find and fix your app’s vulnerabilities.
|
||||||
[](https://discord.gg/YjKFvEZSdZ)
|
|
||||||
[](https://strix.ai)
|
|
||||||
|
|
||||||
<a href="https://trendshift.io/repositories/15362" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15362" alt="usestrix%2Fstrix | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
<br/>
|
||||||
|
|
||||||
|
|
||||||
[](https://deepwiki.com/usestrix/strix)
|
<a href="https://docs.strix.ai"><img src="https://img.shields.io/badge/Docs-docs.strix.ai-2b9246?style=for-the-badge&logo=gitbook&logoColor=white" alt="Docs"></a>
|
||||||
|
<a href="https://strix.ai"><img src="https://img.shields.io/badge/Website-strix.ai-3b82f6?style=for-the-badge&logoColor=white" alt="Website"></a>
|
||||||
|
<a href="https://pypi.org/project/strix-agent/"><img src="https://img.shields.io/badge/PyPI-strix--agent-f59e0b?style=for-the-badge&logo=pypi&logoColor=white" alt="PyPI"></a>
|
||||||
|
|
||||||
|
<a href="https://deepwiki.com/usestrix/strix"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki"></a>
|
||||||
|
<a href="https://github.com/usestrix/strix"><img src="https://img.shields.io/github/stars/usestrix/strix?style=flat-square" alt="GitHub Stars"></a>
|
||||||
|
<a href="LICENSE"><img src="https://img.shields.io/badge/License-Apache%202.0-3b82f6?style=flat-square" alt="License"></a>
|
||||||
|
<a href="https://pypi.org/project/strix-agent/"><img src="https://img.shields.io/pypi/v/strix-agent?style=flat-square" alt="PyPI Version"></a>
|
||||||
|
|
||||||
|
|
||||||
|
<a href="https://discord.gg/YjKFvEZSdZ"><img src="https://github.com/usestrix/.github/raw/main/imgs/Discord.png" height="40" alt="Join Discord"></a>
|
||||||
|
<a href="https://x.com/strix_ai"><img src="https://github.com/usestrix/.github/raw/main/imgs/X.png" height="40" alt="Follow on X"></a>
|
||||||
|
|
||||||
|
|
||||||
|
<a href="https://trendshift.io/repositories/15362" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15362" alt="usestrix/strix | Trendshift" width="250" height="55"/></a>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<br>
|
<br/>
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<img src=".github/screenshot.png" alt="Strix Demo" width="800" style="border-radius: 16px;">
|
<img src=".github/screenshot.png" alt="Strix Demo" width="900" style="border-radius: 16px;">
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<br>
|
<br>
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> **New!** Strix now integrates seamlessly with GitHub Actions and CI/CD pipelines. Automatically scan for vulnerabilities on every pull request and block insecure code before it reaches production!
|
> **New!** Strix integrates seamlessly with GitHub Actions and CI/CD pipelines. Automatically scan for vulnerabilities on every pull request and block insecure code before it reaches production!
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🦉 Strix Overview
|
|
||||||
|
## Strix Overview
|
||||||
|
|
||||||
Strix are autonomous AI agents that act just like real hackers - they run your code dynamically, find vulnerabilities, and validate them through actual proof-of-concepts. Built for developers and security teams who need fast, accurate security testing without the overhead of manual pentesting or the false positives of static analysis tools.
|
Strix are autonomous AI agents that act just like real hackers - they run your code dynamically, find vulnerabilities, and validate them through actual proof-of-concepts. Built for developers and security teams who need fast, accurate security testing without the overhead of manual pentesting or the false positives of static analysis tools.
|
||||||
|
|
||||||
**Key Capabilities:**
|
**Key Capabilities:**
|
||||||
|
|
||||||
- 🔧 **Full hacker toolkit** out of the box
|
- **Full hacker toolkit** out of the box
|
||||||
- 🤝 **Teams of agents** that collaborate and scale
|
- **Teams of agents** that collaborate and scale
|
||||||
- ✅ **Real validation** with PoCs, not false positives
|
- **Real validation** with PoCs, not false positives
|
||||||
- 💻 **Developer‑first** CLI with actionable reports
|
- **Developer‑first** CLI with actionable reports
|
||||||
- 🔄 **Auto‑fix & reporting** to accelerate remediation
|
- **Auto‑fix & reporting** to accelerate remediation
|
||||||
|
|
||||||
|
|
||||||
## 🎯 Use Cases
|
## 🎯 Use Cases
|
||||||
@@ -87,7 +93,7 @@ strix --target ./app-directory
|
|||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> First run automatically pulls the sandbox Docker image. Results are saved to `strix_runs/<run-name>`
|
> First run automatically pulls the sandbox Docker image. Results are saved to `strix_runs/<run-name>`
|
||||||
|
|
||||||
## ☁️ Run Strix in Cloud
|
## Run Strix in Cloud
|
||||||
|
|
||||||
Want to skip the local setup, API keys, and unpredictable LLM costs? Run the hosted cloud version of Strix at **[app.strix.ai](https://strix.ai)**.
|
Want to skip the local setup, API keys, and unpredictable LLM costs? Run the hosted cloud version of Strix at **[app.strix.ai](https://strix.ai)**.
|
||||||
|
|
||||||
@@ -104,7 +110,7 @@ Launch a scan in just a few minutes—no setup or configuration required—and y
|
|||||||
|
|
||||||
## ✨ Features
|
## ✨ Features
|
||||||
|
|
||||||
### 🛠️ Agentic Security Tools
|
### Agentic Security Tools
|
||||||
|
|
||||||
Strix agents come equipped with a comprehensive security testing toolkit:
|
Strix agents come equipped with a comprehensive security testing toolkit:
|
||||||
|
|
||||||
@@ -116,7 +122,7 @@ Strix agents come equipped with a comprehensive security testing toolkit:
|
|||||||
- **Code Analysis** - Static and dynamic analysis capabilities
|
- **Code Analysis** - Static and dynamic analysis capabilities
|
||||||
- **Knowledge Management** - Structured findings and attack documentation
|
- **Knowledge Management** - Structured findings and attack documentation
|
||||||
|
|
||||||
### 🎯 Comprehensive Vulnerability Detection
|
### Comprehensive Vulnerability Detection
|
||||||
|
|
||||||
Strix can identify and validate a wide range of security vulnerabilities:
|
Strix can identify and validate a wide range of security vulnerabilities:
|
||||||
|
|
||||||
@@ -128,7 +134,7 @@ Strix can identify and validate a wide range of security vulnerabilities:
|
|||||||
- **Authentication** - JWT vulnerabilities, session management
|
- **Authentication** - JWT vulnerabilities, session management
|
||||||
- **Infrastructure** - Misconfigurations, exposed services
|
- **Infrastructure** - Misconfigurations, exposed services
|
||||||
|
|
||||||
### 🕸️ Graph of Agents
|
### Graph of Agents
|
||||||
|
|
||||||
Advanced multi-agent orchestration for comprehensive security testing:
|
Advanced multi-agent orchestration for comprehensive security testing:
|
||||||
|
|
||||||
@@ -138,7 +144,7 @@ Advanced multi-agent orchestration for comprehensive security testing:
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 💻 Usage Examples
|
## Usage Examples
|
||||||
|
|
||||||
### Basic Usage
|
### Basic Usage
|
||||||
|
|
||||||
@@ -169,7 +175,7 @@ strix --target api.your-app.com --instruction "Focus on business logic flaws and
|
|||||||
strix --target api.your-app.com --instruction-file ./instruction.md
|
strix --target api.your-app.com --instruction-file ./instruction.md
|
||||||
```
|
```
|
||||||
|
|
||||||
### 🤖 Headless Mode
|
### Headless Mode
|
||||||
|
|
||||||
Run Strix programmatically without interactive UI using the `-n/--non-interactive` flag—perfect for servers and automated jobs. The CLI prints real-time vulnerability findings, and the final report before exiting. Exits with non-zero code when vulnerabilities are found.
|
Run Strix programmatically without interactive UI using the `-n/--non-interactive` flag—perfect for servers and automated jobs. The CLI prints real-time vulnerability findings, and the final report before exiting. Exits with non-zero code when vulnerabilities are found.
|
||||||
|
|
||||||
@@ -177,7 +183,7 @@ Run Strix programmatically without interactive UI using the `-n/--non-interactiv
|
|||||||
strix -n --target https://your-app.com
|
strix -n --target https://your-app.com
|
||||||
```
|
```
|
||||||
|
|
||||||
### 🔄 CI/CD (GitHub Actions)
|
### CI/CD (GitHub Actions)
|
||||||
|
|
||||||
Strix can be added to your pipeline to run a security test on pull requests with a lightweight GitHub Actions workflow:
|
Strix can be added to your pipeline to run a security test on pull requests with a lightweight GitHub Actions workflow:
|
||||||
|
|
||||||
@@ -204,7 +210,7 @@ jobs:
|
|||||||
run: strix -n -t ./ --scan-mode quick
|
run: strix -n -t ./ --scan-mode quick
|
||||||
```
|
```
|
||||||
|
|
||||||
### ⚙️ Configuration
|
### Configuration
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export STRIX_LLM="openai/gpt-5"
|
export STRIX_LLM="openai/gpt-5"
|
||||||
@@ -227,22 +233,23 @@ export STRIX_REASONING_EFFORT="high" # control thinking effort (default: high,
|
|||||||
|
|
||||||
See the [LLM Providers documentation](https://docs.strix.ai/llm-providers/overview) for all supported providers including Vertex AI, Bedrock, Azure, and local models.
|
See the [LLM Providers documentation](https://docs.strix.ai/llm-providers/overview) for all supported providers including Vertex AI, Bedrock, Azure, and local models.
|
||||||
|
|
||||||
## 📚 Documentation
|
## Documentation
|
||||||
|
|
||||||
Full documentation is available at **[docs.strix.ai](https://docs.strix.ai)** — including detailed guides for usage, CI/CD integrations, skills, and advanced configuration.
|
Full documentation is available at **[docs.strix.ai](https://docs.strix.ai)** — including detailed guides for usage, CI/CD integrations, skills, and advanced configuration.
|
||||||
|
|
||||||
## 🤝 Contributing
|
## Contributing
|
||||||
|
|
||||||
We welcome contributions of code, docs, and new skills - check out our [Contributing Guide](https://docs.strix.ai/contributing) to get started or open a [pull request](https://github.com/usestrix/strix/pulls)/[issue](https://github.com/usestrix/strix/issues).
|
We welcome contributions of code, docs, and new skills - check out our [Contributing Guide](https://docs.strix.ai/contributing) to get started or open a [pull request](https://github.com/usestrix/strix/pulls)/[issue](https://github.com/usestrix/strix/issues).
|
||||||
|
|
||||||
## 👥 Join Our Community
|
## Join Our Community
|
||||||
|
|
||||||
Have questions? Found a bug? Want to contribute? **[Join our Discord!](https://discord.gg/YjKFvEZSdZ)**
|
Have questions? Found a bug? Want to contribute? **[Join our Discord!](https://discord.gg/YjKFvEZSdZ)**
|
||||||
|
|
||||||
## 🌟 Support the Project
|
## Support the Project
|
||||||
|
|
||||||
**Love Strix?** Give us a ⭐ on GitHub!
|
**Love Strix?** Give us a ⭐ on GitHub!
|
||||||
## 🙏 Acknowledgements
|
|
||||||
|
## Acknowledgements
|
||||||
|
|
||||||
Strix builds on the incredible work of open-source projects like [LiteLLM](https://github.com/BerriAI/litellm), [Caido](https://github.com/caido/caido), [ProjectDiscovery](https://github.com/projectdiscovery), [Playwright](https://github.com/microsoft/playwright), and [Textual](https://github.com/Textualize/textual). Huge thanks to their maintainers!
|
Strix builds on the incredible work of open-source projects like [LiteLLM](https://github.com/BerriAI/litellm), [Caido](https://github.com/caido/caido), [ProjectDiscovery](https://github.com/projectdiscovery), [Playwright](https://github.com/microsoft/playwright), and [Textual](https://github.com/Textualize/textual). Huge thanks to their maintainers!
|
||||||
|
|
||||||
|
|||||||
35
poetry.lock
generated
35
poetry.lock
generated
@@ -379,19 +379,18 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "azure-core"
|
name = "azure-core"
|
||||||
version = "1.35.0"
|
version = "1.38.0"
|
||||||
description = "Microsoft Azure Core Library for Python"
|
description = "Microsoft Azure Core Library for Python"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "azure_core-1.35.0-py3-none-any.whl", hash = "sha256:8db78c72868a58f3de8991eb4d22c4d368fae226dac1002998d6c50437e7dad1"},
|
{file = "azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335"},
|
||||||
{file = "azure_core-1.35.0.tar.gz", hash = "sha256:c0be528489485e9ede59b6971eb63c1eaacf83ef53001bfe3904e475e972be5c"},
|
{file = "azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
requests = ">=2.21.0"
|
requests = ">=2.21.0"
|
||||||
six = ">=1.11.0"
|
|
||||||
typing-extensions = ">=4.6.0"
|
typing-extensions = ">=4.6.0"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
@@ -1199,6 +1198,18 @@ files = [
|
|||||||
{file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"},
|
{file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "defusedxml"
|
||||||
|
version = "0.7.1"
|
||||||
|
description = "XML bomb protection for Python stdlib modules"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||||
|
groups = ["main"]
|
||||||
|
files = [
|
||||||
|
{file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"},
|
||||||
|
{file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dill"
|
name = "dill"
|
||||||
version = "0.4.0"
|
version = "0.4.0"
|
||||||
@@ -1490,14 +1501,14 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "filelock"
|
name = "filelock"
|
||||||
version = "3.20.1"
|
version = "3.20.3"
|
||||||
description = "A platform independent file lock."
|
description = "A platform independent file lock."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.10"
|
python-versions = ">=3.10"
|
||||||
groups = ["main", "dev"]
|
groups = ["main", "dev"]
|
||||||
files = [
|
files = [
|
||||||
{file = "filelock-3.20.1-py3-none-any.whl", hash = "sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a"},
|
{file = "filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1"},
|
||||||
{file = "filelock-3.20.1.tar.gz", hash = "sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c"},
|
{file = "filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -7095,19 +7106,19 @@ test = ["aiohttp (>=3.10.5)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil",
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "virtualenv"
|
name = "virtualenv"
|
||||||
version = "20.34.0"
|
version = "20.36.1"
|
||||||
description = "Virtual Python Environment builder"
|
description = "Virtual Python Environment builder"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
groups = ["dev"]
|
groups = ["dev"]
|
||||||
files = [
|
files = [
|
||||||
{file = "virtualenv-20.34.0-py3-none-any.whl", hash = "sha256:341f5afa7eee943e4984a9207c025feedd768baff6753cd660c857ceb3e36026"},
|
{file = "virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f"},
|
||||||
{file = "virtualenv-20.34.0.tar.gz", hash = "sha256:44815b2c9dee7ed86e387b842a84f20b93f7f417f95886ca1996a72a4138eb1a"},
|
{file = "virtualenv-20.36.1.tar.gz", hash = "sha256:8befb5c81842c641f8ee658481e42641c68b5eab3521d8e092d18320902466ba"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
distlib = ">=0.3.7,<1"
|
distlib = ">=0.3.7,<1"
|
||||||
filelock = ">=3.12.2,<4"
|
filelock = {version = ">=3.20.1,<4", markers = "python_version >= \"3.10\""}
|
||||||
platformdirs = ">=3.9.1,<5"
|
platformdirs = ">=3.9.1,<5"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
@@ -7424,4 +7435,4 @@ vertex = ["google-cloud-aiplatform"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = "^3.12"
|
python-versions = "^3.12"
|
||||||
content-hash = "91f49e313e5690bbef87e17730441f26d366daeccb16b5020e03e581fbb9d4d5"
|
content-hash = "0424a0e82fe49501f3a80166676e257a9dae97093d9bc730489789195f523735"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "strix-agent"
|
name = "strix-agent"
|
||||||
version = "0.6.0"
|
version = "0.6.1"
|
||||||
description = "Open-source AI Hackers for your apps"
|
description = "Open-source AI Hackers for your apps"
|
||||||
authors = ["Strix <hi@usestrix.com>"]
|
authors = ["Strix <hi@usestrix.com>"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
@@ -69,6 +69,7 @@ gql = { version = "^3.5.3", extras = ["requests"], optional = true }
|
|||||||
pyte = { version = "^0.8.1", optional = true }
|
pyte = { version = "^0.8.1", optional = true }
|
||||||
libtmux = { version = "^0.46.2", optional = true }
|
libtmux = { version = "^0.46.2", optional = true }
|
||||||
numpydoc = { version = "^1.8.0", optional = true }
|
numpydoc = { version = "^1.8.0", optional = true }
|
||||||
|
defusedxml = "^0.7.1"
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
vertex = ["google-cloud-aiplatform"]
|
vertex = ["google-cloud-aiplatform"]
|
||||||
|
|||||||
@@ -111,7 +111,6 @@ hiddenimports = [
|
|||||||
'strix.llm.llm',
|
'strix.llm.llm',
|
||||||
'strix.llm.config',
|
'strix.llm.config',
|
||||||
'strix.llm.utils',
|
'strix.llm.utils',
|
||||||
'strix.llm.request_queue',
|
|
||||||
'strix.llm.memory_compressor',
|
'strix.llm.memory_compressor',
|
||||||
'strix.runtime',
|
'strix.runtime',
|
||||||
'strix.runtime.runtime',
|
'strix.runtime.runtime',
|
||||||
|
|||||||
@@ -308,17 +308,18 @@ Tool calls use XML format:
|
|||||||
|
|
||||||
CRITICAL RULES:
|
CRITICAL RULES:
|
||||||
0. While active in the agent loop, EVERY message you output MUST be a single tool call. Do not send plain text-only responses.
|
0. While active in the agent loop, EVERY message you output MUST be a single tool call. Do not send plain text-only responses.
|
||||||
1. One tool call per message
|
1. Exactly one tool call per message — never include more than one <function>...</function> block in a single LLM message.
|
||||||
2. Tool call must be last in message
|
2. Tool call must be last in message
|
||||||
3. End response after </function> tag. It's your stop word. Do not continue after it.
|
3. EVERY tool call MUST end with </function>. This is MANDATORY. Never omit the closing tag. End your response immediately after </function>.
|
||||||
4. Use ONLY the exact XML format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters.
|
4. Use ONLY the exact XML format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters.
|
||||||
5. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants).
|
5. When sending ANY multi-line content in tool parameters, use real newlines (actual line breaks). Do NOT emit literal "\n" sequences. If you send "\n" instead of real line breaks inside the XML parameter value, tools may fail or behave incorrectly.
|
||||||
|
6. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants).
|
||||||
- Correct: <function=think> ... </function>
|
- Correct: <function=think> ... </function>
|
||||||
- Incorrect: <thinking_tools.think> ... </function>
|
- Incorrect: <thinking_tools.think> ... </function>
|
||||||
- Incorrect: <think> ... </think>
|
- Incorrect: <think> ... </think>
|
||||||
- Incorrect: {"think": {...}}
|
- Incorrect: {"think": {...}}
|
||||||
6. Parameters must use <parameter=param_name>value</parameter> exactly. Do NOT pass parameters as JSON or key:value lines. Do NOT add quotes/braces around values.
|
7. Parameters must use <parameter=param_name>value</parameter> exactly. Do NOT pass parameters as JSON or key:value lines. Do NOT add quotes/braces around values.
|
||||||
7. Do NOT wrap tool calls in markdown/code fences or add any text before or after the tool block.
|
8. Do NOT wrap tool calls in markdown/code fences or add any text before or after the tool block.
|
||||||
|
|
||||||
Example (agent creation tool):
|
Example (agent creation tool):
|
||||||
<function=create_agent>
|
<function=create_agent>
|
||||||
@@ -331,6 +332,8 @@ SPRAYING EXECUTION NOTE:
|
|||||||
- When performing large payload sprays or fuzzing, encapsulate the entire spraying loop inside a single python or terminal tool call (e.g., a Python script using asyncio/aiohttp). Do not issue one tool call per payload.
|
- When performing large payload sprays or fuzzing, encapsulate the entire spraying loop inside a single python or terminal tool call (e.g., a Python script using asyncio/aiohttp). Do not issue one tool call per payload.
|
||||||
- Favor batch-mode CLI tools (sqlmap, ffuf, nuclei, zaproxy, arjun) where appropriate and check traffic via the proxy when beneficial
|
- Favor batch-mode CLI tools (sqlmap, ffuf, nuclei, zaproxy, arjun) where appropriate and check traffic via the proxy when beneficial
|
||||||
|
|
||||||
|
REMINDER: Always close each tool call with </function> before going into the next. Incomplete tool calls will fail.
|
||||||
|
|
||||||
{{ get_tools_prompt() }}
|
{{ get_tools_prompt() }}
|
||||||
</tool_usage>
|
</tool_usage>
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
|
||||||
@@ -18,6 +17,7 @@ from strix.llm import LLM, LLMConfig, LLMRequestFailedError
|
|||||||
from strix.llm.utils import clean_content
|
from strix.llm.utils import clean_content
|
||||||
from strix.runtime import SandboxInitializationError
|
from strix.runtime import SandboxInitializationError
|
||||||
from strix.tools import process_tool_invocations
|
from strix.tools import process_tool_invocations
|
||||||
|
from strix.utils.resource_paths import get_strix_resource_path
|
||||||
|
|
||||||
from .state import AgentState
|
from .state import AgentState
|
||||||
|
|
||||||
@@ -35,8 +35,7 @@ class AgentMeta(type):
|
|||||||
if name == "BaseAgent":
|
if name == "BaseAgent":
|
||||||
return new_cls
|
return new_cls
|
||||||
|
|
||||||
agents_dir = Path(__file__).parent
|
prompt_dir = get_strix_resource_path("agents", name)
|
||||||
prompt_dir = agents_dir / name
|
|
||||||
|
|
||||||
new_cls.agent_name = name
|
new_cls.agent_name = name
|
||||||
new_cls.jinja_env = Environment(
|
new_cls.jinja_env = Environment(
|
||||||
@@ -66,20 +65,21 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
self.llm_config = config.get("llm_config", self.default_llm_config)
|
self.llm_config = config.get("llm_config", self.default_llm_config)
|
||||||
if self.llm_config is None:
|
if self.llm_config is None:
|
||||||
raise ValueError("llm_config is required but not provided")
|
raise ValueError("llm_config is required but not provided")
|
||||||
self.llm = LLM(self.llm_config, agent_name=self.agent_name)
|
|
||||||
|
|
||||||
state_from_config = config.get("state")
|
state_from_config = config.get("state")
|
||||||
if state_from_config is not None:
|
if state_from_config is not None:
|
||||||
self.state = state_from_config
|
self.state = state_from_config
|
||||||
else:
|
else:
|
||||||
self.state = AgentState(
|
self.state = AgentState(
|
||||||
agent_name=self.agent_name,
|
agent_name="Root Agent",
|
||||||
max_iterations=self.max_iterations,
|
max_iterations=self.max_iterations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.llm = LLM(self.llm_config, agent_name=self.agent_name)
|
||||||
|
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
self.llm.set_agent_identity(self.agent_name, self.state.agent_id)
|
self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id)
|
||||||
self._current_task: asyncio.Task[Any] | None = None
|
self._current_task: asyncio.Task[Any] | None = None
|
||||||
|
self._force_stop = False
|
||||||
|
|
||||||
from strix.telemetry.tracer import get_global_tracer
|
from strix.telemetry.tracer import get_global_tracer
|
||||||
|
|
||||||
@@ -157,6 +157,11 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
return self._handle_sandbox_error(e, tracer)
|
return self._handle_sandbox_error(e, tracer)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
if self._force_stop:
|
||||||
|
self._force_stop = False
|
||||||
|
await self._enter_waiting_state(tracer, was_cancelled=True)
|
||||||
|
continue
|
||||||
|
|
||||||
self._check_agent_messages(self.state)
|
self._check_agent_messages(self.state)
|
||||||
|
|
||||||
if self.state.is_waiting_for_input():
|
if self.state.is_waiting_for_input():
|
||||||
@@ -247,7 +252,8 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
async def _wait_for_input(self) -> None:
|
async def _wait_for_input(self) -> None:
|
||||||
import asyncio
|
if self._force_stop:
|
||||||
|
return
|
||||||
|
|
||||||
if self.state.has_waiting_timeout():
|
if self.state.has_waiting_timeout():
|
||||||
self.state.resume_from_waiting()
|
self.state.resume_from_waiting()
|
||||||
@@ -340,6 +346,7 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
|
|
||||||
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
|
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
|
||||||
final_response = None
|
final_response = None
|
||||||
|
|
||||||
async for response in self.llm.generate(self.state.get_conversation_history()):
|
async for response in self.llm.generate(self.state.get_conversation_history()):
|
||||||
final_response = response
|
final_response = response
|
||||||
if tracer and response.content:
|
if tracer and response.content:
|
||||||
@@ -585,6 +592,11 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def cancel_current_execution(self) -> None:
|
def cancel_current_execution(self) -> None:
|
||||||
|
self._force_stop = True
|
||||||
if self._current_task and not self._current_task.done():
|
if self._current_task and not self._current_task.done():
|
||||||
self._current_task.cancel()
|
try:
|
||||||
|
loop = self._current_task.get_loop()
|
||||||
|
loop.call_soon_threadsafe(self._current_task.cancel)
|
||||||
|
except RuntimeError:
|
||||||
|
self._current_task.cancel()
|
||||||
self._current_task = None
|
self._current_task = None
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ class Config:
|
|||||||
litellm_base_url = None
|
litellm_base_url = None
|
||||||
ollama_api_base = None
|
ollama_api_base = None
|
||||||
strix_reasoning_effort = "high"
|
strix_reasoning_effort = "high"
|
||||||
|
strix_llm_max_retries = "5"
|
||||||
|
strix_memory_compressor_timeout = "30"
|
||||||
llm_timeout = "300"
|
llm_timeout = "300"
|
||||||
llm_rate_limit_delay = "4.0"
|
|
||||||
llm_rate_limit_concurrent = "1"
|
|
||||||
|
|
||||||
# Tool & Feature Configuration
|
# Tool & Feature Configuration
|
||||||
perplexity_api_key = None
|
perplexity_api_key = None
|
||||||
@@ -27,7 +27,7 @@ class Config:
|
|||||||
# Runtime Configuration
|
# Runtime Configuration
|
||||||
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
|
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
|
||||||
strix_runtime_backend = "docker"
|
strix_runtime_backend = "docker"
|
||||||
strix_sandbox_execution_timeout = "500"
|
strix_sandbox_execution_timeout = "120"
|
||||||
strix_sandbox_connect_timeout = "10"
|
strix_sandbox_connect_timeout = "10"
|
||||||
|
|
||||||
# Telemetry
|
# Telemetry
|
||||||
|
|||||||
601
strix/llm/llm.py
601
strix/llm/llm.py
@@ -1,60 +1,29 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from jinja2 import (
|
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
||||||
Environment,
|
from litellm import acompletion, completion_cost, stream_chunk_builder, supports_reasoning
|
||||||
FileSystemLoader,
|
|
||||||
select_autoescape,
|
|
||||||
)
|
|
||||||
from litellm import completion_cost, stream_chunk_builder, supports_reasoning
|
|
||||||
from litellm.utils import supports_prompt_caching, supports_vision
|
from litellm.utils import supports_prompt_caching, supports_vision
|
||||||
|
|
||||||
from strix.config import Config
|
from strix.config import Config
|
||||||
from strix.llm.config import LLMConfig
|
from strix.llm.config import LLMConfig
|
||||||
from strix.llm.memory_compressor import MemoryCompressor
|
from strix.llm.memory_compressor import MemoryCompressor
|
||||||
from strix.llm.request_queue import get_global_queue
|
from strix.llm.utils import (
|
||||||
from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations
|
_truncate_to_first_function,
|
||||||
|
fix_incomplete_tool_call,
|
||||||
|
parse_tool_invocations,
|
||||||
|
)
|
||||||
from strix.skills import load_skills
|
from strix.skills import load_skills
|
||||||
from strix.tools import get_tools_prompt
|
from strix.tools import get_tools_prompt
|
||||||
|
from strix.utils.resource_paths import get_strix_resource_path
|
||||||
|
|
||||||
|
|
||||||
MAX_RETRIES = 5
|
|
||||||
RETRY_MULTIPLIER = 8
|
|
||||||
RETRY_MIN = 8
|
|
||||||
RETRY_MAX = 64
|
|
||||||
|
|
||||||
|
|
||||||
def _should_retry(exception: Exception) -> bool:
|
|
||||||
status_code = None
|
|
||||||
if hasattr(exception, "status_code"):
|
|
||||||
status_code = exception.status_code
|
|
||||||
elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
|
|
||||||
status_code = exception.response.status_code
|
|
||||||
if status_code is not None:
|
|
||||||
return bool(litellm._should_retry(status_code))
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
litellm.modify_params = True
|
litellm.modify_params = True
|
||||||
|
|
||||||
_LLM_API_KEY = Config.get("llm_api_key")
|
|
||||||
_LLM_API_BASE = (
|
|
||||||
Config.get("llm_api_base")
|
|
||||||
or Config.get("openai_api_base")
|
|
||||||
or Config.get("litellm_base_url")
|
|
||||||
or Config.get("ollama_api_base")
|
|
||||||
)
|
|
||||||
_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")
|
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestFailedError(Exception):
|
class LLMRequestFailedError(Exception):
|
||||||
def __init__(self, message: str, details: str | None = None):
|
def __init__(self, message: str, details: str | None = None):
|
||||||
@@ -63,20 +32,11 @@ class LLMRequestFailedError(Exception):
|
|||||||
self.details = details
|
self.details = details
|
||||||
|
|
||||||
|
|
||||||
class StepRole(str, Enum):
|
|
||||||
AGENT = "agent"
|
|
||||||
USER = "user"
|
|
||||||
SYSTEM = "system"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMResponse:
|
class LLMResponse:
|
||||||
content: str
|
content: str
|
||||||
tool_invocations: list[dict[str, Any]] | None = None
|
tool_invocations: list[dict[str, Any]] | None = None
|
||||||
scan_id: str | None = None
|
thinking_blocks: list[dict[str, Any]] | None = None
|
||||||
step_number: int = 1
|
|
||||||
role: StepRole = StepRole.AGENT
|
|
||||||
thinking_blocks: list[dict[str, Any]] | None = None # For reasoning models.
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -84,76 +44,63 @@ class RequestStats:
|
|||||||
input_tokens: int = 0
|
input_tokens: int = 0
|
||||||
output_tokens: int = 0
|
output_tokens: int = 0
|
||||||
cached_tokens: int = 0
|
cached_tokens: int = 0
|
||||||
cache_creation_tokens: int = 0
|
|
||||||
cost: float = 0.0
|
cost: float = 0.0
|
||||||
requests: int = 0
|
requests: int = 0
|
||||||
failed_requests: int = 0
|
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, int | float]:
|
def to_dict(self) -> dict[str, int | float]:
|
||||||
return {
|
return {
|
||||||
"input_tokens": self.input_tokens,
|
"input_tokens": self.input_tokens,
|
||||||
"output_tokens": self.output_tokens,
|
"output_tokens": self.output_tokens,
|
||||||
"cached_tokens": self.cached_tokens,
|
"cached_tokens": self.cached_tokens,
|
||||||
"cache_creation_tokens": self.cache_creation_tokens,
|
|
||||||
"cost": round(self.cost, 4),
|
"cost": round(self.cost, 4),
|
||||||
"requests": self.requests,
|
"requests": self.requests,
|
||||||
"failed_requests": self.failed_requests,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class LLM:
|
class LLM:
|
||||||
def __init__(
|
def __init__(self, config: LLMConfig, agent_name: str | None = None):
|
||||||
self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
|
|
||||||
):
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.agent_name = agent_name
|
self.agent_name = agent_name
|
||||||
self.agent_id = agent_id
|
self.agent_id: str | None = None
|
||||||
self._total_stats = RequestStats()
|
self._total_stats = RequestStats()
|
||||||
self._last_request_stats = RequestStats()
|
self.memory_compressor = MemoryCompressor(model_name=config.model_name)
|
||||||
|
self.system_prompt = self._load_system_prompt(agent_name)
|
||||||
|
|
||||||
if _STRIX_REASONING_EFFORT:
|
reasoning = Config.get("strix_reasoning_effort")
|
||||||
self._reasoning_effort = _STRIX_REASONING_EFFORT
|
if reasoning:
|
||||||
elif self.config.scan_mode == "quick":
|
self._reasoning_effort = reasoning
|
||||||
|
elif config.scan_mode == "quick":
|
||||||
self._reasoning_effort = "medium"
|
self._reasoning_effort = "medium"
|
||||||
else:
|
else:
|
||||||
self._reasoning_effort = "high"
|
self._reasoning_effort = "high"
|
||||||
|
|
||||||
self.memory_compressor = MemoryCompressor(
|
def _load_system_prompt(self, agent_name: str | None) -> str:
|
||||||
model_name=self.config.model_name,
|
if not agent_name:
|
||||||
timeout=self.config.timeout,
|
return ""
|
||||||
)
|
|
||||||
|
|
||||||
if agent_name:
|
try:
|
||||||
prompt_dir = Path(__file__).parent.parent / "agents" / agent_name
|
prompt_dir = get_strix_resource_path("agents", agent_name)
|
||||||
skills_dir = Path(__file__).parent.parent / "skills"
|
skills_dir = get_strix_resource_path("skills")
|
||||||
|
env = Environment(
|
||||||
loader = FileSystemLoader([prompt_dir, skills_dir])
|
loader=FileSystemLoader([prompt_dir, skills_dir]),
|
||||||
self.jinja_env = Environment(
|
|
||||||
loader=loader,
|
|
||||||
autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
|
autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
skills_to_load = [
|
||||||
skills_to_load = list(self.config.skills or [])
|
*list(self.config.skills or []),
|
||||||
skills_to_load.append(f"scan_modes/{self.config.scan_mode}")
|
f"scan_modes/{self.config.scan_mode}",
|
||||||
|
]
|
||||||
|
skill_content = load_skills(skills_to_load, env)
|
||||||
|
env.globals["get_skill"] = lambda name: skill_content.get(name, "")
|
||||||
|
|
||||||
skill_content = load_skills(skills_to_load, self.jinja_env)
|
result = env.get_template("system_prompt.jinja").render(
|
||||||
|
get_tools_prompt=get_tools_prompt,
|
||||||
def get_skill(name: str) -> str:
|
loaded_skill_names=list(skill_content.keys()),
|
||||||
return skill_content.get(name, "")
|
**skill_content,
|
||||||
|
)
|
||||||
self.jinja_env.globals["get_skill"] = get_skill
|
return str(result)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
|
return ""
|
||||||
get_tools_prompt=get_tools_prompt,
|
|
||||||
loaded_skill_names=list(skill_content.keys()),
|
|
||||||
**skill_content,
|
|
||||||
)
|
|
||||||
except (FileNotFoundError, OSError, ValueError) as e:
|
|
||||||
logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
|
|
||||||
self.system_prompt = "You are a helpful AI assistant."
|
|
||||||
else:
|
|
||||||
self.system_prompt = "You are a helpful AI assistant."
|
|
||||||
|
|
||||||
def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
|
def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
|
||||||
if agent_name:
|
if agent_name:
|
||||||
@@ -161,328 +108,121 @@ class LLM:
|
|||||||
if agent_id:
|
if agent_id:
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
|
|
||||||
def _build_identity_message(self) -> dict[str, Any] | None:
|
async def generate(
|
||||||
if not (self.agent_name and str(self.agent_name).strip()):
|
self, conversation_history: list[dict[str, Any]]
|
||||||
return None
|
) -> AsyncIterator[LLMResponse]:
|
||||||
identity_name = self.agent_name
|
messages = self._prepare_messages(conversation_history)
|
||||||
identity_id = self.agent_id
|
max_retries = int(Config.get("strix_llm_max_retries") or "5")
|
||||||
content = (
|
|
||||||
"\n\n"
|
|
||||||
"<agent_identity>\n"
|
|
||||||
"<meta>Internal metadata: do not echo or reference; "
|
|
||||||
"not part of history or tool calls.</meta>\n"
|
|
||||||
"<note>You are now assuming the role of this agent. "
|
|
||||||
"Act strictly as this agent and maintain self-identity for this step. "
|
|
||||||
"Now go answer the next needed step!</note>\n"
|
|
||||||
f"<agent_name>{identity_name}</agent_name>\n"
|
|
||||||
f"<agent_id>{identity_id}</agent_id>\n"
|
|
||||||
"</agent_identity>\n\n"
|
|
||||||
)
|
|
||||||
return {"role": "user", "content": content}
|
|
||||||
|
|
||||||
def _add_cache_control_to_content(
|
for attempt in range(max_retries + 1):
|
||||||
self, content: str | list[dict[str, Any]]
|
try:
|
||||||
) -> str | list[dict[str, Any]]:
|
async for response in self._stream(messages):
|
||||||
if isinstance(content, str):
|
yield response
|
||||||
return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
|
return # noqa: TRY300
|
||||||
if isinstance(content, list) and content:
|
except Exception as e: # noqa: BLE001
|
||||||
last_item = content[-1]
|
if attempt >= max_retries or not self._should_retry(e):
|
||||||
if isinstance(last_item, dict) and last_item.get("type") == "text":
|
self._raise_error(e)
|
||||||
return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
|
wait = min(10, 2 * (2**attempt))
|
||||||
return content
|
await asyncio.sleep(wait)
|
||||||
|
|
||||||
def _is_anthropic_model(self) -> bool:
|
async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
|
||||||
if not self.config.model_name:
|
accumulated = ""
|
||||||
return False
|
chunks: list[Any] = []
|
||||||
model_lower = self.config.model_name.lower()
|
|
||||||
return any(provider in model_lower for provider in ["anthropic/", "claude"])
|
|
||||||
|
|
||||||
def _calculate_cache_interval(self, total_messages: int) -> int:
|
self._total_stats.requests += 1
|
||||||
if total_messages <= 1:
|
response = await acompletion(**self._build_completion_args(messages), stream=True)
|
||||||
return 10
|
|
||||||
|
|
||||||
max_cached_messages = 3
|
async for chunk in response:
|
||||||
non_system_messages = total_messages - 1
|
chunks.append(chunk)
|
||||||
|
delta = self._get_chunk_content(chunk)
|
||||||
interval = 10
|
if delta:
|
||||||
while non_system_messages // interval > max_cached_messages:
|
accumulated += delta
|
||||||
interval += 10
|
if "</function>" in accumulated:
|
||||||
|
accumulated = accumulated[
|
||||||
return interval
|
: accumulated.find("</function>") + len("</function>")
|
||||||
|
]
|
||||||
def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
yield LLMResponse(content=accumulated)
|
||||||
if (
|
|
||||||
not self.config.enable_prompt_caching
|
|
||||||
or not supports_prompt_caching(self.config.model_name)
|
|
||||||
or not messages
|
|
||||||
):
|
|
||||||
return messages
|
|
||||||
|
|
||||||
if not self._is_anthropic_model():
|
|
||||||
return messages
|
|
||||||
|
|
||||||
cached_messages = list(messages)
|
|
||||||
|
|
||||||
if cached_messages and cached_messages[0].get("role") == "system":
|
|
||||||
system_message = cached_messages[0].copy()
|
|
||||||
system_message["content"] = self._add_cache_control_to_content(
|
|
||||||
system_message["content"]
|
|
||||||
)
|
|
||||||
cached_messages[0] = system_message
|
|
||||||
|
|
||||||
total_messages = len(cached_messages)
|
|
||||||
if total_messages > 1:
|
|
||||||
interval = self._calculate_cache_interval(total_messages)
|
|
||||||
|
|
||||||
cached_count = 0
|
|
||||||
for i in range(interval, total_messages, interval):
|
|
||||||
if cached_count >= 3:
|
|
||||||
break
|
break
|
||||||
|
yield LLMResponse(content=accumulated)
|
||||||
|
|
||||||
if i < len(cached_messages):
|
if chunks:
|
||||||
message = cached_messages[i].copy()
|
self._update_usage_stats(stream_chunk_builder(chunks))
|
||||||
message["content"] = self._add_cache_control_to_content(message["content"])
|
|
||||||
cached_messages[i] = message
|
|
||||||
cached_count += 1
|
|
||||||
|
|
||||||
return cached_messages
|
accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
|
||||||
|
yield LLMResponse(
|
||||||
|
content=accumulated,
|
||||||
|
tool_invocations=parse_tool_invocations(accumulated),
|
||||||
|
thinking_blocks=self._extract_thinking(chunks),
|
||||||
|
)
|
||||||
|
|
||||||
def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
messages = [{"role": "system", "content": self.system_prompt}]
|
messages = [{"role": "system", "content": self.system_prompt}]
|
||||||
|
|
||||||
identity_message = self._build_identity_message()
|
if self.agent_name:
|
||||||
if identity_message:
|
messages.append(
|
||||||
messages.append(identity_message)
|
{
|
||||||
|
"role": "user",
|
||||||
compressed_history = list(self.memory_compressor.compress_history(conversation_history))
|
"content": (
|
||||||
|
f"\n\n<agent_identity>\n"
|
||||||
|
f"<meta>Internal metadata: do not echo or reference.</meta>\n"
|
||||||
|
f"<agent_name>{self.agent_name}</agent_name>\n"
|
||||||
|
f"<agent_id>{self.agent_id}</agent_id>\n"
|
||||||
|
f"</agent_identity>\n\n"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
compressed = list(self.memory_compressor.compress_history(conversation_history))
|
||||||
conversation_history.clear()
|
conversation_history.clear()
|
||||||
conversation_history.extend(compressed_history)
|
conversation_history.extend(compressed)
|
||||||
messages.extend(compressed_history)
|
messages.extend(compressed)
|
||||||
|
|
||||||
return self._prepare_cached_messages(messages)
|
if self._is_anthropic() and self.config.enable_prompt_caching:
|
||||||
|
messages = self._add_cache_control(messages)
|
||||||
|
|
||||||
async def _stream_and_accumulate(
|
return messages
|
||||||
self,
|
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
scan_id: str | None,
|
|
||||||
step_number: int,
|
|
||||||
) -> AsyncIterator[LLMResponse]:
|
|
||||||
accumulated_content = ""
|
|
||||||
chunks: list[Any] = []
|
|
||||||
|
|
||||||
async for chunk in self._stream_request(messages):
|
def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
|
||||||
chunks.append(chunk)
|
if not self._supports_vision():
|
||||||
delta = self._extract_chunk_delta(chunk)
|
messages = self._strip_images(messages)
|
||||||
if delta:
|
|
||||||
accumulated_content += delta
|
|
||||||
|
|
||||||
if "</function>" in accumulated_content:
|
args: dict[str, Any] = {
|
||||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
|
||||||
accumulated_content = accumulated_content[:function_end]
|
|
||||||
|
|
||||||
yield LLMResponse(
|
|
||||||
scan_id=scan_id,
|
|
||||||
step_number=step_number,
|
|
||||||
role=StepRole.AGENT,
|
|
||||||
content=accumulated_content,
|
|
||||||
tool_invocations=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if chunks:
|
|
||||||
complete_response = stream_chunk_builder(chunks)
|
|
||||||
self._update_usage_stats(complete_response)
|
|
||||||
|
|
||||||
accumulated_content = _truncate_to_first_function(accumulated_content)
|
|
||||||
if "</function>" in accumulated_content:
|
|
||||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
|
||||||
accumulated_content = accumulated_content[:function_end]
|
|
||||||
|
|
||||||
tool_invocations = parse_tool_invocations(accumulated_content)
|
|
||||||
|
|
||||||
# Extract thinking blocks from the complete response if available
|
|
||||||
thinking_blocks = None
|
|
||||||
if chunks and self._should_include_reasoning_effort():
|
|
||||||
complete_response = stream_chunk_builder(chunks)
|
|
||||||
if (
|
|
||||||
hasattr(complete_response, "choices")
|
|
||||||
and complete_response.choices
|
|
||||||
and hasattr(complete_response.choices[0], "message")
|
|
||||||
):
|
|
||||||
message = complete_response.choices[0].message
|
|
||||||
if hasattr(message, "thinking_blocks") and message.thinking_blocks:
|
|
||||||
thinking_blocks = message.thinking_blocks
|
|
||||||
|
|
||||||
yield LLMResponse(
|
|
||||||
scan_id=scan_id,
|
|
||||||
step_number=step_number,
|
|
||||||
role=StepRole.AGENT,
|
|
||||||
content=accumulated_content,
|
|
||||||
tool_invocations=tool_invocations if tool_invocations else None,
|
|
||||||
thinking_blocks=thinking_blocks,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _raise_llm_error(self, e: Exception) -> None:
|
|
||||||
error_map: list[tuple[type, str]] = [
|
|
||||||
(litellm.RateLimitError, "Rate limit exceeded"),
|
|
||||||
(litellm.AuthenticationError, "Invalid API key"),
|
|
||||||
(litellm.NotFoundError, "Model not found"),
|
|
||||||
(litellm.ContextWindowExceededError, "Context too long"),
|
|
||||||
(litellm.ContentPolicyViolationError, "Content policy violation"),
|
|
||||||
(litellm.ServiceUnavailableError, "Service unavailable"),
|
|
||||||
(litellm.Timeout, "Request timed out"),
|
|
||||||
(litellm.UnprocessableEntityError, "Unprocessable entity"),
|
|
||||||
(litellm.InternalServerError, "Internal server error"),
|
|
||||||
(litellm.APIConnectionError, "Connection error"),
|
|
||||||
(litellm.UnsupportedParamsError, "Unsupported parameters"),
|
|
||||||
(litellm.BudgetExceededError, "Budget exceeded"),
|
|
||||||
(litellm.APIResponseValidationError, "Response validation error"),
|
|
||||||
(litellm.JSONSchemaValidationError, "JSON schema validation error"),
|
|
||||||
(litellm.InvalidRequestError, "Invalid request"),
|
|
||||||
(litellm.BadRequestError, "Bad request"),
|
|
||||||
(litellm.APIError, "API error"),
|
|
||||||
(litellm.OpenAIError, "OpenAI error"),
|
|
||||||
]
|
|
||||||
|
|
||||||
from strix.telemetry import posthog
|
|
||||||
|
|
||||||
for error_type, message in error_map:
|
|
||||||
if isinstance(e, error_type):
|
|
||||||
posthog.error(f"llm_{error_type.__name__}", message)
|
|
||||||
raise LLMRequestFailedError(f"LLM request failed: {message}", str(e)) from e
|
|
||||||
|
|
||||||
posthog.error("llm_unknown_error", type(e).__name__)
|
|
||||||
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
|
|
||||||
|
|
||||||
async def generate(
|
|
||||||
self,
|
|
||||||
conversation_history: list[dict[str, Any]],
|
|
||||||
scan_id: str | None = None,
|
|
||||||
step_number: int = 1,
|
|
||||||
) -> AsyncIterator[LLMResponse]:
|
|
||||||
messages = self._prepare_messages(conversation_history)
|
|
||||||
|
|
||||||
last_error: Exception | None = None
|
|
||||||
for attempt in range(MAX_RETRIES):
|
|
||||||
try:
|
|
||||||
async for response in self._stream_and_accumulate(messages, scan_id, step_number):
|
|
||||||
yield response
|
|
||||||
return # noqa: TRY300
|
|
||||||
except Exception as e: # noqa: BLE001
|
|
||||||
last_error = e
|
|
||||||
if not _should_retry(e) or attempt == MAX_RETRIES - 1:
|
|
||||||
break
|
|
||||||
wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt))
|
|
||||||
wait_time = max(RETRY_MIN, wait_time)
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
|
|
||||||
if last_error:
|
|
||||||
self._raise_llm_error(last_error)
|
|
||||||
|
|
||||||
def _extract_chunk_delta(self, chunk: Any) -> str:
|
|
||||||
if chunk.choices and hasattr(chunk.choices[0], "delta"):
|
|
||||||
delta = chunk.choices[0].delta
|
|
||||||
return getattr(delta, "content", "") or ""
|
|
||||||
return ""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def usage_stats(self) -> dict[str, dict[str, int | float]]:
|
|
||||||
return {
|
|
||||||
"total": self._total_stats.to_dict(),
|
|
||||||
"last_request": self._last_request_stats.to_dict(),
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_cache_config(self) -> dict[str, bool]:
|
|
||||||
return {
|
|
||||||
"enabled": self.config.enable_prompt_caching,
|
|
||||||
"supported": supports_prompt_caching(self.config.model_name),
|
|
||||||
}
|
|
||||||
|
|
||||||
def _should_include_reasoning_effort(self) -> bool:
|
|
||||||
if not self.config.model_name:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
return bool(supports_reasoning(model=self.config.model_name))
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _model_supports_vision(self) -> bool:
|
|
||||||
if not self.config.model_name:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
return bool(supports_vision(model=self.config.model_name))
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _filter_images_from_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
||||||
filtered_messages = []
|
|
||||||
for msg in messages:
|
|
||||||
content = msg.get("content")
|
|
||||||
updated_msg = msg
|
|
||||||
if isinstance(content, list):
|
|
||||||
filtered_content = []
|
|
||||||
for item in content:
|
|
||||||
if isinstance(item, dict):
|
|
||||||
if item.get("type") == "image_url":
|
|
||||||
filtered_content.append(
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "[Screenshot removed - model does not support "
|
|
||||||
"vision. Use view_source or execute_js instead.]",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
filtered_content.append(item)
|
|
||||||
else:
|
|
||||||
filtered_content.append(item)
|
|
||||||
if filtered_content:
|
|
||||||
text_parts = [
|
|
||||||
item.get("text", "") if isinstance(item, dict) else str(item)
|
|
||||||
for item in filtered_content
|
|
||||||
]
|
|
||||||
all_text = all(
|
|
||||||
isinstance(item, dict) and item.get("type") == "text"
|
|
||||||
for item in filtered_content
|
|
||||||
)
|
|
||||||
if all_text:
|
|
||||||
updated_msg = {**msg, "content": "\n".join(text_parts)}
|
|
||||||
else:
|
|
||||||
updated_msg = {**msg, "content": filtered_content}
|
|
||||||
else:
|
|
||||||
updated_msg = {**msg, "content": ""}
|
|
||||||
filtered_messages.append(updated_msg)
|
|
||||||
return filtered_messages
|
|
||||||
|
|
||||||
async def _stream_request(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
) -> AsyncIterator[Any]:
|
|
||||||
if not self._model_supports_vision():
|
|
||||||
messages = self._filter_images_from_messages(messages)
|
|
||||||
|
|
||||||
completion_args: dict[str, Any] = {
|
|
||||||
"model": self.config.model_name,
|
"model": self.config.model_name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"timeout": self.config.timeout,
|
"timeout": self.config.timeout,
|
||||||
"stream_options": {"include_usage": True},
|
"stream_options": {"include_usage": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
if _LLM_API_KEY:
|
if api_key := Config.get("llm_api_key"):
|
||||||
completion_args["api_key"] = _LLM_API_KEY
|
args["api_key"] = api_key
|
||||||
if _LLM_API_BASE:
|
if api_base := (
|
||||||
completion_args["api_base"] = _LLM_API_BASE
|
Config.get("llm_api_base")
|
||||||
|
or Config.get("openai_api_base")
|
||||||
|
or Config.get("litellm_base_url")
|
||||||
|
or Config.get("ollama_api_base")
|
||||||
|
):
|
||||||
|
args["api_base"] = api_base
|
||||||
|
if self._supports_reasoning():
|
||||||
|
args["reasoning_effort"] = self._reasoning_effort
|
||||||
|
|
||||||
completion_args["stop"] = ["</function>"]
|
return args
|
||||||
|
|
||||||
if self._should_include_reasoning_effort():
|
def _get_chunk_content(self, chunk: Any) -> str:
|
||||||
completion_args["reasoning_effort"] = self._reasoning_effort
|
if chunk.choices and hasattr(chunk.choices[0], "delta"):
|
||||||
|
return getattr(chunk.choices[0].delta, "content", "") or ""
|
||||||
|
return ""
|
||||||
|
|
||||||
queue = get_global_queue()
|
def _extract_thinking(self, chunks: list[Any]) -> list[dict[str, Any]] | None:
|
||||||
self._total_stats.requests += 1
|
if not chunks or not self._supports_reasoning():
|
||||||
self._last_request_stats = RequestStats(requests=1)
|
return None
|
||||||
|
try:
|
||||||
async for chunk in queue.stream_request(completion_args):
|
resp = stream_chunk_builder(chunks)
|
||||||
yield chunk
|
if resp.choices and hasattr(resp.choices[0].message, "thinking_blocks"):
|
||||||
|
blocks: list[dict[str, Any]] = resp.choices[0].message.thinking_blocks
|
||||||
|
return blocks
|
||||||
|
except Exception: # noqa: BLE001, S110 # nosec B110
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
def _update_usage_stats(self, response: Any) -> None:
|
def _update_usage_stats(self, response: Any) -> None:
|
||||||
try:
|
try:
|
||||||
@@ -491,45 +231,88 @@ class LLM:
|
|||||||
output_tokens = getattr(response.usage, "completion_tokens", 0)
|
output_tokens = getattr(response.usage, "completion_tokens", 0)
|
||||||
|
|
||||||
cached_tokens = 0
|
cached_tokens = 0
|
||||||
cache_creation_tokens = 0
|
|
||||||
|
|
||||||
if hasattr(response.usage, "prompt_tokens_details"):
|
if hasattr(response.usage, "prompt_tokens_details"):
|
||||||
prompt_details = response.usage.prompt_tokens_details
|
prompt_details = response.usage.prompt_tokens_details
|
||||||
if hasattr(prompt_details, "cached_tokens"):
|
if hasattr(prompt_details, "cached_tokens"):
|
||||||
cached_tokens = prompt_details.cached_tokens or 0
|
cached_tokens = prompt_details.cached_tokens or 0
|
||||||
|
|
||||||
if hasattr(response.usage, "cache_creation_input_tokens"):
|
|
||||||
cache_creation_tokens = response.usage.cache_creation_input_tokens or 0
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
input_tokens = 0
|
input_tokens = 0
|
||||||
output_tokens = 0
|
output_tokens = 0
|
||||||
cached_tokens = 0
|
cached_tokens = 0
|
||||||
cache_creation_tokens = 0
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cost = completion_cost(response) or 0.0
|
cost = completion_cost(response) or 0.0
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
logger.warning(f"Failed to calculate cost: {e}")
|
|
||||||
cost = 0.0
|
cost = 0.0
|
||||||
|
|
||||||
self._total_stats.input_tokens += input_tokens
|
self._total_stats.input_tokens += input_tokens
|
||||||
self._total_stats.output_tokens += output_tokens
|
self._total_stats.output_tokens += output_tokens
|
||||||
self._total_stats.cached_tokens += cached_tokens
|
self._total_stats.cached_tokens += cached_tokens
|
||||||
self._total_stats.cache_creation_tokens += cache_creation_tokens
|
|
||||||
self._total_stats.cost += cost
|
self._total_stats.cost += cost
|
||||||
|
|
||||||
self._last_request_stats.input_tokens = input_tokens
|
except Exception: # noqa: BLE001, S110 # nosec B110
|
||||||
self._last_request_stats.output_tokens = output_tokens
|
pass
|
||||||
self._last_request_stats.cached_tokens = cached_tokens
|
|
||||||
self._last_request_stats.cache_creation_tokens = cache_creation_tokens
|
|
||||||
self._last_request_stats.cost = cost
|
|
||||||
|
|
||||||
if cached_tokens > 0:
|
def _should_retry(self, e: Exception) -> bool:
|
||||||
logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
|
code = getattr(e, "status_code", None) or getattr(
|
||||||
if cache_creation_tokens > 0:
|
getattr(e, "response", None), "status_code", None
|
||||||
logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")
|
)
|
||||||
|
return code is None or litellm._should_retry(code)
|
||||||
|
|
||||||
logger.info(f"Usage stats: {self.usage_stats}")
|
def _raise_error(self, e: Exception) -> None:
|
||||||
except Exception as e: # noqa: BLE001
|
from strix.telemetry import posthog
|
||||||
logger.warning(f"Failed to update usage stats: {e}")
|
|
||||||
|
posthog.error("llm_error", type(e).__name__)
|
||||||
|
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
|
||||||
|
|
||||||
|
def _is_anthropic(self) -> bool:
|
||||||
|
if not self.config.model_name:
|
||||||
|
return False
|
||||||
|
return any(p in self.config.model_name.lower() for p in ["anthropic/", "claude"])
|
||||||
|
|
||||||
|
def _supports_vision(self) -> bool:
|
||||||
|
try:
|
||||||
|
return bool(supports_vision(model=self.config.model_name))
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _supports_reasoning(self) -> bool:
|
||||||
|
try:
|
||||||
|
return bool(supports_reasoning(model=self.config.model_name))
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
result = []
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_parts = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict) and item.get("type") == "text":
|
||||||
|
text_parts.append(item.get("text", ""))
|
||||||
|
elif isinstance(item, dict) and item.get("type") == "image_url":
|
||||||
|
text_parts.append("[Image removed - model doesn't support vision]")
|
||||||
|
result.append({**msg, "content": "\n".join(text_parts)})
|
||||||
|
else:
|
||||||
|
result.append(msg)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
if not messages or not supports_prompt_caching(self.config.model_name):
|
||||||
|
return messages
|
||||||
|
|
||||||
|
result = list(messages)
|
||||||
|
|
||||||
|
if result[0].get("role") == "system":
|
||||||
|
content = result[0]["content"]
|
||||||
|
result[0] = {
|
||||||
|
**result[0],
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
|
||||||
|
]
|
||||||
|
if isinstance(content, str)
|
||||||
|
else content,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ def _extract_message_text(msg: dict[str, Any]) -> str:
|
|||||||
def _summarize_messages(
|
def _summarize_messages(
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
model: str,
|
model: str,
|
||||||
timeout: int = 600,
|
timeout: int = 30,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if not messages:
|
if not messages:
|
||||||
empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
|
empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
|
||||||
@@ -148,11 +148,11 @@ class MemoryCompressor:
|
|||||||
self,
|
self,
|
||||||
max_images: int = 3,
|
max_images: int = 3,
|
||||||
model_name: str | None = None,
|
model_name: str | None = None,
|
||||||
timeout: int = 600,
|
timeout: int | None = None,
|
||||||
):
|
):
|
||||||
self.max_images = max_images
|
self.max_images = max_images
|
||||||
self.model_name = model_name or Config.get("strix_llm")
|
self.model_name = model_name or Config.get("strix_llm")
|
||||||
self.timeout = timeout
|
self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")
|
||||||
|
|
||||||
if not self.model_name:
|
if not self.model_name:
|
||||||
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from collections.abc import AsyncIterator
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from litellm import acompletion
|
|
||||||
from litellm.types.utils import ModelResponseStream
|
|
||||||
|
|
||||||
from strix.config import Config
|
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestQueue:
    """Throttle streaming LLM completion requests.

    At most ``max_concurrent`` requests hold a slot at the same time, and the
    start of each request is spaced at least ``delay_between_requests`` seconds
    after the previously scheduled one. Both limits are read from ``Config``
    when the queue is constructed.
    """

    def __init__(self) -> None:
        # Defaults when the config values are unset: 4 seconds between
        # requests and a single request at a time.
        self.delay_between_requests = float(Config.get("llm_rate_limit_delay") or "4.0")
        self.max_concurrent = int(Config.get("llm_rate_limit_concurrent") or "1")
        self._semaphore = threading.BoundedSemaphore(self.max_concurrent)
        self._last_request_time = 0.0
        self._lock = threading.Lock()

    async def stream_request(
        self, completion_args: dict[str, Any]
    ) -> AsyncIterator[ModelResponseStream]:
        """Stream the completion for ``completion_args`` once a slot is free.

        Args:
            completion_args: Keyword arguments forwarded to ``acompletion``.

        Yields:
            The response chunks, in the order they are received.
        """
        # Poll with a non-blocking acquire: a blocking acquire would stall the
        # event loop on every poll, including the coroutine that currently
        # holds the slot and has to finish before it can be released.
        #
        # The wait deliberately sits outside the try/finally below. If this
        # coroutine is cancelled while still waiting, no slot is held, and
        # releasing one anyway would raise ValueError on a BoundedSemaphore.
        while not self._semaphore.acquire(blocking=False):
            await asyncio.sleep(0.1)

        try:
            with self._lock:
                now = time.time()
                time_since_last = now - self._last_request_time
                sleep_needed = max(0, self.delay_between_requests - time_since_last)
                # Reserve the start time now so the next caller spaces itself
                # after this request even while this one is still sleeping.
                self._last_request_time = now + sleep_needed

            if sleep_needed > 0:
                await asyncio.sleep(sleep_needed)

            async for chunk in self._stream_request(completion_args):
                yield chunk
        finally:
            self._semaphore.release()

    async def _stream_request(
        self, completion_args: dict[str, Any]
    ) -> AsyncIterator[ModelResponseStream]:
        """Call ``acompletion`` in streaming mode and yield each chunk."""
        response = await acompletion(**completion_args, stream=True)

        async for chunk in response:
            yield chunk
|
|
||||||
|
|
||||||
|
|
||||||
_global_queue: LLMRequestQueue | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_global_queue() -> LLMRequestQueue:
    """Return the module-wide ``LLMRequestQueue``, creating it on first use."""
    global _global_queue  # noqa: PLW0603
    queue = _global_queue
    if queue is None:
        # First call: build the queue and remember it for later callers.
        queue = LLMRequestQueue()
        _global_queue = queue
    return queue
|
|
||||||
@@ -18,7 +18,7 @@ def _truncate_to_first_function(content: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
|
def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
|
||||||
content = _fix_stopword(content)
|
content = fix_incomplete_tool_call(content)
|
||||||
|
|
||||||
tool_invocations: list[dict[str, Any]] = []
|
tool_invocations: list[dict[str, Any]] = []
|
||||||
|
|
||||||
@@ -46,16 +46,15 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
|
|||||||
return tool_invocations if tool_invocations else None
|
return tool_invocations if tool_invocations else None
|
||||||
|
|
||||||
|
|
||||||
def fix_incomplete_tool_call(content: str) -> str:
    """Fix incomplete tool calls by adding missing </function> tag.

    Only content with exactly one ``<function=`` opener and no closing
    ``</function>`` is touched; anything else is returned unchanged.
    """
    has_single_opener = content.count("<function=") == 1
    if not has_single_opener or "</function>" in content:
        return content

    trimmed = content.rstrip()
    if trimmed.endswith("</"):
        # The closing tag was cut off right after "</": finish it in place.
        return trimmed + "function>"
    return trimmed + "\n</function>"
|
||||||
|
|
||||||
|
|
||||||
@@ -74,7 +73,7 @@ def clean_content(content: str) -> str:
|
|||||||
if not content:
|
if not content:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
content = _fix_stopword(content)
|
content = fix_incomplete_tool_call(content)
|
||||||
|
|
||||||
tool_pattern = r"<function=[^>]+>.*?</function>"
|
tool_pattern = r"<function=[^>]+>.*?</function>"
|
||||||
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)
|
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from jinja2 import Environment
|
from jinja2 import Environment
|
||||||
|
|
||||||
|
from strix.utils.resource_paths import get_strix_resource_path
|
||||||
|
|
||||||
|
|
||||||
def get_available_skills() -> dict[str, list[str]]:
|
def get_available_skills() -> dict[str, list[str]]:
|
||||||
skills_dir = Path(__file__).parent
|
skills_dir = get_strix_resource_path("skills")
|
||||||
available_skills = {}
|
available_skills: dict[str, list[str]] = {}
|
||||||
|
|
||||||
|
if not skills_dir.exists():
|
||||||
|
return available_skills
|
||||||
|
|
||||||
for category_dir in skills_dir.iterdir():
|
for category_dir in skills_dir.iterdir():
|
||||||
if category_dir.is_dir() and not category_dir.name.startswith("__"):
|
if category_dir.is_dir() and not category_dir.name.startswith("__"):
|
||||||
@@ -72,7 +75,7 @@ def load_skills(skill_names: list[str], jinja_env: Environment) -> dict[str, str
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
skill_content = {}
|
skill_content = {}
|
||||||
skills_dir = Path(__file__).parent
|
skills_dir = get_strix_resource_path("skills")
|
||||||
|
|
||||||
available_skills = get_available_skills()
|
available_skills = get_available_skills()
|
||||||
|
|
||||||
|
|||||||
@@ -430,10 +430,8 @@ class Tracer:
|
|||||||
"input_tokens": 0,
|
"input_tokens": 0,
|
||||||
"output_tokens": 0,
|
"output_tokens": 0,
|
||||||
"cached_tokens": 0,
|
"cached_tokens": 0,
|
||||||
"cache_creation_tokens": 0,
|
|
||||||
"cost": 0.0,
|
"cost": 0.0,
|
||||||
"requests": 0,
|
"requests": 0,
|
||||||
"failed_requests": 0,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for agent_instance in _agent_instances.values():
|
for agent_instance in _agent_instances.values():
|
||||||
@@ -442,10 +440,8 @@ class Tracer:
|
|||||||
total_stats["input_tokens"] += agent_stats.input_tokens
|
total_stats["input_tokens"] += agent_stats.input_tokens
|
||||||
total_stats["output_tokens"] += agent_stats.output_tokens
|
total_stats["output_tokens"] += agent_stats.output_tokens
|
||||||
total_stats["cached_tokens"] += agent_stats.cached_tokens
|
total_stats["cached_tokens"] += agent_stats.cached_tokens
|
||||||
total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
|
|
||||||
total_stats["cost"] += agent_stats.cost
|
total_stats["cost"] += agent_stats.cost
|
||||||
total_stats["requests"] += agent_stats.requests
|
total_stats["requests"] += agent_stats.requests
|
||||||
total_stats["failed_requests"] += agent_stats.failed_requests
|
|
||||||
|
|
||||||
total_stats["cost"] = round(total_stats["cost"], 4)
|
total_stats["cost"] = round(total_stats["cost"], 4)
|
||||||
|
|
||||||
|
|||||||
@@ -14,12 +14,13 @@ from .argument_parser import convert_arguments
|
|||||||
from .registry import (
|
from .registry import (
|
||||||
get_tool_by_name,
|
get_tool_by_name,
|
||||||
get_tool_names,
|
get_tool_names,
|
||||||
|
get_tool_param_schema,
|
||||||
needs_agent_state,
|
needs_agent_state,
|
||||||
should_execute_in_sandbox,
|
should_execute_in_sandbox,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "500")
|
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
|
||||||
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
|
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
|
||||||
|
|
||||||
|
|
||||||
@@ -110,14 +111,51 @@ async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwarg
|
|||||||
|
|
||||||
def validate_tool_availability(tool_name: str | None) -> tuple[bool, str]:
    """Check that ``tool_name`` names a registered tool.

    Returns:
        ``(True, "")`` when the tool is registered; otherwise ``(False, message)``
        where the message lists the registered tool names in sorted order.
    """
    if tool_name is not None and tool_name in get_tool_names():
        return True, ""

    available = ", ".join(sorted(get_tool_names()))
    if tool_name is None:
        return False, f"Tool name is missing. Available tools: {available}"
    return False, f"Tool '{tool_name}' is not available. Available tools: {available}"
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_tool_arguments(tool_name: str, kwargs: dict[str, Any]) -> str | None:
    """Check ``kwargs`` against the parameter schema recorded for ``tool_name``.

    Returns:
        ``None`` when the arguments are acceptable or when no schema with
        parameters is recorded for the tool; otherwise an error message that
        names the offending parameters and lists the valid ones. Unknown
        parameters are reported before missing required ones.
    """
    schema = get_tool_param_schema(tool_name)
    if not (schema and schema.get("has_params")):
        return None

    known: set[str] = schema.get("params", set())
    mandatory: set[str] = schema.get("required", set())
    hint = _format_schema_hint(tool_name, mandatory, known - mandatory)

    unexpected = set(kwargs.keys()) - known
    if unexpected:
        names = ", ".join(sorted(unexpected))
        return f"Tool '{tool_name}' received unknown parameter(s): {names}\n{hint}"

    # A required parameter counts as missing when it is absent, None or "".
    absent = [
        name for name in mandatory if name not in kwargs or kwargs.get(name) in (None, "")
    ]
    if absent:
        names = ", ".join(sorted(absent))
        return f"Tool '{tool_name}' missing required parameter(s): {names}\n{hint}"

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _format_schema_hint(tool_name: str, required: set[str], optional: set[str]) -> str:
    """Build a multi-line summary of the parameters accepted by ``tool_name``.

    A header line is followed by a "Required" line and an "Optional" line;
    each of those is emitted only when its set is non-empty, with the names
    sorted alphabetically.
    """
    groups = (("Required", required), ("Optional", optional))
    lines = [f"Valid parameters for '{tool_name}':"]
    lines.extend(f" {label}: {', '.join(sorted(names))}" for label, names in groups if names)
    return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
async def execute_tool_with_validation(
|
async def execute_tool_with_validation(
|
||||||
tool_name: str | None, agent_state: Any | None = None, **kwargs: Any
|
tool_name: str | None, agent_state: Any | None = None, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@@ -127,6 +165,10 @@ async def execute_tool_with_validation(
|
|||||||
|
|
||||||
assert tool_name is not None
|
assert tool_name is not None
|
||||||
|
|
||||||
|
arg_error = _validate_tool_arguments(tool_name, kwargs)
|
||||||
|
if arg_error:
|
||||||
|
return f"Error: {arg_error}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await execute_tool(tool_name, agent_state, **kwargs)
|
result = await execute_tool(tool_name, agent_state, **kwargs)
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e: # noqa: BLE001
|
||||||
|
|||||||
@@ -55,6 +55,7 @@
|
|||||||
- Print statements and stdout are captured
|
- Print statements and stdout are captured
|
||||||
- Variables persist between executions in the same session
|
- Variables persist between executions in the same session
|
||||||
- Imports, function definitions, etc. persist in the session
|
- Imports, function definitions, etc. persist in the session
|
||||||
|
- IMPORTANT (multiline): Put real line breaks in <parameter=code>. Do NOT emit literal "\n" sequences.
|
||||||
- IPython magic commands are fully supported (%pip, %time, %whos, %%writefile, etc.)
|
- IPython magic commands are fully supported (%pip, %time, %whos, %%writefile, etc.)
|
||||||
- Line magics (%) and cell magics (%%) work as expected
|
- Line magics (%) and cell magics (%%) work as expected
|
||||||
6. CLOSE: Terminates the session completely and frees memory
|
6. CLOSE: Terminates the session completely and frees memory
|
||||||
@@ -73,6 +74,14 @@
|
|||||||
print("Security analysis session started")</parameter>
|
print("Security analysis session started")</parameter>
|
||||||
</function>
|
</function>
|
||||||
|
|
||||||
|
<function=python_action>
|
||||||
|
<parameter=action>execute</parameter>
|
||||||
|
<parameter=code>import requests
|
||||||
|
url = "https://example.com"
|
||||||
|
resp = requests.get(url, timeout=10)
|
||||||
|
print(resp.status_code)</parameter>
|
||||||
|
</function>
|
||||||
|
|
||||||
# Analyze security data in the default session
|
# Analyze security data in the default session
|
||||||
<function=python_action>
|
<function=python_action>
|
||||||
<parameter=action>execute</parameter>
|
<parameter=action>execute</parameter>
|
||||||
|
|||||||
@@ -7,9 +7,14 @@ from inspect import signature
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import defusedxml.ElementTree as DefusedET
|
||||||
|
|
||||||
|
from strix.utils.resource_paths import get_strix_resource_path
|
||||||
|
|
||||||
|
|
||||||
tools: list[dict[str, Any]] = []
|
tools: list[dict[str, Any]] = []
|
||||||
_tools_by_name: dict[str, Callable[..., Any]] = {}
|
_tools_by_name: dict[str, Callable[..., Any]] = {}
|
||||||
|
_tool_param_schemas: dict[str, dict[str, Any]] = {}
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -82,6 +87,34 @@ def _load_xml_schema(path: Path) -> Any:
|
|||||||
return tools_dict
|
return tools_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_param_schema(tool_xml: str) -> dict[str, Any]:
    """Extract the declared parameter names from a tool's XML schema text.

    Only the ``<parameters>...</parameters>`` section of ``tool_xml`` is parsed.

    Returns:
        A dict with ``"params"`` (all named parameters), ``"required"`` (those
        whose ``required`` attribute is ``"true"``, case-insensitively) and
        ``"has_params"``. When the section is missing or cannot be parsed,
        both sets are empty and ``"has_params"`` is ``False``.
    """
    open_tag, close_tag = "<parameters>", "</parameters>"
    start = tool_xml.find(open_tag)
    end = tool_xml.find(close_tag)
    if -1 in (start, end):
        return {"params": set(), "required": set(), "has_params": False}

    try:
        root = DefusedET.fromstring(tool_xml[start : end + len(close_tag)])
    except DefusedET.ParseError:
        return {"params": set(), "required": set(), "has_params": False}

    # Parameters without a name attribute are ignored entirely.
    named = [node.attrib for node in root.findall(".//parameter") if node.attrib.get("name")]
    params: set[str] = {attrs["name"] for attrs in named}
    required: set[str] = {
        attrs["name"] for attrs in named if attrs.get("required", "false").lower() == "true"
    }

    return {"params": params, "required": required, "has_params": bool(params or required)}
|
||||||
|
|
||||||
|
|
||||||
def _get_module_name(func: Callable[..., Any]) -> str:
|
def _get_module_name(func: Callable[..., Any]) -> str:
|
||||||
module = inspect.getmodule(func)
|
module = inspect.getmodule(func)
|
||||||
if not module:
|
if not module:
|
||||||
@@ -95,6 +128,27 @@ def _get_module_name(func: Callable[..., Any]) -> str:
|
|||||||
return "unknown"
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_schema_path(func: Callable[..., Any]) -> Path | None:
|
||||||
|
module = inspect.getmodule(func)
|
||||||
|
if not module or not module.__name__:
|
||||||
|
return None
|
||||||
|
|
||||||
|
module_name = module.__name__
|
||||||
|
|
||||||
|
if ".tools." not in module_name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
parts = module_name.split(".tools.")[-1].split(".")
|
||||||
|
if len(parts) < 2:
|
||||||
|
return None
|
||||||
|
|
||||||
|
folder = parts[0]
|
||||||
|
file_stem = parts[1]
|
||||||
|
schema_file = f"{file_stem}_schema.xml"
|
||||||
|
|
||||||
|
return get_strix_resource_path("tools", folder, schema_file)
|
||||||
|
|
||||||
|
|
||||||
def register_tool(
|
def register_tool(
|
||||||
func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
|
func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
|
||||||
) -> Callable[..., Any]:
|
) -> Callable[..., Any]:
|
||||||
@@ -109,11 +163,8 @@ def register_tool(
|
|||||||
sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
|
sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
|
||||||
if not sandbox_mode:
|
if not sandbox_mode:
|
||||||
try:
|
try:
|
||||||
module_path = Path(inspect.getfile(f))
|
schema_path = _get_schema_path(f)
|
||||||
schema_file_name = f"{module_path.stem}_schema.xml"
|
xml_tools = _load_xml_schema(schema_path) if schema_path else None
|
||||||
schema_path = module_path.parent / schema_file_name
|
|
||||||
|
|
||||||
xml_tools = _load_xml_schema(schema_path)
|
|
||||||
|
|
||||||
if xml_tools is not None and f.__name__ in xml_tools:
|
if xml_tools is not None and f.__name__ in xml_tools:
|
||||||
func_dict["xml_schema"] = xml_tools[f.__name__]
|
func_dict["xml_schema"] = xml_tools[f.__name__]
|
||||||
@@ -131,6 +182,11 @@ def register_tool(
|
|||||||
"</tool>"
|
"</tool>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not sandbox_mode:
|
||||||
|
xml_schema = func_dict.get("xml_schema")
|
||||||
|
param_schema = _parse_param_schema(xml_schema if isinstance(xml_schema, str) else "")
|
||||||
|
_tool_param_schemas[str(func_dict["name"])] = param_schema
|
||||||
|
|
||||||
tools.append(func_dict)
|
tools.append(func_dict)
|
||||||
_tools_by_name[str(func_dict["name"])] = f
|
_tools_by_name[str(func_dict["name"])] = f
|
||||||
|
|
||||||
@@ -153,6 +209,10 @@ def get_tool_names() -> list[str]:
|
|||||||
return list(_tools_by_name.keys())
|
return list(_tools_by_name.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_param_schema(name: str) -> dict[str, Any] | None:
    """Return the parameter schema recorded for tool ``name``, or ``None`` if there is none."""
    schema = _tool_param_schemas.get(name)
    return schema
|
||||||
|
|
||||||
|
|
||||||
def needs_agent_state(tool_name: str) -> bool:
|
def needs_agent_state(tool_name: str) -> bool:
|
||||||
tool_func = get_tool_by_name(tool_name)
|
tool_func = get_tool_by_name(tool_name)
|
||||||
if not tool_func:
|
if not tool_func:
|
||||||
@@ -194,3 +254,4 @@ def get_tools_prompt() -> str:
|
|||||||
def clear_registry() -> None:
    """Empty the tool list, the name-to-callable map and the parameter schemas."""
    for registry in (tools, _tools_by_name, _tool_param_schemas):
        registry.clear()
|
||||||
|
|||||||
@@ -95,6 +95,12 @@
|
|||||||
<parameter=command>ls -la</parameter>
|
<parameter=command>ls -la</parameter>
|
||||||
</function>
|
</function>
|
||||||
|
|
||||||
|
<function=terminal_execute>
|
||||||
|
<parameter=command>cd /workspace
|
||||||
|
pwd
|
||||||
|
ls -la</parameter>
|
||||||
|
</function>
|
||||||
|
|
||||||
# Run a command with custom timeout
|
# Run a command with custom timeout
|
||||||
<function=terminal_execute>
|
<function=terminal_execute>
|
||||||
<parameter=command>npm install</parameter>
|
<parameter=command>npm install</parameter>
|
||||||
|
|||||||
0
strix/utils/__init__.py
Normal file
0
strix/utils/__init__.py
Normal file
13
strix/utils/resource_paths.py
Normal file
13
strix/utils/resource_paths.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def get_strix_resource_path(*parts: str) -> Path:
    """Resolve ``parts`` against the root directory of the ``strix`` package.

    When ``sys._MEIPASS`` is set (presumably a PyInstaller-style frozen build)
    and contains a ``strix`` directory, that directory is the root. Otherwise
    the root is the directory two levels above this file.
    """
    root = Path(__file__).resolve().parent.parent

    bundle_dir = getattr(sys, "_MEIPASS", None)
    if bundle_dir and (bundled_root := Path(bundle_dir) / "strix").exists():
        root = bundled_root

    return root.joinpath(*parts)
|
||||||
Reference in New Issue
Block a user