diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py index 2aac97f..dd4f99c 100644 --- a/strix/telemetry/tracer.py +++ b/strix/telemetry/tracer.py @@ -799,7 +799,10 @@ class Tracer: ) def get_total_llm_stats(self) -> dict[str, Any]: - from strix.tools.agents_graph.agents_graph_actions import _agent_instances + from strix.tools.agents_graph.agents_graph_actions import ( + _agent_instances, + _completed_agent_llm_stats, + ) total_stats = { "input_tokens": 0, @@ -809,6 +812,15 @@ class Tracer: "requests": 0, } + for agent_id, completed_stats in _completed_agent_llm_stats.items(): + if agent_id in _agent_instances: + continue + total_stats["input_tokens"] += int(completed_stats.get("input_tokens", 0) or 0) + total_stats["output_tokens"] += int(completed_stats.get("output_tokens", 0) or 0) + total_stats["cached_tokens"] += int(completed_stats.get("cached_tokens", 0) or 0) + total_stats["cost"] += float(completed_stats.get("cost", 0.0) or 0.0) + total_stats["requests"] += int(completed_stats.get("requests", 0) or 0) + for agent_instance in _agent_instances.values(): if hasattr(agent_instance, "llm") and hasattr(agent_instance.llm, "_total_stats"): agent_stats = agent_instance.llm._total_stats diff --git a/strix/tools/agents_graph/agents_graph_actions.py b/strix/tools/agents_graph/agents_graph_actions.py index 76313d7..cee21df 100644 --- a/strix/tools/agents_graph/agents_graph_actions.py +++ b/strix/tools/agents_graph/agents_graph_actions.py @@ -19,9 +19,36 @@ _running_agents: dict[str, threading.Thread] = {} _agent_instances: dict[str, Any] = {} +_completed_agent_llm_stats: dict[str, dict[str, int | float]] = {} + _agent_states: dict[str, Any] = {} +def _snapshot_agent_llm_stats(agent: Any) -> dict[str, int | float] | None: + if not hasattr(agent, "llm") or not hasattr(agent.llm, "_total_stats"): + return None + + stats = agent.llm._total_stats + return { + "input_tokens": stats.input_tokens, + "output_tokens": stats.output_tokens, + "cached_tokens": stats.cached_tokens, + "cost": stats.cost, + "requests": stats.requests, + } + + +def _persist_completed_agent_llm_stats(agent_id: str, agent: Any) -> None: + stats = _snapshot_agent_llm_stats(agent) + if stats is None: + return + + _completed_agent_llm_stats[agent_id] = stats + node = _agent_graph["nodes"].get(agent_id) + if node is not None: + node["llm_stats"] = stats + + def _is_whitebox_agent(agent_id: str) -> bool: agent = _agent_instances.get(agent_id) return bool(getattr(getattr(agent, "llm_config", None), "is_whitebox", False)) @@ -237,6 +264,7 @@ def _run_agent_in_thread( _agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat() _agent_graph["nodes"][state.agent_id]["result"] = {"error": str(e)} _running_agents.pop(state.agent_id, None) + _persist_completed_agent_llm_stats(state.agent_id, agent) _agent_instances.pop(state.agent_id, None) raise else: @@ -247,6 +275,7 @@ def _run_agent_in_thread( _agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat() _agent_graph["nodes"][state.agent_id]["result"] = result _running_agents.pop(state.agent_id, None) + _persist_completed_agent_llm_stats(state.agent_id, agent) _agent_instances.pop(state.agent_id, None) return {"result": result} diff --git a/tests/telemetry/test_tracer.py b/tests/telemetry/test_tracer.py index 10f887e..3b4326c 100644 --- a/tests/telemetry/test_tracer.py +++ b/tests/telemetry/test_tracer.py @@ -10,6 +10,7 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExportResult from strix.telemetry import tracer as tracer_module from strix.telemetry import utils as telemetry_utils from strix.telemetry.tracer import Tracer, set_global_tracer +from strix.tools.agents_graph import agents_graph_actions def _load_events(events_path: Path) -> list[dict[str, Any]]: @@ -255,6 +256,77 @@ def test_events_with_agent_id_include_agent_name(monkeypatch, tmp_path) -> None: assert chat_event["actor"]["agent_name"] == "Root Agent" +def test_get_total_llm_stats_includes_completed_subagents(monkeypatch, tmp_path) -> None: + monkeypatch.chdir(tmp_path) + + class DummyStats: + def __init__( + self, + *, + input_tokens: int, + output_tokens: int, + cached_tokens: int, + cost: float, + requests: int, + ) -> None: + self.input_tokens = input_tokens + self.output_tokens = output_tokens + self.cached_tokens = cached_tokens + self.cost = cost + self.requests = requests + + class DummyLLM: + def __init__(self, stats: DummyStats) -> None: + self._total_stats = stats + + class DummyAgent: + def __init__(self, stats: DummyStats) -> None: + self.llm = DummyLLM(stats) + + tracer = Tracer("cost-rollup") + set_global_tracer(tracer) + + monkeypatch.setattr( + agents_graph_actions, + "_agent_instances", + { + "root-agent": DummyAgent( + DummyStats( + input_tokens=1_000, + output_tokens=250, + cached_tokens=100, + cost=0.12345, + requests=2, + ) + ) + }, + ) + monkeypatch.setattr( + agents_graph_actions, + "_completed_agent_llm_stats", + { + "sub-agent-1": { + "input_tokens": 2_000, + "output_tokens": 500, + "cached_tokens": 400, + "cost": 0.54321, + "requests": 3, + } + }, + ) + + stats = tracer.get_total_llm_stats() + + assert stats["total"] == { + "input_tokens": 3_000, + "output_tokens": 750, + "cached_tokens": 500, + "cost": 0.6667, + "requests": 5, + } + assert stats["total_tokens"] == 3_750 + + def test_run_metadata_is_only_on_run_lifecycle_events(monkeypatch, tmp_path) -> None: monkeypatch.chdir(tmp_path)