Handle revert removals locally and retarget prompt input

This commit is contained in:
Shantur Rathore
2025-12-25 15:12:44 +00:00
parent 94aa469e90
commit 2603b1d260
7 changed files with 158 additions and 22 deletions

View File

@@ -143,6 +143,18 @@ export function removePermissionV2(instanceId: string, permissionId: string): vo
store.removePermission(permissionId)
}
export function removeMessageV2(instanceId: string, messageId: string): void {
if (!messageId) return
const store = messageStoreBus.getOrCreate(instanceId)
store.removeMessage(messageId)
}
export function removeMessagePartV2(instanceId: string, messageId: string, partId: string): void {
if (!messageId || !partId) return
const store = messageStoreBus.getOrCreate(instanceId)
store.removeMessagePart(messageId, partId)
}
export function ensureSessionMetadataV2(instanceId: string, session: Session | null | undefined): void {
if (!session) return
const store = messageStoreBus.getOrCreate(instanceId)

View File

@@ -191,6 +191,8 @@ export interface InstanceMessageStore {
hydrateMessages: (sessionId: string, inputs: MessageUpsertInput[], infos?: Iterable<MessageInfo>) => void
upsertMessage: (input: MessageUpsertInput) => void
applyPartUpdate: (input: PartUpdateInput) => void
removeMessage: (messageId: string) => void
removeMessagePart: (messageId: string, partId: string) => void
bufferPendingPart: (entry: PendingPartEntry) => void
flushPendingParts: (messageId: string) => void
replaceMessageId: (options: ReplaceMessageIdOptions) => void
@@ -508,10 +510,10 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt
bufferPendingPart({ messageId: input.messageId, part: input.part, receivedAt: Date.now() })
return
}
const partId = ensurePartId(input.messageId, input.part, message.partIds.length)
const cloned = clonePart(input.part)
setState(
"messages",
input.messageId,
@@ -520,7 +522,7 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt
draft.partIds = [...draft.partIds, partId]
}
const existing = draft.parts[partId]
const nextRevision = existing ? existing.revision + 1 : cloned.version ?? 0
const nextRevision = existing ? existing.revision + 1 : (cloned as any).version ?? 0
draft.parts[partId] = {
id: partId,
data: cloned,
@@ -540,12 +542,106 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt
timestamp: Date.now(),
})
}
// Any part update can change the rendered height of the message
// list, so we treat it as a session revision for scroll purposes.
bumpSessionRevision(message.sessionId)
}
function removeMessage(messageId: string) {
if (!messageId) return
const record = state.messages[messageId]
const sessionIds = new Set<string>()
if (record?.sessionId) {
sessionIds.add(record.sessionId)
} else {
Object.values(state.sessions).forEach((session) => {
if (session.messageIds.includes(messageId)) {
sessionIds.add(session.id)
}
})
}
clearRecordDisplayCacheForMessages(instanceId, [messageId])
batch(() => {
sessionIds.forEach((sessionId) => {
setState("sessions", sessionId, "messageIds", (ids = []) => ids.filter((id) => id !== messageId))
})
setState("messages", (prev) => {
if (!prev[messageId]) return prev
const next = { ...prev }
delete next[messageId]
return next
})
setState("messageInfoVersion", (prev) => {
if (!(messageId in prev)) return prev
const next = { ...prev }
delete next[messageId]
return next
})
messageInfoCache.delete(messageId)
setState("pendingParts", (prev) => {
if (!prev[messageId]) return prev
const next = { ...prev }
delete next[messageId]
return next
})
setState("permissions", "byMessage", (prev) => {
if (!prev[messageId]) return prev
const next = { ...prev }
delete next[messageId]
return next
})
sessionIds.forEach((sessionId) => {
withUsageState(sessionId, (draft) => removeUsageEntry(draft, messageId))
if (state.latestTodos[sessionId]?.messageId === messageId) {
clearLatestTodoSnapshot(sessionId)
}
bumpSessionRevision(sessionId)
})
})
}
function removeMessagePart(messageId: string, partId: string) {
if (!messageId || !partId) return
const message = state.messages[messageId]
if (!message) return
clearRecordDisplayCacheForMessages(instanceId, [messageId])
batch(() => {
setState(
"messages",
messageId,
produce((draft: MessageRecord) => {
if (!draft.parts[partId] && !draft.partIds.includes(partId)) return
draft.partIds = draft.partIds.filter((id) => id !== partId)
delete draft.parts[partId]
draft.updatedAt = Date.now()
draft.revision += 1
}),
)
setState("permissions", "byMessage", messageId, (prev) => {
if (!prev || !prev[partId]) return prev
const next = { ...prev }
delete next[partId]
return next
})
bumpSessionRevision(message.sessionId)
})
}
function flushPendingParts(messageId: string) {
const pending = state.pendingParts[messageId]
@@ -868,8 +964,10 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt
addOrUpdateSession,
hydrateMessages,
upsertMessage,
applyPartUpdate,
bufferPendingPart,
applyPartUpdate,
removeMessage,
removeMessagePart,
bufferPendingPart,
flushPendingParts,
replaceMessageId,
setMessageInfo,