Improve message stream caching and virtualization for large sessions
This commit is contained in:
@@ -20,7 +20,7 @@ import { preferences } from "./preferences"
|
||||
import { setSessionPendingPermission } from "./session-state"
|
||||
import { setHasInstances } from "./ui"
|
||||
import { messageStoreBus } from "./message-v2/bus"
|
||||
import { clearScrollCacheForInstance } from "../lib/scroll-cache"
|
||||
import { clearCacheForInstance } from "../lib/global-cache"
|
||||
import type { MessageRecord } from "./message-v2/types"
|
||||
|
||||
|
||||
@@ -296,7 +296,7 @@ function removeInstance(id: string) {
|
||||
}
|
||||
|
||||
// Clean up session indexes and drafts for removed instance
|
||||
clearScrollCacheForInstance(id)
|
||||
clearCacheForInstance(id)
|
||||
messageStoreBus.unregisterInstance(id)
|
||||
clearInstanceDraftPrompts(id)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import { createInstanceMessageStore } from "./instance-store"
|
||||
import type { InstanceMessageStore } from "./instance-store"
|
||||
import { clearCacheForInstance } from "../../lib/global-cache"
|
||||
|
||||
class MessageStoreBus {
|
||||
private stores = new Map<string, InstanceMessageStore>()
|
||||
private teardownHandlers = new Set<(instanceId: string) => void>()
|
||||
|
||||
registerInstance(instanceId: string, store?: InstanceMessageStore): InstanceMessageStore {
|
||||
if (this.stores.has(instanceId)) {
|
||||
@@ -22,20 +24,41 @@ class MessageStoreBus {
|
||||
return this.registerInstance(instanceId)
|
||||
}
|
||||
|
||||
onInstanceDestroyed(handler: (instanceId: string) => void): () => void {
|
||||
this.teardownHandlers.add(handler)
|
||||
return () => {
|
||||
this.teardownHandlers.delete(handler)
|
||||
}
|
||||
}
|
||||
|
||||
unregisterInstance(instanceId: string) {
|
||||
const store = this.stores.get(instanceId)
|
||||
if (store) {
|
||||
store.clearInstance()
|
||||
}
|
||||
clearCacheForInstance(instanceId)
|
||||
this.notifyInstanceDestroyed(instanceId)
|
||||
this.stores.delete(instanceId)
|
||||
}
|
||||
|
||||
clearAll() {
|
||||
for (const [instanceId, store] of this.stores.entries()) {
|
||||
store.clearInstance()
|
||||
clearCacheForInstance(instanceId)
|
||||
this.notifyInstanceDestroyed(instanceId)
|
||||
this.stores.delete(instanceId)
|
||||
}
|
||||
}
|
||||
|
||||
private notifyInstanceDestroyed(instanceId: string) {
|
||||
for (const handler of this.teardownHandlers) {
|
||||
try {
|
||||
handler(instanceId)
|
||||
} catch (error) {
|
||||
console.error("Failed to run message store teardown handler", error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const messageStoreBus = new MessageStoreBus()
|
||||
|
||||
@@ -24,6 +24,7 @@ function createInitialState(instanceId: string): InstanceMessageState {
|
||||
messages: {},
|
||||
messageInfoVersion: {},
|
||||
pendingParts: {},
|
||||
sessionRevisions: {},
|
||||
permissions: {
|
||||
queue: [],
|
||||
active: null,
|
||||
@@ -41,8 +42,52 @@ function ensurePartId(messageId: string, part: ClientPart, index: number): strin
|
||||
return `${messageId}-part-${index}`
|
||||
}
|
||||
|
||||
const PENDING_PART_MAX_AGE_MS = 30_000
|
||||
|
||||
function clonePart(part: ClientPart): ClientPart {
|
||||
return JSON.parse(JSON.stringify(part)) as ClientPart
|
||||
if (!part || typeof part !== "object") {
|
||||
return part
|
||||
}
|
||||
const cloned: Record<string, any> = { ...part }
|
||||
if ("renderCache" in cloned) {
|
||||
cloned.renderCache = undefined
|
||||
}
|
||||
if ("text" in cloned) {
|
||||
cloned.text = cloneStructuredValue(cloned.text)
|
||||
}
|
||||
if ("thinking" in cloned && typeof cloned.thinking === "object") {
|
||||
cloned.thinking = cloneStructuredValue(cloned.thinking)
|
||||
}
|
||||
if ("content" in cloned && Array.isArray(cloned.content)) {
|
||||
cloned.content = cloneStructuredValue(cloned.content)
|
||||
}
|
||||
return cloned as ClientPart
|
||||
}
|
||||
|
||||
function cloneStructuredValue<T>(value: T): T {
|
||||
if (Array.isArray(value)) {
|
||||
return value.map((item) => cloneStructuredValue(item)) as T
|
||||
}
|
||||
if (value && typeof value === "object") {
|
||||
const next: Record<string, any> = {}
|
||||
Object.entries(value as Record<string, any>).forEach(([key, nested]) => {
|
||||
next[key] = cloneStructuredValue(nested)
|
||||
})
|
||||
return next as T
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
function areMessageIdListsEqual(a: string[], b: string[]): boolean {
|
||||
if (a.length !== b.length) {
|
||||
return false
|
||||
}
|
||||
for (let index = 0; index < a.length; index++) {
|
||||
if (a[index] !== b[index]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
function createEmptyUsageState(): SessionUsageState {
|
||||
@@ -158,6 +203,7 @@ export interface InstanceMessageStore {
|
||||
getSessionUsage: (sessionId: string) => SessionUsageState | undefined
|
||||
setScrollSnapshot: (sessionId: string, scope: string, snapshot: Omit<ScrollSnapshot, "updatedAt">) => void
|
||||
getScrollSnapshot: (sessionId: string, scope: string) => ScrollSnapshot | undefined
|
||||
getSessionRevision: (sessionId: string) => number
|
||||
getSessionMessageIds: (sessionId: string) => string[]
|
||||
getMessage: (messageId: string) => MessageRecord | undefined
|
||||
clearInstance: () => void
|
||||
@@ -167,6 +213,15 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
|
||||
const [state, setState] = createStore<InstanceMessageState>(createInitialState(instanceId))
|
||||
const messageInfoCache = new Map<string, MessageInfo>()
|
||||
|
||||
function bumpSessionRevision(sessionId: string) {
|
||||
if (!sessionId) return
|
||||
setState("sessionRevisions", sessionId, (value = 0) => value + 1)
|
||||
}
|
||||
|
||||
function getSessionRevisionValue(sessionId: string) {
|
||||
return state.sessionRevisions[sessionId] ?? 0
|
||||
}
|
||||
|
||||
function withUsageState(sessionId: string, updater: (draft: SessionUsageState) => void) {
|
||||
setState("usage", sessionId, (current) => {
|
||||
const draft = current
|
||||
@@ -223,6 +278,7 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
|
||||
|
||||
function addOrUpdateSession(input: SessionUpsertInput) {
|
||||
const session = ensureSessionEntry(input.id)
|
||||
const previousIds = [...session.messageIds]
|
||||
const nextMessageIds = Array.isArray(input.messageIds) ? input.messageIds : session.messageIds
|
||||
|
||||
setState("sessions", input.id, {
|
||||
@@ -233,6 +289,10 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
|
||||
messageIds: nextMessageIds,
|
||||
revert: input.revert ?? session.revert ?? null,
|
||||
})
|
||||
|
||||
if (Array.isArray(input.messageIds) && !areMessageIdListsEqual(previousIds, nextMessageIds)) {
|
||||
bumpSessionRevision(input.id)
|
||||
}
|
||||
}
|
||||
|
||||
function hydrateMessages(sessionId: string, inputs: MessageUpsertInput[], infos?: Iterable<MessageInfo>) {
|
||||
@@ -303,7 +363,7 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
|
||||
|
||||
setState("messages", (prev) => ({ ...prev, ...nextMessages }))
|
||||
setState("messageInfoVersion", (prev) => ({ ...prev, ...nextMessageInfoVersion }))
|
||||
setState("pendingParts", (prev) => ({ ...prev, ...nextPendingParts }))
|
||||
setState("pendingParts", () => nextPendingParts)
|
||||
setState("permissions", "byMessage", (prev) => ({ ...prev, ...nextPermissionsByMessage }))
|
||||
|
||||
if (usageState) {
|
||||
@@ -315,6 +375,8 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
|
||||
messageIds: incomingIds,
|
||||
updatedAt: Date.now(),
|
||||
}))
|
||||
|
||||
bumpSessionRevision(sessionId)
|
||||
}
|
||||
|
||||
function insertMessageIntoSession(sessionId: string, messageId: string) {
|
||||
@@ -374,12 +436,24 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
|
||||
|
||||
insertMessageIntoSession(input.sessionId, input.id)
|
||||
flushPendingParts(input.id)
|
||||
bumpSessionRevision(input.sessionId)
|
||||
}
|
||||
|
||||
function bufferPendingPart(entry: PendingPartEntry) {
|
||||
setState("pendingParts", entry.messageId, (list = []) => [...list, entry])
|
||||
}
|
||||
|
||||
function clearPendingPartsForMessage(messageId: string) {
|
||||
setState("pendingParts", (prev) => {
|
||||
if (!prev[messageId]) {
|
||||
return prev
|
||||
}
|
||||
const next = { ...prev }
|
||||
delete next[messageId]
|
||||
return next
|
||||
})
|
||||
}
|
||||
|
||||
function applyPartUpdate(input: PartUpdateInput) {
|
||||
const message = state.messages[input.messageId]
|
||||
if (!message) {
|
||||
@@ -417,12 +491,14 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
|
||||
if (!pending || pending.length === 0) {
|
||||
return
|
||||
}
|
||||
pending.forEach((entry) => applyPartUpdate({ messageId, part: entry.part }))
|
||||
setState("pendingParts", (prev) => {
|
||||
const next = { ...prev }
|
||||
delete next[messageId]
|
||||
return next
|
||||
})
|
||||
const now = Date.now()
|
||||
const validEntries = pending.filter((entry) => now - entry.receivedAt <= PENDING_PART_MAX_AGE_MS)
|
||||
if (validEntries.length === 0) {
|
||||
clearPendingPartsForMessage(messageId)
|
||||
return
|
||||
}
|
||||
validEntries.forEach((entry) => applyPartUpdate({ messageId, part: entry.part }))
|
||||
clearPendingPartsForMessage(messageId)
|
||||
}
|
||||
|
||||
function replaceMessageId(options: ReplaceMessageIdOptions) {
|
||||
@@ -444,6 +520,8 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
|
||||
return next
|
||||
})
|
||||
|
||||
const affectedSessions = new Set<string>()
|
||||
|
||||
Object.values(state.sessions).forEach((session) => {
|
||||
const index = session.messageIds.indexOf(options.oldId)
|
||||
if (index === -1) return
|
||||
@@ -452,8 +530,11 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
|
||||
next[index] = options.newId
|
||||
return next
|
||||
})
|
||||
affectedSessions.add(session.id)
|
||||
})
|
||||
|
||||
affectedSessions.forEach((sessionId) => bumpSessionRevision(sessionId))
|
||||
|
||||
const infoEntry = messageInfoCache.get(options.oldId)
|
||||
if (infoEntry) {
|
||||
messageInfoCache.set(options.newId, infoEntry)
|
||||
@@ -482,12 +563,8 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
|
||||
const pending = state.pendingParts[options.oldId]
|
||||
if (pending) {
|
||||
setState("pendingParts", options.newId, pending)
|
||||
setState("pendingParts", (prev) => {
|
||||
const next = { ...prev }
|
||||
delete next[options.oldId]
|
||||
return next
|
||||
})
|
||||
}
|
||||
clearPendingPartsForMessage(options.oldId)
|
||||
}
|
||||
|
||||
function setMessageInfo(messageId: string, info: MessageInfo) {
|
||||
@@ -608,6 +685,7 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
|
||||
getSessionUsage,
|
||||
setScrollSnapshot,
|
||||
getScrollSnapshot,
|
||||
getSessionRevision: getSessionRevisionValue,
|
||||
getSessionMessageIds: (sessionId: string) => state.sessions[sessionId]?.messageIds ?? [],
|
||||
getMessage: (messageId: string) => state.messages[messageId],
|
||||
clearInstance,
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import { decodeHtmlEntities } from "../../lib/markdown"
|
||||
import { partHasRenderableText } from "../../types/message"
|
||||
import type { MessageDisplayParts, Message } from "../../types/message"
|
||||
|
||||
function decodeTextSegment(segment: any): any {
|
||||
if (typeof segment === "string") {
|
||||
@@ -74,23 +72,3 @@ export function normalizeMessagePart(part: any): any {
|
||||
return normalized
|
||||
}
|
||||
|
||||
export function computeDisplayParts(message: Message, showThinking: boolean): MessageDisplayParts {
|
||||
const text: any[] = []
|
||||
const tool: any[] = []
|
||||
const reasoning: any[] = []
|
||||
|
||||
for (const part of message.parts) {
|
||||
if (part.type === "text" && !part.synthetic && partHasRenderableText(part)) {
|
||||
text.push(part)
|
||||
} else if (part.type === "tool") {
|
||||
tool.push(part)
|
||||
} else if (part.type === "reasoning" && showThinking && partHasRenderableText(part)) {
|
||||
reasoning.push(part)
|
||||
}
|
||||
}
|
||||
|
||||
const combined = reasoning.length > 0 ? [...text, ...reasoning] : [...text]
|
||||
const version = typeof message.version === "number" ? message.version : 0
|
||||
|
||||
return { text, tool, reasoning, combined, showThinking, version }
|
||||
}
|
||||
|
||||
72
packages/ui/src/stores/message-v2/record-display-cache.ts
Normal file
72
packages/ui/src/stores/message-v2/record-display-cache.ts
Normal file
@@ -0,0 +1,72 @@
|
||||
import type { ClientPart } from "../../types/message"
|
||||
import { partHasRenderableText } from "../../types/message"
|
||||
import type { MessageRecord } from "./types"
|
||||
|
||||
export type ToolCallPart = Extract<ClientPart, { type: "tool" }>
|
||||
|
||||
export interface RecordDisplayData {
|
||||
orderedParts: ClientPart[]
|
||||
textAndReasoningParts: ClientPart[]
|
||||
toolParts: ToolCallPart[]
|
||||
}
|
||||
|
||||
interface RecordDisplayCacheEntry {
|
||||
revision: number
|
||||
data: RecordDisplayData
|
||||
}
|
||||
|
||||
const recordDisplayCache = new Map<string, RecordDisplayCacheEntry>()
|
||||
|
||||
function makeCacheKey(instanceId: string, messageId: string, showThinking: boolean) {
|
||||
return `${instanceId}:${messageId}:${showThinking ? 1 : 0}`
|
||||
}
|
||||
|
||||
function isToolPart(part: ClientPart): part is ToolCallPart {
|
||||
return part.type === "tool"
|
||||
}
|
||||
|
||||
export function buildRecordDisplayData(instanceId: string, record: MessageRecord, showThinking: boolean): RecordDisplayData {
|
||||
const cacheKey = makeCacheKey(instanceId, record.id, showThinking)
|
||||
const cached = recordDisplayCache.get(cacheKey)
|
||||
if (cached && cached.revision === record.revision) {
|
||||
return cached.data
|
||||
}
|
||||
|
||||
const orderedParts: ClientPart[] = []
|
||||
const textAndReasoningParts: ClientPart[] = []
|
||||
const toolParts: ToolCallPart[] = []
|
||||
|
||||
for (const partId of record.partIds) {
|
||||
const entry = record.parts[partId]
|
||||
if (!entry?.data) continue
|
||||
const part = entry.data
|
||||
orderedParts.push(part)
|
||||
|
||||
if (isToolPart(part)) {
|
||||
toolParts.push(part)
|
||||
continue
|
||||
}
|
||||
|
||||
if (part.type === "text" && !part.synthetic && partHasRenderableText(part)) {
|
||||
textAndReasoningParts.push(part)
|
||||
continue
|
||||
}
|
||||
|
||||
if (part.type === "reasoning" && showThinking && partHasRenderableText(part)) {
|
||||
textAndReasoningParts.push(part)
|
||||
}
|
||||
}
|
||||
|
||||
const data = { orderedParts, textAndReasoningParts, toolParts }
|
||||
recordDisplayCache.set(cacheKey, { revision: record.revision, data })
|
||||
return data
|
||||
}
|
||||
|
||||
export function clearRecordDisplayCacheForInstance(instanceId: string) {
|
||||
const prefix = `${instanceId}:`
|
||||
for (const key of recordDisplayCache.keys()) {
|
||||
if (key.startsWith(prefix)) {
|
||||
recordDisplayCache.delete(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -95,8 +95,7 @@ export interface InstanceMessageState {
|
||||
messages: Record<string, MessageRecord>
|
||||
messageInfoVersion: Record<string, number>
|
||||
pendingParts: Record<string, PendingPartEntry[]>
|
||||
|
||||
|
||||
sessionRevisions: Record<string, number>
|
||||
permissions: InstancePermissionState
|
||||
usage: Record<string, SessionUsageState>
|
||||
scrollState: Record<string, ScrollSnapshot>
|
||||
|
||||
Reference in New Issue
Block a user