Improve message stream caching and virtualization for large sessions

This commit is contained in:
Shantur Rathore
2025-11-26 13:30:20 +00:00
parent c77bfc2ee7
commit fad2809299
18 changed files with 1142 additions and 402 deletions

View File

@@ -1,9 +1,10 @@
import { createMemo, Show, onMount, createEffect } from "solid-js"
import { createMemo, Show, createEffect, onCleanup } from "solid-js"
import { DiffView, DiffModeEnum } from "@git-diff-view/solid"
import type { DiffHighlighterLang } from "@git-diff-view/core"
import { getLanguageFromPath } from "../lib/markdown"
import { normalizeDiffText } from "../lib/diff-utils"
import { setToolRenderCache } from "../lib/tool-render-cache"
import { setCacheEntry } from "../lib/global-cache"
import type { CacheEntryParams } from "../lib/global-cache"
import type { DiffViewMode } from "../stores/preferences"
interface ToolCallDiffViewerProps {
@@ -13,7 +14,7 @@ interface ToolCallDiffViewerProps {
mode: DiffViewMode
onRendered?: () => void
cachedHtml?: string
cacheKey?: string
cacheEntryParams?: CacheEntryParams
}
type DiffData = {
@@ -22,6 +23,13 @@ type DiffData = {
hunks: string[]
}
type CaptureContext = {
theme: ToolCallDiffViewerProps["theme"]
mode: DiffViewMode
diffText: string
cacheEntryParams?: CacheEntryParams
}
export function ToolCallDiffViewer(props: ToolCallDiffViewerProps) {
const diffData = createMemo<DiffData | null>(() => {
const normalized = normalizeDiffText(props.diffText)
@@ -46,30 +54,93 @@ export function ToolCallDiffViewer(props: ToolCallDiffViewerProps) {
})
let diffContainerRef: HTMLDivElement | undefined
let pendingCapture: number | undefined
let pendingContext: CaptureContext | undefined
let lastRenderedMarkup: string | undefined
let lastCachedHtml: string | undefined
const captureAndCacheHtml = () => {
if (diffContainerRef && props.cacheKey && !props.cachedHtml) {
// Extract the rendered HTML from DiffView container
const renderedHtml = diffContainerRef.innerHTML
if (renderedHtml) {
setToolRenderCache(props.cacheKey, {
text: props.diffText,
html: renderedHtml,
theme: props.theme,
mode: props.mode,
const clearPendingCapture = () => {
if (pendingCapture !== undefined) {
cancelAnimationFrame(pendingCapture)
pendingCapture = undefined
}
pendingContext = undefined
}
const runCapture = (context: CaptureContext) => {
if (!diffContainerRef) {
props.onRendered?.()
return
}
const markup = diffContainerRef.innerHTML
if (!markup) {
props.onRendered?.()
return
}
const hasChanged = markup !== lastRenderedMarkup
if (hasChanged) {
lastRenderedMarkup = markup
if (context.cacheEntryParams) {
setCacheEntry(context.cacheEntryParams, {
text: context.diffText,
html: markup,
theme: context.theme,
mode: context.mode,
})
}
}
props.onRendered?.()
}
// Also capture HTML when diff data changes
const scheduleCapture = (context: CaptureContext) => {
clearPendingCapture()
pendingContext = context
pendingCapture = requestAnimationFrame(() => {
const activeContext = pendingContext
pendingContext = undefined
pendingCapture = undefined
if (activeContext) {
runCapture(activeContext)
}
})
}
createEffect(() => {
const data = diffData()
if (data && !props.cachedHtml) {
// Delay to allow DiffView to re-render with new data
setTimeout(captureAndCacheHtml, 100)
const cachedHtml = props.cachedHtml
if (cachedHtml) {
clearPendingCapture()
if (cachedHtml !== lastCachedHtml) {
lastCachedHtml = cachedHtml
lastRenderedMarkup = cachedHtml
props.onRendered?.()
}
return
}
lastCachedHtml = undefined
const data = diffData()
const theme = props.theme
const mode = props.mode
if (!data) {
clearPendingCapture()
return
}
scheduleCapture({
theme,
mode,
diffText: props.diffText,
cacheEntryParams: props.cacheEntryParams,
})
})
onCleanup(() => {
clearPendingCapture()
})
return (

View File

@@ -1,26 +1,29 @@
import { For, Show, createMemo } from "solid-js"
import type { Message, SDKPart, MessageInfo, ClientPart } from "../types/message"
import type { MessageInfo, ClientPart } from "../types/message"
import { partHasRenderableText } from "../types/message"
import type { MessageRecord } from "../stores/message-v2/types"
import { formatTokenTotal } from "../lib/formatters"
import { preferences } from "../stores/preferences"
import MessagePart from "./message-part"
interface MessageItemProps {
message: Message
record: MessageRecord
messageInfo?: MessageInfo
instanceId: string
sessionId: string
isQueued?: boolean
parts?: ClientPart[]
combinedParts: ClientPart[]
orderedParts: ClientPart[]
onRevert?: (messageId: string) => void
onFork?: (messageId?: string) => void
}
export default function MessageItem(props: MessageItemProps) {
const isUser = () => props.message.type === "user"
const isUser = () => props.record.role === "user"
const showUsageMetrics = () => preferences().showUsageMetrics ?? true
const timestamp = () => {
const date = new Date(props.message.timestamp)
const createdTime = props.messageInfo?.time?.created ?? props.record.createdAt
const date = new Date(createdTime)
return date.toLocaleTimeString([], { hour: "2-digit", minute: "2-digit" })
}
@@ -30,10 +33,10 @@ export default function MessageItem(props: MessageItemProps) {
filename?: string
}
const displayParts = () => props.parts ?? props.message.parts
const combinedParts = () => props.combinedParts
const fileAttachments = () =>
props.message.parts.filter((part): part is FilePart => part?.type === "file" && typeof (part as FilePart).url === "string")
props.orderedParts.filter((part): part is FilePart => part?.type === "file" && typeof (part as FilePart).url === "string")
const getAttachmentName = (part: FilePart) => {
if (part.filename && part.filename.trim().length > 0) {
@@ -123,7 +126,7 @@ export default function MessageItem(props: MessageItemProps) {
return true
}
return displayParts().some((part) => partHasRenderableText(part))
return combinedParts().some((part) => partHasRenderableText(part))
}
const isGenerating = () => {
@@ -133,7 +136,7 @@ export default function MessageItem(props: MessageItemProps) {
const handleRevert = () => {
if (props.onRevert && isUser()) {
props.onRevert(props.message.id)
props.onRevert(props.record.id)
}
}
@@ -227,7 +230,7 @@ export default function MessageItem(props: MessageItemProps) {
<Show when={isUser() && props.onFork}>
<button
class="bg-transparent border border-[var(--border-base)] text-[var(--text-muted)] cursor-pointer px-3 py-0.5 rounded text-xs font-semibold leading-none transition-all duration-200 flex items-center justify-center h-6 hover:bg-[var(--surface-hover)] hover:border-[var(--accent-primary)] hover:text-[var(--accent-primary)] active:scale-95"
onClick={() => props.onFork?.(props.message.id)}
onClick={() => props.onFork?.(props.record.id)}
title="Fork from this message"
aria-label="Fork from this message"
>
@@ -254,11 +257,11 @@ export default function MessageItem(props: MessageItemProps) {
</div>
</Show>
<For each={displayParts()}>
<For each={combinedParts()}>
{(part) => (
<MessagePart
part={part}
messageType={props.message.type}
messageType={props.record.role}
instanceId={props.instanceId}
sessionId={props.sessionId}
/>
@@ -341,7 +344,7 @@ export default function MessageItem(props: MessageItemProps) {
</div>
)}
</Show>
<Show when={props.message.status === "sending"}>
<Show when={props.record.status === "sending"}>
<div class="message-sending">
@@ -349,7 +352,7 @@ export default function MessageItem(props: MessageItemProps) {
</div>
</Show>
<Show when={props.message.status === "error"}>
<Show when={props.record.status === "error"}>
<div class="message-error"> Message failed to send</div>
</Show>
</div>

View File

@@ -2,21 +2,50 @@ import { For, Show, createMemo, createSignal, createEffect, onCleanup } from "so
import MessageItem from "./message-item"
import ToolCall from "./tool-call"
import Kbd from "./kbd"
import type { Message, MessageInfo, ClientPart, MessageDisplayParts } from "../types/message"
import { partHasRenderableText } from "../types/message"
import type { MessageInfo, ClientPart } from "../types/message"
import { getSessionInfo } from "../stores/sessions"
import { showCommandPalette } from "../stores/command-palette"
import { messageStoreBus } from "../stores/message-v2/bus"
import type { MessageRecord } from "../stores/message-v2/types"
import { buildRecordDisplayData, clearRecordDisplayCacheForInstance, type ToolCallPart } from "../stores/message-v2/record-display-cache"
import { useConfig } from "../stores/preferences"
import { getScrollCache, setScrollCache } from "../lib/scroll-cache"
import { sseManager } from "../lib/sse-manager"
import { formatTokenTotal } from "../lib/formatters"
import { useScrollCache } from "../lib/hooks/use-scroll-cache"
const SCROLL_SCOPE = "session"
const TOOL_ICON = "🔧"
const codeNomadLogo = new URL("../images/CodeNomad-Icon.png", import.meta.url).href
const INITIAL_BATCH_COUNT = 150
const PREPEND_CHUNK_COUNT = 50
const LOAD_MORE_THRESHOLD_PX = 320
const ESTIMATED_MESSAGE_HEIGHT = 120
const messageItemCache = new Map<string, MessageDisplayItem>()
const toolItemCache = new Map<string, ToolDisplayItem>()
function makeInstanceCacheKey(instanceId: string, id: string) {
return `${instanceId}:${id}`
}
function clearInstanceCaches(instanceId: string) {
clearRecordDisplayCacheForInstance(instanceId)
const prefix = `${instanceId}:`
for (const key of messageItemCache.keys()) {
if (key.startsWith(prefix)) {
messageItemCache.delete(key)
}
}
for (const key of toolItemCache.keys()) {
if (key.startsWith(prefix)) {
toolItemCache.delete(key)
}
}
}
messageStoreBus.onInstanceDestroyed(clearInstanceCaches)
function formatTokens(tokens: number): string {
return formatTokenTotal(tokens)
}
@@ -32,8 +61,9 @@ interface MessageStreamV2Props {
interface MessageDisplayItem {
type: "message"
message: Message
record: MessageRecord
combinedParts: ClientPart[]
orderedParts: ClientPart[]
messageInfo?: MessageInfo
isQueued: boolean
}
@@ -48,61 +78,28 @@ interface ToolDisplayItem {
partVersion: number
}
type DisplayItem = MessageDisplayItem | ToolDisplayItem
type ToolCallPart = Extract<ClientPart, { type: "tool" }>
function isToolPart(part: ClientPart): part is ToolCallPart {
return part.type === "tool"
interface MessageDisplayBlock {
record: MessageRecord
messageItem: MessageDisplayItem | null
toolItems: ToolDisplayItem[]
}
function recordToMessage(record: MessageRecord): Message {
const parts = record.partIds
.map((partId) => record.parts[partId]?.data)
.filter((part): part is ClientPart => Boolean(part))
return {
id: record.id,
sessionId: record.sessionId,
type: record.role,
parts,
timestamp: record.createdAt,
status: record.status,
version: record.revision,
}
interface MeasurementEntry {
revision: number
height: number
}
function computeDisplayPartsForMessage(message: Message, showThinking: boolean): MessageDisplayParts {
const text: ClientPart[] = []
const tool: ClientPart[] = []
const reasoning: ClientPart[] = []
for (const part of message.parts) {
if (part.type === "text" && !part.synthetic && partHasRenderableText(part)) {
text.push(part)
} else if (part.type === "tool") {
tool.push(part)
} else if (part.type === "reasoning" && showThinking && partHasRenderableText(part)) {
reasoning.push(part)
}
}
const combined = reasoning.length > 0 ? [...text, ...reasoning] : [...text]
const version = typeof message.version === "number" ? message.version : 0
return { text, tool, reasoning, combined, showThinking, version }
}
function hasRenderableContent(message: Message, combinedParts: ClientPart[], info?: MessageInfo): boolean {
if (message.type !== "assistant" && message.type !== "user") {
function hasRenderableContent(record: MessageRecord, combinedParts: ClientPart[], info?: MessageInfo): boolean {
if (record.role !== "assistant" && record.role !== "user") {
return false
}
if (message.type !== "assistant" || combinedParts.length > 0) {
if (record.role !== "assistant" || combinedParts.length > 0) {
return true
}
if (info && info.role === "assistant" && info.error) {
return true
}
return message.status === "error"
return record.status === "error"
}
export default function MessageStreamV2(props: MessageStreamV2Props) {
@@ -115,6 +112,63 @@ export default function MessageStreamV2(props: MessageStreamV2Props) {
.filter((record): record is MessageRecord => Boolean(record)),
)
const [visibleRange, setVisibleRange] = createSignal({ start: 0, end: 0 })
const [rangeInitialized, setRangeInitialized] = createSignal(false)
const [forceFullHistory, setForceFullHistory] = createSignal(false)
const messageMeasurements = new Map<string, MeasurementEntry>()
const [measurementVersion, setMeasurementVersion] = createSignal(0)
const [virtualPadding, setVirtualPadding] = createSignal(0)
const [reachedAbsoluteTop, setReachedAbsoluteTop] = createSignal(false)
const showLoadOlderButton = createMemo(() => visibleRange().start > 0 && reachedAbsoluteTop())
function updateMeasurementCache(messageId: string, revision: number, height: number) {
const safeHeight = Math.max(0, height)
const existing = messageMeasurements.get(messageId)
if (existing && existing.revision === revision && Math.abs(existing.height - safeHeight) < 1) {
return
}
messageMeasurements.set(messageId, { revision, height: safeHeight })
setMeasurementVersion((value) => value + 1)
}
function getAverageMeasuredHeight() {
if (messageMeasurements.size === 0) {
return ESTIMATED_MESSAGE_HEIGHT
}
let total = 0
for (const entry of messageMeasurements.values()) {
total += entry.height
}
return total / messageMeasurements.size
}
const messageIndexMap = createMemo(() => {
const map = new Map<string, number>()
const records = messageRecords()
records.forEach((record, index) => map.set(record.id, index))
return map
})
const lastAssistantIndex = createMemo(() => {
const records = messageRecords()
for (let index = records.length - 1; index >= 0; index--) {
if (records[index].role === "assistant") {
return index
}
}
return -1
})
const visibleRecords = createMemo(() => {
const records = messageRecords()
const range = visibleRange()
if (range.end === 0) {
return records
}
return records.slice(range.start, range.end)
})
const sessionRevision = createMemo(() => store().getSessionRevision(props.sessionId))
const usageSnapshot = createMemo(() => store().getSessionUsage(props.sessionId))
const sessionInfo = createMemo(() =>
getSessionInfo(props.instanceId, props.sessionId) ?? {
@@ -145,30 +199,50 @@ export default function MessageStreamV2(props: MessageStreamV2Props) {
const messageInfoMap = createMemo(() => {
const map = new Map<string, MessageInfo>()
messageIds().forEach((id) => {
const info = store().getMessageInfo(id)
const records = visibleRecords()
records.forEach((record) => {
const info = store().getMessageInfo(record.id)
if (info) {
map.set(id, info)
map.set(record.id, info)
}
})
return map
})
const revertTarget = createMemo(() => store().getSessionRevert(props.sessionId))
const displayItems = createMemo<DisplayItem[]>(() => {
const scrollCache = useScrollCache({
instanceId: () => props.instanceId,
sessionId: () => props.sessionId,
scope: SCROLL_SCOPE,
})
let previousToken: string | undefined
createEffect(() => {
const sessionId = props.sessionId
store()
messageMeasurements.clear()
setMeasurementVersion((value) => value + 1)
setVirtualPadding(0)
setVisibleRange({ start: 0, end: 0 })
setRangeInitialized(false)
setReachedAbsoluteTop(false)
const snapshot = store().getScrollSnapshot(sessionId, SCROLL_SCOPE)
setForceFullHistory(Boolean(snapshot && !snapshot.atBottom))
previousToken = undefined
})
const displayBlocks = createMemo<MessageDisplayBlock[]>(() => {
const infoMap = messageInfoMap()
const showThinking = preferences().showThinkingBlocks
const revert = revertTarget()
const items: DisplayItem[] = []
const records = messageRecords()
let lastAssistantIndex = -1
for (let i = records.length - 1; i >= 0; i--) {
if (records[i].role === "assistant") {
lastAssistantIndex = i
break
}
}
const instanceId = props.instanceId
const blocks: MessageDisplayBlock[] = []
const usedMessageKeys = new Set<string>()
const usedToolKeys = new Set<string>()
const records = visibleRecords()
const globalAssistantIndex = lastAssistantIndex()
const indexMap = messageIndexMap()
for (let index = 0; index < records.length; index++) {
const record = records[index]
@@ -176,61 +250,197 @@ export default function MessageStreamV2(props: MessageStreamV2Props) {
break
}
const baseMessage = recordToMessage(record)
const displayParts = computeDisplayPartsForMessage(baseMessage, showThinking)
baseMessage.displayParts = displayParts
const combinedParts = displayParts.combined
const { orderedParts, textAndReasoningParts, toolParts } = buildRecordDisplayData(instanceId, record, showThinking)
const messageInfo = infoMap.get(record.id)
const isQueued =
baseMessage.type === "user" && (lastAssistantIndex === -1 || index > lastAssistantIndex)
const recordCacheKey = makeInstanceCacheKey(instanceId, record.id)
const recordIndex = indexMap.get(record.id) ?? 0
const isQueued = record.role === "user" && (globalAssistantIndex === -1 || recordIndex > globalAssistantIndex)
if (hasRenderableContent(baseMessage, combinedParts, messageInfo)) {
items.push({
type: "message",
message: baseMessage,
combinedParts,
messageInfo,
isQueued,
})
let messageItem: MessageDisplayItem | null = null
if (hasRenderableContent(record, textAndReasoningParts, messageInfo)) {
let cached = messageItemCache.get(recordCacheKey)
if (!cached) {
cached = {
type: "message",
record,
combinedParts: textAndReasoningParts,
orderedParts,
messageInfo,
isQueued,
}
messageItemCache.set(recordCacheKey, cached)
} else {
cached.record = record
cached.combinedParts = textAndReasoningParts
cached.orderedParts = orderedParts
cached.messageInfo = messageInfo
cached.isQueued = isQueued
}
messageItem = cached
usedMessageKeys.add(recordCacheKey)
}
const toolParts: ToolCallPart[] = displayParts.tool.filter(isToolPart)
const toolItems: ToolDisplayItem[] = []
toolParts.forEach((toolPart, toolIndex) => {
const partVersion = typeof toolPart.version === "number" ? toolPart.version : 0
const messageVersion = typeof baseMessage.version === "number" ? baseMessage.version : 0
const key = toolPart.id || `${record.id}-tool-${toolIndex}`
items.push({
type: "tool",
key,
toolPart,
messageInfo,
messageId: record.id,
messageVersion,
partVersion,
})
const messageVersion = record.revision
const key = `${record.id}:${toolPart.id ?? toolIndex}`
const toolCacheKey = makeInstanceCacheKey(instanceId, key)
let toolItem = toolItemCache.get(toolCacheKey)
if (!toolItem) {
toolItem = {
type: "tool",
key,
toolPart,
messageInfo,
messageId: record.id,
messageVersion,
partVersion,
}
toolItemCache.set(toolCacheKey, toolItem)
} else {
toolItem.key = key
toolItem.toolPart = toolPart
toolItem.messageInfo = messageInfo
toolItem.messageId = record.id
toolItem.messageVersion = messageVersion
toolItem.partVersion = partVersion
}
toolItems.push(toolItem)
usedToolKeys.add(toolCacheKey)
})
if (!messageItem && toolItems.length === 0) {
continue
}
blocks.push({ record, messageItem, toolItems })
}
return items
for (const key of messageItemCache.keys()) {
if (!usedMessageKeys.has(key)) {
messageItemCache.delete(key)
}
}
for (const key of toolItemCache.keys()) {
if (!usedToolKeys.has(key)) {
toolItemCache.delete(key)
}
}
return blocks
})
createEffect(() => {
const records = messageRecords()
const total = records.length
const requireFullHistory = forceFullHistory()
if (total === 0) {
setVisibleRange({ start: 0, end: 0 })
setRangeInitialized(false)
return
}
setVisibleRange((current) => {
if (!rangeInitialized() || requireFullHistory) {
const start = requireFullHistory ? 0 : Math.max(0, total - INITIAL_BATCH_COUNT)
if (!rangeInitialized()) {
setRangeInitialized(true)
}
if (requireFullHistory) {
setForceFullHistory(false)
}
return { start, end: total }
}
const nextEnd = total
let nextStart = current.start
if (nextStart > nextEnd) {
nextStart = Math.max(0, nextEnd - INITIAL_BATCH_COUNT)
}
return { start: nextStart, end: nextEnd }
})
})
createEffect(() => {
measurementVersion()
const range = visibleRange()
if (range.start <= 0) {
setVirtualPadding(0)
return
}
const records = messageRecords()
const trimmed = records.slice(0, range.start)
if (trimmed.length === 0) {
setVirtualPadding(0)
return
}
const fallback = getAverageMeasuredHeight()
let total = 0
for (const record of trimmed) {
const entry = messageMeasurements.get(record.id)
total += entry?.height ?? fallback
}
setVirtualPadding(total)
})
const changeToken = createMemo(() => {
const entries = displayItems()
return entries
.map((item) => {
if (item.type === "message") {
return `${item.message.id}:${item.message.version}:${item.combinedParts.length}`
}
const status = item.toolPart.state?.status || "unknown"
return `tool:${item.key}:${item.partVersion}:${status}`
})
.join("|")
const revisionValue = sessionRevision()
const range = visibleRange()
const blocks = displayBlocks()
if (blocks.length === 0) {
return `${revisionValue}:${range.start}:${range.end}:empty`
}
const lastBlock = blocks[blocks.length - 1]
const lastTool = lastBlock.toolItems[lastBlock.toolItems.length - 1]
const tailSignature = lastTool
? `tool:${lastTool.key}:${lastTool.partVersion}`
: `msg:${lastBlock.record.id}:${lastBlock.record.revision}`
return `${revisionValue}:${range.start}:${range.end}:${tailSignature}`
})
const [autoScroll, setAutoScroll] = createSignal(true)
const [showScrollButton, setShowScrollButton] = createSignal(false)
let containerRef: HTMLDivElement | undefined
function captureScrollSnapshot() {
if (!containerRef) return { height: 0, top: 0 }
return { height: containerRef.scrollHeight, top: containerRef.scrollTop }
}
function restoreScrollSnapshot(snapshot?: { height: number; top: number }) {
if (!containerRef || !snapshot) return
requestAnimationFrame(() => {
if (!containerRef) return
const delta = containerRef.scrollHeight - snapshot.height
containerRef.scrollTop = snapshot.top + delta
})
}
function prependChunk(amount = PREPEND_CHUNK_COUNT) {
if (visibleRange().start === 0) {
return
}
const snapshot = captureScrollSnapshot()
setVisibleRange((range) => {
if (range.start === 0) {
return range
}
const nextStart = Math.max(0, range.start - amount)
return { start: nextStart, end: range.end }
})
restoreScrollSnapshot(snapshot)
}
function loadAllOlderMessages() {
if (visibleRange().start === 0) {
return
}
const snapshot = captureScrollSnapshot()
setVisibleRange((range) => ({ start: 0, end: range.end }))
restoreScrollSnapshot(snapshot)
}
function isNearBottom(element: HTMLDivElement, offset = 48) {
const { scrollTop, scrollHeight, clientHeight } = element
return scrollHeight - (scrollTop + clientHeight) <= offset
@@ -246,41 +456,45 @@ export default function MessageStreamV2(props: MessageStreamV2Props) {
function persistScrollState() {
if (!containerRef) return
setScrollCache(
{ instanceId: props.instanceId, sessionId: props.sessionId, scope: SCROLL_SCOPE },
{
scrollTop: containerRef.scrollTop,
atBottom: isNearBottom(containerRef),
},
)
scrollCache.persist(containerRef, { atBottomOffset: 48 })
}
function handleScroll(event: Event) {
if (!containerRef) return
const atBottom = isNearBottom(containerRef)
setShowScrollButton(!atBottom)
const atAbsoluteTop = containerRef.scrollTop <= 4
setReachedAbsoluteTop(atAbsoluteTop)
if (event.isTrusted) {
setAutoScroll(atBottom)
if (containerRef.scrollTop <= LOAD_MORE_THRESHOLD_PX && visibleRange().start > 0) {
prependChunk()
}
}
persistScrollState()
}
createEffect(() => {
const scrollSnapshot = getScrollCache({ instanceId: props.instanceId, sessionId: props.sessionId, scope: SCROLL_SCOPE })
requestAnimationFrame(() => {
if (!containerRef) return
if (scrollSnapshot) {
const maxScrollTop = Math.max(containerRef.scrollHeight - containerRef.clientHeight, 0)
containerRef.scrollTop = Math.min(scrollSnapshot.scrollTop, maxScrollTop)
setAutoScroll(scrollSnapshot.atBottom)
setShowScrollButton(!scrollSnapshot.atBottom)
} else {
scrollToBottom(true)
}
const sessionId = props.sessionId
store()
const target = containerRef
if (!target) return
scrollCache.restore(target, {
fallback: () => scrollToBottom(true),
onApplied: (snapshot) => {
if (snapshot) {
setAutoScroll(snapshot.atBottom)
setShowScrollButton(!snapshot.atBottom)
} else {
const atBottom = isNearBottom(target)
setAutoScroll(atBottom)
setShowScrollButton(!atBottom)
}
},
})
void sessionId
})
let previousToken: string | undefined
createEffect(() => {
const token = changeToken()
if (!token || token === previousToken) {
@@ -293,7 +507,7 @@ export default function MessageStreamV2(props: MessageStreamV2Props) {
})
createEffect(() => {
if (displayItems().length === 0) {
if (messageRecords().length === 0) {
setShowScrollButton(false)
setAutoScroll(true)
}
@@ -358,7 +572,7 @@ export default function MessageStreamV2(props: MessageStreamV2Props) {
}}
onScroll={handleScroll}
>
<Show when={!props.loading && displayItems().length === 0}>
<Show when={!props.loading && messageRecords().length === 0}>
<div class="empty-state">
<div class="empty-state-content">
<div class="flex flex-col items-center gap-3 mb-6">
@@ -388,41 +602,84 @@ export default function MessageStreamV2(props: MessageStreamV2Props) {
</div>
</Show>
<For each={displayItems()}>
{(item) => {
if (item.type === "message") {
return (
<MessageItem
message={item.message}
messageInfo={item.messageInfo}
parts={item.combinedParts}
instanceId={props.instanceId}
sessionId={props.sessionId}
isQueued={item.isQueued}
onRevert={props.onRevert}
onFork={props.onFork}
/>
)
<Show when={virtualPadding() > 0}>
<div class="message-stream-virtual-padding" style={{ height: `${virtualPadding()}px` }} aria-hidden="true" />
</Show>
<Show when={showLoadOlderButton()}>
<div class="message-stream-load-older">
<button type="button" class="message-stream-load-older-button" onClick={loadAllOlderMessages}>
Load older messages
</button>
</div>
</Show>
<For each={displayBlocks()}>
{(block) => {
let blockRef: HTMLDivElement | undefined
const scheduleMeasurement = () => {
if (!blockRef) return
requestAnimationFrame(() => {
if (!blockRef) return
updateMeasurementCache(block.record.id, block.record.revision, blockRef.clientHeight)
})
}
createEffect(() => {
void block.record.revision
scheduleMeasurement()
})
return (
<div class="tool-call-message" data-key={item.key}>
<div class="tool-call-header-label">
<div class="tool-call-header-meta">
<span class="tool-call-icon">{TOOL_ICON}</span>
<span>Tool Call</span>
<span class="tool-name">{item.toolPart.tool || "unknown"}</span>
</div>
</div>
<ToolCall
toolCall={item.toolPart}
toolCallId={item.key}
messageId={item.messageId}
messageVersion={item.messageVersion}
partVersion={item.partVersion}
instanceId={props.instanceId}
sessionId={props.sessionId}
/>
<div
class="message-stream-block"
data-message-id={block.record.id}
ref={(element) => {
blockRef = element || undefined
if (element) {
scheduleMeasurement()
}
}}
>
<Show when={block.messageItem} keyed>
{(message) => (
<MessageItem
record={message.record}
messageInfo={message.messageInfo}
combinedParts={message.combinedParts}
orderedParts={message.orderedParts}
instanceId={props.instanceId}
sessionId={props.sessionId}
isQueued={message.isQueued}
onRevert={props.onRevert}
onFork={props.onFork}
/>
)}
</Show>
<For each={block.toolItems}>
{(item) => (
<div class="tool-call-message" data-key={item.key}>
<div class="tool-call-header-label">
<div class="tool-call-header-meta">
<span class="tool-call-icon">{TOOL_ICON}</span>
<span>Tool Call</span>
<span class="tool-name">{item.toolPart.tool || "unknown"}</span>
</div>
</div>
<ToolCall
toolCall={item.toolPart}
toolCallId={item.key}
messageId={item.messageId}
messageVersion={item.messageVersion}
partVersion={item.partVersion}
instanceId={props.instanceId}
sessionId={props.sessionId}
/>
</div>
)}
</For>
</div>
)
}}

View File

@@ -6,11 +6,12 @@ import { ToolCallDiffViewer } from "./diff-viewer"
import { useTheme } from "../lib/theme"
import { getLanguageFromPath } from "../lib/markdown"
import { isRenderableDiffText } from "../lib/diff-utils"
import { getToolRenderCache, setToolRenderCache } from "../lib/tool-render-cache"
import { useGlobalCache } from "../lib/hooks/use-global-cache"
import { useScrollCache } from "../lib/hooks/use-scroll-cache"
import { useConfig } from "../stores/preferences"
import type { DiffViewMode } from "../stores/preferences"
import { sendPermissionResponse } from "../stores/instances"
import type { TextPart, SDKPart, ClientPart } from "../types/message"
import type { TextPart, SDKPart, ClientPart, RenderCache } from "../types/message"
type ToolCallPart = Extract<ClientPart, { type: "tool" }>
@@ -34,46 +35,19 @@ function isToolStateError(state: ToolState): state is ToolStateError {
}
const toolScrollState = new Map<string, { scrollTop: number; atBottom: boolean }>()
const TOOL_CALL_CACHE_SCOPE = "tool-call"
function makeRenderCacheKey(
toolCallId?: string | null,
messageId?: string,
messageVersion?: number,
partVersion?: number,
variant = "default",
) {
const suffix = `${messageVersion ?? 0}:${partVersion ?? 0}`
const keyBase = `${messageId}:${toolCallId}`
return `${keyBase}::${suffix}`
}
function updateScrollState(id: string, element: HTMLElement) {
if (!id) return
const distanceFromBottom = element.scrollHeight - (element.scrollTop + element.clientHeight)
const atBottom = distanceFromBottom <= 2
toolScrollState.set(id, { scrollTop: element.scrollTop, atBottom })
}
function restoreScrollState(id: string, element: HTMLElement) {
if (!id) return
const state = toolScrollState.get(id)
if (!state) {
requestAnimationFrame(() => {
element.scrollTop = element.scrollHeight
updateScrollState(id, element)
})
return
}
requestAnimationFrame(() => {
if (state.atBottom) {
element.scrollTop = element.scrollHeight
} else {
const maxScrollTop = Math.max(element.scrollHeight - element.clientHeight, 0)
element.scrollTop = Math.min(state.scrollTop, maxScrollTop)
}
updateScrollState(id, element)
})
const messageComponent = messageId ?? "unknown-message"
const toolCallComponent = toolCallId ?? "unknown-tool-call"
const versionComponent = `${messageVersion ?? 0}:${partVersion ?? 0}`
return `${messageComponent}:${toolCallComponent}:${versionComponent}:${variant}`
}
@@ -348,6 +322,34 @@ export default function ToolCall(props: ToolCallProps) {
const { isDark } = useTheme()
const toolCallId = () => props.toolCallId || props.toolCall?.id || ""
const store = createMemo(() => messageStoreBus.getOrCreate(props.instanceId))
const cacheContext = createMemo(() => ({
toolCallId: toolCallId(),
messageId: props.messageId,
messageVersion: props.messageVersion ?? 0,
partVersion: props.partVersion ?? 0,
}))
const createVariantCache = (variant: string) =>
useGlobalCache({
instanceId: () => props.instanceId,
sessionId: () => props.sessionId,
scope: TOOL_CALL_CACHE_SCOPE,
key: () => {
const context = cacheContext()
return makeRenderCacheKey(
context.toolCallId || undefined,
context.messageId,
context.messageVersion,
context.partVersion,
variant,
)
},
})
const diffCache = createVariantCache("diff")
const permissionDiffCache = createVariantCache("permission-diff")
const markdownCache = createVariantCache("markdown")
const permissionState = createMemo(() => store().getPermissionState(props.messageId, toolCallId() || props.toolCall?.id))
const pendingPermission = createMemo(() => {
const state = permissionState()
@@ -383,30 +385,49 @@ export default function ToolCall(props: ToolCallProps) {
let scrollContainerRef: HTMLDivElement | undefined
let toolCallRootRef: HTMLDivElement | undefined
const handleScrollRendered = () => {
const id = toolCallId()
if (!id || !scrollContainerRef) return
restoreScrollState(id, scrollContainerRef)
const scrollScopeId = createMemo(() => {
const id = toolCallId()
if (id) return id
const messageKey = props.messageId || "unknown"
const partKey = typeof props.partVersion === "number" ? props.partVersion : 0
return `${messageKey}:${partKey}`
})
const scrollCache = useScrollCache({
instanceId: () => props.instanceId,
sessionId: () => props.sessionId,
scope: () => `${TOOL_CALL_CACHE_SCOPE}:scroll:${scrollScopeId()}`,
})
const persistScrollSnapshot = (element?: HTMLElement | null) => {
if (!element) return
scrollCache.persist(element, { atBottomOffset: 2 })
}
const restoreScrollSnapshot = (element?: HTMLElement | null) => {
if (!element) return
scrollCache.restore(element, {
fallback: () => {
requestAnimationFrame(() => {
if (!element || !element.isConnected) return
element.scrollTop = element.scrollHeight
persistScrollSnapshot(element)
})
},
})
}
const handleScrollRendered = () => {
if (!scrollContainerRef) return
restoreScrollSnapshot(scrollContainerRef)
}
const initializeScrollContainer = (element: HTMLDivElement | null | undefined) => {
const resolvedElement = element || undefined
scrollContainerRef = resolvedElement
const id = toolCallId()
if (!resolvedElement || !id) return
if (!toolScrollState.has(id)) {
requestAnimationFrame(() => {
if (!scrollContainerRef || toolCallId() !== id) return
scrollContainerRef.scrollTop = scrollContainerRef.scrollHeight
updateScrollState(id, scrollContainerRef)
})
} else {
restoreScrollState(id, resolvedElement)
}
if (!resolvedElement) return
restoreScrollSnapshot(resolvedElement)
}
createEffect(() => {
@@ -435,16 +456,6 @@ export default function ToolCall(props: ToolCallProps) {
}
})
// Cleanup cache entry when component unmounts or toolCallId changes
createEffect(() => {
const id = toolCallId()
if (!id) return
onCleanup(() => {
toolScrollState.delete(id)
})
})
createEffect(() => {
if (props.toolCall?.tool !== "task") return
const state = props.toolCall?.state
@@ -734,25 +745,20 @@ export default function ToolCall(props: ToolCallProps) {
return renderMarkdownTool(toolName, state)
}
function renderDiffTool(payload: DiffPayload, options?: { cacheKeySuffix?: string; disableScrollTracking?: boolean; label?: string }) {
function renderDiffTool(payload: DiffPayload, options?: { variant?: string; disableScrollTracking?: boolean; label?: string }) {
const relativePath = payload.filePath ? getRelativePath(payload.filePath) : ""
const toolbarLabel = options?.label || (relativePath ? `Diff · ${relativePath}` : "Diff")
const cacheKeyBase = makeRenderCacheKey(toolCallId(), props.messageId, props.messageVersion, props.partVersion)
const cacheKey = options?.cacheKeySuffix ? `${cacheKeyBase}${options.cacheKeySuffix}` : cacheKeyBase
const selectedVariant = options?.variant === "permission-diff" ? "permission-diff" : "diff"
const cacheHandle = selectedVariant === "permission-diff" ? permissionDiffCache : diffCache
const diffMode = () => (preferences().diffViewMode || "split") as DiffViewMode
const themeKey = isDark() ? "dark" : "light"
// Check if we have valid cache
let cachedHtml: string | undefined
if (cacheKey) {
const cached = getToolRenderCache(cacheKey)
const currentMode = diffMode()
if (cached &&
cached.text === payload.diffText &&
cached.theme === themeKey &&
cached.mode === currentMode) {
cachedHtml = cached.html
}
const cached = cacheHandle.get<RenderCache>()
const currentMode = diffMode()
if (cached && cached.text === payload.diffText && cached.theme === themeKey && cached.mode === currentMode) {
cachedHtml = cached.html
}
const handleModeChange = (mode: DiffViewMode) => {
@@ -760,10 +766,6 @@ export default function ToolCall(props: ToolCallProps) {
}
const handleDiffRendered = () => {
if (cacheKey && !cachedHtml) {
// Cache will be updated by the diff viewer component itself
// We'll capture HTML from the rendered component
}
if (!options?.disableScrollTracking) {
handleScrollRendered()
}
@@ -776,7 +778,7 @@ export default function ToolCall(props: ToolCallProps) {
if (options?.disableScrollTracking) return
initializeScrollContainer(element)
}}
onScroll={options?.disableScrollTracking ? undefined : (event) => updateScrollState(toolCallId(), event.currentTarget)}
onScroll={options?.disableScrollTracking ? undefined : (event) => persistScrollSnapshot(event.currentTarget)}
>
<div class="tool-call-diff-toolbar" role="group" aria-label="Diff view mode">
@@ -806,7 +808,7 @@ export default function ToolCall(props: ToolCallProps) {
theme={themeKey}
mode={diffMode()}
cachedHtml={cachedHtml}
cacheKey={cacheKey}
cacheEntryParams={cacheHandle.params()}
onRendered={handleDiffRendered}
/>
</div>
@@ -822,20 +824,15 @@ export default function ToolCall(props: ToolCallProps) {
const isLarge = toolName === "edit" || toolName === "write" || toolName === "patch"
const messageClass = `message-text tool-call-markdown${isLarge ? " tool-call-markdown-large" : ""}`
const disableHighlight = state?.status === "running"
const cacheKey = makeRenderCacheKey(toolCallId(), props.messageId, props.messageVersion, props.partVersion)
const markdownPart: TextPart = { type: "text", text: content }
if (cacheKey) {
const cached = getToolRenderCache(cacheKey)
if (cached) {
markdownPart.renderCache = cached
}
const cached = markdownCache.get<RenderCache>()
if (cached) {
markdownPart.renderCache = cached
}
const handleMarkdownRendered = () => {
if (cacheKey) {
setToolRenderCache(cacheKey, markdownPart.renderCache)
}
markdownCache.set(markdownPart.renderCache)
handleScrollRendered()
}
@@ -843,7 +840,7 @@ export default function ToolCall(props: ToolCallProps) {
<div
class={messageClass}
ref={(element) => initializeScrollContainer(element)}
onScroll={(event) => updateScrollState(toolCallId(), event.currentTarget)}
onScroll={(event) => persistScrollSnapshot(event.currentTarget)}
>
<Markdown
part={markdownPart}
@@ -1053,7 +1050,7 @@ export default function ToolCall(props: ToolCallProps) {
<div
class="message-text tool-call-markdown tool-call-task-container"
ref={(element) => initializeScrollContainer(element)}
onScroll={(event) => updateScrollState(toolCallId(), event.currentTarget)}
onScroll={(event) => persistScrollSnapshot(event.currentTarget)}
>
<div class="tool-call-task-summary">
<For each={summary}>
@@ -1131,7 +1128,7 @@ export default function ToolCall(props: ToolCallProps) {
{(payload) => (
<div class="tool-call-permission-diff">
{renderDiffTool(payload(), {
cacheKeySuffix: "::permission",
variant: "permission-diff",
disableScrollTracking: true,
label: payload().filePath ? `Requested diff · ${getRelativePath(payload().filePath || "")}` : "Requested diff",
})}