Track tool part versions for reliable updates

This commit is contained in:
Shantur Rathore
2025-11-07 15:12:21 +00:00
parent 27da588d22
commit 7193103646
3 changed files with 17 additions and 60 deletions

View File

@@ -4,8 +4,6 @@ import { Markdown } from "./markdown"
import { useTheme } from "../lib/theme" import { useTheme } from "../lib/theme"
import type { TextPart } from "../types/message" import type { TextPart } from "../types/message"
// Module-level cache for stable TextPart objects per tool call
const markdownPartCache = new Map<string, TextPart>()
const toolScrollState = new Map<string, { scrollTop: number; atBottom: boolean }>() const toolScrollState = new Map<string, { scrollTop: number; atBottom: boolean }>()
function updateScrollState(id: string, element: HTMLElement) { function updateScrollState(id: string, element: HTMLElement) {
@@ -37,27 +35,6 @@ function restoreScrollState(id: string, element: HTMLElement) {
}) })
} }
function getCachedMarkdownPart(id: string, text: string): TextPart {
if (!id) {
// No caching case - return fresh object
return { type: "text", text }
}
const part = markdownPartCache.get(id)
if (!part) {
const freshPart: TextPart = { type: "text", text }
markdownPartCache.set(id, freshPart)
return freshPart
}
if (part.text !== text) {
const freshPart: TextPart = { type: "text", text }
markdownPartCache.set(id, freshPart)
return freshPart
}
return part
}
interface ToolCallProps { interface ToolCallProps {
toolCall: any toolCall: any
@@ -184,7 +161,6 @@ export default function ToolCall(props: ToolCallProps) {
if (!id) return if (!id) return
onCleanup(() => { onCleanup(() => {
markdownPartCache.delete(id)
toolScrollState.delete(id) toolScrollState.delete(id)
}) })
}) })
@@ -367,7 +343,7 @@ export default function ToolCall(props: ToolCallProps) {
const messageClass = `message-text tool-call-markdown${isLarge ? " tool-call-markdown-large" : ""}` const messageClass = `message-text tool-call-markdown${isLarge ? " tool-call-markdown-large" : ""}`
const disableHighlight = state?.status === "running" const disableHighlight = state?.status === "running"
const cachedPart = getCachedMarkdownPart(toolCallId(), content) const markdownPart: TextPart = { type: "text", text: content }
return ( return (
<div <div
@@ -390,7 +366,7 @@ export default function ToolCall(props: ToolCallProps) {
onScroll={(event) => updateScrollState(toolCallId(), event.currentTarget)} onScroll={(event) => updateScrollState(toolCallId(), event.currentTarget)}
> >
<Markdown <Markdown
part={cachedPart} part={markdownPart}
isDark={isDark()} isDark={isDark()}
disableHighlight={disableHighlight} disableHighlight={disableHighlight}
onRendered={handleMarkdownRendered} onRendered={handleMarkdownRendered}

View File

@@ -185,33 +185,17 @@ export function computeDisplayParts(message: Message, showThinking: boolean): Me
return { text, tool, reasoning, combined, showThinking, version } return { text, tool, reasoning, combined, showThinking, version }
} }
function ensurePartVersionsMap(message: Message) { function initializePartVersion(part: any, version = 0) {
if (!message.partVersions) { if (!part || typeof part !== "object") return
message.partVersions = new Map() const partAny = part as any
} partAny.version = typeof partAny.version === "number" ? partAny.version : version
return message.partVersions
} }
function initializePartVersion(message: Message, part: any) { function bumpPartVersion(previousPart: any, nextPart: any): number {
const partId = typeof part?.id === "string" ? part.id : null const prevVersion = typeof previousPart?.version === "number" ? previousPart.version : -1
if (!partId) return const nextVersion = prevVersion + 1
const versions = ensurePartVersionsMap(message) initializePartVersion(nextPart, nextVersion)
if (!versions.has(partId)) { return nextVersion
versions.set(partId, 0)
}
const partAny = part as any
partAny.__version = versions.get(partId)
}
function bumpPartVersion(message: Message, part: any): number | undefined {
const partId = typeof part?.id === "string" ? part.id : null
if (!partId) return undefined
const versions = ensurePartVersionsMap(message)
const next = (versions.get(partId) ?? 0) + 1
versions.set(partId, next)
const partAny = part as any
partAny.__version = next
return next
} }
function withSession(instanceId: string, sessionId: string, updater: (session: Session) => void) { function withSession(instanceId: string, sessionId: string, updater: (session: Session) => void) {
@@ -836,10 +820,9 @@ async function loadMessages(instanceId: string, sessionId: string, force = false
timestamp: info.time?.created || Date.now(), timestamp: info.time?.created || Date.now(),
status: "complete" as const, status: "complete" as const,
version: 0, version: 0,
partVersions: new Map(),
} }
parts.forEach((part: any) => initializePartVersion(message, part)) parts.forEach((part: any) => initializePartVersion(part))
message.displayParts = computeDisplayParts(message, preferences().showThinkingBlocks) message.displayParts = computeDisplayParts(message, preferences().showThinkingBlocks)
@@ -956,10 +939,9 @@ function handleMessageUpdate(instanceId: string, event: any): void {
timestamp: Date.now(), timestamp: Date.now(),
status: "streaming" as const, status: "streaming" as const,
version: 0, version: 0,
partVersions: new Map(),
} }
initializePartVersion(newMessage, part) initializePartVersion(part)
newMessage.displayParts = computeDisplayParts(newMessage, preferences().showThinkingBlocks) newMessage.displayParts = computeDisplayParts(newMessage, preferences().showThinkingBlocks)
let insertIndex = session.messages.length let insertIndex = session.messages.length
@@ -1017,11 +999,11 @@ function handleMessageUpdate(instanceId: string, event: any): void {
const partIndex = partMap.get(part.id) const partIndex = partMap.get(part.id)
if (partIndex === undefined) { if (partIndex === undefined) {
initializePartVersion(part)
baseParts.push(part) baseParts.push(part)
if (part.id && typeof part.id === "string") { if (part.id && typeof part.id === "string") {
partMap.set(part.id, baseParts.length - 1) partMap.set(part.id, baseParts.length - 1)
} }
initializePartVersion(message, part)
shouldIncrementVersion = true shouldIncrementVersion = true
// Clear render cache for new text parts // Clear render cache for new text parts
if (part.type === "text") { if (part.type === "text") {
@@ -1040,8 +1022,8 @@ function handleMessageUpdate(instanceId: string, event: any): void {
return return
} }
bumpPartVersion(previousPart, part)
baseParts[partIndex] = part baseParts[partIndex] = part
bumpPartVersion(message, part)
if (part.type !== "text" || !previousPart || previousPart.text !== part.text) { if (part.type !== "text" || !previousPart || previousPart.text !== part.text) {
shouldIncrementVersion = true shouldIncrementVersion = true
// Clear render cache when text changes // Clear render cache when text changes
@@ -1347,10 +1329,9 @@ async function sendMessage(
timestamp: Date.now(), timestamp: Date.now(),
status: "sending", status: "sending",
version: 0, version: 0,
partVersions: new Map(),
} }
optimisticParts.forEach((part: any) => initializePartVersion(optimisticMessage, part)) optimisticParts.forEach((part: any) => initializePartVersion(part))
optimisticMessage.displayParts = computeDisplayParts(optimisticMessage, preferences().showThinkingBlocks) optimisticMessage.displayParts = computeDisplayParts(optimisticMessage, preferences().showThinkingBlocks)

View File

@@ -21,7 +21,6 @@ export interface Message {
timestamp: number timestamp: number
status: "sending" | "sent" | "streaming" | "complete" | "error" status: "sending" | "sent" | "streaming" | "complete" | "error"
version: number version: number
partVersions?: Map<string, number>
displayParts?: MessageDisplayParts displayParts?: MessageDisplayParts
} }
@@ -29,6 +28,7 @@ export interface TextPart {
id?: string id?: string
type: "text" type: "text"
text: string text: string
version?: number
synthetic?: boolean synthetic?: boolean
renderCache?: RenderCache renderCache?: RenderCache
} }