fix: add thinking blocks

This commit is contained in:
Daniel Sangorrin
2025-12-16 15:11:24 +01:00
committed by Ahmed Allam
parent 49421f50d5
commit 226678f3f2
3 changed files with 23 additions and 4 deletions

View File

@@ -251,7 +251,7 @@ class BaseAgent(metaclass=AgentMeta):
         if self.state.has_waiting_timeout():
             self.state.resume_from_waiting()
-            self.state.add_message("assistant", "Waiting timeout reached. Resuming execution.")
+            self.state.add_message("user", "Waiting timeout reached. Resuming execution.")
         from strix.telemetry.tracer import get_global_tracer
@@ -364,7 +364,8 @@ class BaseAgent(metaclass=AgentMeta):
             self.state.add_message("user", corrective_message)
             return False
-        self.state.add_message("assistant", final_response.content)
+        thinking_blocks = getattr(final_response, "thinking_blocks", None)
+        self.state.add_message("assistant", final_response.content, thinking_blocks=thinking_blocks)
         if tracer:
             tracer.clear_streaming_content(self.state.agent_id)
             tracer.log_chat_message(

View File

@@ -43,8 +43,11 @@ class AgentState(BaseModel):
         self.iteration += 1
         self.last_updated = datetime.now(UTC).isoformat()

-    def add_message(self, role: str, content: Any) -> None:
-        self.messages.append({"role": role, "content": content})
+    def add_message(self, role: str, content: Any, thinking_blocks: list[dict[str, Any]] | None = None) -> None:
+        message = {"role": role, "content": content}
+        if thinking_blocks:
+            message["thinking_blocks"] = thinking_blocks
+        self.messages.append(message)
         self.last_updated = datetime.now(UTC).isoformat()

     def add_action(self, action: dict[str, Any]) -> None:

View File

@@ -75,6 +75,7 @@ class LLMResponse:
     scan_id: str | None = None
     step_number: int = 1
     role: StepRole = StepRole.AGENT
+    thinking_blocks: list[dict[str, Any]] | None = None  # For reasoning models.

 @dataclass
@@ -291,12 +292,26 @@ class LLM:
         tool_invocations = parse_tool_invocations(accumulated_content)

+        # Extract thinking blocks from the complete response if available
+        thinking_blocks = None
+        if chunks and self._should_include_reasoning_effort():
+            complete_response = stream_chunk_builder(chunks)
+            if (
+                hasattr(complete_response, "choices")
+                and complete_response.choices
+                and hasattr(complete_response.choices[0], "message")
+            ):
+                message = complete_response.choices[0].message
+                if hasattr(message, "thinking_blocks") and message.thinking_blocks:
+                    thinking_blocks = message.thinking_blocks
+
         yield LLMResponse(
             scan_id=scan_id,
             step_number=step_number,
             role=StepRole.AGENT,
             content=accumulated_content,
             tool_invocations=tool_invocations if tool_invocations else None,
+            thinking_blocks=thinking_blocks,
         )

     def _raise_llm_error(self, e: Exception) -> None: