This commit is contained in:
bearsyankees
2026-04-12 23:52:52 -04:00
parent 3b018447dc
commit a6dd550aec
5 changed files with 65 additions and 59 deletions

View File

@@ -134,7 +134,8 @@ class BaseAgent(metaclass=AgentMeta):
} }
agents_graph_actions._agent_graph["nodes"][self.state.agent_id] = node agents_graph_actions._agent_graph["nodes"][self.state.agent_id] = node
agents_graph_actions._agent_instances[self.state.agent_id] = self with agents_graph_actions._agent_llm_stats_lock:
agents_graph_actions._agent_instances[self.state.agent_id] = self
agents_graph_actions._agent_states[self.state.agent_id] = self.state agents_graph_actions._agent_states[self.state.agent_id] = self.state
if self.state.parent_id: if self.state.parent_id:

View File

@@ -801,27 +801,23 @@ class Tracer:
def get_total_llm_stats(self) -> dict[str, Any]: def get_total_llm_stats(self) -> dict[str, Any]:
from strix.tools.agents_graph.agents_graph_actions import ( from strix.tools.agents_graph.agents_graph_actions import (
_agent_instances, _agent_instances,
_completed_agent_llm_stats, _completed_agent_llm_totals,
_agent_llm_stats_lock,
) )
with _agent_llm_stats_lock:
completed_totals = dict(_completed_agent_llm_totals)
active_agents = list(_agent_instances.values())
total_stats = { total_stats = {
"input_tokens": 0, "input_tokens": int(completed_totals.get("input_tokens", 0) or 0),
"output_tokens": 0, "output_tokens": int(completed_totals.get("output_tokens", 0) or 0),
"cached_tokens": 0, "cached_tokens": int(completed_totals.get("cached_tokens", 0) or 0),
"cost": 0.0, "cost": float(completed_totals.get("cost", 0.0) or 0.0),
"requests": 0, "requests": int(completed_totals.get("requests", 0) or 0),
} }
for agent_id, completed_stats in _completed_agent_llm_stats.items(): for agent_instance in active_agents:
if agent_id in _agent_instances:
continue
total_stats["input_tokens"] += int(completed_stats.get("input_tokens", 0) or 0)
total_stats["output_tokens"] += int(completed_stats.get("output_tokens", 0) or 0)
total_stats["cached_tokens"] += int(completed_stats.get("cached_tokens", 0) or 0)
total_stats["cost"] += float(completed_stats.get("cost", 0.0) or 0.0)
total_stats["requests"] += int(completed_stats.get("requests", 0) or 0)
for agent_instance in _agent_instances.values():
if hasattr(agent_instance, "llm") and hasattr(agent_instance.llm, "_total_stats"): if hasattr(agent_instance, "llm") and hasattr(agent_instance.llm, "_total_stats"):
agent_stats = agent_instance.llm._total_stats agent_stats = agent_instance.llm._total_stats
total_stats["input_tokens"] += agent_stats.input_tokens total_stats["input_tokens"] += agent_stats.input_tokens

View File

@@ -19,7 +19,20 @@ _running_agents: dict[str, threading.Thread] = {}
_agent_instances: dict[str, Any] = {} _agent_instances: dict[str, Any] = {}
_completed_agent_llm_stats: dict[str, dict[str, int | float]] = {} _agent_llm_stats_lock = threading.Lock()
def _empty_llm_stats_totals() -> dict[str, int | float]:
return {
"input_tokens": 0,
"output_tokens": 0,
"cached_tokens": 0,
"cost": 0.0,
"requests": 0,
}
_completed_agent_llm_totals: dict[str, int | float] = _empty_llm_stats_totals()
_agent_states: dict[str, Any] = {} _agent_states: dict[str, Any] = {}
@@ -38,15 +51,21 @@ def _snapshot_agent_llm_stats(agent: Any) -> dict[str, int | float] | None:
} }
def _persist_completed_agent_llm_stats(agent_id: str, agent: Any) -> None: def _finalize_agent_llm_stats(agent_id: str, agent: Any) -> None:
stats = _snapshot_agent_llm_stats(agent) stats = _snapshot_agent_llm_stats(agent)
if stats is None: with _agent_llm_stats_lock:
return if stats is not None:
_completed_agent_llm_totals["input_tokens"] += int(stats["input_tokens"])
_completed_agent_llm_totals["output_tokens"] += int(stats["output_tokens"])
_completed_agent_llm_totals["cached_tokens"] += int(stats["cached_tokens"])
_completed_agent_llm_totals["cost"] += float(stats["cost"])
_completed_agent_llm_totals["requests"] += int(stats["requests"])
_completed_agent_llm_stats[agent_id] = stats node = _agent_graph["nodes"].get(agent_id)
node = _agent_graph["nodes"].get(agent_id) if node is not None:
if node is not None: node["llm_stats"] = stats
node["llm_stats"] = stats
_agent_instances.pop(agent_id, None)
def _is_whitebox_agent(agent_id: str) -> bool: def _is_whitebox_agent(agent_id: str) -> bool:
@@ -264,8 +283,7 @@ def _run_agent_in_thread(
_agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat() _agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat()
_agent_graph["nodes"][state.agent_id]["result"] = {"error": str(e)} _agent_graph["nodes"][state.agent_id]["result"] = {"error": str(e)}
_running_agents.pop(state.agent_id, None) _running_agents.pop(state.agent_id, None)
_persist_completed_agent_llm_stats(state.agent_id, agent) _finalize_agent_llm_stats(state.agent_id, agent)
_agent_instances.pop(state.agent_id, None)
raise raise
else: else:
if state.stop_requested: if state.stop_requested:
@@ -275,8 +293,7 @@ def _run_agent_in_thread(
_agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat() _agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat()
_agent_graph["nodes"][state.agent_id]["result"] = result _agent_graph["nodes"][state.agent_id]["result"] = result
_running_agents.pop(state.agent_id, None) _running_agents.pop(state.agent_id, None)
_persist_completed_agent_llm_stats(state.agent_id, agent) _finalize_agent_llm_stats(state.agent_id, agent)
_agent_instances.pop(state.agent_id, None)
return {"result": result} return {"result": result}
@@ -447,7 +464,8 @@ def create_agent(
if inherit_context: if inherit_context:
inherited_messages = agent_state.get_conversation_history() inherited_messages = agent_state.get_conversation_history()
_agent_instances[state.agent_id] = agent with _agent_llm_stats_lock:
_agent_instances[state.agent_id] = agent
thread = threading.Thread( thread = threading.Thread(
target=_run_agent_in_thread, target=_run_agent_in_thread,

View File

@@ -303,15 +303,13 @@ def test_get_total_llm_stats_includes_completed_subagents(monkeypatch, tmp_path)
) )
monkeypatch.setattr( monkeypatch.setattr(
agents_graph_actions, agents_graph_actions,
"_completed_agent_llm_stats", "_completed_agent_llm_totals",
{ {
"sub-agent-1": { "input_tokens": 2_000,
"input_tokens": 2_000, "output_tokens": 500,
"output_tokens": 500, "cached_tokens": 400,
"cached_tokens": 400, "cost": 0.54321,
"cost": 0.54321, "requests": 3,
"requests": 3,
}
}, },
) )

View File

@@ -5,16 +5,24 @@ from strix.llm.config import LLMConfig
from strix.tools.agents_graph import agents_graph_actions from strix.tools.agents_graph import agents_graph_actions
def test_create_agent_inherits_parent_whitebox_flag(monkeypatch) -> None: def _reset_agent_graph_state() -> None:
monkeypatch.setenv("STRIX_LLM", "openai/gpt-5")
agents_graph_actions._agent_graph["nodes"].clear() agents_graph_actions._agent_graph["nodes"].clear()
agents_graph_actions._agent_graph["edges"].clear() agents_graph_actions._agent_graph["edges"].clear()
agents_graph_actions._agent_messages.clear() agents_graph_actions._agent_messages.clear()
agents_graph_actions._running_agents.clear() agents_graph_actions._running_agents.clear()
agents_graph_actions._agent_instances.clear() agents_graph_actions._agent_instances.clear()
agents_graph_actions._completed_agent_llm_totals.clear()
agents_graph_actions._completed_agent_llm_totals.update(
agents_graph_actions._empty_llm_stats_totals()
)
agents_graph_actions._agent_states.clear() agents_graph_actions._agent_states.clear()
def test_create_agent_inherits_parent_whitebox_flag(monkeypatch) -> None:
monkeypatch.setenv("STRIX_LLM", "openai/gpt-5")
_reset_agent_graph_state()
parent_id = "parent-agent" parent_id = "parent-agent"
parent_llm = LLMConfig(timeout=123, scan_mode="standard", is_whitebox=True) parent_llm = LLMConfig(timeout=123, scan_mode="standard", is_whitebox=True)
agents_graph_actions._agent_instances[parent_id] = SimpleNamespace( agents_graph_actions._agent_instances[parent_id] = SimpleNamespace(
@@ -66,12 +74,7 @@ def test_create_agent_inherits_parent_whitebox_flag(monkeypatch) -> None:
def test_delegation_prompt_includes_wiki_memory_instruction_in_whitebox(monkeypatch) -> None: def test_delegation_prompt_includes_wiki_memory_instruction_in_whitebox(monkeypatch) -> None:
monkeypatch.setenv("STRIX_LLM", "openai/gpt-5") monkeypatch.setenv("STRIX_LLM", "openai/gpt-5")
agents_graph_actions._agent_graph["nodes"].clear() _reset_agent_graph_state()
agents_graph_actions._agent_graph["edges"].clear()
agents_graph_actions._agent_messages.clear()
agents_graph_actions._running_agents.clear()
agents_graph_actions._agent_instances.clear()
agents_graph_actions._agent_states.clear()
parent_id = "parent-1" parent_id = "parent-1"
child_id = "child-1" child_id = "child-1"
@@ -116,12 +119,7 @@ def test_delegation_prompt_includes_wiki_memory_instruction_in_whitebox(monkeypa
def test_agent_finish_appends_wiki_update_for_whitebox(monkeypatch) -> None: def test_agent_finish_appends_wiki_update_for_whitebox(monkeypatch) -> None:
monkeypatch.setenv("STRIX_LLM", "openai/gpt-5") monkeypatch.setenv("STRIX_LLM", "openai/gpt-5")
agents_graph_actions._agent_graph["nodes"].clear() _reset_agent_graph_state()
agents_graph_actions._agent_graph["edges"].clear()
agents_graph_actions._agent_messages.clear()
agents_graph_actions._running_agents.clear()
agents_graph_actions._agent_instances.clear()
agents_graph_actions._agent_states.clear()
parent_id = "parent-2" parent_id = "parent-2"
child_id = "child-2" child_id = "child-2"
@@ -192,12 +190,7 @@ def test_agent_finish_appends_wiki_update_for_whitebox(monkeypatch) -> None:
def test_run_agent_in_thread_injects_shared_wiki_context_in_whitebox(monkeypatch) -> None: def test_run_agent_in_thread_injects_shared_wiki_context_in_whitebox(monkeypatch) -> None:
monkeypatch.setenv("STRIX_LLM", "openai/gpt-5") monkeypatch.setenv("STRIX_LLM", "openai/gpt-5")
agents_graph_actions._agent_graph["nodes"].clear() _reset_agent_graph_state()
agents_graph_actions._agent_graph["edges"].clear()
agents_graph_actions._agent_messages.clear()
agents_graph_actions._running_agents.clear()
agents_graph_actions._agent_instances.clear()
agents_graph_actions._agent_states.clear()
parent_id = "parent-3" parent_id = "parent-3"
child_id = "child-3" child_id = "child-3"