Stabilize message stream rendering caches

This commit is contained in:
Shantur Rathore
2025-10-28 13:48:56 +00:00
parent 79e4931b28
commit 6597783e85
4 changed files with 190 additions and 56 deletions

View File

@@ -6,6 +6,7 @@ interface MessageItemProps {
message: Message message: Message
messageInfo?: any messageInfo?: any
isQueued?: boolean isQueued?: boolean
parts?: any[]
onRevert?: (messageId: string) => void onRevert?: (messageId: string) => void
} }
@@ -16,6 +17,8 @@ export default function MessageItem(props: MessageItemProps) {
return date.toLocaleTimeString([], { hour: "2-digit", minute: "2-digit" }) return date.toLocaleTimeString([], { hour: "2-digit", minute: "2-digit" })
} }
const messageParts = () => props.parts ?? props.message.parts
const errorMessage = () => { const errorMessage = () => {
if (!props.messageInfo?.error) return null if (!props.messageInfo?.error) return null
@@ -36,7 +39,7 @@ export default function MessageItem(props: MessageItemProps) {
} }
const hasContent = () => { const hasContent = () => {
return props.message.parts.length > 0 || errorMessage() !== null return messageParts().length > 0 || errorMessage() !== null
} }
const isGenerating = () => { const isGenerating = () => {
@@ -81,7 +84,7 @@ export default function MessageItem(props: MessageItemProps) {
</div> </div>
</Show> </Show>
<For each={props.message.parts}>{(part) => <MessagePart part={part} />}</For> <For each={messageParts()}>{(part) => <MessagePart part={part} />}</For>
</div> </div>
<Show when={props.message.status === "sending"}> <Show when={props.message.status === "sending"}>

View File

@@ -1,11 +1,11 @@
import { For, Show, createSignal, createEffect, createMemo } from "solid-js" import { For, Show, createSignal, createEffect, createMemo } from "solid-js"
import type { Message } from "../types/message" import type { Message, MessageDisplayParts } from "../types/message"
import MessageItem from "./message-item" import MessageItem from "./message-item"
import ToolCall from "./tool-call" import ToolCall from "./tool-call"
import { sseManager } from "../lib/sse-manager" import { sseManager } from "../lib/sse-manager"
import Kbd from "./kbd" import Kbd from "./kbd"
import { preferences } from "../stores/preferences" import { preferences } from "../stores/preferences"
import { providers, getSessionInfo } from "../stores/sessions" import { providers, getSessionInfo, computeDisplayParts } from "../stores/sessions"
// Calculate session tokens and cost from messagesInfo (matches TUI logic) // Calculate session tokens and cost from messagesInfo (matches TUI logic)
function calculateSessionInfo(messagesInfo?: Map<string, any>, instanceId?: string) { function calculateSessionInfo(messagesInfo?: Map<string, any>, instanceId?: string) {
@@ -120,17 +120,46 @@ interface MessageStreamProps {
onRevert?: (messageId: string) => void onRevert?: (messageId: string) => void
} }
interface DisplayItem { interface MessageDisplayItem {
type: "message" | "tool" type: "message"
data: any message: Message
combinedParts: any[]
isQueued: boolean
messageInfo?: any messageInfo?: any
} }
interface ToolDisplayItem {
type: "tool"
key: string
toolPart: any
messageInfo?: any
}
type DisplayItem = MessageDisplayItem | ToolDisplayItem
interface MessageCacheEntry {
version: number
showThinking: boolean
isQueued: boolean
messageInfo?: any
displayParts: MessageDisplayParts
item: MessageDisplayItem
}
interface ToolCacheEntry {
toolPart: any
messageInfo?: any
item: ToolDisplayItem
}
export default function MessageStream(props: MessageStreamProps) { export default function MessageStream(props: MessageStreamProps) {
let containerRef: HTMLDivElement | undefined let containerRef: HTMLDivElement | undefined
const [autoScroll, setAutoScroll] = createSignal(true) const [autoScroll, setAutoScroll] = createSignal(true)
const [showScrollButton, setShowScrollButton] = createSignal(false) const [showScrollButton, setShowScrollButton] = createSignal(false)
let messageItemCache = new Map<string, MessageCacheEntry>()
let toolItemCache = new Map<string, ToolCacheEntry>()
const connectionStatus = () => sseManager.getStatus(props.instanceId) const connectionStatus = () => sseManager.getStatus(props.instanceId)
const sessionInfo = createMemo(() => { const sessionInfo = createMemo(() => {
@@ -180,9 +209,11 @@ export default function MessageStream(props: MessageStreamProps) {
const displayItems = createMemo(() => { const displayItems = createMemo(() => {
// Ensure memo reacts to preference changes // Ensure memo reacts to preference changes
preferences().showThinkingBlocks const showThinking = preferences().showThinkingBlocks
const items: DisplayItem[] = [] const items: DisplayItem[] = []
const newMessageCache = new Map<string, MessageCacheEntry>()
const newToolCache = new Map<string, ToolCacheEntry>()
let lastAssistantIndex = -1 let lastAssistantIndex = -1
for (let i = props.messages.length - 1; i >= 0; i--) { for (let i = props.messages.length - 1; i >= 0; i--) {
@@ -201,35 +232,82 @@ export default function MessageStream(props: MessageStreamProps) {
break break
} }
// Use precomputed displayParts, fallback to empty arrays if not available const baseDisplayParts = message.displayParts
const displayParts = message.displayParts || { text: [], tool: [], reasoning: [] } const displayParts: MessageDisplayParts =
const textParts = displayParts.text baseDisplayParts && baseDisplayParts.showThinking === showThinking
const toolParts = displayParts.tool ? baseDisplayParts
const reasoningParts = displayParts.reasoning : computeDisplayParts(message, showThinking)
const combinedParts = displayParts.combined
const version = message.version ?? 0
const isQueued = message.type === "user" && (lastAssistantIndex === -1 || index > lastAssistantIndex) const isQueued = message.type === "user" && (lastAssistantIndex === -1 || index > lastAssistantIndex)
if (textParts.length > 0 || reasoningParts.length > 0 || messageInfo?.error) { const cacheEntry = messageItemCache.get(message.id)
items.push({ if (
cacheEntry &&
cacheEntry.version === version &&
cacheEntry.showThinking === showThinking &&
cacheEntry.isQueued === isQueued &&
cacheEntry.messageInfo === messageInfo
) {
cacheEntry.displayParts = displayParts
cacheEntry.version = version
cacheEntry.showThinking = showThinking
cacheEntry.isQueued = isQueued
cacheEntry.messageInfo = messageInfo
cacheEntry.item.message = message
cacheEntry.item.combinedParts = combinedParts
cacheEntry.item.isQueued = isQueued
cacheEntry.item.messageInfo = messageInfo
newMessageCache.set(message.id, cacheEntry)
items.push(cacheEntry.item)
} else {
const messageItem: MessageDisplayItem = {
type: "message", type: "message",
data: { message,
...message, combinedParts,
parts: [...textParts, ...reasoningParts], isQueued,
isQueued,
},
messageInfo, messageInfo,
}
newMessageCache.set(message.id, {
version,
showThinking,
isQueued,
messageInfo,
displayParts,
item: messageItem,
}) })
items.push(messageItem)
} }
for (const toolPart of toolParts) { for (let toolIndex = 0; toolIndex < displayParts.tool.length; toolIndex++) {
items.push({ const toolPart = displayParts.tool[toolIndex]
type: "tool", const toolKey = typeof toolPart?.id === "string" ? toolPart.id : `${message.id}-tool-${toolIndex}`
data: toolPart,
messageInfo, const toolEntry = toolItemCache.get(toolKey)
}) if (toolEntry && toolEntry.toolPart === toolPart && toolEntry.messageInfo === messageInfo) {
toolEntry.item.toolPart = toolPart
toolEntry.item.messageInfo = messageInfo
toolEntry.toolPart = toolPart
toolEntry.messageInfo = messageInfo
newToolCache.set(toolKey, toolEntry)
items.push(toolEntry.item)
} else {
const toolItem: ToolDisplayItem = {
type: "tool",
key: toolKey,
toolPart,
messageInfo,
}
newToolCache.set(toolKey, { toolPart, messageInfo, item: toolItem })
items.push(toolItem)
}
} }
} }
messageItemCache = newMessageCache
toolItemCache = newToolCache
return items return items
}) })
@@ -301,29 +379,30 @@ export default function MessageStream(props: MessageStreamProps) {
</Show> </Show>
<For each={displayItems()} fallback={null}> <For each={displayItems()} fallback={null}>
{(item, index) => { {(item) => {
const key = item.type === "message" ? `msg-${item.data.id}` : `tool-${item.data.id}` if (item.type === "message") {
return ( return (
<Show
when={item.type === "message"}
fallback={
<div class="tool-call-message" data-key={key}>
<div class="tool-call-header-label">
<span class="tool-call-icon">🔧</span>
<span>Tool Call</span>
<span class="tool-name">{item.data?.tool || "unknown"}</span>
</div>
<ToolCall toolCall={item.data} toolCallId={item.data.id} />
</div>
}
>
<MessageItem <MessageItem
message={item.data} message={item.message}
messageInfo={item.messageInfo} messageInfo={item.messageInfo}
isQueued={item.data.isQueued} isQueued={item.isQueued}
parts={item.combinedParts}
onRevert={props.onRevert} onRevert={props.onRevert}
/> />
</Show> )
}
const toolPart = item.toolPart
return (
<div class="tool-call-message" data-key={item.key}>
<div class="tool-call-header-label">
<span class="tool-call-icon">🔧</span>
<span>Tool Call</span>
<span class="tool-name">{toolPart?.tool || "unknown"}</span>
</div>
<ToolCall toolCall={toolPart} toolCallId={toolPart?.id} />
</div>
) )
}} }}
</For> </For>

View File

@@ -82,7 +82,7 @@ function removeSessionIndexes(instanceId: string) {
sessionIndexes.delete(instanceId) sessionIndexes.delete(instanceId)
} }
function computeDisplayParts(message: Message, showThinking: boolean): MessageDisplayParts { export function computeDisplayParts(message: Message, showThinking: boolean): MessageDisplayParts {
const text: any[] = [] const text: any[] = []
const tool: any[] = [] const tool: any[] = []
const reasoning: any[] = [] const reasoning: any[] = []
@@ -97,7 +97,10 @@ function computeDisplayParts(message: Message, showThinking: boolean): MessageDi
} }
} }
return { text, tool, reasoning } const combined = reasoning.length > 0 ? [...text, ...reasoning] : [...text]
const version = typeof message.version === "number" ? message.version : 0
return { text, tool, reasoning, combined, showThinking, version }
} }
function withSession(instanceId: string, sessionId: string, updater: (session: Session) => void) { function withSession(instanceId: string, sessionId: string, updater: (session: Session) => void) {
@@ -710,6 +713,7 @@ async function loadMessages(instanceId: string, sessionId: string, force = false
parts: apiMessage.parts || [], parts: apiMessage.parts || [],
timestamp: info.time?.created || Date.now(), timestamp: info.time?.created || Date.now(),
status: "complete" as const, status: "complete" as const,
version: 0,
} }
message.displayParts = computeDisplayParts(message, preferences().showThinkingBlocks) message.displayParts = computeDisplayParts(message, preferences().showThinkingBlocks)
@@ -824,6 +828,7 @@ function handleMessageUpdate(instanceId: string, event: any): void {
parts: [part], parts: [part],
timestamp: Date.now(), timestamp: Date.now(),
status: "streaming" as const, status: "streaming" as const,
version: 0,
} }
newMessage.displayParts = computeDisplayParts(newMessage, preferences().showThinkingBlocks) newMessage.displayParts = computeDisplayParts(newMessage, preferences().showThinkingBlocks)
@@ -841,6 +846,9 @@ function handleMessageUpdate(instanceId: string, event: any): void {
} else { } else {
// Update existing message // Update existing message
const message = session.messages[messageIndex] const message = session.messages[messageIndex]
if (typeof message.version !== "number") {
message.version = 0
}
// Strip synthetic parts when real data arrives // Strip synthetic parts when real data arrives
let filteredSynthetics = false let filteredSynthetics = false
@@ -864,14 +872,32 @@ function handleMessageUpdate(instanceId: string, event: any): void {
index.partIndex.set(message.id, partMap) index.partIndex.set(message.id, partMap)
} }
let shouldIncrementVersion = filteredSynthetics || replacedTemp
const partIndex = partMap.get(part.id) const partIndex = partMap.get(part.id)
if (partIndex === undefined) { if (partIndex === undefined) {
baseParts.push(part) baseParts.push(part)
if (part.id && typeof part.id === "string") { if (part.id && typeof part.id === "string") {
partMap.set(part.id, baseParts.length - 1) partMap.set(part.id, baseParts.length - 1)
} }
shouldIncrementVersion = true
} else { } else {
const previousPart = baseParts[partIndex]
const textUnchanged =
!filteredSynthetics &&
!replacedTemp &&
part.type === "text" &&
previousPart?.type === "text" &&
previousPart.text === part.text
if (textUnchanged) {
return
}
baseParts[partIndex] = part baseParts[partIndex] = part
if (part.type !== "text" || !previousPart || previousPart.text !== part.text) {
shouldIncrementVersion = true
}
} }
const oldId = message.id const oldId = message.id
@@ -879,7 +905,16 @@ function handleMessageUpdate(instanceId: string, event: any): void {
message.status = message.status === "sending" ? "streaming" : message.status message.status = message.status === "sending" ? "streaming" : message.status
message.parts = baseParts message.parts = baseParts
message.displayParts = computeDisplayParts(message, preferences().showThinkingBlocks) if (shouldIncrementVersion) {
message.version += 1
message.displayParts = computeDisplayParts(message, preferences().showThinkingBlocks)
} else if (
!message.displayParts ||
message.displayParts.showThinking !== preferences().showThinkingBlocks ||
message.displayParts.version !== message.version
) {
message.displayParts = computeDisplayParts(message, preferences().showThinkingBlocks)
}
// Update message index if ID changed // Update message index if ID changed
if (oldId !== message.id) { if (oldId !== message.id) {
@@ -947,11 +982,17 @@ function handleMessageUpdate(instanceId: string, event: any): void {
if (tempMessageIndex > -1) { if (tempMessageIndex > -1) {
// Replace queued message // Replace queued message
const message = session.messages[tempMessageIndex] const message = session.messages[tempMessageIndex]
if (typeof message.version !== "number") {
message.version = 0
}
const oldId = message.id const oldId = message.id
message.id = info.id message.id = info.id
message.type = (info.role === "user" ? "user" : "assistant") as "user" | "assistant" message.type = (info.role === "user" ? "user" : "assistant") as "user" | "assistant"
message.timestamp = info.time?.created || Date.now() message.timestamp = info.time?.created || Date.now()
message.status = "complete" as const message.status = "complete" as const
message.version += 1
message.displayParts = computeDisplayParts(message, preferences().showThinkingBlocks)
if (oldId !== message.id) { if (oldId !== message.id) {
index.messageIndex.delete(oldId) index.messageIndex.delete(oldId)
@@ -971,6 +1012,7 @@ function handleMessageUpdate(instanceId: string, event: any): void {
parts: [], parts: [],
timestamp: info.time?.created || Date.now(), timestamp: info.time?.created || Date.now(),
status: "complete" as const, status: "complete" as const,
version: 0,
} }
newMessage.displayParts = computeDisplayParts(newMessage, preferences().showThinkingBlocks) newMessage.displayParts = computeDisplayParts(newMessage, preferences().showThinkingBlocks)
@@ -989,16 +1031,21 @@ function handleMessageUpdate(instanceId: string, event: any): void {
} else { } else {
// Update existing message status // Update existing message status
const message = session.messages[messageIndex] const message = session.messages[messageIndex]
if (typeof message.version !== "number") {
message.version = 0
}
message.status = "complete" as const message.status = "complete" as const
message.version += 1
message.displayParts = computeDisplayParts(message, preferences().showThinkingBlocks)
session.messagesInfo.set(info.id, info)
withSession(instanceId, info.sessionID, (session) => {
// Session already mutated in place
})
updateSessionInfo(instanceId, info.sessionID)
} }
session.messagesInfo.set(info.id, info)
withSession(instanceId, info.sessionID, (session) => {
// Session already mutated in place
})
updateSessionInfo(instanceId, info.sessionID)
} }
} }
@@ -1110,6 +1157,7 @@ async function sendMessage(
parts: optimisticParts, parts: optimisticParts,
timestamp: Date.now(), timestamp: Date.now(),
status: "sending", status: "sending",
version: 0,
} }
optimisticMessage.displayParts = computeDisplayParts(optimisticMessage, preferences().showThinkingBlocks) optimisticMessage.displayParts = computeDisplayParts(optimisticMessage, preferences().showThinkingBlocks)

View File

@@ -2,6 +2,9 @@ export interface MessageDisplayParts {
text: any[] text: any[]
tool: any[] tool: any[]
reasoning: any[] reasoning: any[]
combined: any[]
showThinking: boolean
version: number
} }
export interface Message { export interface Message {
@@ -11,5 +14,6 @@ export interface Message {
parts: any[] parts: any[]
timestamp: number timestamp: number
status: "sending" | "sent" | "streaming" | "complete" | "error" status: "sending" | "sent" | "streaming" | "complete" | "error"
version: number
displayParts?: MessageDisplayParts displayParts?: MessageDisplayParts
} }