12 Commits

Author SHA1 Message Date
0xallam
e5104eb93a chore(release): bump version to 0.6.1 2026-01-14 21:30:14 -08:00
0xallam
d8a08e9a8c chore(prompt): discourage literal \n in tool params 2026-01-14 21:29:06 -08:00
0xallam
f6475cec07 chore(prompt): enforce single tool call per message and remove stop word usage 2026-01-14 19:51:08 -08:00
0xallam
31baa0dfc0 fix: restore ollama_api_base config fallback for Ollama support 2026-01-14 18:54:45 -08:00
0xallam
56526cbf90 fix(agent): fix agent loop hanging and simplify LLM module
- Fix agent loop getting stuck by adding hard stop mechanism
- Add _force_stop flag for immediate task cancellation across threads
- Use thread-safe loop.call_soon_threadsafe for cross-thread cancellation
- Remove request_queue.py (eliminated threading/queue complexity causing hangs)
- Simplify llm.py: direct acompletion calls, cleaner streaming
- Reduce retry wait times to prevent long hangs during retries
- Make timeouts configurable (llm_max_retries, memory_compressor_timeout, sandbox_execution_timeout)
- Keep essential token tracking (input/output/cached tokens, cost, requests)
- Maintain Anthropic prompt caching for system messages
2026-01-14 18:54:45 -08:00
0xallam
47faeb1ef3 fix(agent): use correct agent name in identity instead of class name 2026-01-14 11:24:24 -08:00
0xallam
435ac82d9e chore: add defusedxml dependency 2026-01-14 10:57:32 -08:00
0xallam
f08014cf51 fix(agent): fix tool schemas not retrieved on pyinstaller binary and validate tool call args 2026-01-14 10:57:32 -08:00
dependabot[bot]
bc8e14f68a chore(deps-dev): bump virtualenv from 20.34.0 to 20.36.1
Bumps [virtualenv](https://github.com/pypa/virtualenv) from 20.34.0 to 20.36.1.
- [Release notes](https://github.com/pypa/virtualenv/releases)
- [Changelog](https://github.com/pypa/virtualenv/blob/main/docs/changelog.rst)
- [Commits](https://github.com/pypa/virtualenv/compare/20.34.0...20.36.1)

---
updated-dependencies:
- dependency-name: virtualenv
  dependency-version: 20.36.1
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-13 17:15:58 -08:00
dependabot[bot]
eae2b783c0 chore(deps): bump filelock from 3.20.1 to 3.20.3
Bumps [filelock](https://github.com/tox-dev/py-filelock) from 3.20.1 to 3.20.3.
- [Release notes](https://github.com/tox-dev/py-filelock/releases)
- [Changelog](https://github.com/tox-dev/filelock/blob/main/docs/changelog.rst)
- [Commits](https://github.com/tox-dev/py-filelock/compare/3.20.1...3.20.3)

---
updated-dependencies:
- dependency-name: filelock
  dependency-version: 3.20.3
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-13 17:15:43 -08:00
dependabot[bot]
058cf1abdb chore(deps): bump azure-core from 1.35.0 to 1.38.0
Bumps [azure-core](https://github.com/Azure/azure-sdk-for-python) from 1.35.0 to 1.38.0.
- [Release notes](https://github.com/Azure/azure-sdk-for-python/releases)
- [Commits](https://github.com/Azure/azure-sdk-for-python/compare/azure-core_1.35.0...azure-core_1.38.0)

---
updated-dependencies:
- dependency-name: azure-core
  dependency-version: 1.38.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-13 17:15:22 -08:00
Ahmed Allam
d16bdb277a Update README 2026-01-14 05:00:16 +04:00
19 changed files with 448 additions and 561 deletions

View File

@@ -1,55 +1,61 @@
<p align="center">
<a href="https://strix.ai/">
<img src=".github/logo.png" width="150" alt="Strix Logo">
<img src="https://github.com/usestrix/.github/raw/main/imgs/cover.png" alt="Strix Banner" width="100%">
</a>
</p>
<h1 align="center">Strix</h1>
<h2 align="center">Open-source AI Hackers to secure your Apps</h2>
<div align="center">
[![Python](https://img.shields.io/pypi/pyversions/strix-agent?color=3776AB)](https://pypi.org/project/strix-agent/)
[![PyPI](https://img.shields.io/pypi/v/strix-agent?color=10b981)](https://pypi.org/project/strix-agent/)
[![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
[![Docs](https://img.shields.io/badge/Docs-docs.strix.ai-10b981.svg)](https://docs.strix.ai)
# Strix
[![GitHub Stars](https://img.shields.io/github/stars/usestrix/strix)](https://github.com/usestrix/strix)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.gg/YjKFvEZSdZ)
[![Website](https://img.shields.io/badge/Website-strix.ai-2d3748.svg)](https://strix.ai)
### Open-source AI hackers to find and fix your apps' vulnerabilities.
<a href="https://trendshift.io/repositories/15362" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15362" alt="usestrix%2Fstrix | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<br/>
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/usestrix/strix)
<a href="https://docs.strix.ai"><img src="https://img.shields.io/badge/Docs-docs.strix.ai-2b9246?style=for-the-badge&logo=gitbook&logoColor=white" alt="Docs"></a>
<a href="https://strix.ai"><img src="https://img.shields.io/badge/Website-strix.ai-3b82f6?style=for-the-badge&logoColor=white" alt="Website"></a>
<a href="https://pypi.org/project/strix-agent/"><img src="https://img.shields.io/badge/PyPI-strix--agent-f59e0b?style=for-the-badge&logo=pypi&logoColor=white" alt="PyPI"></a>
<a href="https://deepwiki.com/usestrix/strix"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki"></a>
<a href="https://github.com/usestrix/strix"><img src="https://img.shields.io/github/stars/usestrix/strix?style=flat-square" alt="GitHub Stars"></a>
<a href="LICENSE"><img src="https://img.shields.io/badge/License-Apache%202.0-3b82f6?style=flat-square" alt="License"></a>
<a href="https://pypi.org/project/strix-agent/"><img src="https://img.shields.io/pypi/v/strix-agent?style=flat-square" alt="PyPI Version"></a>
<a href="https://discord.gg/YjKFvEZSdZ"><img src="https://github.com/usestrix/.github/raw/main/imgs/Discord.png" height="40" alt="Join Discord"></a>
<a href="https://x.com/strix_ai"><img src="https://github.com/usestrix/.github/raw/main/imgs/X.png" height="40" alt="Follow on X"></a>
<a href="https://trendshift.io/repositories/15362" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15362" alt="usestrix/strix | Trendshift" width="250" height="55"/></a>
</div>
<br>
<br/>
<div align="center">
<img src=".github/screenshot.png" alt="Strix Demo" width="800" style="border-radius: 16px;">
<img src=".github/screenshot.png" alt="Strix Demo" width="900" style="border-radius: 16px;">
</div>
<br>
> [!TIP]
> **New!** Strix now integrates seamlessly with GitHub Actions and CI/CD pipelines. Automatically scan for vulnerabilities on every pull request and block insecure code before it reaches production!
> **New!** Strix integrates seamlessly with GitHub Actions and CI/CD pipelines. Automatically scan for vulnerabilities on every pull request and block insecure code before it reaches production!
---
## 🦉 Strix Overview
## Strix Overview
Strix are autonomous AI agents that act just like real hackers - they run your code dynamically, find vulnerabilities, and validate them through actual proof-of-concepts. Built for developers and security teams who need fast, accurate security testing without the overhead of manual pentesting or the false positives of static analysis tools.
**Key Capabilities:**
- 🔧 **Full hacker toolkit** out of the box
- 🤝 **Teams of agents** that collaborate and scale
- **Real validation** with PoCs, not false positives
- 💻 **Developer-first** CLI with actionable reports
- 🔄 **Autofix & reporting** to accelerate remediation
- **Full hacker toolkit** out of the box
- **Teams of agents** that collaborate and scale
- **Real validation** with PoCs, not false positives
- **Developer-first** CLI with actionable reports
- **Autofix & reporting** to accelerate remediation
## 🎯 Use Cases
@@ -87,7 +93,7 @@ strix --target ./app-directory
> [!NOTE]
> First run automatically pulls the sandbox Docker image. Results are saved to `strix_runs/<run-name>`
## ☁️ Run Strix in Cloud
## Run Strix in Cloud
Want to skip the local setup, API keys, and unpredictable LLM costs? Run the hosted cloud version of Strix at **[app.strix.ai](https://strix.ai)**.
@@ -104,7 +110,7 @@ Launch a scan in just a few minutes—no setup or configuration required—and y
## ✨ Features
### 🛠️ Agentic Security Tools
### Agentic Security Tools
Strix agents come equipped with a comprehensive security testing toolkit:
@@ -116,7 +122,7 @@ Strix agents come equipped with a comprehensive security testing toolkit:
- **Code Analysis** - Static and dynamic analysis capabilities
- **Knowledge Management** - Structured findings and attack documentation
### 🎯 Comprehensive Vulnerability Detection
### Comprehensive Vulnerability Detection
Strix can identify and validate a wide range of security vulnerabilities:
@@ -128,7 +134,7 @@ Strix can identify and validate a wide range of security vulnerabilities:
- **Authentication** - JWT vulnerabilities, session management
- **Infrastructure** - Misconfigurations, exposed services
### 🕸️ Graph of Agents
### Graph of Agents
Advanced multi-agent orchestration for comprehensive security testing:
@@ -138,7 +144,7 @@ Advanced multi-agent orchestration for comprehensive security testing:
---
## 💻 Usage Examples
## Usage Examples
### Basic Usage
@@ -169,7 +175,7 @@ strix --target api.your-app.com --instruction "Focus on business logic flaws and
strix --target api.your-app.com --instruction-file ./instruction.md
```
### 🤖 Headless Mode
### Headless Mode
Run Strix programmatically without interactive UI using the `-n/--non-interactive` flag—perfect for servers and automated jobs. The CLI prints real-time vulnerability findings, and the final report before exiting. Exits with non-zero code when vulnerabilities are found.
@@ -177,7 +183,7 @@ Run Strix programmatically without interactive UI using the `-n/--non-interactiv
strix -n --target https://your-app.com
```
### 🔄 CI/CD (GitHub Actions)
### CI/CD (GitHub Actions)
Strix can be added to your pipeline to run a security test on pull requests with a lightweight GitHub Actions workflow:
@@ -204,7 +210,7 @@ jobs:
run: strix -n -t ./ --scan-mode quick
```
### ⚙️ Configuration
### Configuration
```bash
export STRIX_LLM="openai/gpt-5"
@@ -227,22 +233,23 @@ export STRIX_REASONING_EFFORT="high" # control thinking effort (default: high,
See the [LLM Providers documentation](https://docs.strix.ai/llm-providers/overview) for all supported providers including Vertex AI, Bedrock, Azure, and local models.
## 📚 Documentation
## Documentation
Full documentation is available at **[docs.strix.ai](https://docs.strix.ai)** — including detailed guides for usage, CI/CD integrations, skills, and advanced configuration.
## 🤝 Contributing
## Contributing
We welcome contributions of code, docs, and new skills - check out our [Contributing Guide](https://docs.strix.ai/contributing) to get started or open a [pull request](https://github.com/usestrix/strix/pulls)/[issue](https://github.com/usestrix/strix/issues).
## 👥 Join Our Community
## Join Our Community
Have questions? Found a bug? Want to contribute? **[Join our Discord!](https://discord.gg/YjKFvEZSdZ)**
## 🌟 Support the Project
## Support the Project
**Love Strix?** Give us a ⭐ on GitHub!
## 🙏 Acknowledgements
## Acknowledgements
Strix builds on the incredible work of open-source projects like [LiteLLM](https://github.com/BerriAI/litellm), [Caido](https://github.com/caido/caido), [ProjectDiscovery](https://github.com/projectdiscovery), [Playwright](https://github.com/microsoft/playwright), and [Textual](https://github.com/Textualize/textual). Huge thanks to their maintainers!

35
poetry.lock generated
View File

@@ -379,19 +379,18 @@ files = [
[[package]]
name = "azure-core"
version = "1.35.0"
version = "1.38.0"
description = "Microsoft Azure Core Library for Python"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "azure_core-1.35.0-py3-none-any.whl", hash = "sha256:8db78c72868a58f3de8991eb4d22c4d368fae226dac1002998d6c50437e7dad1"},
{file = "azure_core-1.35.0.tar.gz", hash = "sha256:c0be528489485e9ede59b6971eb63c1eaacf83ef53001bfe3904e475e972be5c"},
{file = "azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335"},
{file = "azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993"},
]
[package.dependencies]
requests = ">=2.21.0"
six = ">=1.11.0"
typing-extensions = ">=4.6.0"
[package.extras]
@@ -1199,6 +1198,18 @@ files = [
{file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"},
]
[[package]]
name = "defusedxml"
version = "0.7.1"
description = "XML bomb protection for Python stdlib modules"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
groups = ["main"]
files = [
{file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"},
{file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
]
[[package]]
name = "dill"
version = "0.4.0"
@@ -1490,14 +1501,14 @@ files = [
[[package]]
name = "filelock"
version = "3.20.1"
version = "3.20.3"
description = "A platform independent file lock."
optional = false
python-versions = ">=3.10"
groups = ["main", "dev"]
files = [
{file = "filelock-3.20.1-py3-none-any.whl", hash = "sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a"},
{file = "filelock-3.20.1.tar.gz", hash = "sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c"},
{file = "filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1"},
{file = "filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1"},
]
[[package]]
@@ -7095,19 +7106,19 @@ test = ["aiohttp (>=3.10.5)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil",
[[package]]
name = "virtualenv"
version = "20.34.0"
version = "20.36.1"
description = "Virtual Python Environment builder"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "virtualenv-20.34.0-py3-none-any.whl", hash = "sha256:341f5afa7eee943e4984a9207c025feedd768baff6753cd660c857ceb3e36026"},
{file = "virtualenv-20.34.0.tar.gz", hash = "sha256:44815b2c9dee7ed86e387b842a84f20b93f7f417f95886ca1996a72a4138eb1a"},
{file = "virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f"},
{file = "virtualenv-20.36.1.tar.gz", hash = "sha256:8befb5c81842c641f8ee658481e42641c68b5eab3521d8e092d18320902466ba"},
]
[package.dependencies]
distlib = ">=0.3.7,<1"
filelock = ">=3.12.2,<4"
filelock = {version = ">=3.20.1,<4", markers = "python_version >= \"3.10\""}
platformdirs = ">=3.9.1,<5"
[package.extras]
@@ -7424,4 +7435,4 @@ vertex = ["google-cloud-aiplatform"]
[metadata]
lock-version = "2.1"
python-versions = "^3.12"
content-hash = "91f49e313e5690bbef87e17730441f26d366daeccb16b5020e03e581fbb9d4d5"
content-hash = "0424a0e82fe49501f3a80166676e257a9dae97093d9bc730489789195f523735"

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "strix-agent"
version = "0.6.0"
version = "0.6.1"
description = "Open-source AI Hackers for your apps"
authors = ["Strix <hi@usestrix.com>"]
readme = "README.md"
@@ -69,6 +69,7 @@ gql = { version = "^3.5.3", extras = ["requests"], optional = true }
pyte = { version = "^0.8.1", optional = true }
libtmux = { version = "^0.46.2", optional = true }
numpydoc = { version = "^1.8.0", optional = true }
defusedxml = "^0.7.1"
[tool.poetry.extras]
vertex = ["google-cloud-aiplatform"]

View File

@@ -111,7 +111,6 @@ hiddenimports = [
'strix.llm.llm',
'strix.llm.config',
'strix.llm.utils',
'strix.llm.request_queue',
'strix.llm.memory_compressor',
'strix.runtime',
'strix.runtime.runtime',

View File

@@ -308,17 +308,18 @@ Tool calls use XML format:
CRITICAL RULES:
0. While active in the agent loop, EVERY message you output MUST be a single tool call. Do not send plain text-only responses.
1. One tool call per message
1. Exactly one tool call per message — never include more than one <function>...</function> block in a single LLM message.
2. Tool call must be last in message
3. End response after </function> tag. It's your stop word. Do not continue after it.
3. EVERY tool call MUST end with </function>. This is MANDATORY. Never omit the closing tag. End your response immediately after </function>.
4. Use ONLY the exact XML format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters.
5. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants).
5. When sending ANY multi-line content in tool parameters, use real newlines (actual line breaks). Do NOT emit literal "\n" sequences. If you send "\n" instead of real line breaks inside the XML parameter value, tools may fail or behave incorrectly.
6. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants).
- Correct: <function=think> ... </function>
- Incorrect: <thinking_tools.think> ... </function>
- Incorrect: <think> ... </think>
- Incorrect: {"think": {...}}
6. Parameters must use <parameter=param_name>value</parameter> exactly. Do NOT pass parameters as JSON or key:value lines. Do NOT add quotes/braces around values.
7. Do NOT wrap tool calls in markdown/code fences or add any text before or after the tool block.
7. Parameters must use <parameter=param_name>value</parameter> exactly. Do NOT pass parameters as JSON or key:value lines. Do NOT add quotes/braces around values.
8. Do NOT wrap tool calls in markdown/code fences or add any text before or after the tool block.
Example (agent creation tool):
<function=create_agent>
@@ -331,6 +332,8 @@ SPRAYING EXECUTION NOTE:
- When performing large payload sprays or fuzzing, encapsulate the entire spraying loop inside a single python or terminal tool call (e.g., a Python script using asyncio/aiohttp). Do not issue one tool call per payload.
- Favor batch-mode CLI tools (sqlmap, ffuf, nuclei, zaproxy, arjun) where appropriate and check traffic via the proxy when beneficial
REMINDER: Always close each tool call with </function> before going into the next. Incomplete tool calls will fail.
{{ get_tools_prompt() }}
</tool_usage>

View File

@@ -1,7 +1,6 @@
import asyncio
import contextlib
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
@@ -18,6 +17,7 @@ from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content
from strix.runtime import SandboxInitializationError
from strix.tools import process_tool_invocations
from strix.utils.resource_paths import get_strix_resource_path
from .state import AgentState
@@ -35,8 +35,7 @@ class AgentMeta(type):
if name == "BaseAgent":
return new_cls
agents_dir = Path(__file__).parent
prompt_dir = agents_dir / name
prompt_dir = get_strix_resource_path("agents", name)
new_cls.agent_name = name
new_cls.jinja_env = Environment(
@@ -66,20 +65,21 @@ class BaseAgent(metaclass=AgentMeta):
self.llm_config = config.get("llm_config", self.default_llm_config)
if self.llm_config is None:
raise ValueError("llm_config is required but not provided")
self.llm = LLM(self.llm_config, agent_name=self.agent_name)
state_from_config = config.get("state")
if state_from_config is not None:
self.state = state_from_config
else:
self.state = AgentState(
agent_name=self.agent_name,
agent_name="Root Agent",
max_iterations=self.max_iterations,
)
self.llm = LLM(self.llm_config, agent_name=self.agent_name)
with contextlib.suppress(Exception):
self.llm.set_agent_identity(self.agent_name, self.state.agent_id)
self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id)
self._current_task: asyncio.Task[Any] | None = None
self._force_stop = False
from strix.telemetry.tracer import get_global_tracer
@@ -157,6 +157,11 @@ class BaseAgent(metaclass=AgentMeta):
return self._handle_sandbox_error(e, tracer)
while True:
if self._force_stop:
self._force_stop = False
await self._enter_waiting_state(tracer, was_cancelled=True)
continue
self._check_agent_messages(self.state)
if self.state.is_waiting_for_input():
@@ -247,7 +252,8 @@ class BaseAgent(metaclass=AgentMeta):
continue
async def _wait_for_input(self) -> None:
import asyncio
if self._force_stop:
return
if self.state.has_waiting_timeout():
self.state.resume_from_waiting()
@@ -340,6 +346,7 @@ class BaseAgent(metaclass=AgentMeta):
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
final_response = None
async for response in self.llm.generate(self.state.get_conversation_history()):
final_response = response
if tracer and response.content:
@@ -585,6 +592,11 @@ class BaseAgent(metaclass=AgentMeta):
return True
def cancel_current_execution(self) -> None:
self._force_stop = True
if self._current_task and not self._current_task.done():
self._current_task.cancel()
try:
loop = self._current_task.get_loop()
loop.call_soon_threadsafe(self._current_task.cancel)
except RuntimeError:
self._current_task.cancel()
self._current_task = None

View File

@@ -16,9 +16,9 @@ class Config:
litellm_base_url = None
ollama_api_base = None
strix_reasoning_effort = "high"
strix_llm_max_retries = "5"
strix_memory_compressor_timeout = "30"
llm_timeout = "300"
llm_rate_limit_delay = "4.0"
llm_rate_limit_concurrent = "1"
# Tool & Feature Configuration
perplexity_api_key = None
@@ -27,7 +27,7 @@ class Config:
# Runtime Configuration
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
strix_runtime_backend = "docker"
strix_sandbox_execution_timeout = "500"
strix_sandbox_execution_timeout = "120"
strix_sandbox_connect_timeout = "10"
# Telemetry

View File

@@ -1,60 +1,29 @@
import asyncio
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
import litellm
from jinja2 import (
Environment,
FileSystemLoader,
select_autoescape,
)
from litellm import completion_cost, stream_chunk_builder, supports_reasoning
from jinja2 import Environment, FileSystemLoader, select_autoescape
from litellm import acompletion, completion_cost, stream_chunk_builder, supports_reasoning
from litellm.utils import supports_prompt_caching, supports_vision
from strix.config import Config
from strix.llm.config import LLMConfig
from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.request_queue import get_global_queue
from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations
from strix.llm.utils import (
_truncate_to_first_function,
fix_incomplete_tool_call,
parse_tool_invocations,
)
from strix.skills import load_skills
from strix.tools import get_tools_prompt
from strix.utils.resource_paths import get_strix_resource_path
MAX_RETRIES = 5
RETRY_MULTIPLIER = 8
RETRY_MIN = 8
RETRY_MAX = 64
def _should_retry(exception: Exception) -> bool:
status_code = None
if hasattr(exception, "status_code"):
status_code = exception.status_code
elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
status_code = exception.response.status_code
if status_code is not None:
return bool(litellm._should_retry(status_code))
return True
logger = logging.getLogger(__name__)
litellm.drop_params = True
litellm.modify_params = True
_LLM_API_KEY = Config.get("llm_api_key")
_LLM_API_BASE = (
Config.get("llm_api_base")
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
)
_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")
class LLMRequestFailedError(Exception):
def __init__(self, message: str, details: str | None = None):
@@ -63,20 +32,11 @@ class LLMRequestFailedError(Exception):
self.details = details
class StepRole(str, Enum):
AGENT = "agent"
USER = "user"
SYSTEM = "system"
@dataclass
class LLMResponse:
content: str
tool_invocations: list[dict[str, Any]] | None = None
scan_id: str | None = None
step_number: int = 1
role: StepRole = StepRole.AGENT
thinking_blocks: list[dict[str, Any]] | None = None # For reasoning models.
thinking_blocks: list[dict[str, Any]] | None = None
@dataclass
@@ -84,76 +44,63 @@ class RequestStats:
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0
cache_creation_tokens: int = 0
cost: float = 0.0
requests: int = 0
failed_requests: int = 0
def to_dict(self) -> dict[str, int | float]:
return {
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"cached_tokens": self.cached_tokens,
"cache_creation_tokens": self.cache_creation_tokens,
"cost": round(self.cost, 4),
"requests": self.requests,
"failed_requests": self.failed_requests,
}
class LLM:
def __init__(
self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
):
def __init__(self, config: LLMConfig, agent_name: str | None = None):
self.config = config
self.agent_name = agent_name
self.agent_id = agent_id
self.agent_id: str | None = None
self._total_stats = RequestStats()
self._last_request_stats = RequestStats()
self.memory_compressor = MemoryCompressor(model_name=config.model_name)
self.system_prompt = self._load_system_prompt(agent_name)
if _STRIX_REASONING_EFFORT:
self._reasoning_effort = _STRIX_REASONING_EFFORT
elif self.config.scan_mode == "quick":
reasoning = Config.get("strix_reasoning_effort")
if reasoning:
self._reasoning_effort = reasoning
elif config.scan_mode == "quick":
self._reasoning_effort = "medium"
else:
self._reasoning_effort = "high"
self.memory_compressor = MemoryCompressor(
model_name=self.config.model_name,
timeout=self.config.timeout,
)
def _load_system_prompt(self, agent_name: str | None) -> str:
if not agent_name:
return ""
if agent_name:
prompt_dir = Path(__file__).parent.parent / "agents" / agent_name
skills_dir = Path(__file__).parent.parent / "skills"
loader = FileSystemLoader([prompt_dir, skills_dir])
self.jinja_env = Environment(
loader=loader,
try:
prompt_dir = get_strix_resource_path("agents", agent_name)
skills_dir = get_strix_resource_path("skills")
env = Environment(
loader=FileSystemLoader([prompt_dir, skills_dir]),
autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
)
try:
skills_to_load = list(self.config.skills or [])
skills_to_load.append(f"scan_modes/{self.config.scan_mode}")
skills_to_load = [
*list(self.config.skills or []),
f"scan_modes/{self.config.scan_mode}",
]
skill_content = load_skills(skills_to_load, env)
env.globals["get_skill"] = lambda name: skill_content.get(name, "")
skill_content = load_skills(skills_to_load, self.jinja_env)
def get_skill(name: str) -> str:
return skill_content.get(name, "")
self.jinja_env.globals["get_skill"] = get_skill
self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
get_tools_prompt=get_tools_prompt,
loaded_skill_names=list(skill_content.keys()),
**skill_content,
)
except (FileNotFoundError, OSError, ValueError) as e:
logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
self.system_prompt = "You are a helpful AI assistant."
else:
self.system_prompt = "You are a helpful AI assistant."
result = env.get_template("system_prompt.jinja").render(
get_tools_prompt=get_tools_prompt,
loaded_skill_names=list(skill_content.keys()),
**skill_content,
)
return str(result)
except Exception: # noqa: BLE001
return ""
def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
if agent_name:
@@ -161,328 +108,121 @@ class LLM:
if agent_id:
self.agent_id = agent_id
def _build_identity_message(self) -> dict[str, Any] | None:
if not (self.agent_name and str(self.agent_name).strip()):
return None
identity_name = self.agent_name
identity_id = self.agent_id
content = (
"\n\n"
"<agent_identity>\n"
"<meta>Internal metadata: do not echo or reference; "
"not part of history or tool calls.</meta>\n"
"<note>You are now assuming the role of this agent. "
"Act strictly as this agent and maintain self-identity for this step. "
"Now go answer the next needed step!</note>\n"
f"<agent_name>{identity_name}</agent_name>\n"
f"<agent_id>{identity_id}</agent_id>\n"
"</agent_identity>\n\n"
)
return {"role": "user", "content": content}
async def generate(
self, conversation_history: list[dict[str, Any]]
) -> AsyncIterator[LLMResponse]:
messages = self._prepare_messages(conversation_history)
max_retries = int(Config.get("strix_llm_max_retries") or "5")
def _add_cache_control_to_content(
self, content: str | list[dict[str, Any]]
) -> str | list[dict[str, Any]]:
if isinstance(content, str):
return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
if isinstance(content, list) and content:
last_item = content[-1]
if isinstance(last_item, dict) and last_item.get("type") == "text":
return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
return content
for attempt in range(max_retries + 1):
try:
async for response in self._stream(messages):
yield response
return # noqa: TRY300
except Exception as e: # noqa: BLE001
if attempt >= max_retries or not self._should_retry(e):
self._raise_error(e)
wait = min(10, 2 * (2**attempt))
await asyncio.sleep(wait)
def _is_anthropic_model(self) -> bool:
if not self.config.model_name:
return False
model_lower = self.config.model_name.lower()
return any(provider in model_lower for provider in ["anthropic/", "claude"])
async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
accumulated = ""
chunks: list[Any] = []
def _calculate_cache_interval(self, total_messages: int) -> int:
if total_messages <= 1:
return 10
self._total_stats.requests += 1
response = await acompletion(**self._build_completion_args(messages), stream=True)
max_cached_messages = 3
non_system_messages = total_messages - 1
interval = 10
while non_system_messages // interval > max_cached_messages:
interval += 10
return interval
def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
if (
not self.config.enable_prompt_caching
or not supports_prompt_caching(self.config.model_name)
or not messages
):
return messages
if not self._is_anthropic_model():
return messages
cached_messages = list(messages)
if cached_messages and cached_messages[0].get("role") == "system":
system_message = cached_messages[0].copy()
system_message["content"] = self._add_cache_control_to_content(
system_message["content"]
)
cached_messages[0] = system_message
total_messages = len(cached_messages)
if total_messages > 1:
interval = self._calculate_cache_interval(total_messages)
cached_count = 0
for i in range(interval, total_messages, interval):
if cached_count >= 3:
async for chunk in response:
chunks.append(chunk)
delta = self._get_chunk_content(chunk)
if delta:
accumulated += delta
if "</function>" in accumulated:
accumulated = accumulated[
: accumulated.find("</function>") + len("</function>")
]
yield LLMResponse(content=accumulated)
break
yield LLMResponse(content=accumulated)
if i < len(cached_messages):
message = cached_messages[i].copy()
message["content"] = self._add_cache_control_to_content(message["content"])
cached_messages[i] = message
cached_count += 1
if chunks:
self._update_usage_stats(stream_chunk_builder(chunks))
return cached_messages
accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
yield LLMResponse(
content=accumulated,
tool_invocations=parse_tool_invocations(accumulated),
thinking_blocks=self._extract_thinking(chunks),
)
def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
messages = [{"role": "system", "content": self.system_prompt}]
identity_message = self._build_identity_message()
if identity_message:
messages.append(identity_message)
compressed_history = list(self.memory_compressor.compress_history(conversation_history))
if self.agent_name:
messages.append(
{
"role": "user",
"content": (
f"\n\n<agent_identity>\n"
f"<meta>Internal metadata: do not echo or reference.</meta>\n"
f"<agent_name>{self.agent_name}</agent_name>\n"
f"<agent_id>{self.agent_id}</agent_id>\n"
f"</agent_identity>\n\n"
),
}
)
compressed = list(self.memory_compressor.compress_history(conversation_history))
conversation_history.clear()
conversation_history.extend(compressed_history)
messages.extend(compressed_history)
conversation_history.extend(compressed)
messages.extend(compressed)
return self._prepare_cached_messages(messages)
if self._is_anthropic() and self.config.enable_prompt_caching:
messages = self._add_cache_control(messages)
async def _stream_and_accumulate(
self,
messages: list[dict[str, Any]],
scan_id: str | None,
step_number: int,
) -> AsyncIterator[LLMResponse]:
accumulated_content = ""
chunks: list[Any] = []
return messages
async for chunk in self._stream_request(messages):
chunks.append(chunk)
delta = self._extract_chunk_delta(chunk)
if delta:
accumulated_content += delta
def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
if not self._supports_vision():
messages = self._strip_images(messages)
if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end]
yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content=accumulated_content,
tool_invocations=None,
)
if chunks:
complete_response = stream_chunk_builder(chunks)
self._update_usage_stats(complete_response)
accumulated_content = _truncate_to_first_function(accumulated_content)
if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end]
tool_invocations = parse_tool_invocations(accumulated_content)
# Extract thinking blocks from the complete response if available
thinking_blocks = None
if chunks and self._should_include_reasoning_effort():
complete_response = stream_chunk_builder(chunks)
if (
hasattr(complete_response, "choices")
and complete_response.choices
and hasattr(complete_response.choices[0], "message")
):
message = complete_response.choices[0].message
if hasattr(message, "thinking_blocks") and message.thinking_blocks:
thinking_blocks = message.thinking_blocks
yield LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content=accumulated_content,
tool_invocations=tool_invocations if tool_invocations else None,
thinking_blocks=thinking_blocks,
)
def _raise_llm_error(self, e: Exception) -> None:
    """Translate a litellm exception into LLMRequestFailedError and report it.

    The mapping is scanned in order (more specific litellm types first), so
    the first isinstance match wins; anything unmatched falls through to a
    generic error. This method always raises.
    """
    error_map: list[tuple[type, str]] = [
        (litellm.RateLimitError, "Rate limit exceeded"),
        (litellm.AuthenticationError, "Invalid API key"),
        (litellm.NotFoundError, "Model not found"),
        (litellm.ContextWindowExceededError, "Context too long"),
        (litellm.ContentPolicyViolationError, "Content policy violation"),
        (litellm.ServiceUnavailableError, "Service unavailable"),
        (litellm.Timeout, "Request timed out"),
        (litellm.UnprocessableEntityError, "Unprocessable entity"),
        (litellm.InternalServerError, "Internal server error"),
        (litellm.APIConnectionError, "Connection error"),
        (litellm.UnsupportedParamsError, "Unsupported parameters"),
        (litellm.BudgetExceededError, "Budget exceeded"),
        (litellm.APIResponseValidationError, "Response validation error"),
        (litellm.JSONSchemaValidationError, "JSON schema validation error"),
        (litellm.InvalidRequestError, "Invalid request"),
        (litellm.BadRequestError, "Bad request"),
        (litellm.APIError, "API error"),
        (litellm.OpenAIError, "OpenAI error"),
    ]
    from strix.telemetry import posthog

    matched = next(
        ((err_type, label) for err_type, label in error_map if isinstance(e, err_type)),
        None,
    )
    if matched is None:
        posthog.error("llm_unknown_error", type(e).__name__)
        raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
    err_type, label = matched
    posthog.error(f"llm_{err_type.__name__}", label)
    raise LLMRequestFailedError(f"LLM request failed: {label}", str(e)) from e
async def generate(
    self,
    conversation_history: list[dict[str, Any]],
    scan_id: str | None = None,
    step_number: int = 1,
) -> AsyncIterator[LLMResponse]:
    """Stream LLM responses for a conversation, retrying transient failures.

    Args:
        conversation_history: Prior messages; handed to _prepare_messages,
            which builds the outgoing message list (and may compress the
            history — confirm against _prepare_messages).
        scan_id: Optional scan identifier propagated into each LLMResponse.
        step_number: Step index propagated into each LLMResponse.

    Yields:
        LLMResponse objects produced by _stream_and_accumulate.

    Raises:
        LLMRequestFailedError: raised by _raise_llm_error once retries are
            exhausted or the error is classified as non-retryable.
    """
    messages = self._prepare_messages(conversation_history)
    last_error: Exception | None = None
    for attempt in range(MAX_RETRIES):
        try:
            async for response in self._stream_and_accumulate(messages, scan_id, step_number):
                yield response
            # Stream fully consumed without error: we are done.
            return  # noqa: TRY300
        except Exception as e:  # noqa: BLE001
            last_error = e
            # Stop retrying on non-retryable errors or after the final attempt.
            if not _should_retry(e) or attempt == MAX_RETRIES - 1:
                break
            # Exponential backoff clamped into [RETRY_MIN, RETRY_MAX].
            wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt))
            wait_time = max(RETRY_MIN, wait_time)
            await asyncio.sleep(wait_time)
    if last_error:
        self._raise_llm_error(last_error)
def _extract_chunk_delta(self, chunk: Any) -> str:
if chunk.choices and hasattr(chunk.choices[0], "delta"):
delta = chunk.choices[0].delta
return getattr(delta, "content", "") or ""
return ""
@property
def usage_stats(self) -> dict[str, dict[str, int | float]]:
    """Snapshot of cumulative and most-recent-request token/cost statistics."""
    snapshot: dict[str, dict[str, int | float]] = {}
    snapshot["total"] = self._total_stats.to_dict()
    snapshot["last_request"] = self._last_request_stats.to_dict()
    return snapshot
def get_cache_config(self) -> dict[str, bool]:
    """Report whether prompt caching is enabled and supported for this model."""
    model_name = self.config.model_name
    cache_config: dict[str, bool] = {"enabled": self.config.enable_prompt_caching}
    cache_config["supported"] = supports_prompt_caching(model_name)
    return cache_config
def _should_include_reasoning_effort(self) -> bool:
    """Best-effort probe: does the configured model advertise reasoning support?"""
    model = self.config.model_name
    if not model:
        return False
    try:
        supported = supports_reasoning(model=model)
    except Exception:  # noqa: BLE001
        return False
    return bool(supported)
def _model_supports_vision(self) -> bool:
    """Best-effort probe: does the configured model advertise vision support?"""
    model = self.config.model_name
    if not model:
        return False
    try:
        supported = supports_vision(model=model)
    except Exception:  # noqa: BLE001
        return False
    return bool(supported)
def _filter_images_from_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
filtered_messages = []
for msg in messages:
content = msg.get("content")
updated_msg = msg
if isinstance(content, list):
filtered_content = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "image_url":
filtered_content.append(
{
"type": "text",
"text": "[Screenshot removed - model does not support "
"vision. Use view_source or execute_js instead.]",
}
)
else:
filtered_content.append(item)
else:
filtered_content.append(item)
if filtered_content:
text_parts = [
item.get("text", "") if isinstance(item, dict) else str(item)
for item in filtered_content
]
all_text = all(
isinstance(item, dict) and item.get("type") == "text"
for item in filtered_content
)
if all_text:
updated_msg = {**msg, "content": "\n".join(text_parts)}
else:
updated_msg = {**msg, "content": filtered_content}
else:
updated_msg = {**msg, "content": ""}
filtered_messages.append(updated_msg)
return filtered_messages
async def _stream_request(
self,
messages: list[dict[str, Any]],
) -> AsyncIterator[Any]:
if not self._model_supports_vision():
messages = self._filter_images_from_messages(messages)
completion_args: dict[str, Any] = {
args: dict[str, Any] = {
"model": self.config.model_name,
"messages": messages,
"timeout": self.config.timeout,
"stream_options": {"include_usage": True},
}
if _LLM_API_KEY:
completion_args["api_key"] = _LLM_API_KEY
if _LLM_API_BASE:
completion_args["api_base"] = _LLM_API_BASE
if api_key := Config.get("llm_api_key"):
args["api_key"] = api_key
if api_base := (
Config.get("llm_api_base")
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
):
args["api_base"] = api_base
if self._supports_reasoning():
args["reasoning_effort"] = self._reasoning_effort
completion_args["stop"] = ["</function>"]
return args
if self._should_include_reasoning_effort():
completion_args["reasoning_effort"] = self._reasoning_effort
def _get_chunk_content(self, chunk: Any) -> str:
if chunk.choices and hasattr(chunk.choices[0], "delta"):
return getattr(chunk.choices[0].delta, "content", "") or ""
return ""
queue = get_global_queue()
self._total_stats.requests += 1
self._last_request_stats = RequestStats(requests=1)
async for chunk in queue.stream_request(completion_args):
yield chunk
def _extract_thinking(self, chunks: list[Any]) -> list[dict[str, Any]] | None:
    """Pull thinking_blocks out of the assembled streaming response, if any."""
    if not chunks or not self._supports_reasoning():
        return None
    try:
        assembled = stream_chunk_builder(chunks)
        choices = assembled.choices
        if choices:
            message = choices[0].message
            if hasattr(message, "thinking_blocks"):
                result: list[dict[str, Any]] = message.thinking_blocks
                return result
    except Exception:  # noqa: BLE001, S110 # nosec B110
        pass
    return None
def _update_usage_stats(self, response: Any) -> None:
try:
@@ -491,45 +231,88 @@ class LLM:
output_tokens = getattr(response.usage, "completion_tokens", 0)
cached_tokens = 0
cache_creation_tokens = 0
if hasattr(response.usage, "prompt_tokens_details"):
prompt_details = response.usage.prompt_tokens_details
if hasattr(prompt_details, "cached_tokens"):
cached_tokens = prompt_details.cached_tokens or 0
if hasattr(response.usage, "cache_creation_input_tokens"):
cache_creation_tokens = response.usage.cache_creation_input_tokens or 0
else:
input_tokens = 0
output_tokens = 0
cached_tokens = 0
cache_creation_tokens = 0
try:
cost = completion_cost(response) or 0.0
except Exception as e: # noqa: BLE001
logger.warning(f"Failed to calculate cost: {e}")
except Exception: # noqa: BLE001
cost = 0.0
self._total_stats.input_tokens += input_tokens
self._total_stats.output_tokens += output_tokens
self._total_stats.cached_tokens += cached_tokens
self._total_stats.cache_creation_tokens += cache_creation_tokens
self._total_stats.cost += cost
self._last_request_stats.input_tokens = input_tokens
self._last_request_stats.output_tokens = output_tokens
self._last_request_stats.cached_tokens = cached_tokens
self._last_request_stats.cache_creation_tokens = cache_creation_tokens
self._last_request_stats.cost = cost
except Exception: # noqa: BLE001, S110 # nosec B110
pass
if cached_tokens > 0:
logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
if cache_creation_tokens > 0:
logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")
def _should_retry(self, e: Exception) -> bool:
    """True when e carries no HTTP status, or litellm deems the status retryable."""
    status = getattr(e, "status_code", None) or getattr(
        getattr(e, "response", None), "status_code", None
    )
    if status is None:
        return True
    return litellm._should_retry(status)
logger.info(f"Usage stats: {self.usage_stats}")
except Exception as e: # noqa: BLE001
logger.warning(f"Failed to update usage stats: {e}")
def _raise_error(self, e: Exception) -> None:
    """Report e to telemetry and re-raise it as an LLMRequestFailedError."""
    from strix.telemetry import posthog

    error_name = type(e).__name__
    posthog.error("llm_error", error_name)
    raise LLMRequestFailedError(f"LLM request failed: {error_name}", str(e)) from e
def _is_anthropic(self) -> bool:
if not self.config.model_name:
return False
return any(p in self.config.model_name.lower() for p in ["anthropic/", "claude"])
def _supports_vision(self) -> bool:
    """Best-effort litellm vision-capability probe; False on any failure."""
    try:
        supported = supports_vision(model=self.config.model_name)
    except Exception:  # noqa: BLE001
        return False
    return bool(supported)
def _supports_reasoning(self) -> bool:
    """Best-effort litellm reasoning-capability probe; False on any failure."""
    try:
        supported = supports_reasoning(model=self.config.model_name)
    except Exception:  # noqa: BLE001
        return False
    return bool(supported)
def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
result = []
for msg in messages:
content = msg.get("content")
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif isinstance(item, dict) and item.get("type") == "image_url":
text_parts.append("[Image removed - model doesn't support vision]")
result.append({**msg, "content": "\n".join(text_parts)})
else:
result.append(msg)
return result
def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Attach an ephemeral cache_control marker to a leading system message.

    Returns the input untouched when empty or caching is unsupported;
    otherwise returns a shallow copy with the system message (if first and
    its content is a plain string) rewrapped as a cache-marked text part.
    """
    if not messages or not supports_prompt_caching(self.config.model_name):
        return messages
    updated = list(messages)
    first = updated[0]
    if first.get("role") != "system":
        return updated
    content = first["content"]
    if isinstance(content, str):
        content = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
    updated[0] = {**first, "content": content}
    return updated

View File

@@ -86,7 +86,7 @@ def _extract_message_text(msg: dict[str, Any]) -> str:
def _summarize_messages(
messages: list[dict[str, Any]],
model: str,
timeout: int = 600,
timeout: int = 30,
) -> dict[str, Any]:
if not messages:
empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
@@ -148,11 +148,11 @@ class MemoryCompressor:
self,
max_images: int = 3,
model_name: str | None = None,
timeout: int = 600,
timeout: int | None = None,
):
self.max_images = max_images
self.model_name = model_name or Config.get("strix_llm")
self.timeout = timeout
self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")
if not self.model_name:
raise ValueError("STRIX_LLM environment variable must be set and not empty")

View File

@@ -1,58 +0,0 @@
import asyncio
import threading
import time
from collections.abc import AsyncIterator
from typing import Any
from litellm import acompletion
from litellm.types.utils import ModelResponseStream
from strix.config import Config
class LLMRequestQueue:
    """Throttled gateway for streaming LLM completions.

    Enforces a maximum number of concurrent requests (via a threading
    BoundedSemaphore) and a minimum delay between request starts, using
    configuration values read at construction time.
    """

    def __init__(self) -> None:
        # Seconds to wait between consecutive request starts (default 4.0).
        self.delay_between_requests = float(Config.get("llm_rate_limit_delay") or "4.0")
        # Maximum simultaneous in-flight requests (default 1).
        self.max_concurrent = int(Config.get("llm_rate_limit_concurrent") or "1")
        self._semaphore = threading.BoundedSemaphore(self.max_concurrent)
        # Timestamp (time.time()) the last request was scheduled to start.
        self._last_request_time = 0.0
        # Guards _last_request_time across threads.
        self._lock = threading.Lock()

    async def stream_request(
        self, completion_args: dict[str, Any]
    ) -> AsyncIterator[ModelResponseStream]:
        """Yield streaming chunks for one completion, respecting rate limits.

        NOTE(review): the blocking ``self._semaphore.acquire(timeout=0.2)``
        runs on the event-loop thread, stalling other tasks for up to 0.2s
        per spin. Also, if this coroutine is cancelled before the semaphore
        is acquired, the ``finally`` still calls ``release()``, which can
        over-release the BoundedSemaphore (ValueError) — confirm before
        reusing this pattern.
        """
        try:
            # Spin until a concurrency slot is free, yielding control to the
            # event loop between blocking acquire attempts.
            while not self._semaphore.acquire(timeout=0.2):
                await asyncio.sleep(0.1)
            with self._lock:
                now = time.time()
                time_since_last = now - self._last_request_time
                # Remaining wait so starts are >= delay_between_requests apart.
                sleep_needed = max(0, self.delay_between_requests - time_since_last)
                # Reserve this request's (possibly future) start slot while
                # still holding the lock, so concurrent callers queue behind it.
                self._last_request_time = now + sleep_needed
            if sleep_needed > 0:
                await asyncio.sleep(sleep_needed)
            async for chunk in self._stream_request(completion_args):
                yield chunk
        finally:
            self._semaphore.release()

    async def _stream_request(
        self, completion_args: dict[str, Any]
    ) -> AsyncIterator[ModelResponseStream]:
        """Open the litellm streaming completion and forward its chunks."""
        response = await acompletion(**completion_args, stream=True)
        async for chunk in response:
            yield chunk
async def _stream_request(
self, completion_args: dict[str, Any]
) -> AsyncIterator[ModelResponseStream]:
response = await acompletion(**completion_args, stream=True)
async for chunk in response:
yield chunk
_global_queue: LLMRequestQueue | None = None
def get_global_queue() -> LLMRequestQueue:
    """Return the process-wide LLMRequestQueue, creating it on first use."""
    global _global_queue  # noqa: PLW0603
    queue = _global_queue
    if queue is None:
        queue = LLMRequestQueue()
        _global_queue = queue
    return queue

View File

@@ -18,7 +18,7 @@ def _truncate_to_first_function(content: str) -> str:
def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
content = _fix_stopword(content)
content = fix_incomplete_tool_call(content)
tool_invocations: list[dict[str, Any]] = []
@@ -46,16 +46,15 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
return tool_invocations if tool_invocations else None
def _fix_stopword(content: str) -> str:
def fix_incomplete_tool_call(content: str) -> str:
"""Fix incomplete tool calls by adding missing </function> tag."""
if (
"<function=" in content
and content.count("<function=") == 1
and "</function>" not in content
):
if content.endswith("</"):
content = content.rstrip() + "function>"
else:
content = content + "\n</function>"
content = content.rstrip()
content = content + "function>" if content.endswith("</") else content + "\n</function>"
return content
@@ -74,7 +73,7 @@ def clean_content(content: str) -> str:
if not content:
return ""
content = _fix_stopword(content)
content = fix_incomplete_tool_call(content)
tool_pattern = r"<function=[^>]+>.*?</function>"
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)

View File

@@ -1,11 +1,14 @@
from pathlib import Path
from jinja2 import Environment
from strix.utils.resource_paths import get_strix_resource_path
def get_available_skills() -> dict[str, list[str]]:
skills_dir = Path(__file__).parent
available_skills = {}
skills_dir = get_strix_resource_path("skills")
available_skills: dict[str, list[str]] = {}
if not skills_dir.exists():
return available_skills
for category_dir in skills_dir.iterdir():
if category_dir.is_dir() and not category_dir.name.startswith("__"):
@@ -72,7 +75,7 @@ def load_skills(skill_names: list[str], jinja_env: Environment) -> dict[str, str
logger = logging.getLogger(__name__)
skill_content = {}
skills_dir = Path(__file__).parent
skills_dir = get_strix_resource_path("skills")
available_skills = get_available_skills()

View File

@@ -430,10 +430,8 @@ class Tracer:
"input_tokens": 0,
"output_tokens": 0,
"cached_tokens": 0,
"cache_creation_tokens": 0,
"cost": 0.0,
"requests": 0,
"failed_requests": 0,
}
for agent_instance in _agent_instances.values():
@@ -442,10 +440,8 @@ class Tracer:
total_stats["input_tokens"] += agent_stats.input_tokens
total_stats["output_tokens"] += agent_stats.output_tokens
total_stats["cached_tokens"] += agent_stats.cached_tokens
total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
total_stats["cost"] += agent_stats.cost
total_stats["requests"] += agent_stats.requests
total_stats["failed_requests"] += agent_stats.failed_requests
total_stats["cost"] = round(total_stats["cost"], 4)

View File

@@ -14,12 +14,13 @@ from .argument_parser import convert_arguments
from .registry import (
get_tool_by_name,
get_tool_names,
get_tool_param_schema,
needs_agent_state,
should_execute_in_sandbox,
)
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "500")
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
@@ -110,14 +111,51 @@ async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwarg
def validate_tool_availability(tool_name: str | None) -> tuple[bool, str]:
    """Check that tool_name names a registered tool.

    The original span contained both the pre- and post-diff return
    statements (unreachable duplicates); this keeps the enriched messages
    that list the available tools.

    Returns:
        (True, "") when available; otherwise (False, <error message that
        includes the sorted list of available tools>).
    """
    if tool_name is None:
        available = ", ".join(sorted(get_tool_names()))
        return False, f"Tool name is missing. Available tools: {available}"
    if tool_name not in get_tool_names():
        available = ", ".join(sorted(get_tool_names()))
        return False, f"Tool '{tool_name}' is not available. Available tools: {available}"
    return True, ""
def _validate_tool_arguments(tool_name: str, kwargs: dict[str, Any]) -> str | None:
    """Validate kwargs against the tool's declared parameter schema.

    Returns an error string describing unknown or missing-required
    parameters (with a schema hint appended), or None when valid or when
    the tool declares no parameters.
    """
    schema = get_tool_param_schema(tool_name)
    if not schema or not schema.get("has_params"):
        return None
    allowed: set[str] = schema.get("params", set())
    required: set[str] = schema.get("required", set())
    hint = _format_schema_hint(tool_name, required, allowed - required)
    extras = set(kwargs) - allowed
    if extras:
        extras_list = ", ".join(sorted(extras))
        return f"Tool '{tool_name}' received unknown parameter(s): {extras_list}\n{hint}"
    # Treat absent, None, and empty-string values as missing.
    missing = sorted(p for p in required if kwargs.get(p) in (None, ""))
    if missing:
        missing_list = ", ".join(missing)
        return f"Tool '{tool_name}' missing required parameter(s): {missing_list}\n{hint}"
    return None
def _format_schema_hint(tool_name: str, required: set[str], optional: set[str]) -> str:
parts = [f"Valid parameters for '{tool_name}':"]
if required:
parts.append(f" Required: {', '.join(sorted(required))}")
if optional:
parts.append(f" Optional: {', '.join(sorted(optional))}")
return "\n".join(parts)
async def execute_tool_with_validation(
tool_name: str | None, agent_state: Any | None = None, **kwargs: Any
) -> Any:
@@ -127,6 +165,10 @@ async def execute_tool_with_validation(
assert tool_name is not None
arg_error = _validate_tool_arguments(tool_name, kwargs)
if arg_error:
return f"Error: {arg_error}"
try:
result = await execute_tool(tool_name, agent_state, **kwargs)
except Exception as e: # noqa: BLE001

View File

@@ -55,6 +55,7 @@
- Print statements and stdout are captured
- Variables persist between executions in the same session
- Imports, function definitions, etc. persist in the session
- IMPORTANT (multiline): Put real line breaks in <parameter=code>. Do NOT emit literal "\n" sequences.
- IPython magic commands are fully supported (%pip, %time, %whos, %%writefile, etc.)
- Line magics (%) and cell magics (%%) work as expected
6. CLOSE: Terminates the session completely and frees memory
@@ -73,6 +74,14 @@
print("Security analysis session started")</parameter>
</function>
<function=python_action>
<parameter=action>execute</parameter>
<parameter=code>import requests
url = "https://example.com"
resp = requests.get(url, timeout=10)
print(resp.status_code)</parameter>
</function>
# Analyze security data in the default session
<function=python_action>
<parameter=action>execute</parameter>

View File

@@ -7,9 +7,14 @@ from inspect import signature
from pathlib import Path
from typing import Any
import defusedxml.ElementTree as DefusedET
from strix.utils.resource_paths import get_strix_resource_path
tools: list[dict[str, Any]] = []
_tools_by_name: dict[str, Callable[..., Any]] = {}
_tool_param_schemas: dict[str, dict[str, Any]] = {}
logger = logging.getLogger(__name__)
@@ -82,6 +87,34 @@ def _load_xml_schema(path: Path) -> Any:
return tools_dict
def _parse_param_schema(tool_xml: str) -> dict[str, Any]:
    """Extract parameter names and required flags from a tool's XML schema.

    Returns a dict with keys "params" (all parameter names), "required"
    (names flagged required="true"), and "has_params". An absent or
    unparseable <parameters> section yields the empty schema.
    """
    empty: dict[str, Any] = {"params": set(), "required": set(), "has_params": False}
    start = tool_xml.find("<parameters>")
    end = tool_xml.find("</parameters>")
    if start == -1 or end == -1:
        return empty
    section = tool_xml[start : end + len("</parameters>")]
    try:
        root = DefusedET.fromstring(section)
    except DefusedET.ParseError:
        return empty
    names: set[str] = set()
    mandatory: set[str] = set()
    for node in root.findall(".//parameter"):
        pname = node.attrib.get("name")
        if not pname:
            continue
        names.add(pname)
        if node.attrib.get("required", "false").lower() == "true":
            mandatory.add(pname)
    return {"params": names, "required": mandatory, "has_params": bool(names or mandatory)}
def _get_module_name(func: Callable[..., Any]) -> str:
module = inspect.getmodule(func)
if not module:
@@ -95,6 +128,27 @@ def _get_module_name(func: Callable[..., Any]) -> str:
return "unknown"
def _get_schema_path(func: Callable[..., Any]) -> Path | None:
    """Locate the XML schema file for the tool module that defines func.

    Expects module paths of the form ``...tools.<folder>.<file>``; returns
    None when the module cannot be resolved or does not match that layout.
    """
    module = inspect.getmodule(func)
    module_name = getattr(module, "__name__", None)
    if not module_name or ".tools." not in module_name:
        return None
    tail = module_name.split(".tools.")[-1].split(".")
    if len(tail) < 2:
        return None
    folder, file_stem = tail[0], tail[1]
    return get_strix_resource_path("tools", folder, f"{file_stem}_schema.xml")
def register_tool(
func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
) -> Callable[..., Any]:
@@ -109,11 +163,8 @@ def register_tool(
sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
if not sandbox_mode:
try:
module_path = Path(inspect.getfile(f))
schema_file_name = f"{module_path.stem}_schema.xml"
schema_path = module_path.parent / schema_file_name
xml_tools = _load_xml_schema(schema_path)
schema_path = _get_schema_path(f)
xml_tools = _load_xml_schema(schema_path) if schema_path else None
if xml_tools is not None and f.__name__ in xml_tools:
func_dict["xml_schema"] = xml_tools[f.__name__]
@@ -131,6 +182,11 @@ def register_tool(
"</tool>"
)
if not sandbox_mode:
xml_schema = func_dict.get("xml_schema")
param_schema = _parse_param_schema(xml_schema if isinstance(xml_schema, str) else "")
_tool_param_schemas[str(func_dict["name"])] = param_schema
tools.append(func_dict)
_tools_by_name[str(func_dict["name"])] = f
@@ -153,6 +209,10 @@ def get_tool_names() -> list[str]:
return list(_tools_by_name.keys())
def get_tool_param_schema(name: str) -> dict[str, Any] | None:
return _tool_param_schemas.get(name)
def needs_agent_state(tool_name: str) -> bool:
tool_func = get_tool_by_name(tool_name)
if not tool_func:
@@ -194,3 +254,4 @@ def get_tools_prompt() -> str:
def clear_registry() -> None:
tools.clear()
_tools_by_name.clear()
_tool_param_schemas.clear()

View File

@@ -95,6 +95,12 @@
<parameter=command>ls -la</parameter>
</function>
<function=terminal_execute>
<parameter=command>cd /workspace
pwd
ls -la</parameter>
</function>
# Run a command with custom timeout
<function=terminal_execute>
<parameter=command>npm install</parameter>

0
strix/utils/__init__.py Normal file
View File

View File

@@ -0,0 +1,13 @@
import sys
from pathlib import Path
def get_strix_resource_path(*parts: str) -> Path:
    """Resolve a path to a bundled strix resource, PyInstaller-aware.

    When running from a PyInstaller bundle (sys._MEIPASS set) and the
    bundled ``strix`` directory exists, resolve under it; otherwise resolve
    relative to the package root (two levels above this file).
    """
    meipass = getattr(sys, "_MEIPASS", None)
    if meipass:
        bundled = Path(meipass) / "strix"
        if bundled.exists():
            return bundled.joinpath(*parts)
    package_root = Path(__file__).resolve().parent.parent
    return package_root.joinpath(*parts)