fix: ensure LLM stats tracking is accurate by including completed subagents
This commit is contained in:
@@ -799,7 +799,10 @@ class Tracer:
|
||||
)
|
||||
|
||||
def get_total_llm_stats(self) -> dict[str, Any]:
|
||||
from strix.tools.agents_graph.agents_graph_actions import _agent_instances
|
||||
from strix.tools.agents_graph.agents_graph_actions import (
|
||||
_agent_instances,
|
||||
_completed_agent_llm_stats,
|
||||
)
|
||||
|
||||
total_stats = {
|
||||
"input_tokens": 0,
|
||||
@@ -809,6 +812,15 @@ class Tracer:
|
||||
"requests": 0,
|
||||
}
|
||||
|
||||
for agent_id, completed_stats in _completed_agent_llm_stats.items():
|
||||
if agent_id in _agent_instances:
|
||||
continue
|
||||
total_stats["input_tokens"] += int(completed_stats.get("input_tokens", 0) or 0)
|
||||
total_stats["output_tokens"] += int(completed_stats.get("output_tokens", 0) or 0)
|
||||
total_stats["cached_tokens"] += int(completed_stats.get("cached_tokens", 0) or 0)
|
||||
total_stats["cost"] += float(completed_stats.get("cost", 0.0) or 0.0)
|
||||
total_stats["requests"] += int(completed_stats.get("requests", 0) or 0)
|
||||
|
||||
for agent_instance in _agent_instances.values():
|
||||
if hasattr(agent_instance, "llm") and hasattr(agent_instance.llm, "_total_stats"):
|
||||
agent_stats = agent_instance.llm._total_stats
|
||||
|
||||
@@ -19,9 +19,36 @@ _running_agents: dict[str, threading.Thread] = {}
|
||||
|
||||
_agent_instances: dict[str, Any] = {}
|
||||
|
||||
_completed_agent_llm_stats: dict[str, dict[str, int | float]] = {}
|
||||
|
||||
_agent_states: dict[str, Any] = {}
|
||||
|
||||
|
||||
def _snapshot_agent_llm_stats(agent: Any) -> dict[str, int | float] | None:
|
||||
if not hasattr(agent, "llm") or not hasattr(agent.llm, "_total_stats"):
|
||||
return None
|
||||
|
||||
stats = agent.llm._total_stats
|
||||
return {
|
||||
"input_tokens": stats.input_tokens,
|
||||
"output_tokens": stats.output_tokens,
|
||||
"cached_tokens": stats.cached_tokens,
|
||||
"cost": stats.cost,
|
||||
"requests": stats.requests,
|
||||
}
|
||||
|
||||
|
||||
def _persist_completed_agent_llm_stats(agent_id: str, agent: Any) -> None:
    """Persist a finishing agent's LLM usage before its instance is discarded.

    Stores the snapshot in the module-level ``_completed_agent_llm_stats``
    registry and, when the agent has a node in the graph, mirrors it onto
    that node's ``llm_stats`` field. A no-op when no stats are available.
    """
    if (snapshot := _snapshot_agent_llm_stats(agent)) is None:
        return

    _completed_agent_llm_stats[agent_id] = snapshot
    graph_node = _agent_graph["nodes"].get(agent_id)
    if graph_node is not None:
        graph_node["llm_stats"] = snapshot
|
||||
|
||||
|
||||
def _is_whitebox_agent(agent_id: str) -> bool:
    """Return True when the live agent's LLM config flags it as whitebox.

    Unknown agent ids and agents without an ``llm_config`` (or without an
    ``is_whitebox`` attribute) evaluate to False.
    """
    instance = _agent_instances.get(agent_id)
    config = getattr(instance, "llm_config", None)
    # bool() keeps the original truthiness semantics for non-bool flag values.
    return bool(getattr(config, "is_whitebox", False))
|
||||
@@ -237,6 +264,7 @@ def _run_agent_in_thread(
|
||||
_agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat()
|
||||
_agent_graph["nodes"][state.agent_id]["result"] = {"error": str(e)}
|
||||
_running_agents.pop(state.agent_id, None)
|
||||
_persist_completed_agent_llm_stats(state.agent_id, agent)
|
||||
_agent_instances.pop(state.agent_id, None)
|
||||
raise
|
||||
else:
|
||||
@@ -247,6 +275,7 @@ def _run_agent_in_thread(
|
||||
_agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat()
|
||||
_agent_graph["nodes"][state.agent_id]["result"] = result
|
||||
_running_agents.pop(state.agent_id, None)
|
||||
_persist_completed_agent_llm_stats(state.agent_id, agent)
|
||||
_agent_instances.pop(state.agent_id, None)
|
||||
|
||||
return {"result": result}
|
||||
|
||||
@@ -10,6 +10,7 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExportResult
|
||||
from strix.telemetry import tracer as tracer_module
|
||||
from strix.telemetry import utils as telemetry_utils
|
||||
from strix.telemetry.tracer import Tracer, set_global_tracer
|
||||
from strix.tools.agents_graph import agents_graph_actions
|
||||
|
||||
|
||||
def _load_events(events_path: Path) -> list[dict[str, Any]]:
|
||||
@@ -255,6 +256,77 @@ def test_events_with_agent_id_include_agent_name(monkeypatch, tmp_path) -> None:
|
||||
assert chat_event["actor"]["agent_name"] == "Root Agent"
|
||||
|
||||
|
||||
def test_get_total_llm_stats_includes_completed_subagents(monkeypatch, tmp_path) -> None:
    """Totals must combine live agents' stats with persisted finished-agent stats."""
    monkeypatch.chdir(tmp_path)

    class FakeUsage:
        """Attribute bag mimicking an LLM's cumulative ``_total_stats`` object."""

        def __init__(
            self,
            *,
            input_tokens: int,
            output_tokens: int,
            cached_tokens: int,
            cost: float,
            requests: int,
        ) -> None:
            self.input_tokens = input_tokens
            self.output_tokens = output_tokens
            self.cached_tokens = cached_tokens
            self.cost = cost
            self.requests = requests

    class FakeLLM:
        def __init__(self, usage: FakeUsage) -> None:
            self._total_stats = usage

    class FakeAgent:
        def __init__(self, usage: FakeUsage) -> None:
            self.llm = FakeLLM(usage)

    tracer = Tracer("cost-rollup")
    set_global_tracer(tracer)

    # One agent still running ...
    running = {
        "root-agent": FakeAgent(
            FakeUsage(
                input_tokens=1_000,
                output_tokens=250,
                cached_tokens=100,
                cost=0.12345,
                requests=2,
            )
        )
    }
    # ... and one subagent already finished, with persisted stats only.
    finished = {
        "sub-agent-1": {
            "input_tokens": 2_000,
            "output_tokens": 500,
            "cached_tokens": 400,
            "cost": 0.54321,
            "requests": 3,
        }
    }
    monkeypatch.setattr(agents_graph_actions, "_agent_instances", running)
    monkeypatch.setattr(agents_graph_actions, "_completed_agent_llm_stats", finished)

    stats = tracer.get_total_llm_stats()

    assert stats["total"] == {
        "input_tokens": 3_000,
        "output_tokens": 750,
        "cached_tokens": 500,
        "cost": 0.6667,
        "requests": 5,
    }
    assert stats["total_tokens"] == 3_750
|
||||
|
||||
|
||||
def test_run_metadata_is_only_on_run_lifecycle_events(monkeypatch, tmp_path) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user