fix: ensure LLM stats tracking is accurate by including completed subagents

This commit is contained in:
bearsyankees
2026-04-12 23:22:23 -04:00
parent 62e9af36d2
commit 3b018447dc
3 changed files with 114 additions and 1 deletions

View File

@@ -799,7 +799,10 @@ class Tracer:
)
def get_total_llm_stats(self) -> dict[str, Any]:
from strix.tools.agents_graph.agents_graph_actions import _agent_instances
from strix.tools.agents_graph.agents_graph_actions import (
_agent_instances,
_completed_agent_llm_stats,
)
total_stats = {
"input_tokens": 0,
@@ -809,6 +812,15 @@ class Tracer:
"requests": 0,
}
for agent_id, completed_stats in _completed_agent_llm_stats.items():
if agent_id in _agent_instances:
continue
total_stats["input_tokens"] += int(completed_stats.get("input_tokens", 0) or 0)
total_stats["output_tokens"] += int(completed_stats.get("output_tokens", 0) or 0)
total_stats["cached_tokens"] += int(completed_stats.get("cached_tokens", 0) or 0)
total_stats["cost"] += float(completed_stats.get("cost", 0.0) or 0.0)
total_stats["requests"] += int(completed_stats.get("requests", 0) or 0)
for agent_instance in _agent_instances.values():
if hasattr(agent_instance, "llm") and hasattr(agent_instance.llm, "_total_stats"):
agent_stats = agent_instance.llm._total_stats

View File

@@ -19,9 +19,36 @@ _running_agents: dict[str, threading.Thread] = {}
_agent_instances: dict[str, Any] = {}
_completed_agent_llm_stats: dict[str, dict[str, int | float]] = {}
_agent_states: dict[str, Any] = {}
def _snapshot_agent_llm_stats(agent: Any) -> dict[str, int | float] | None:
if not hasattr(agent, "llm") or not hasattr(agent.llm, "_total_stats"):
return None
stats = agent.llm._total_stats
return {
"input_tokens": stats.input_tokens,
"output_tokens": stats.output_tokens,
"cached_tokens": stats.cached_tokens,
"cost": stats.cost,
"requests": stats.requests,
}
def _persist_completed_agent_llm_stats(agent_id: str, agent: Any) -> None:
stats = _snapshot_agent_llm_stats(agent)
if stats is None:
return
_completed_agent_llm_stats[agent_id] = stats
node = _agent_graph["nodes"].get(agent_id)
if node is not None:
node["llm_stats"] = stats
def _is_whitebox_agent(agent_id: str) -> bool:
agent = _agent_instances.get(agent_id)
return bool(getattr(getattr(agent, "llm_config", None), "is_whitebox", False))
@@ -237,6 +264,7 @@ def _run_agent_in_thread(
_agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat()
_agent_graph["nodes"][state.agent_id]["result"] = {"error": str(e)}
_running_agents.pop(state.agent_id, None)
_persist_completed_agent_llm_stats(state.agent_id, agent)
_agent_instances.pop(state.agent_id, None)
raise
else:
@@ -247,6 +275,7 @@ def _run_agent_in_thread(
_agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat()
_agent_graph["nodes"][state.agent_id]["result"] = result
_running_agents.pop(state.agent_id, None)
_persist_completed_agent_llm_stats(state.agent_id, agent)
_agent_instances.pop(state.agent_id, None)
return {"result": result}