Compare commits

12 Commits

| Author | SHA1 | Date |
|---|---|---|
| | e5104eb93a | |
| | d8a08e9a8c | |
| | f6475cec07 | |
| | 31baa0dfc0 | |
| | 56526cbf90 | |
| | 47faeb1ef3 | |
| | 435ac82d9e | |
| | f08014cf51 | |
| | bc8e14f68a | |
| | eae2b783c0 | |
| | 058cf1abdb | |
| | d16bdb277a | |

README.md (79 lines changed)
@@ -1,55 +1,61 @@
<p align="center">
<a href="https://strix.ai/">
<img src=".github/logo.png" width="150" alt="Strix Logo">
<img src="https://github.com/usestrix/.github/raw/main/imgs/cover.png" alt="Strix Banner" width="100%">
</a>
</p>

<h1 align="center">Strix</h1>

<h2 align="center">Open-source AI Hackers to secure your Apps</h2>

<div align="center">

[](https://pypi.org/project/strix-agent/)
[](https://pypi.org/project/strix-agent/)
[](LICENSE)
[](https://docs.strix.ai)
# Strix

[](https://github.com/usestrix/strix)
[](https://discord.gg/YjKFvEZSdZ)
[](https://strix.ai)
### Open-source AI hackers to find and fix your app’s vulnerabilities.

<a href="https://trendshift.io/repositories/15362" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15362" alt="usestrix%2Fstrix | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<br/>

[](https://deepwiki.com/usestrix/strix)
<a href="https://docs.strix.ai"><img src="https://img.shields.io/badge/Docs-docs.strix.ai-2b9246?style=for-the-badge&logo=gitbook&logoColor=white" alt="Docs"></a>
<a href="https://strix.ai"><img src="https://img.shields.io/badge/Website-strix.ai-3b82f6?style=for-the-badge&logoColor=white" alt="Website"></a>
<a href="https://pypi.org/project/strix-agent/"><img src="https://img.shields.io/badge/PyPI-strix--agent-f59e0b?style=for-the-badge&logo=pypi&logoColor=white" alt="PyPI"></a>

<a href="https://deepwiki.com/usestrix/strix"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki"></a>
<a href="https://github.com/usestrix/strix"><img src="https://img.shields.io/github/stars/usestrix/strix?style=flat-square" alt="GitHub Stars"></a>
<a href="LICENSE"><img src="https://img.shields.io/badge/License-Apache%202.0-3b82f6?style=flat-square" alt="License"></a>
<a href="https://pypi.org/project/strix-agent/"><img src="https://img.shields.io/pypi/v/strix-agent?style=flat-square" alt="PyPI Version"></a>

<a href="https://discord.gg/YjKFvEZSdZ"><img src="https://github.com/usestrix/.github/raw/main/imgs/Discord.png" height="40" alt="Join Discord"></a>
<a href="https://x.com/strix_ai"><img src="https://github.com/usestrix/.github/raw/main/imgs/X.png" height="40" alt="Follow on X"></a>

<a href="https://trendshift.io/repositories/15362" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15362" alt="usestrix/strix | Trendshift" width="250" height="55"/></a>

</div>

<br>
<br/>

<div align="center">
<img src=".github/screenshot.png" alt="Strix Demo" width="800" style="border-radius: 16px;">
<img src=".github/screenshot.png" alt="Strix Demo" width="900" style="border-radius: 16px;">
</div>

<br>

> [!TIP]
> **New!** Strix now integrates seamlessly with GitHub Actions and CI/CD pipelines. Automatically scan for vulnerabilities on every pull request and block insecure code before it reaches production!
> **New!** Strix integrates seamlessly with GitHub Actions and CI/CD pipelines. Automatically scan for vulnerabilities on every pull request and block insecure code before it reaches production!

---

## 🦉 Strix Overview
## Strix Overview

Strix are autonomous AI agents that act just like real hackers - they run your code dynamically, find vulnerabilities, and validate them through actual proof-of-concepts. Built for developers and security teams who need fast, accurate security testing without the overhead of manual pentesting or the false positives of static analysis tools.

**Key Capabilities:**

- 🔧 **Full hacker toolkit** out of the box
- 🤝 **Teams of agents** that collaborate and scale
- ✅ **Real validation** with PoCs, not false positives
- 💻 **Developer‑first** CLI with actionable reports
- 🔄 **Auto‑fix & reporting** to accelerate remediation
- **Full hacker toolkit** out of the box
- **Teams of agents** that collaborate and scale
- **Real validation** with PoCs, not false positives
- **Developer‑first** CLI with actionable reports
- **Auto‑fix & reporting** to accelerate remediation

## 🎯 Use Cases
@@ -87,7 +93,7 @@ strix --target ./app-directory
> [!NOTE]
> First run automatically pulls the sandbox Docker image. Results are saved to `strix_runs/<run-name>`

## ☁️ Run Strix in Cloud
## Run Strix in Cloud

Want to skip the local setup, API keys, and unpredictable LLM costs? Run the hosted cloud version of Strix at **[app.strix.ai](https://strix.ai)**.
@@ -104,7 +110,7 @@ Launch a scan in just a few minutes—no setup or configuration required—and y
## ✨ Features

### 🛠️ Agentic Security Tools
### Agentic Security Tools

Strix agents come equipped with a comprehensive security testing toolkit:
@@ -116,7 +122,7 @@ Strix agents come equipped with a comprehensive security testing toolkit:
- **Code Analysis** - Static and dynamic analysis capabilities
- **Knowledge Management** - Structured findings and attack documentation

### 🎯 Comprehensive Vulnerability Detection
### Comprehensive Vulnerability Detection

Strix can identify and validate a wide range of security vulnerabilities:
@@ -128,7 +134,7 @@ Strix can identify and validate a wide range of security vulnerabilities:
- **Authentication** - JWT vulnerabilities, session management
- **Infrastructure** - Misconfigurations, exposed services

### 🕸️ Graph of Agents
### Graph of Agents

Advanced multi-agent orchestration for comprehensive security testing:
@@ -138,7 +144,7 @@ Advanced multi-agent orchestration for comprehensive security testing:
---

## 💻 Usage Examples
## Usage Examples

### Basic Usage
@@ -169,7 +175,7 @@ strix --target api.your-app.com --instruction "Focus on business logic flaws and
strix --target api.your-app.com --instruction-file ./instruction.md
```

### 🤖 Headless Mode
### Headless Mode

Run Strix programmatically without an interactive UI using the `-n/--non-interactive` flag—perfect for servers and automated jobs. The CLI prints real-time vulnerability findings and the final report before exiting, and exits with a non-zero code when vulnerabilities are found.
@@ -177,7 +183,7 @@ Run Strix programmatically without interactive UI using the `-n/--non-interactiv
strix -n --target https://your-app.com
```

### 🔄 CI/CD (GitHub Actions)
### CI/CD (GitHub Actions)

Strix can be added to your pipeline to run a security test on pull requests with a lightweight GitHub Actions workflow:
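The workflow below gates merges on the scanner's exit status: `strix -n` returns non-zero when vulnerabilities are found, so the job fails and the pull request is blocked. As a rough illustration of that gating pattern in Python, a hypothetical wrapper (the `merge_allowed` helper is invented here; only the `strix` CLI name comes from the docs above):

```python
import shutil
import subprocess

def merge_allowed(returncode: int) -> bool:
    # A zero exit code means no vulnerabilities were reported.
    return returncode == 0

# Only attempt the scan when the CLI is actually installed.
if shutil.which("strix"):
    result = subprocess.run(
        ["strix", "-n", "-t", "./", "--scan-mode", "quick"], check=False
    )
    print("merge allowed:", merge_allowed(result.returncode))
```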
@@ -204,7 +210,7 @@ jobs:
run: strix -n -t ./ --scan-mode quick
```

### ⚙️ Configuration
### Configuration

```bash
export STRIX_LLM="openai/gpt-5"
@@ -227,22 +233,23 @@ export STRIX_REASONING_EFFORT="high" # control thinking effort (default: high,
See the [LLM Providers documentation](https://docs.strix.ai/llm-providers/overview) for all supported providers including Vertex AI, Bedrock, Azure, and local models.

## 📚 Documentation
## Documentation

Full documentation is available at **[docs.strix.ai](https://docs.strix.ai)** — including detailed guides for usage, CI/CD integrations, skills, and advanced configuration.

## 🤝 Contributing
## Contributing

We welcome contributions of code, docs, and new skills - check out our [Contributing Guide](https://docs.strix.ai/contributing) to get started or open a [pull request](https://github.com/usestrix/strix/pulls)/[issue](https://github.com/usestrix/strix/issues).

## 👥 Join Our Community
## Join Our Community

Have questions? Found a bug? Want to contribute? **[Join our Discord!](https://discord.gg/YjKFvEZSdZ)**

## 🌟 Support the Project
## Support the Project

**Love Strix?** Give us a ⭐ on GitHub!

## 🙏 Acknowledgements
## Acknowledgements

Strix builds on the incredible work of open-source projects like [LiteLLM](https://github.com/BerriAI/litellm), [Caido](https://github.com/caido/caido), [ProjectDiscovery](https://github.com/projectdiscovery), [Playwright](https://github.com/microsoft/playwright), and [Textual](https://github.com/Textualize/textual). Huge thanks to their maintainers!
poetry.lock (generated, 35 lines changed)
@@ -379,19 +379,18 @@ files = [
[[package]]
name = "azure-core"
version = "1.35.0"
version = "1.38.0"
description = "Microsoft Azure Core Library for Python"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "azure_core-1.35.0-py3-none-any.whl", hash = "sha256:8db78c72868a58f3de8991eb4d22c4d368fae226dac1002998d6c50437e7dad1"},
{file = "azure_core-1.35.0.tar.gz", hash = "sha256:c0be528489485e9ede59b6971eb63c1eaacf83ef53001bfe3904e475e972be5c"},
{file = "azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335"},
{file = "azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993"},
]

[package.dependencies]
requests = ">=2.21.0"
six = ">=1.11.0"
typing-extensions = ">=4.6.0"

[package.extras]
@@ -1199,6 +1198,18 @@ files = [
{file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"},
]

[[package]]
name = "defusedxml"
version = "0.7.1"
description = "XML bomb protection for Python stdlib modules"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
groups = ["main"]
files = [
{file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"},
{file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
]

[[package]]
name = "dill"
version = "0.4.0"
@@ -1490,14 +1501,14 @@ files = [
[[package]]
name = "filelock"
version = "3.20.1"
version = "3.20.3"
description = "A platform independent file lock."
optional = false
python-versions = ">=3.10"
groups = ["main", "dev"]
files = [
{file = "filelock-3.20.1-py3-none-any.whl", hash = "sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a"},
{file = "filelock-3.20.1.tar.gz", hash = "sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c"},
{file = "filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1"},
{file = "filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1"},
]

[[package]]
@@ -7095,19 +7106,19 @@ test = ["aiohttp (>=3.10.5)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil",
[[package]]
name = "virtualenv"
version = "20.34.0"
version = "20.36.1"
description = "Virtual Python Environment builder"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "virtualenv-20.34.0-py3-none-any.whl", hash = "sha256:341f5afa7eee943e4984a9207c025feedd768baff6753cd660c857ceb3e36026"},
{file = "virtualenv-20.34.0.tar.gz", hash = "sha256:44815b2c9dee7ed86e387b842a84f20b93f7f417f95886ca1996a72a4138eb1a"},
{file = "virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f"},
{file = "virtualenv-20.36.1.tar.gz", hash = "sha256:8befb5c81842c641f8ee658481e42641c68b5eab3521d8e092d18320902466ba"},
]

[package.dependencies]
distlib = ">=0.3.7,<1"
filelock = ">=3.12.2,<4"
filelock = {version = ">=3.20.1,<4", markers = "python_version >= \"3.10\""}
platformdirs = ">=3.9.1,<5"

[package.extras]
@@ -7424,4 +7435,4 @@ vertex = ["google-cloud-aiplatform"]
[metadata]
lock-version = "2.1"
python-versions = "^3.12"
content-hash = "91f49e313e5690bbef87e17730441f26d366daeccb16b5020e03e581fbb9d4d5"
content-hash = "0424a0e82fe49501f3a80166676e257a9dae97093d9bc730489789195f523735"
@@ -1,6 +1,6 @@
[tool.poetry]
name = "strix-agent"
version = "0.6.0"
version = "0.6.1"
description = "Open-source AI Hackers for your apps"
authors = ["Strix <hi@usestrix.com>"]
readme = "README.md"
@@ -69,6 +69,7 @@ gql = { version = "^3.5.3", extras = ["requests"], optional = true }
pyte = { version = "^0.8.1", optional = true }
libtmux = { version = "^0.46.2", optional = true }
numpydoc = { version = "^1.8.0", optional = true }
defusedxml = "^0.7.1"

[tool.poetry.extras]
vertex = ["google-cloud-aiplatform"]
@@ -111,7 +111,6 @@ hiddenimports = [
'strix.llm.llm',
'strix.llm.config',
'strix.llm.utils',
'strix.llm.request_queue',
'strix.llm.memory_compressor',
'strix.runtime',
'strix.runtime.runtime',
@@ -308,17 +308,18 @@ Tool calls use XML format:
CRITICAL RULES:
0. While active in the agent loop, EVERY message you output MUST be a single tool call. Do not send plain text-only responses.
1. One tool call per message
1. Exactly one tool call per message — never include more than one <function>...</function> block in a single LLM message.
2. Tool call must be last in message
3. End response after </function> tag. It's your stop word. Do not continue after it.
3. EVERY tool call MUST end with </function>. This is MANDATORY. Never omit the closing tag. End your response immediately after </function>.
4. Use ONLY the exact XML format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters.
5. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants).
5. When sending ANY multi-line content in tool parameters, use real newlines (actual line breaks). Do NOT emit literal "\n" sequences. If you send "\n" instead of real line breaks inside the XML parameter value, tools may fail or behave incorrectly.
6. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants).
- Correct: <function=think> ... </function>
- Incorrect: <thinking_tools.think> ... </function>
- Incorrect: <think> ... </think>
- Incorrect: {"think": {...}}
6. Parameters must use <parameter=param_name>value</parameter> exactly. Do NOT pass parameters as JSON or key:value lines. Do NOT add quotes/braces around values.
7. Do NOT wrap tool calls in markdown/code fences or add any text before or after the tool block.
7. Parameters must use <parameter=param_name>value</parameter> exactly. Do NOT pass parameters as JSON or key:value lines. Do NOT add quotes/braces around values.
8. Do NOT wrap tool calls in markdown/code fences or add any text before or after the tool block.

Example (agent creation tool):
<function=create_agent>
@@ -331,6 +332,8 @@ SPRAYING EXECUTION NOTE:
- When performing large payload sprays or fuzzing, encapsulate the entire spraying loop inside a single python or terminal tool call (e.g., a Python script using asyncio/aiohttp). Do not issue one tool call per payload.
- Favor batch-mode CLI tools (sqlmap, ffuf, nuclei, zaproxy, arjun) where appropriate and check traffic via the proxy when beneficial

REMINDER: Always close each tool call with </function> before going into the next. Incomplete tool calls will fail.

{{ get_tools_prompt() }}
</tool_usage>
@@ -1,7 +1,6 @@
import asyncio
import contextlib
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
@@ -18,6 +17,7 @@ from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content
from strix.runtime import SandboxInitializationError
from strix.tools import process_tool_invocations
from strix.utils.resource_paths import get_strix_resource_path

from .state import AgentState
@@ -35,8 +35,7 @@ class AgentMeta(type):
if name == "BaseAgent":
return new_cls

agents_dir = Path(__file__).parent
prompt_dir = agents_dir / name
prompt_dir = get_strix_resource_path("agents", name)

new_cls.agent_name = name
new_cls.jinja_env = Environment(
@@ -66,20 +65,21 @@ class BaseAgent(metaclass=AgentMeta):
self.llm_config = config.get("llm_config", self.default_llm_config)
if self.llm_config is None:
raise ValueError("llm_config is required but not provided")
self.llm = LLM(self.llm_config, agent_name=self.agent_name)

state_from_config = config.get("state")
if state_from_config is not None:
self.state = state_from_config
else:
self.state = AgentState(
agent_name=self.agent_name,
agent_name="Root Agent",
max_iterations=self.max_iterations,
)

self.llm = LLM(self.llm_config, agent_name=self.agent_name)

with contextlib.suppress(Exception):
self.llm.set_agent_identity(self.agent_name, self.state.agent_id)
self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id)
self._current_task: asyncio.Task[Any] | None = None
self._force_stop = False

from strix.telemetry.tracer import get_global_tracer
@@ -157,6 +157,11 @@ class BaseAgent(metaclass=AgentMeta):
return self._handle_sandbox_error(e, tracer)

while True:
if self._force_stop:
self._force_stop = False
await self._enter_waiting_state(tracer, was_cancelled=True)
continue

self._check_agent_messages(self.state)

if self.state.is_waiting_for_input():
@@ -247,7 +252,8 @@ class BaseAgent(metaclass=AgentMeta):
continue

async def _wait_for_input(self) -> None:
import asyncio
if self._force_stop:
return

if self.state.has_waiting_timeout():
self.state.resume_from_waiting()
@@ -340,6 +346,7 @@ class BaseAgent(metaclass=AgentMeta):
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
final_response = None

async for response in self.llm.generate(self.state.get_conversation_history()):
final_response = response
if tracer and response.content:
@@ -585,6 +592,11 @@ class BaseAgent(metaclass=AgentMeta):
return True

def cancel_current_execution(self) -> None:
self._force_stop = True
if self._current_task and not self._current_task.done():
self._current_task.cancel()
try:
loop = self._current_task.get_loop()
loop.call_soon_threadsafe(self._current_task.cancel)
except RuntimeError:
self._current_task.cancel()
self._current_task = None
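`cancel_current_execution` may be called from a thread other than the one running the event loop, so a plain `task.cancel()` is not safe; the hunk routes it through `loop.call_soon_threadsafe`. A minimal standalone demonstration of that pattern:

```python
import asyncio
import threading

async def worker() -> str:
    try:
        await asyncio.sleep(10)
        return "finished"
    except asyncio.CancelledError:
        # The worker absorbs the cancellation and reports it.
        return "cancelled"

async def main() -> str:
    task = asyncio.create_task(worker())
    await asyncio.sleep(0)  # let the worker start and block on sleep()
    loop = task.get_loop()
    # From a foreign thread, the cancel must be marshalled onto the loop.
    threading.Thread(target=lambda: loop.call_soon_threadsafe(task.cancel)).start()
    return await task

print(asyncio.run(main()))  # cancelled
```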
@@ -16,9 +16,9 @@ class Config:
litellm_base_url = None
ollama_api_base = None
strix_reasoning_effort = "high"
strix_llm_max_retries = "5"
strix_memory_compressor_timeout = "30"
llm_timeout = "300"
llm_rate_limit_delay = "4.0"
llm_rate_limit_concurrent = "1"

# Tool & Feature Configuration
perplexity_api_key = None
@@ -27,7 +27,7 @@ class Config:
# Runtime Configuration
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
strix_runtime_backend = "docker"
strix_sandbox_execution_timeout = "500"
strix_sandbox_execution_timeout = "120"
strix_sandbox_connect_timeout = "10"

# Telemetry
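The `Config` defaults above are stored as strings (e.g. `strix_sandbox_execution_timeout = "120"`) and are typically overridden through environment variables such as the `STRIX_LLM` export shown in the README. A hypothetical reader for that pattern: the key names mirror the defaults shown, but the helper itself is invented:

```python
import os

DEFAULTS = {
    "strix_llm_max_retries": "5",
    "strix_sandbox_execution_timeout": "120",
    "llm_rate_limit_delay": "4.0",
}

def get_setting(key: str) -> str:
    # The upper-cased key doubles as the environment variable name,
    # and the environment takes precedence over the string default.
    return os.environ.get(key.upper(), DEFAULTS[key])

print(get_setting("strix_llm_max_retries"))
```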
strix/llm/llm.py (601 lines changed)
@@ -1,60 +1,29 @@
import asyncio
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any

import litellm
from jinja2 import (
Environment,
FileSystemLoader,
select_autoescape,
)
from litellm import completion_cost, stream_chunk_builder, supports_reasoning
from jinja2 import Environment, FileSystemLoader, select_autoescape
from litellm import acompletion, completion_cost, stream_chunk_builder, supports_reasoning
from litellm.utils import supports_prompt_caching, supports_vision

from strix.config import Config
from strix.llm.config import LLMConfig
from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.request_queue import get_global_queue
from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations
from strix.llm.utils import (
_truncate_to_first_function,
fix_incomplete_tool_call,
parse_tool_invocations,
)
from strix.skills import load_skills
from strix.tools import get_tools_prompt
from strix.utils.resource_paths import get_strix_resource_path


MAX_RETRIES = 5
RETRY_MULTIPLIER = 8
RETRY_MIN = 8
RETRY_MAX = 64


def _should_retry(exception: Exception) -> bool:
status_code = None
if hasattr(exception, "status_code"):
status_code = exception.status_code
elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
status_code = exception.response.status_code
if status_code is not None:
return bool(litellm._should_retry(status_code))
return True


logger = logging.getLogger(__name__)

litellm.drop_params = True
litellm.modify_params = True

_LLM_API_KEY = Config.get("llm_api_key")
_LLM_API_BASE = (
Config.get("llm_api_base")
or Config.get("openai_api_base")
or Config.get("litellm_base_url")
or Config.get("ollama_api_base")
)
_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")


class LLMRequestFailedError(Exception):
def __init__(self, message: str, details: str | None = None):
@@ -63,20 +32,11 @@ class LLMRequestFailedError(Exception):
self.details = details


class StepRole(str, Enum):
AGENT = "agent"
USER = "user"
SYSTEM = "system"


@dataclass
class LLMResponse:
content: str
tool_invocations: list[dict[str, Any]] | None = None
scan_id: str | None = None
step_number: int = 1
role: StepRole = StepRole.AGENT
thinking_blocks: list[dict[str, Any]] | None = None  # For reasoning models.
thinking_blocks: list[dict[str, Any]] | None = None


@dataclass
@@ -84,76 +44,63 @@ class RequestStats:
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0
cache_creation_tokens: int = 0
cost: float = 0.0
requests: int = 0
failed_requests: int = 0

def to_dict(self) -> dict[str, int | float]:
return {
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"cached_tokens": self.cached_tokens,
"cache_creation_tokens": self.cache_creation_tokens,
"cost": round(self.cost, 4),
"requests": self.requests,
"failed_requests": self.failed_requests,
}


class LLM:
def __init__(
self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
):
def __init__(self, config: LLMConfig, agent_name: str | None = None):
self.config = config
self.agent_name = agent_name
self.agent_id = agent_id
self.agent_id: str | None = None
self._total_stats = RequestStats()
self._last_request_stats = RequestStats()
self.memory_compressor = MemoryCompressor(model_name=config.model_name)
self.system_prompt = self._load_system_prompt(agent_name)

if _STRIX_REASONING_EFFORT:
self._reasoning_effort = _STRIX_REASONING_EFFORT
elif self.config.scan_mode == "quick":
reasoning = Config.get("strix_reasoning_effort")
if reasoning:
self._reasoning_effort = reasoning
elif config.scan_mode == "quick":
self._reasoning_effort = "medium"
else:
self._reasoning_effort = "high"

self.memory_compressor = MemoryCompressor(
model_name=self.config.model_name,
timeout=self.config.timeout,
)
def _load_system_prompt(self, agent_name: str | None) -> str:
if not agent_name:
return ""

if agent_name:
prompt_dir = Path(__file__).parent.parent / "agents" / agent_name
skills_dir = Path(__file__).parent.parent / "skills"

loader = FileSystemLoader([prompt_dir, skills_dir])
self.jinja_env = Environment(
loader=loader,
try:
prompt_dir = get_strix_resource_path("agents", agent_name)
skills_dir = get_strix_resource_path("skills")
env = Environment(
loader=FileSystemLoader([prompt_dir, skills_dir]),
autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
)

try:
skills_to_load = list(self.config.skills or [])
skills_to_load.append(f"scan_modes/{self.config.scan_mode}")
skills_to_load = [
*list(self.config.skills or []),
f"scan_modes/{self.config.scan_mode}",
]
skill_content = load_skills(skills_to_load, env)
env.globals["get_skill"] = lambda name: skill_content.get(name, "")

skill_content = load_skills(skills_to_load, self.jinja_env)

def get_skill(name: str) -> str:
return skill_content.get(name, "")

self.jinja_env.globals["get_skill"] = get_skill

self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
get_tools_prompt=get_tools_prompt,
loaded_skill_names=list(skill_content.keys()),
**skill_content,
)
except (FileNotFoundError, OSError, ValueError) as e:
logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
self.system_prompt = "You are a helpful AI assistant."
else:
self.system_prompt = "You are a helpful AI assistant."
result = env.get_template("system_prompt.jinja").render(
get_tools_prompt=get_tools_prompt,
loaded_skill_names=list(skill_content.keys()),
**skill_content,
)
return str(result)
except Exception:  # noqa: BLE001
return ""

def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
if agent_name:
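The rewritten `_load_system_prompt` builds a Jinja environment whose `FileSystemLoader` searches both the agent's prompt directory and the skills directory, then exposes loaded skills via a `get_skill` global before rendering `system_prompt.jinja`. A minimal standalone sketch of that technique, with temporary directories standing in for the real resource paths and a plain dict standing in for `load_skills`:

```python
import tempfile
from pathlib import Path

from jinja2 import Environment, FileSystemLoader, select_autoescape

with tempfile.TemporaryDirectory() as prompts, tempfile.TemporaryDirectory() as skills:
    Path(prompts, "system_prompt.jinja").write_text("Role: {{ get_skill('recon') }}")
    skill_content = {"recon": "enumerate endpoints first"}  # stand-in for load_skills()

    env = Environment(
        # Templates are resolved against both directories, prompts first.
        loader=FileSystemLoader([prompts, skills]),
        autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
    )
    env.globals["get_skill"] = lambda name: skill_content.get(name, "")
    rendered = env.get_template("system_prompt.jinja").render()

print(rendered)  # Role: enumerate endpoints first
```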
@@ -161,328 +108,121 @@ class LLM:
if agent_id:
self.agent_id = agent_id

def _build_identity_message(self) -> dict[str, Any] | None:
if not (self.agent_name and str(self.agent_name).strip()):
return None
identity_name = self.agent_name
identity_id = self.agent_id
content = (
"\n\n"
"<agent_identity>\n"
"<meta>Internal metadata: do not echo or reference; "
"not part of history or tool calls.</meta>\n"
"<note>You are now assuming the role of this agent. "
"Act strictly as this agent and maintain self-identity for this step. "
"Now go answer the next needed step!</note>\n"
f"<agent_name>{identity_name}</agent_name>\n"
f"<agent_id>{identity_id}</agent_id>\n"
"</agent_identity>\n\n"
)
return {"role": "user", "content": content}
async def generate(
self, conversation_history: list[dict[str, Any]]
) -> AsyncIterator[LLMResponse]:
messages = self._prepare_messages(conversation_history)
max_retries = int(Config.get("strix_llm_max_retries") or "5")

def _add_cache_control_to_content(
self, content: str | list[dict[str, Any]]
) -> str | list[dict[str, Any]]:
if isinstance(content, str):
return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
if isinstance(content, list) and content:
last_item = content[-1]
if isinstance(last_item, dict) and last_item.get("type") == "text":
return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
return content
for attempt in range(max_retries + 1):
try:
async for response in self._stream(messages):
yield response
return  # noqa: TRY300
except Exception as e:  # noqa: BLE001
if attempt >= max_retries or not self._should_retry(e):
self._raise_error(e)
wait = min(10, 2 * (2**attempt))
await asyncio.sleep(wait)
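The retry loop in `generate` sleeps `min(10, 2 * (2**attempt))` seconds between attempts: exponential backoff capped at 10 seconds. The resulting schedule:

```python
def backoff(attempt: int) -> int:
    # Same formula as the retry loop above: exponential growth, 10 s cap.
    return min(10, 2 * (2**attempt))

print([backoff(a) for a in range(6)])  # [2, 4, 8, 10, 10, 10]
```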
def _is_anthropic_model(self) -> bool:
|
||||
if not self.config.model_name:
|
||||
return False
|
||||
model_lower = self.config.model_name.lower()
|
||||
return any(provider in model_lower for provider in ["anthropic/", "claude"])
|
||||
async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
|
||||
accumulated = ""
|
||||
chunks: list[Any] = []
|
||||
|
||||
def _calculate_cache_interval(self, total_messages: int) -> int:
|
||||
if total_messages <= 1:
|
||||
return 10
|
||||
self._total_stats.requests += 1
|
||||
response = await acompletion(**self._build_completion_args(messages), stream=True)
|
||||
|
||||
max_cached_messages = 3
|
||||
non_system_messages = total_messages - 1
|
||||
|
||||
interval = 10
|
||||
while non_system_messages // interval > max_cached_messages:
|
||||
interval += 10
|
||||
|
||||
return interval
|
||||
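`_calculate_cache_interval` widens the marker interval in steps of 10 until at most three non-system messages would carry cache-control markers (Anthropic-style prompt caching permits only a few cache breakpoints). Reproduced standalone:

```python
def calculate_cache_interval(total_messages: int, max_cached_messages: int = 3) -> int:
    if total_messages <= 1:
        return 10
    non_system = total_messages - 1  # the first message is the system prompt
    interval = 10
    # Widen until no more than max_cached_messages markers would be placed.
    while non_system // interval > max_cached_messages:
        interval += 10
    return interval

print(calculate_cache_interval(81))  # 30
```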
|
||||
def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
if (
|
||||
not self.config.enable_prompt_caching
|
||||
or not supports_prompt_caching(self.config.model_name)
|
||||
or not messages
|
||||
):
|
||||
return messages
|
||||
|
||||
if not self._is_anthropic_model():
|
||||
return messages
|
||||
|
||||
cached_messages = list(messages)
|
||||
|
||||
if cached_messages and cached_messages[0].get("role") == "system":
|
||||
system_message = cached_messages[0].copy()
|
||||
system_message["content"] = self._add_cache_control_to_content(
|
||||
system_message["content"]
|
||||
)
|
||||
cached_messages[0] = system_message
|
||||
|
||||
total_messages = len(cached_messages)
|
||||
if total_messages > 1:
|
||||
interval = self._calculate_cache_interval(total_messages)
|
||||
|
||||
cached_count = 0
|
||||
for i in range(interval, total_messages, interval):
|
||||
if cached_count >= 3:
|
||||
async for chunk in response:
|
||||
chunks.append(chunk)
|
||||
delta = self._get_chunk_content(chunk)
|
||||
if delta:
|
||||
accumulated += delta
|
||||
if "</function>" in accumulated:
|
||||
accumulated = accumulated[
|
||||
: accumulated.find("</function>") + len("</function>")
|
||||
]
|
||||
yield LLMResponse(content=accumulated)
|
||||
break
|
||||
yield LLMResponse(content=accumulated)
|
||||
|
||||
if i < len(cached_messages):
|
||||
message = cached_messages[i].copy()
|
||||
message["content"] = self._add_cache_control_to_content(message["content"])
|
||||
cached_messages[i] = message
|
||||
cached_count += 1
|
||||
if chunks:
|
||||
self._update_usage_stats(stream_chunk_builder(chunks))
|
||||
|
||||
return cached_messages
|
||||
accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
|
||||
yield LLMResponse(
|
||||
content=accumulated,
|
||||
tool_invocations=parse_tool_invocations(accumulated),
|
||||
thinking_blocks=self._extract_thinking(chunks),
|
||||
)
|
||||
|
||||
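The interval search in `_calculate_cache_interval` can be exercised standalone; `calculate_cache_interval` below is a hypothetical module-level mirror of the method's logic (smallest multiple-of-10 interval that marks at most 3 non-system messages):

```python
def calculate_cache_interval(total_messages: int) -> int:
    # Hypothetical standalone mirror of _calculate_cache_interval.
    if total_messages <= 1:
        return 10
    max_cached_messages = 3
    non_system_messages = total_messages - 1
    interval = 10
    while non_system_messages // interval > max_cached_messages:
        interval += 10
    return interval
```

For example, 31 messages (30 non-system) still fit under interval 10, while 41 messages force interval 20.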
    def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
        messages = [{"role": "system", "content": self.system_prompt}]

        identity_message = self._build_identity_message()
        if identity_message:
            messages.append(identity_message)

        if self.agent_name:
            messages.append(
                {
                    "role": "user",
                    "content": (
                        f"\n\n<agent_identity>\n"
                        f"<meta>Internal metadata: do not echo or reference.</meta>\n"
                        f"<agent_name>{self.agent_name}</agent_name>\n"
                        f"<agent_id>{self.agent_id}</agent_id>\n"
                        f"</agent_identity>\n\n"
                    ),
                }
            )

        compressed = list(self.memory_compressor.compress_history(conversation_history))
        conversation_history.clear()
        conversation_history.extend(compressed)
        messages.extend(compressed)

        return self._prepare_cached_messages(messages)

        if self._is_anthropic() and self.config.enable_prompt_caching:
            messages = self._add_cache_control(messages)
        return messages
    async def _stream_and_accumulate(
        self,
        messages: list[dict[str, Any]],
        scan_id: str | None,
        step_number: int,
    ) -> AsyncIterator[LLMResponse]:
        accumulated_content = ""
        chunks: list[Any] = []

        async for chunk in self._stream_request(messages):
            chunks.append(chunk)
            delta = self._extract_chunk_delta(chunk)
            if delta:
                accumulated_content += delta

                if "</function>" in accumulated_content:
                    function_end = accumulated_content.find("</function>") + len("</function>")
                    accumulated_content = accumulated_content[:function_end]

                yield LLMResponse(
                    scan_id=scan_id,
                    step_number=step_number,
                    role=StepRole.AGENT,
                    content=accumulated_content,
                    tool_invocations=None,
                )

        if chunks:
            complete_response = stream_chunk_builder(chunks)
            self._update_usage_stats(complete_response)

        accumulated_content = _truncate_to_first_function(accumulated_content)
        if "</function>" in accumulated_content:
            function_end = accumulated_content.find("</function>") + len("</function>")
            accumulated_content = accumulated_content[:function_end]

        tool_invocations = parse_tool_invocations(accumulated_content)

        # Extract thinking blocks from the complete response if available
        thinking_blocks = None
        if chunks and self._should_include_reasoning_effort():
            complete_response = stream_chunk_builder(chunks)
            if (
                hasattr(complete_response, "choices")
                and complete_response.choices
                and hasattr(complete_response.choices[0], "message")
            ):
                message = complete_response.choices[0].message
                if hasattr(message, "thinking_blocks") and message.thinking_blocks:
                    thinking_blocks = message.thinking_blocks

        yield LLMResponse(
            scan_id=scan_id,
            step_number=step_number,
            role=StepRole.AGENT,
            content=accumulated_content,
            tool_invocations=tool_invocations if tool_invocations else None,
            thinking_blocks=thinking_blocks,
        )
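The `</function>` truncation appears twice above; a standalone sketch of just that step (`truncate_at_function_end` is a hypothetical name, not a repo helper):

```python
def truncate_at_function_end(content: str) -> str:
    # Drop anything streamed after the first closing </function> tag.
    if "</function>" in content:
        content = content[: content.find("</function>") + len("</function>")]
    return content
```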
    def _raise_llm_error(self, e: Exception) -> None:
        error_map: list[tuple[type, str]] = [
            (litellm.RateLimitError, "Rate limit exceeded"),
            (litellm.AuthenticationError, "Invalid API key"),
            (litellm.NotFoundError, "Model not found"),
            (litellm.ContextWindowExceededError, "Context too long"),
            (litellm.ContentPolicyViolationError, "Content policy violation"),
            (litellm.ServiceUnavailableError, "Service unavailable"),
            (litellm.Timeout, "Request timed out"),
            (litellm.UnprocessableEntityError, "Unprocessable entity"),
            (litellm.InternalServerError, "Internal server error"),
            (litellm.APIConnectionError, "Connection error"),
            (litellm.UnsupportedParamsError, "Unsupported parameters"),
            (litellm.BudgetExceededError, "Budget exceeded"),
            (litellm.APIResponseValidationError, "Response validation error"),
            (litellm.JSONSchemaValidationError, "JSON schema validation error"),
            (litellm.InvalidRequestError, "Invalid request"),
            (litellm.BadRequestError, "Bad request"),
            (litellm.APIError, "API error"),
            (litellm.OpenAIError, "OpenAI error"),
        ]

        from strix.telemetry import posthog

        for error_type, message in error_map:
            if isinstance(e, error_type):
                posthog.error(f"llm_{error_type.__name__}", message)
                raise LLMRequestFailedError(f"LLM request failed: {message}", str(e)) from e

        posthog.error("llm_unknown_error", type(e).__name__)
        raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
    async def generate(
        self,
        conversation_history: list[dict[str, Any]],
        scan_id: str | None = None,
        step_number: int = 1,
    ) -> AsyncIterator[LLMResponse]:
        messages = self._prepare_messages(conversation_history)

        last_error: Exception | None = None
        for attempt in range(MAX_RETRIES):
            try:
                async for response in self._stream_and_accumulate(messages, scan_id, step_number):
                    yield response
                return  # noqa: TRY300
            except Exception as e:  # noqa: BLE001
                last_error = e
                if not _should_retry(e) or attempt == MAX_RETRIES - 1:
                    break
                wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt))
                wait_time = max(RETRY_MIN, wait_time)
                await asyncio.sleep(wait_time)

        if last_error:
            self._raise_llm_error(last_error)
    def _extract_chunk_delta(self, chunk: Any) -> str:
        if chunk.choices and hasattr(chunk.choices[0], "delta"):
            delta = chunk.choices[0].delta
            return getattr(delta, "content", "") or ""
        return ""

    @property
    def usage_stats(self) -> dict[str, dict[str, int | float]]:
        return {
            "total": self._total_stats.to_dict(),
            "last_request": self._last_request_stats.to_dict(),
        }

    def get_cache_config(self) -> dict[str, bool]:
        return {
            "enabled": self.config.enable_prompt_caching,
            "supported": supports_prompt_caching(self.config.model_name),
        }

    def _should_include_reasoning_effort(self) -> bool:
        if not self.config.model_name:
            return False
        try:
            return bool(supports_reasoning(model=self.config.model_name))
        except Exception:  # noqa: BLE001
            return False

    def _model_supports_vision(self) -> bool:
        if not self.config.model_name:
            return False
        try:
            return bool(supports_vision(model=self.config.model_name))
        except Exception:  # noqa: BLE001
            return False

    def _filter_images_from_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        filtered_messages = []
        for msg in messages:
            content = msg.get("content")
            updated_msg = msg
            if isinstance(content, list):
                filtered_content = []
                for item in content:
                    if isinstance(item, dict):
                        if item.get("type") == "image_url":
                            filtered_content.append(
                                {
                                    "type": "text",
                                    "text": "[Screenshot removed - model does not support "
                                    "vision. Use view_source or execute_js instead.]",
                                }
                            )
                        else:
                            filtered_content.append(item)
                    else:
                        filtered_content.append(item)
                if filtered_content:
                    text_parts = [
                        item.get("text", "") if isinstance(item, dict) else str(item)
                        for item in filtered_content
                    ]
                    all_text = all(
                        isinstance(item, dict) and item.get("type") == "text"
                        for item in filtered_content
                    )
                    if all_text:
                        updated_msg = {**msg, "content": "\n".join(text_parts)}
                    else:
                        updated_msg = {**msg, "content": filtered_content}
                else:
                    updated_msg = {**msg, "content": ""}
            filtered_messages.append(updated_msg)
        return filtered_messages
    async def _stream_request(
        self,
        messages: list[dict[str, Any]],
    ) -> AsyncIterator[Any]:
        if not self._model_supports_vision():
            messages = self._filter_images_from_messages(messages)

        completion_args: dict[str, Any] = {
            "model": self.config.model_name,
            "messages": messages,
            "timeout": self.config.timeout,
            "stream_options": {"include_usage": True},
        }

        if _LLM_API_KEY:
            completion_args["api_key"] = _LLM_API_KEY
        if _LLM_API_BASE:
            completion_args["api_base"] = _LLM_API_BASE

        completion_args["stop"] = ["</function>"]

        if self._should_include_reasoning_effort():
            completion_args["reasoning_effort"] = self._reasoning_effort

        queue = get_global_queue()
        self._total_stats.requests += 1
        self._last_request_stats = RequestStats(requests=1)

        async for chunk in queue.stream_request(completion_args):
            yield chunk

    def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
        if not self._supports_vision():
            messages = self._strip_images(messages)

        args: dict[str, Any] = {
            "model": self.config.model_name,
            "messages": messages,
            "timeout": self.config.timeout,
            "stream_options": {"include_usage": True},
        }

        if api_key := Config.get("llm_api_key"):
            args["api_key"] = api_key
        if api_base := (
            Config.get("llm_api_base")
            or Config.get("openai_api_base")
            or Config.get("litellm_base_url")
            or Config.get("ollama_api_base")
        ):
            args["api_base"] = api_base
        if self._supports_reasoning():
            args["reasoning_effort"] = self._reasoning_effort

        return args

    def _get_chunk_content(self, chunk: Any) -> str:
        if chunk.choices and hasattr(chunk.choices[0], "delta"):
            return getattr(chunk.choices[0].delta, "content", "") or ""
        return ""

    def _extract_thinking(self, chunks: list[Any]) -> list[dict[str, Any]] | None:
        if not chunks or not self._supports_reasoning():
            return None
        try:
            resp = stream_chunk_builder(chunks)
            if resp.choices and hasattr(resp.choices[0].message, "thinking_blocks"):
                blocks: list[dict[str, Any]] = resp.choices[0].message.thinking_blocks
                return blocks
        except Exception:  # noqa: BLE001, S110 # nosec B110
            pass
        return None
@@ -491,45 +231,88 @@ class LLM:
    def _update_usage_stats(self, response: Any) -> None:
        try:
            output_tokens = getattr(response.usage, "completion_tokens", 0)

            cached_tokens = 0
            cache_creation_tokens = 0

            if hasattr(response.usage, "prompt_tokens_details"):
                prompt_details = response.usage.prompt_tokens_details
                if hasattr(prompt_details, "cached_tokens"):
                    cached_tokens = prompt_details.cached_tokens or 0

            if hasattr(response.usage, "cache_creation_input_tokens"):
                cache_creation_tokens = response.usage.cache_creation_input_tokens or 0

            else:
                input_tokens = 0
                output_tokens = 0
                cached_tokens = 0
                cache_creation_tokens = 0

            try:
                cost = completion_cost(response) or 0.0
            except Exception as e:  # noqa: BLE001
                logger.warning(f"Failed to calculate cost: {e}")
            except Exception:  # noqa: BLE001
                cost = 0.0

            self._total_stats.input_tokens += input_tokens
            self._total_stats.output_tokens += output_tokens
            self._total_stats.cached_tokens += cached_tokens
            self._total_stats.cache_creation_tokens += cache_creation_tokens
            self._total_stats.cost += cost

            self._last_request_stats.input_tokens = input_tokens
            self._last_request_stats.output_tokens = output_tokens
            self._last_request_stats.cached_tokens = cached_tokens
            self._last_request_stats.cache_creation_tokens = cache_creation_tokens
            self._last_request_stats.cost = cost
        except Exception:  # noqa: BLE001, S110 # nosec B110
            pass

            if cached_tokens > 0:
                logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
            if cache_creation_tokens > 0:
                logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")

            logger.info(f"Usage stats: {self.usage_stats}")
        except Exception as e:  # noqa: BLE001
            logger.warning(f"Failed to update usage stats: {e}")

    def _should_retry(self, e: Exception) -> bool:
        code = getattr(e, "status_code", None) or getattr(
            getattr(e, "response", None), "status_code", None
        )
        return code is None or litellm._should_retry(code)

    def _raise_error(self, e: Exception) -> None:
        from strix.telemetry import posthog

        posthog.error("llm_error", type(e).__name__)
        raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
    def _is_anthropic(self) -> bool:
        if not self.config.model_name:
            return False
        return any(p in self.config.model_name.lower() for p in ["anthropic/", "claude"])

    def _supports_vision(self) -> bool:
        try:
            return bool(supports_vision(model=self.config.model_name))
        except Exception:  # noqa: BLE001
            return False

    def _supports_reasoning(self) -> bool:
        try:
            return bool(supports_reasoning(model=self.config.model_name))
        except Exception:  # noqa: BLE001
            return False

    def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        result = []
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                text_parts = []
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "text":
                        text_parts.append(item.get("text", ""))
                    elif isinstance(item, dict) and item.get("type") == "image_url":
                        text_parts.append("[Image removed - model doesn't support vision]")
                result.append({**msg, "content": "\n".join(text_parts)})
            else:
                result.append(msg)
        return result

    def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        if not messages or not supports_prompt_caching(self.config.model_name):
            return messages

        result = list(messages)

        if result[0].get("role") == "system":
            content = result[0]["content"]
            result[0] = {
                **result[0],
                "content": [
                    {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
                ]
                if isinstance(content, str)
                else content,
            }
        return result
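The system-prompt wrapping done by `_add_cache_control` can be sketched standalone; `add_cache_control` below is a hypothetical free-function mirror (it skips the `supports_prompt_caching` lookup, which needs litellm):

```python
def add_cache_control(messages: list[dict]) -> list[dict]:
    # Wrap a string system prompt in a content block marked ephemeral,
    # the shape Anthropic prompt caching expects; other messages untouched.
    if not messages:
        return messages
    result = list(messages)
    first = result[0]
    if first.get("role") == "system" and isinstance(first.get("content"), str):
        result[0] = {
            **first,
            "content": [
                {
                    "type": "text",
                    "text": first["content"],
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        }
    return result
```

Note the original also leaves the input list unmutated by rebuilding the first message dict.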
@@ -86,7 +86,7 @@ def _extract_message_text(msg: dict[str, Any]) -> str:
def _summarize_messages(
    messages: list[dict[str, Any]],
    model: str,
    timeout: int = 600,
    timeout: int = 30,
) -> dict[str, Any]:
    if not messages:
        empty_summary = "<context_summary message_count='0'>{text}</context_summary>"

@@ -148,11 +148,11 @@ class MemoryCompressor:
    self,
    max_images: int = 3,
    model_name: str | None = None,
    timeout: int = 600,
    timeout: int | None = None,
):
    self.max_images = max_images
    self.model_name = model_name or Config.get("strix_llm")
    self.timeout = timeout
    self.timeout = timeout or int(Config.get("strix_memory_compressor_timeout") or "30")

    if not self.model_name:
        raise ValueError("STRIX_LLM environment variable must be set and not empty")
@@ -1,58 +0,0 @@
import asyncio
import threading
import time
from collections.abc import AsyncIterator
from typing import Any

from litellm import acompletion
from litellm.types.utils import ModelResponseStream

from strix.config import Config


class LLMRequestQueue:
    def __init__(self) -> None:
        self.delay_between_requests = float(Config.get("llm_rate_limit_delay") or "4.0")
        self.max_concurrent = int(Config.get("llm_rate_limit_concurrent") or "1")
        self._semaphore = threading.BoundedSemaphore(self.max_concurrent)
        self._last_request_time = 0.0
        self._lock = threading.Lock()

    async def stream_request(
        self, completion_args: dict[str, Any]
    ) -> AsyncIterator[ModelResponseStream]:
        try:
            while not self._semaphore.acquire(timeout=0.2):
                await asyncio.sleep(0.1)

            with self._lock:
                now = time.time()
                time_since_last = now - self._last_request_time
                sleep_needed = max(0, self.delay_between_requests - time_since_last)
                self._last_request_time = now + sleep_needed

            if sleep_needed > 0:
                await asyncio.sleep(sleep_needed)

            async for chunk in self._stream_request(completion_args):
                yield chunk
        finally:
            self._semaphore.release()

    async def _stream_request(
        self, completion_args: dict[str, Any]
    ) -> AsyncIterator[ModelResponseStream]:
        response = await acompletion(**completion_args, stream=True)

        async for chunk in response:
            yield chunk


_global_queue: LLMRequestQueue | None = None


def get_global_queue() -> LLMRequestQueue:
    global _global_queue  # noqa: PLW0603
    if _global_queue is None:
        _global_queue = LLMRequestQueue()
    return _global_queue
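The pacing arithmetic in the deleted queue (`sleep_needed = max(0, delay - time_since_last)`) can be checked in isolation; `pacing_sleep` is a hypothetical pure-function mirror:

```python
def pacing_sleep(last_request_time: float, delay: float, now: float) -> float:
    # How long to sleep so consecutive requests stay at least
    # `delay` seconds apart, mirroring the queue's lock-protected math.
    return max(0.0, delay - (now - last_request_time))
```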
@@ -18,7 +18,7 @@ def _truncate_to_first_function(content: str) -> str:


def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
    content = _fix_stopword(content)
    content = fix_incomplete_tool_call(content)

    tool_invocations: list[dict[str, Any]] = []

@@ -46,16 +46,15 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
    return tool_invocations if tool_invocations else None


def _fix_stopword(content: str) -> str:
def fix_incomplete_tool_call(content: str) -> str:
    """Fix incomplete tool calls by adding missing </function> tag."""
    if (
        "<function=" in content
        and content.count("<function=") == 1
        and "</function>" not in content
    ):
        if content.endswith("</"):
            content = content.rstrip() + "function>"
        else:
            content = content + "\n</function>"
        content = content.rstrip()
        content = content + "function>" if content.endswith("</") else content + "\n</function>"
    return content


@@ -74,7 +73,7 @@ def clean_content(content: str) -> str:
    if not content:
        return ""

    content = _fix_stopword(content)
    content = fix_incomplete_tool_call(content)

    tool_pattern = r"<function=[^>]+>.*?</function>"
    cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)
|
||||
from pathlib import Path
|
||||
|
||||
from jinja2 import Environment
|
||||
|
||||
from strix.utils.resource_paths import get_strix_resource_path
|
||||
|
||||
|
||||
def get_available_skills() -> dict[str, list[str]]:
|
||||
skills_dir = Path(__file__).parent
|
||||
available_skills = {}
|
||||
skills_dir = get_strix_resource_path("skills")
|
||||
available_skills: dict[str, list[str]] = {}
|
||||
|
||||
if not skills_dir.exists():
|
||||
return available_skills
|
||||
|
||||
for category_dir in skills_dir.iterdir():
|
||||
if category_dir.is_dir() and not category_dir.name.startswith("__"):
|
||||
@@ -72,7 +75,7 @@ def load_skills(skill_names: list[str], jinja_env: Environment) -> dict[str, str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
skill_content = {}
|
||||
skills_dir = Path(__file__).parent
|
||||
skills_dir = get_strix_resource_path("skills")
|
||||
|
||||
available_skills = get_available_skills()
|
||||
|
||||
|
||||
@@ -430,10 +430,8 @@ class Tracer:
    "input_tokens": 0,
    "output_tokens": 0,
    "cached_tokens": 0,
    "cache_creation_tokens": 0,
    "cost": 0.0,
    "requests": 0,
    "failed_requests": 0,
}

for agent_instance in _agent_instances.values():
@@ -442,10 +440,8 @@
    total_stats["input_tokens"] += agent_stats.input_tokens
    total_stats["output_tokens"] += agent_stats.output_tokens
    total_stats["cached_tokens"] += agent_stats.cached_tokens
    total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
    total_stats["cost"] += agent_stats.cost
    total_stats["requests"] += agent_stats.requests
    total_stats["failed_requests"] += agent_stats.failed_requests

total_stats["cost"] = round(total_stats["cost"], 4)
@@ -14,12 +14,13 @@ from .argument_parser import convert_arguments
from .registry import (
    get_tool_by_name,
    get_tool_names,
    get_tool_param_schema,
    needs_agent_state,
    should_execute_in_sandbox,
)


SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "500")
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")


@@ -110,14 +111,51 @@ async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwargs):

def validate_tool_availability(tool_name: str | None) -> tuple[bool, str]:
    if tool_name is None:
        return False, "Tool name is missing"
        available = ", ".join(sorted(get_tool_names()))
        return False, f"Tool name is missing. Available tools: {available}"

    if tool_name not in get_tool_names():
        return False, f"Tool '{tool_name}' is not available"
        available = ", ".join(sorted(get_tool_names()))
        return False, f"Tool '{tool_name}' is not available. Available tools: {available}"

    return True, ""
def _validate_tool_arguments(tool_name: str, kwargs: dict[str, Any]) -> str | None:
    param_schema = get_tool_param_schema(tool_name)
    if not param_schema or not param_schema.get("has_params"):
        return None

    allowed_params: set[str] = param_schema.get("params", set())
    required_params: set[str] = param_schema.get("required", set())
    optional_params = allowed_params - required_params

    schema_hint = _format_schema_hint(tool_name, required_params, optional_params)

    unknown_params = set(kwargs.keys()) - allowed_params
    if unknown_params:
        unknown_list = ", ".join(sorted(unknown_params))
        return f"Tool '{tool_name}' received unknown parameter(s): {unknown_list}\n{schema_hint}"

    missing_required = [
        param for param in required_params if param not in kwargs or kwargs.get(param) in (None, "")
    ]
    if missing_required:
        missing_list = ", ".join(sorted(missing_required))
        return f"Tool '{tool_name}' missing required parameter(s): {missing_list}\n{schema_hint}"

    return None


def _format_schema_hint(tool_name: str, required: set[str], optional: set[str]) -> str:
    parts = [f"Valid parameters for '{tool_name}':"]
    if required:
        parts.append(f"  Required: {', '.join(sorted(required))}")
    if optional:
        parts.append(f"  Optional: {', '.join(sorted(optional))}")
    return "\n".join(parts)


async def execute_tool_with_validation(
    tool_name: str | None, agent_state: Any | None = None, **kwargs: Any
) -> Any:
@@ -127,6 +165,10 @@ async def execute_tool_with_validation(
    assert tool_name is not None

    arg_error = _validate_tool_arguments(tool_name, kwargs)
    if arg_error:
        return f"Error: {arg_error}"

    try:
        result = await execute_tool(tool_name, agent_state, **kwargs)
    except Exception as e:  # noqa: BLE001
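The two checks in `_validate_tool_arguments` (unknown, then missing-required) can be sketched standalone; `validate_args` below is a hypothetical signature that takes the allowed/required sets directly instead of looking them up by tool name:

```python
def validate_args(allowed: set, required: set, kwargs: dict) -> "str | None":
    # Reject parameters outside the schema first, then empty/absent
    # required parameters, mirroring the order used above.
    unknown = set(kwargs) - allowed
    if unknown:
        return f"unknown parameter(s): {', '.join(sorted(unknown))}"
    missing = [p for p in sorted(required) if kwargs.get(p) in (None, "")]
    if missing:
        return f"missing required parameter(s): {', '.join(missing)}"
    return None
```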
@@ -55,6 +55,7 @@
- Print statements and stdout are captured
- Variables persist between executions in the same session
- Imports, function definitions, etc. persist in the session
- IMPORTANT (multiline): Put real line breaks in <parameter=code>. Do NOT emit literal "\n" sequences.
- IPython magic commands are fully supported (%pip, %time, %whos, %%writefile, etc.)
- Line magics (%) and cell magics (%%) work as expected
6. CLOSE: Terminates the session completely and frees memory

@@ -73,6 +74,14 @@
print("Security analysis session started")</parameter>
</function>

<function=python_action>
<parameter=action>execute</parameter>
<parameter=code>import requests
url = "https://example.com"
resp = requests.get(url, timeout=10)
print(resp.status_code)</parameter>
</function>

# Analyze security data in the default session
<function=python_action>
<parameter=action>execute</parameter>
@@ -7,9 +7,14 @@ from inspect import signature
from pathlib import Path
from typing import Any

import defusedxml.ElementTree as DefusedET

from strix.utils.resource_paths import get_strix_resource_path


tools: list[dict[str, Any]] = []
_tools_by_name: dict[str, Callable[..., Any]] = {}
_tool_param_schemas: dict[str, dict[str, Any]] = {}
logger = logging.getLogger(__name__)


@@ -82,6 +87,34 @@ def _load_xml_schema(path: Path) -> Any:
    return tools_dict


def _parse_param_schema(tool_xml: str) -> dict[str, Any]:
    params: set[str] = set()
    required: set[str] = set()

    params_start = tool_xml.find("<parameters>")
    params_end = tool_xml.find("</parameters>")

    if params_start == -1 or params_end == -1:
        return {"params": set(), "required": set(), "has_params": False}

    params_section = tool_xml[params_start : params_end + len("</parameters>")]

    try:
        root = DefusedET.fromstring(params_section)
    except DefusedET.ParseError:
        return {"params": set(), "required": set(), "has_params": False}

    for param in root.findall(".//parameter"):
        name = param.attrib.get("name")
        if not name:
            continue
        params.add(name)
        if param.attrib.get("required", "false").lower() == "true":
            required.add(name)

    return {"params": params, "required": required, "has_params": bool(params or required)}
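The schema-extraction logic can be exercised standalone with the stdlib parser; `parse_param_schema` below mirrors the function above but uses `xml.etree.ElementTree` instead of defusedxml (same `fromstring` API, minus the hardening):

```python
import xml.etree.ElementTree as ET

def parse_param_schema(tool_xml: str) -> dict:
    # Slice out the <parameters> block, then collect parameter names and
    # which of them carry required="true".
    start = tool_xml.find("<parameters>")
    end = tool_xml.find("</parameters>")
    if start == -1 or end == -1:
        return {"params": set(), "required": set(), "has_params": False}
    root = ET.fromstring(tool_xml[start : end + len("</parameters>")])
    params, required = set(), set()
    for param in root.findall(".//parameter"):
        name = param.attrib.get("name")
        if not name:
            continue
        params.add(name)
        if param.attrib.get("required", "false").lower() == "true":
            required.add(name)
    return {"params": params, "required": required, "has_params": bool(params or required)}
```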
def _get_module_name(func: Callable[..., Any]) -> str:
    module = inspect.getmodule(func)
    if not module:
@@ -95,6 +128,27 @@ def _get_module_name(func: Callable[..., Any]) -> str:
    return "unknown"


def _get_schema_path(func: Callable[..., Any]) -> Path | None:
    module = inspect.getmodule(func)
    if not module or not module.__name__:
        return None

    module_name = module.__name__

    if ".tools." not in module_name:
        return None

    parts = module_name.split(".tools.")[-1].split(".")
    if len(parts) < 2:
        return None

    folder = parts[0]
    file_stem = parts[1]
    schema_file = f"{file_stem}_schema.xml"

    return get_strix_resource_path("tools", folder, schema_file)
def register_tool(
    func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
) -> Callable[..., Any]:
@@ -109,11 +163,8 @@ def register_tool(
    sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
    if not sandbox_mode:
        try:
            module_path = Path(inspect.getfile(f))
            schema_file_name = f"{module_path.stem}_schema.xml"
            schema_path = module_path.parent / schema_file_name

            xml_tools = _load_xml_schema(schema_path)
            schema_path = _get_schema_path(f)
            xml_tools = _load_xml_schema(schema_path) if schema_path else None

            if xml_tools is not None and f.__name__ in xml_tools:
                func_dict["xml_schema"] = xml_tools[f.__name__]
@@ -131,6 +182,11 @@ def register_tool(
                "</tool>"
            )

    if not sandbox_mode:
        xml_schema = func_dict.get("xml_schema")
        param_schema = _parse_param_schema(xml_schema if isinstance(xml_schema, str) else "")
        _tool_param_schemas[str(func_dict["name"])] = param_schema

    tools.append(func_dict)
    _tools_by_name[str(func_dict["name"])] = f

@@ -153,6 +209,10 @@ def get_tool_names() -> list[str]:
    return list(_tools_by_name.keys())


def get_tool_param_schema(name: str) -> dict[str, Any] | None:
    return _tool_param_schemas.get(name)


def needs_agent_state(tool_name: str) -> bool:
    tool_func = get_tool_by_name(tool_name)
    if not tool_func:
@@ -194,3 +254,4 @@ def get_tools_prompt() -> str:
def clear_registry() -> None:
    tools.clear()
    _tools_by_name.clear()
    _tool_param_schemas.clear()
@@ -95,6 +95,12 @@
<parameter=command>ls -la</parameter>
</function>

<function=terminal_execute>
<parameter=command>cd /workspace
pwd
ls -la</parameter>
</function>

# Run a command with custom timeout
<function=terminal_execute>
<parameter=command>npm install</parameter>
 0  strix/utils/__init__.py (new file)
13  strix/utils/resource_paths.py (new file)
@@ -0,0 +1,13 @@
import sys
from pathlib import Path


def get_strix_resource_path(*parts: str) -> Path:
    frozen_base = getattr(sys, "_MEIPASS", None)
    if frozen_base:
        base = Path(frozen_base) / "strix"
        if base.exists():
            return base.joinpath(*parts)

    base = Path(__file__).resolve().parent.parent
    return base.joinpath(*parts)
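The frozen/unfrozen lookup above can be sketched in isolation; `resource_path` is a hypothetical mirror whose fallback base is this file's own directory rather than the strix package root (an assumption, purely for illustration):

```python
import sys
from pathlib import Path


def resource_path(*parts: str) -> Path:
    # Prefer the PyInstaller bundle dir (sys._MEIPASS) when running frozen;
    # otherwise fall back to a local base directory (assumption: this
    # file's directory, or the filename literal if __file__ is undefined).
    frozen_base = getattr(sys, "_MEIPASS", None)
    if frozen_base:
        base = Path(frozen_base) / "strix"
        if base.exists():
            return base.joinpath(*parts)
    base = Path(globals().get("__file__", "resource_paths.py")).resolve().parent
    return base.joinpath(*parts)
```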