diff --git a/src/components/message-stream.tsx b/src/components/message-stream.tsx index 579c1918..cf9cc770 100644 --- a/src/components/message-stream.tsx +++ b/src/components/message-stream.tsx @@ -5,7 +5,7 @@ import ToolCall from "./tool-call" import { sseManager } from "../lib/sse-manager" import Kbd from "./kbd" import { preferences } from "../stores/preferences" -import { providers, sessionInfoByInstance } from "../stores/sessions" +import { providers, getSessionInfo } from "../stores/sessions" // Calculate session tokens and cost from messagesInfo (matches TUI logic) function calculateSessionInfo(messagesInfo?: Map, instanceId?: string) { @@ -135,7 +135,7 @@ export default function MessageStream(props: MessageStreamProps) { const sessionInfo = createMemo(() => { return ( - sessionInfoByInstance().get(props.instanceId) || { + getSessionInfo(props.instanceId, props.sessionId) || { tokens: 0, cost: 0, contextWindow: 0, @@ -145,7 +145,7 @@ export default function MessageStream(props: MessageStreamProps) { }) const formattedSessionInfo = createMemo(() => { - const sessionInfo = sessionInfoByInstance().get(props.instanceId) || { + const sessionInfo = getSessionInfo(props.instanceId, props.sessionId) || { tokens: 0, cost: 0, contextWindow: 0, diff --git a/src/stores/sessions.ts b/src/stores/sessions.ts index 9e591163..580ef99a 100644 --- a/src/stores/sessions.ts +++ b/src/stores/sessions.ts @@ -25,7 +25,7 @@ const [loading, setLoading] = createSignal({ }) const [messagesLoaded, setMessagesLoaded] = createSignal>>(new Map()) -const [sessionInfoByInstance, setSessionInfoByInstance] = createSignal>(new Map()) +const [sessionInfoByInstance, setSessionInfoByInstance] = createSignal>>(new Map()) async function fetchSessions(instanceId: string): Promise { const instance = instances().get(instanceId) @@ -133,21 +133,26 @@ async function getDefaultModel( return { providerId: "", modelId: "" } } -function updateSessionInfo(instanceId: string) { +function getSessionInfo(instanceId: string, sessionId: string): SessionInfo | undefined { + return sessionInfoByInstance().get(instanceId)?.get(sessionId) +} + +function updateSessionInfo(instanceId: string, sessionId: string) { const instanceSessions = sessions().get(instanceId) if (!instanceSessions) return - let totalTokens = 0 - let totalCost = 0 + const session = instanceSessions.get(sessionId) + if (!session) return + + let tokens = 0 + let cost = 0 let contextWindow = 0 let isSubscriptionModel = false let modelID = "" let providerID = "" - // Calculate from last assistant message in each session (like original calculateSessionInfo) - for (const session of instanceSessions.values()) { - if (session.messagesInfo.size === 0) continue - + // Calculate from last assistant message in this session only + if (session.messagesInfo.size > 0) { // Go backwards through messagesInfo to find the last relevant assistant message (like TUI) const messageArray = Array.from(session.messagesInfo.values()).reverse() @@ -158,23 +163,23 @@ function updateSessionInfo(instanceId: string) { if (usage.output > 0) { if (info.summary) { // If summary message, only count output tokens and stop (like TUI) - totalTokens = usage.output || 0 - totalCost = info.cost || 0 + tokens = usage.output || 0 + cost = info.cost || 0 } else { // Regular message - count all token types (like TUI) - totalTokens = + tokens = (usage.input || 0) + (usage.cache?.read || 0) + (usage.cache?.write || 0) + (usage.output || 0) + (usage.reasoning || 0) - totalCost = info.cost || 0 + cost = info.cost || 0 } // Get model info for context window and subscription check modelID = info.modelID || "" providerID = info.providerID || "" - isSubscriptionModel = totalCost === 0 + isSubscriptionModel = cost === 0 break // Break after finding the last assistant message } @@ -200,12 +205,14 @@ function updateSessionInfo(instanceId: string) { setSessionInfoByInstance((prev) => { const next = new Map(prev) - next.set(instanceId, { - tokens: totalTokens, - cost: totalCost, + const instanceInfo = new Map(prev.get(instanceId)) + instanceInfo.set(sessionId, { + tokens, + cost, contextWindow, isSubscriptionModel, }) + next.set(instanceId, instanceInfo) return next }) } @@ -266,7 +273,20 @@ async function createSession(instanceId: string, agent?: string): Promise { + const next = new Map(prev) + const instanceInfo = new Map(prev.get(instanceId)) + instanceInfo.set(session.id, { + tokens: 0, + cost: 0, + contextWindow: 0, + isSubscriptionModel: false, + }) + next.set(instanceId, instanceInfo) + return next + }) + return session } catch (error) { console.error("Failed to create session:", error) @@ -306,6 +326,22 @@ async function deleteSession(instanceId: string, sessionId: string): Promise { + const next = new Map(prev) + const instanceInfo = next.get(instanceId) + if (instanceInfo) { + const updatedInstanceInfo = new Map(instanceInfo) + updatedInstanceInfo.delete(sessionId) + if (updatedInstanceInfo.size === 0) { + next.delete(instanceId) + } else { + next.set(instanceId, updatedInstanceInfo) + } + } + return next + }) + if (activeSessionId().get(instanceId) === sessionId) { setActiveSessionId((prev) => { const next = new Map(prev) @@ -594,7 +630,7 @@ async function loadMessages(instanceId: string, sessionId: string, force = false }) } - updateSessionInfo(instanceId) + updateSessionInfo(instanceId, sessionId) } function handleMessageUpdate(instanceId: string, event: any): void { @@ -643,7 +679,7 @@ function handleMessageUpdate(instanceId: string, event: any): void { return next }) - updateSessionInfo(instanceId) + updateSessionInfo(instanceId, part.sessionID) } else if (event.type === "message.updated") { const info = event.properties?.info if (!info) return @@ -697,7 +733,7 @@ function handleMessageUpdate(instanceId: string, event: any): void { return next }) - updateSessionInfo(instanceId) + updateSessionInfo(instanceId, info.sessionID) } } @@ -1000,6 +1036,7 @@ export { providers, loading, sessionInfoByInstance, + getSessionInfo, fetchSessions, createSession, deleteSession,