From 226678f3f2e7b067836b692cf87e12011a51d41b Mon Sep 17 00:00:00 2001
From: Daniel Sangorrin
Date: Tue, 16 Dec 2025 15:11:24 +0100
Subject: [PATCH] fix: add thinking blocks

Carry the model's thinking blocks through to the stored conversation:

- LLMResponse gains an optional thinking_blocks field.
- LLM rebuilds the complete response from the streamed chunks with
  stream_chunk_builder (only when chunks exist and
  _should_include_reasoning_effort() is true) and copies
  message.thinking_blocks from it when present and non-empty.
- AgentState.add_message accepts an optional thinking_blocks argument and
  stores it on the message only when it is non-empty.
- BaseAgent passes the response's thinking_blocks along with the assistant
  message, and records the waiting-timeout notice with role "user" instead
  of "assistant".

---
Note: this patch text was restored from a whitespace-collapsed copy. Line
breaks follow the hunk headers and the diffstat; the leading indentation of
hunk lines is inferred from the Python structure and should be confirmed
against the tree (git apply --ignore-whitespace tolerates whitespace
differences in context lines).

 strix/agents/base_agent.py |  5 +++--
 strix/agents/state.py      |  7 +++++--
 strix/llm/llm.py           | 15 +++++++++++++++
 3 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py
index bb0bd84..3763541 100644
--- a/strix/agents/base_agent.py
+++ b/strix/agents/base_agent.py
@@ -251,7 +251,7 @@ class BaseAgent(metaclass=AgentMeta):
 
         if self.state.has_waiting_timeout():
             self.state.resume_from_waiting()
-            self.state.add_message("assistant", "Waiting timeout reached. Resuming execution.")
+            self.state.add_message("user", "Waiting timeout reached. Resuming execution.")
 
             from strix.telemetry.tracer import get_global_tracer
 
@@ -364,7 +364,8 @@ class BaseAgent(metaclass=AgentMeta):
             self.state.add_message("user", corrective_message)
             return False
 
-        self.state.add_message("assistant", final_response.content)
+        thinking_blocks = getattr(final_response, "thinking_blocks", None)
+        self.state.add_message("assistant", final_response.content, thinking_blocks=thinking_blocks)
         if tracer:
             tracer.clear_streaming_content(self.state.agent_id)
             tracer.log_chat_message(
diff --git a/strix/agents/state.py b/strix/agents/state.py
index 81ac657..1f32d61 100644
--- a/strix/agents/state.py
+++ b/strix/agents/state.py
@@ -43,8 +43,11 @@ class AgentState(BaseModel):
         self.iteration += 1
         self.last_updated = datetime.now(UTC).isoformat()
 
-    def add_message(self, role: str, content: Any) -> None:
-        self.messages.append({"role": role, "content": content})
+    def add_message(self, role: str, content: Any, thinking_blocks: list[dict[str, Any]] | None = None) -> None:
+        message = {"role": role, "content": content}
+        if thinking_blocks:
+            message["thinking_blocks"] = thinking_blocks
+        self.messages.append(message)
         self.last_updated = datetime.now(UTC).isoformat()
 
     def add_action(self, action: dict[str, Any]) -> None:
diff --git a/strix/llm/llm.py b/strix/llm/llm.py
index 6bf4347..fd63c30 100644
--- a/strix/llm/llm.py
+++ b/strix/llm/llm.py
@@ -75,6 +75,7 @@ class LLMResponse:
     scan_id: str | None = None
     step_number: int = 1
     role: StepRole = StepRole.AGENT
+    thinking_blocks: list[dict[str, Any]] | None = None  # For reasoning models.
 
 
 @dataclass
@@ -291,12 +292,26 @@ class LLM:
 
         tool_invocations = parse_tool_invocations(accumulated_content)
 
+        # Extract thinking blocks from the complete response if available
+        thinking_blocks = None
+        if chunks and self._should_include_reasoning_effort():
+            complete_response = stream_chunk_builder(chunks)
+            if (
+                hasattr(complete_response, "choices")
+                and complete_response.choices
+                and hasattr(complete_response.choices[0], "message")
+            ):
+                message = complete_response.choices[0].message
+                if hasattr(message, "thinking_blocks") and message.thinking_blocks:
+                    thinking_blocks = message.thinking_blocks
+
         yield LLMResponse(
             scan_id=scan_id,
             step_number=step_number,
             role=StepRole.AGENT,
             content=accumulated_content,
             tool_invocations=tool_invocations if tool_invocations else None,
+            thinking_blocks=thinking_blocks,
         )
 
     def _raise_llm_error(self, e: Exception) -> None: