Track tool part versions for reliable updates
This commit is contained in:
@@ -4,8 +4,6 @@ import { Markdown } from "./markdown"
|
|||||||
import { useTheme } from "../lib/theme"
|
import { useTheme } from "../lib/theme"
|
||||||
import type { TextPart } from "../types/message"
|
import type { TextPart } from "../types/message"
|
||||||
|
|
||||||
// Module-level cache for stable TextPart objects per tool call
|
|
||||||
const markdownPartCache = new Map<string, TextPart>()
|
|
||||||
const toolScrollState = new Map<string, { scrollTop: number; atBottom: boolean }>()
|
const toolScrollState = new Map<string, { scrollTop: number; atBottom: boolean }>()
|
||||||
|
|
||||||
function updateScrollState(id: string, element: HTMLElement) {
|
function updateScrollState(id: string, element: HTMLElement) {
|
||||||
@@ -37,27 +35,6 @@ function restoreScrollState(id: string, element: HTMLElement) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
function getCachedMarkdownPart(id: string, text: string): TextPart {
|
|
||||||
if (!id) {
|
|
||||||
// No caching case - return fresh object
|
|
||||||
return { type: "text", text }
|
|
||||||
}
|
|
||||||
|
|
||||||
const part = markdownPartCache.get(id)
|
|
||||||
if (!part) {
|
|
||||||
const freshPart: TextPart = { type: "text", text }
|
|
||||||
markdownPartCache.set(id, freshPart)
|
|
||||||
return freshPart
|
|
||||||
}
|
|
||||||
|
|
||||||
if (part.text !== text) {
|
|
||||||
const freshPart: TextPart = { type: "text", text }
|
|
||||||
markdownPartCache.set(id, freshPart)
|
|
||||||
return freshPart
|
|
||||||
}
|
|
||||||
|
|
||||||
return part
|
|
||||||
}
|
|
||||||
|
|
||||||
interface ToolCallProps {
|
interface ToolCallProps {
|
||||||
toolCall: any
|
toolCall: any
|
||||||
@@ -184,7 +161,6 @@ export default function ToolCall(props: ToolCallProps) {
|
|||||||
if (!id) return
|
if (!id) return
|
||||||
|
|
||||||
onCleanup(() => {
|
onCleanup(() => {
|
||||||
markdownPartCache.delete(id)
|
|
||||||
toolScrollState.delete(id)
|
toolScrollState.delete(id)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -367,7 +343,7 @@ export default function ToolCall(props: ToolCallProps) {
|
|||||||
const messageClass = `message-text tool-call-markdown${isLarge ? " tool-call-markdown-large" : ""}`
|
const messageClass = `message-text tool-call-markdown${isLarge ? " tool-call-markdown-large" : ""}`
|
||||||
const disableHighlight = state?.status === "running"
|
const disableHighlight = state?.status === "running"
|
||||||
|
|
||||||
const cachedPart = getCachedMarkdownPart(toolCallId(), content)
|
const markdownPart: TextPart = { type: "text", text: content }
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
@@ -390,7 +366,7 @@ export default function ToolCall(props: ToolCallProps) {
|
|||||||
onScroll={(event) => updateScrollState(toolCallId(), event.currentTarget)}
|
onScroll={(event) => updateScrollState(toolCallId(), event.currentTarget)}
|
||||||
>
|
>
|
||||||
<Markdown
|
<Markdown
|
||||||
part={cachedPart}
|
part={markdownPart}
|
||||||
isDark={isDark()}
|
isDark={isDark()}
|
||||||
disableHighlight={disableHighlight}
|
disableHighlight={disableHighlight}
|
||||||
onRendered={handleMarkdownRendered}
|
onRendered={handleMarkdownRendered}
|
||||||
|
|||||||
@@ -185,33 +185,17 @@ export function computeDisplayParts(message: Message, showThinking: boolean): Me
|
|||||||
return { text, tool, reasoning, combined, showThinking, version }
|
return { text, tool, reasoning, combined, showThinking, version }
|
||||||
}
|
}
|
||||||
|
|
||||||
function ensurePartVersionsMap(message: Message) {
|
function initializePartVersion(part: any, version = 0) {
|
||||||
if (!message.partVersions) {
|
if (!part || typeof part !== "object") return
|
||||||
message.partVersions = new Map()
|
const partAny = part as any
|
||||||
}
|
partAny.version = typeof partAny.version === "number" ? partAny.version : version
|
||||||
return message.partVersions
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function initializePartVersion(message: Message, part: any) {
|
function bumpPartVersion(previousPart: any, nextPart: any): number {
|
||||||
const partId = typeof part?.id === "string" ? part.id : null
|
const prevVersion = typeof previousPart?.version === "number" ? previousPart.version : -1
|
||||||
if (!partId) return
|
const nextVersion = prevVersion + 1
|
||||||
const versions = ensurePartVersionsMap(message)
|
initializePartVersion(nextPart, nextVersion)
|
||||||
if (!versions.has(partId)) {
|
return nextVersion
|
||||||
versions.set(partId, 0)
|
|
||||||
}
|
|
||||||
const partAny = part as any
|
|
||||||
partAny.__version = versions.get(partId)
|
|
||||||
}
|
|
||||||
|
|
||||||
function bumpPartVersion(message: Message, part: any): number | undefined {
|
|
||||||
const partId = typeof part?.id === "string" ? part.id : null
|
|
||||||
if (!partId) return undefined
|
|
||||||
const versions = ensurePartVersionsMap(message)
|
|
||||||
const next = (versions.get(partId) ?? 0) + 1
|
|
||||||
versions.set(partId, next)
|
|
||||||
const partAny = part as any
|
|
||||||
partAny.__version = next
|
|
||||||
return next
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function withSession(instanceId: string, sessionId: string, updater: (session: Session) => void) {
|
function withSession(instanceId: string, sessionId: string, updater: (session: Session) => void) {
|
||||||
@@ -836,10 +820,9 @@ async function loadMessages(instanceId: string, sessionId: string, force = false
|
|||||||
timestamp: info.time?.created || Date.now(),
|
timestamp: info.time?.created || Date.now(),
|
||||||
status: "complete" as const,
|
status: "complete" as const,
|
||||||
version: 0,
|
version: 0,
|
||||||
partVersions: new Map(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
parts.forEach((part: any) => initializePartVersion(message, part))
|
parts.forEach((part: any) => initializePartVersion(part))
|
||||||
|
|
||||||
message.displayParts = computeDisplayParts(message, preferences().showThinkingBlocks)
|
message.displayParts = computeDisplayParts(message, preferences().showThinkingBlocks)
|
||||||
|
|
||||||
@@ -956,10 +939,9 @@ function handleMessageUpdate(instanceId: string, event: any): void {
|
|||||||
timestamp: Date.now(),
|
timestamp: Date.now(),
|
||||||
status: "streaming" as const,
|
status: "streaming" as const,
|
||||||
version: 0,
|
version: 0,
|
||||||
partVersions: new Map(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
initializePartVersion(newMessage, part)
|
initializePartVersion(part)
|
||||||
newMessage.displayParts = computeDisplayParts(newMessage, preferences().showThinkingBlocks)
|
newMessage.displayParts = computeDisplayParts(newMessage, preferences().showThinkingBlocks)
|
||||||
|
|
||||||
let insertIndex = session.messages.length
|
let insertIndex = session.messages.length
|
||||||
@@ -1017,11 +999,11 @@ function handleMessageUpdate(instanceId: string, event: any): void {
|
|||||||
const partIndex = partMap.get(part.id)
|
const partIndex = partMap.get(part.id)
|
||||||
|
|
||||||
if (partIndex === undefined) {
|
if (partIndex === undefined) {
|
||||||
|
initializePartVersion(part)
|
||||||
baseParts.push(part)
|
baseParts.push(part)
|
||||||
if (part.id && typeof part.id === "string") {
|
if (part.id && typeof part.id === "string") {
|
||||||
partMap.set(part.id, baseParts.length - 1)
|
partMap.set(part.id, baseParts.length - 1)
|
||||||
}
|
}
|
||||||
initializePartVersion(message, part)
|
|
||||||
shouldIncrementVersion = true
|
shouldIncrementVersion = true
|
||||||
// Clear render cache for new text parts
|
// Clear render cache for new text parts
|
||||||
if (part.type === "text") {
|
if (part.type === "text") {
|
||||||
@@ -1040,8 +1022,8 @@ function handleMessageUpdate(instanceId: string, event: any): void {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bumpPartVersion(previousPart, part)
|
||||||
baseParts[partIndex] = part
|
baseParts[partIndex] = part
|
||||||
bumpPartVersion(message, part)
|
|
||||||
if (part.type !== "text" || !previousPart || previousPart.text !== part.text) {
|
if (part.type !== "text" || !previousPart || previousPart.text !== part.text) {
|
||||||
shouldIncrementVersion = true
|
shouldIncrementVersion = true
|
||||||
// Clear render cache when text changes
|
// Clear render cache when text changes
|
||||||
@@ -1347,10 +1329,9 @@ async function sendMessage(
|
|||||||
timestamp: Date.now(),
|
timestamp: Date.now(),
|
||||||
status: "sending",
|
status: "sending",
|
||||||
version: 0,
|
version: 0,
|
||||||
partVersions: new Map(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
optimisticParts.forEach((part: any) => initializePartVersion(optimisticMessage, part))
|
optimisticParts.forEach((part: any) => initializePartVersion(part))
|
||||||
|
|
||||||
optimisticMessage.displayParts = computeDisplayParts(optimisticMessage, preferences().showThinkingBlocks)
|
optimisticMessage.displayParts = computeDisplayParts(optimisticMessage, preferences().showThinkingBlocks)
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ export interface Message {
|
|||||||
timestamp: number
|
timestamp: number
|
||||||
status: "sending" | "sent" | "streaming" | "complete" | "error"
|
status: "sending" | "sent" | "streaming" | "complete" | "error"
|
||||||
version: number
|
version: number
|
||||||
partVersions?: Map<string, number>
|
|
||||||
displayParts?: MessageDisplayParts
|
displayParts?: MessageDisplayParts
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,6 +28,7 @@ export interface TextPart {
|
|||||||
id?: string
|
id?: string
|
||||||
type: "text"
|
type: "text"
|
||||||
text: string
|
text: string
|
||||||
|
version?: number
|
||||||
synthetic?: boolean
|
synthetic?: boolean
|
||||||
renderCache?: RenderCache
|
renderCache?: RenderCache
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user