Improve message stream caching and virtualization for large sessions

This commit is contained in:
Shantur Rathore
2025-11-26 13:30:20 +00:00
parent c77bfc2ee7
commit fad2809299
18 changed files with 1142 additions and 402 deletions

View File

@@ -24,6 +24,7 @@ function createInitialState(instanceId: string): InstanceMessageState {
messages: {},
messageInfoVersion: {},
pendingParts: {},
sessionRevisions: {},
permissions: {
queue: [],
active: null,
@@ -41,8 +42,52 @@ function ensurePartId(messageId: string, part: ClientPart, index: number): strin
return `${messageId}-part-${index}`
}
const PENDING_PART_MAX_AGE_MS = 30_000
function clonePart(part: ClientPart): ClientPart {
return JSON.parse(JSON.stringify(part)) as ClientPart
if (!part || typeof part !== "object") {
return part
}
const cloned: Record<string, any> = { ...part }
if ("renderCache" in cloned) {
cloned.renderCache = undefined
}
if ("text" in cloned) {
cloned.text = cloneStructuredValue(cloned.text)
}
if ("thinking" in cloned && typeof cloned.thinking === "object") {
cloned.thinking = cloneStructuredValue(cloned.thinking)
}
if ("content" in cloned && Array.isArray(cloned.content)) {
cloned.content = cloneStructuredValue(cloned.content)
}
return cloned as ClientPart
}
function cloneStructuredValue<T>(value: T): T {
if (Array.isArray(value)) {
return value.map((item) => cloneStructuredValue(item)) as T
}
if (value && typeof value === "object") {
const next: Record<string, any> = {}
Object.entries(value as Record<string, any>).forEach(([key, nested]) => {
next[key] = cloneStructuredValue(nested)
})
return next as T
}
return value
}
function areMessageIdListsEqual(a: string[], b: string[]): boolean {
if (a.length !== b.length) {
return false
}
for (let index = 0; index < a.length; index++) {
if (a[index] !== b[index]) {
return false
}
}
return true
}
function createEmptyUsageState(): SessionUsageState {
@@ -158,6 +203,7 @@ export interface InstanceMessageStore {
getSessionUsage: (sessionId: string) => SessionUsageState | undefined
setScrollSnapshot: (sessionId: string, scope: string, snapshot: Omit<ScrollSnapshot, "updatedAt">) => void
getScrollSnapshot: (sessionId: string, scope: string) => ScrollSnapshot | undefined
getSessionRevision: (sessionId: string) => number
getSessionMessageIds: (sessionId: string) => string[]
getMessage: (messageId: string) => MessageRecord | undefined
clearInstance: () => void
@@ -167,6 +213,15 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
const [state, setState] = createStore<InstanceMessageState>(createInitialState(instanceId))
const messageInfoCache = new Map<string, MessageInfo>()
function bumpSessionRevision(sessionId: string) {
if (!sessionId) return
setState("sessionRevisions", sessionId, (value = 0) => value + 1)
}
function getSessionRevisionValue(sessionId: string) {
return state.sessionRevisions[sessionId] ?? 0
}
function withUsageState(sessionId: string, updater: (draft: SessionUsageState) => void) {
setState("usage", sessionId, (current) => {
const draft = current
@@ -223,6 +278,7 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
function addOrUpdateSession(input: SessionUpsertInput) {
const session = ensureSessionEntry(input.id)
const previousIds = [...session.messageIds]
const nextMessageIds = Array.isArray(input.messageIds) ? input.messageIds : session.messageIds
setState("sessions", input.id, {
@@ -233,6 +289,10 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
messageIds: nextMessageIds,
revert: input.revert ?? session.revert ?? null,
})
if (Array.isArray(input.messageIds) && !areMessageIdListsEqual(previousIds, nextMessageIds)) {
bumpSessionRevision(input.id)
}
}
function hydrateMessages(sessionId: string, inputs: MessageUpsertInput[], infos?: Iterable<MessageInfo>) {
@@ -303,7 +363,7 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
setState("messages", (prev) => ({ ...prev, ...nextMessages }))
setState("messageInfoVersion", (prev) => ({ ...prev, ...nextMessageInfoVersion }))
setState("pendingParts", (prev) => ({ ...prev, ...nextPendingParts }))
setState("pendingParts", () => nextPendingParts)
setState("permissions", "byMessage", (prev) => ({ ...prev, ...nextPermissionsByMessage }))
if (usageState) {
@@ -315,6 +375,8 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
messageIds: incomingIds,
updatedAt: Date.now(),
}))
bumpSessionRevision(sessionId)
}
function insertMessageIntoSession(sessionId: string, messageId: string) {
@@ -374,12 +436,24 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
insertMessageIntoSession(input.sessionId, input.id)
flushPendingParts(input.id)
bumpSessionRevision(input.sessionId)
}
function bufferPendingPart(entry: PendingPartEntry) {
setState("pendingParts", entry.messageId, (list = []) => [...list, entry])
}
function clearPendingPartsForMessage(messageId: string) {
setState("pendingParts", (prev) => {
if (!prev[messageId]) {
return prev
}
const next = { ...prev }
delete next[messageId]
return next
})
}
function applyPartUpdate(input: PartUpdateInput) {
const message = state.messages[input.messageId]
if (!message) {
@@ -417,12 +491,14 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
if (!pending || pending.length === 0) {
return
}
pending.forEach((entry) => applyPartUpdate({ messageId, part: entry.part }))
setState("pendingParts", (prev) => {
const next = { ...prev }
delete next[messageId]
return next
})
const now = Date.now()
const validEntries = pending.filter((entry) => now - entry.receivedAt <= PENDING_PART_MAX_AGE_MS)
if (validEntries.length === 0) {
clearPendingPartsForMessage(messageId)
return
}
validEntries.forEach((entry) => applyPartUpdate({ messageId, part: entry.part }))
clearPendingPartsForMessage(messageId)
}
function replaceMessageId(options: ReplaceMessageIdOptions) {
@@ -444,6 +520,8 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
return next
})
const affectedSessions = new Set<string>()
Object.values(state.sessions).forEach((session) => {
const index = session.messageIds.indexOf(options.oldId)
if (index === -1) return
@@ -452,8 +530,11 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
next[index] = options.newId
return next
})
affectedSessions.add(session.id)
})
affectedSessions.forEach((sessionId) => bumpSessionRevision(sessionId))
const infoEntry = messageInfoCache.get(options.oldId)
if (infoEntry) {
messageInfoCache.set(options.newId, infoEntry)
@@ -482,12 +563,8 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
const pending = state.pendingParts[options.oldId]
if (pending) {
setState("pendingParts", options.newId, pending)
setState("pendingParts", (prev) => {
const next = { ...prev }
delete next[options.oldId]
return next
})
}
clearPendingPartsForMessage(options.oldId)
}
function setMessageInfo(messageId: string, info: MessageInfo) {
@@ -608,6 +685,7 @@ export function createInstanceMessageStore(instanceId: string): InstanceMessageS
getSessionUsage,
setScrollSnapshot,
getScrollSnapshot,
getSessionRevision: getSessionRevisionValue,
getSessionMessageIds: (sessionId: string) => state.sessions[sessionId]?.messageIds ?? [],
getMessage: (messageId: string) => state.messages[messageId],
clearInstance,