Improve message stream caching and virtualization for large sessions

This commit is contained in:
Shantur Rathore
2025-11-26 13:30:20 +00:00
parent c77bfc2ee7
commit fad2809299
18 changed files with 1142 additions and 402 deletions

View File

@@ -1,8 +1,10 @@
import { createInstanceMessageStore } from "./instance-store"
import type { InstanceMessageStore } from "./instance-store"
import { clearCacheForInstance } from "../../lib/global-cache"
class MessageStoreBus {
private stores = new Map<string, InstanceMessageStore>()
private teardownHandlers = new Set<(instanceId: string) => void>()
registerInstance(instanceId: string, store?: InstanceMessageStore): InstanceMessageStore {
if (this.stores.has(instanceId)) {
@@ -22,20 +24,41 @@ class MessageStoreBus {
return this.registerInstance(instanceId)
}
onInstanceDestroyed(handler: (instanceId: string) => void): () => void {
this.teardownHandlers.add(handler)
return () => {
this.teardownHandlers.delete(handler)
}
}
unregisterInstance(instanceId: string) {
const store = this.stores.get(instanceId)
if (store) {
store.clearInstance()
}
clearCacheForInstance(instanceId)
this.notifyInstanceDestroyed(instanceId)
this.stores.delete(instanceId)
}
clearAll() {
for (const [instanceId, store] of this.stores.entries()) {
store.clearInstance()
clearCacheForInstance(instanceId)
this.notifyInstanceDestroyed(instanceId)
this.stores.delete(instanceId)
}
}
private notifyInstanceDestroyed(instanceId: string) {
for (const handler of this.teardownHandlers) {
try {
handler(instanceId)
} catch (error) {
console.error("Failed to run message store teardown handler", error)
}
}
}
}
export const messageStoreBus = new MessageStoreBus()

View File

@@ -24,6 +24,7 @@ function createInitialState(instanceId: string): InstanceMessageState {
messages: {},
messageInfoVersion: {},
pendingParts: {},
sessionRevisions: {},
permissions: {
queue: [],
active: null,
@@ -41,8 +42,52 @@ function ensurePartId(messageId: string, part: ClientPart, index: number): strin
return `${messageId}-part-${index}`
}
const PENDING_PART_MAX_AGE_MS = 30_000
function clonePart(part: ClientPart): ClientPart {
return JSON.parse(JSON.stringify(part)) as ClientPart
if (!part || typeof part !== "object") {
return part
}
const cloned: Record<string, any> = { ...part }
if ("renderCache" in cloned) {
cloned.renderCache = undefined
}
if ("text" in cloned) {
cloned.text = cloneStructuredValue(cloned.text)
}
if ("thinking" in cloned && typeof cloned.thinking === "object") {
cloned.thinking = cloneStructuredValue(cloned.thinking)
}
if ("content" in cloned && Array.isArray(cloned.content)) {
cloned.content = cloneStructuredValue(cloned.content)
}
return cloned as ClientPart
}
function cloneStructuredValue<T>(value: T): T {
if (Array.isArray(value)) {
return value.map((item) => cloneStructuredValue(item)) as T
}
if (value && typeof value === "object") {
const next: Record<string, any> = {}
Object.entries(value as Record<string, any>).forEach(([key, nested]) => {
next[key] = cloneStructuredValue(nested)
})
return next as T
}
return value
}
function areMessageIdListsEqual(a: string[], b: string[]): boolean {
if (a.length !== b.length) {
return false
}
for (let index = 0; index < a.length; index++) {
if (a[index] !== b[index]) {
return false
}
}
return true
}
function createEmptyUsageState(): SessionUsageState {
@@ -158,6 +203,7 @@ export interface InstanceMessageStore {
getSessionUsage: (sessionId: string) => SessionUsageState | undefined
setScrollSnapshot: (sessionId: string, scope: string, snapshot: Omit<ScrollSnapshot, "updatedAt">) => void
getScrollSnapshot: (sessionId: string, scope: string) => ScrollSnapshot | undefined
getSessionRevision: (sessionId: string) => number
getSessionMessageIds: (sessionId: string) => string[]
getMessage: (messageId: string) => MessageRecord | undefined
clearInstance: () => void
@@ -167,6 +213,15 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
const [state, setState] = createStore<InstanceMessageState>(createInitialState(instanceId))
const messageInfoCache = new Map<string, MessageInfo>()
function bumpSessionRevision(sessionId: string) {
if (!sessionId) return
setState("sessionRevisions", sessionId, (value = 0) => value + 1)
}
function getSessionRevisionValue(sessionId: string) {
return state.sessionRevisions[sessionId] ?? 0
}
function withUsageState(sessionId: string, updater: (draft: SessionUsageState) => void) {
setState("usage", sessionId, (current) => {
const draft = current
@@ -223,6 +278,7 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
function addOrUpdateSession(input: SessionUpsertInput) {
const session = ensureSessionEntry(input.id)
const previousIds = [...session.messageIds]
const nextMessageIds = Array.isArray(input.messageIds) ? input.messageIds : session.messageIds
setState("sessions", input.id, {
@@ -233,6 +289,10 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
messageIds: nextMessageIds,
revert: input.revert ?? session.revert ?? null,
})
if (Array.isArray(input.messageIds) && !areMessageIdListsEqual(previousIds, nextMessageIds)) {
bumpSessionRevision(input.id)
}
}
function hydrateMessages(sessionId: string, inputs: MessageUpsertInput[], infos?: Iterable<MessageInfo>) {
@@ -303,7 +363,7 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
setState("messages", (prev) => ({ ...prev, ...nextMessages }))
setState("messageInfoVersion", (prev) => ({ ...prev, ...nextMessageInfoVersion }))
setState("pendingParts", (prev) => ({ ...prev, ...nextPendingParts }))
setState("pendingParts", () => nextPendingParts)
setState("permissions", "byMessage", (prev) => ({ ...prev, ...nextPermissionsByMessage }))
if (usageState) {
@@ -315,6 +375,8 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
messageIds: incomingIds,
updatedAt: Date.now(),
}))
bumpSessionRevision(sessionId)
}
function insertMessageIntoSession(sessionId: string, messageId: string) {
@@ -374,12 +436,24 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
insertMessageIntoSession(input.sessionId, input.id)
flushPendingParts(input.id)
bumpSessionRevision(input.sessionId)
}
function bufferPendingPart(entry: PendingPartEntry) {
setState("pendingParts", entry.messageId, (list = []) => [...list, entry])
}
function clearPendingPartsForMessage(messageId: string) {
setState("pendingParts", (prev) => {
if (!prev[messageId]) {
return prev
}
const next = { ...prev }
delete next[messageId]
return next
})
}
function applyPartUpdate(input: PartUpdateInput) {
const message = state.messages[input.messageId]
if (!message) {
@@ -417,12 +491,14 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
if (!pending || pending.length === 0) {
return
}
pending.forEach((entry) => applyPartUpdate({ messageId, part: entry.part }))
setState("pendingParts", (prev) => {
const next = { ...prev }
delete next[messageId]
return next
})
const now = Date.now()
const validEntries = pending.filter((entry) => now - entry.receivedAt <= PENDING_PART_MAX_AGE_MS)
if (validEntries.length === 0) {
clearPendingPartsForMessage(messageId)
return
}
validEntries.forEach((entry) => applyPartUpdate({ messageId, part: entry.part }))
clearPendingPartsForMessage(messageId)
}
function replaceMessageId(options: ReplaceMessageIdOptions) {
@@ -444,6 +520,8 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
return next
})
const affectedSessions = new Set<string>()
Object.values(state.sessions).forEach((session) => {
const index = session.messageIds.indexOf(options.oldId)
if (index === -1) return
@@ -452,8 +530,11 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
next[index] = options.newId
return next
})
affectedSessions.add(session.id)
})
affectedSessions.forEach((sessionId) => bumpSessionRevision(sessionId))
const infoEntry = messageInfoCache.get(options.oldId)
if (infoEntry) {
messageInfoCache.set(options.newId, infoEntry)
@@ -482,12 +563,8 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
const pending = state.pendingParts[options.oldId]
if (pending) {
setState("pendingParts", options.newId, pending)
setState("pendingParts", (prev) => {
const next = { ...prev }
delete next[options.oldId]
return next
})
}
clearPendingPartsForMessage(options.oldId)
}
function setMessageInfo(messageId: string, info: MessageInfo) {
@@ -608,6 +685,7 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
getSessionUsage,
setScrollSnapshot,
getScrollSnapshot,
getSessionRevision: getSessionRevisionValue,
getSessionMessageIds: (sessionId: string) => state.sessions[sessionId]?.messageIds ?? [],
getMessage: (messageId: string) => state.messages[messageId],
clearInstance,

View File

@@ -1,6 +1,4 @@
import { decodeHtmlEntities } from "../../lib/markdown"
import { partHasRenderableText } from "../../types/message"
import type { MessageDisplayParts, Message } from "../../types/message"
function decodeTextSegment(segment: any): any {
if (typeof segment === "string") {
@@ -74,23 +72,3 @@ export function normalizeMessagePart(part: any): any {
return normalized
}
export function computeDisplayParts(message: Message, showThinking: boolean): MessageDisplayParts {
const text: any[] = []
const tool: any[] = []
const reasoning: any[] = []
for (const part of message.parts) {
if (part.type === "text" && !part.synthetic && partHasRenderableText(part)) {
text.push(part)
} else if (part.type === "tool") {
tool.push(part)
} else if (part.type === "reasoning" && showThinking && partHasRenderableText(part)) {
reasoning.push(part)
}
}
const combined = reasoning.length > 0 ? [...text, ...reasoning] : [...text]
const version = typeof message.version === "number" ? message.version : 0
return { text, tool, reasoning, combined, showThinking, version }
}

View File

@@ -0,0 +1,72 @@
import type { ClientPart } from "../../types/message"
import { partHasRenderableText } from "../../types/message"
import type { MessageRecord } from "./types"
export type ToolCallPart = Extract<ClientPart, { type: "tool" }>
export interface RecordDisplayData {
orderedParts: ClientPart[]
textAndReasoningParts: ClientPart[]
toolParts: ToolCallPart[]
}
interface RecordDisplayCacheEntry {
revision: number
data: RecordDisplayData
}
const recordDisplayCache = new Map<string, RecordDisplayCacheEntry>()
function makeCacheKey(instanceId: string, messageId: string, showThinking: boolean) {
return `${instanceId}:${messageId}:${showThinking ? 1 : 0}`
}
function isToolPart(part: ClientPart): part is ToolCallPart {
return part.type === "tool"
}
export function buildRecordDisplayData(instanceId: string, record: MessageRecord, showThinking: boolean): RecordDisplayData {
const cacheKey = makeCacheKey(instanceId, record.id, showThinking)
const cached = recordDisplayCache.get(cacheKey)
if (cached && cached.revision === record.revision) {
return cached.data
}
const orderedParts: ClientPart[] = []
const textAndReasoningParts: ClientPart[] = []
const toolParts: ToolCallPart[] = []
for (const partId of record.partIds) {
const entry = record.parts[partId]
if (!entry?.data) continue
const part = entry.data
orderedParts.push(part)
if (isToolPart(part)) {
toolParts.push(part)
continue
}
if (part.type === "text" && !part.synthetic && partHasRenderableText(part)) {
textAndReasoningParts.push(part)
continue
}
if (part.type === "reasoning" && showThinking && partHasRenderableText(part)) {
textAndReasoningParts.push(part)
}
}
const data = { orderedParts, textAndReasoningParts, toolParts }
recordDisplayCache.set(cacheKey, { revision: record.revision, data })
return data
}
export function clearRecordDisplayCacheForInstance(instanceId: string) {
const prefix = `${instanceId}:`
for (const key of recordDisplayCache.keys()) {
if (key.startsWith(prefix)) {
recordDisplayCache.delete(key)
}
}
}

View File

@@ -95,8 +95,7 @@ export interface InstanceMessageState {
messages: Record<string, MessageRecord>
messageInfoVersion: Record<string, number>
pendingParts: Record<string, PendingPartEntry[]>
sessionRevisions: Record<string, number>
permissions: InstancePermissionState
usage: Record<string, SessionUsageState>
scrollState: Record<string, ScrollSnapshot>