diff --git a/src/components/message-stream.tsx b/src/components/message-stream.tsx index e2c81e0a..579c1918 100644 --- a/src/components/message-stream.tsx +++ b/src/components/message-stream.tsx @@ -5,7 +5,7 @@ import ToolCall from "./tool-call" import { sseManager } from "../lib/sse-manager" import Kbd from "./kbd" import { preferences } from "../stores/preferences" -import { providers } from "../stores/sessions" +import { providers, sessionInfoByInstance } from "../stores/sessions" // Calculate session tokens and cost from messagesInfo (matches TUI logic) function calculateSessionInfo(messagesInfo?: Map, instanceId?: string) { @@ -133,6 +133,32 @@ export default function MessageStream(props: MessageStreamProps) { const connectionStatus = () => sseManager.getStatus(props.instanceId) + const sessionInfo = createMemo(() => { + return ( + sessionInfoByInstance().get(props.instanceId) || { + tokens: 0, + cost: 0, + contextWindow: 0, + isSubscriptionModel: false, + } + ) + }) + + const formattedSessionInfo = createMemo(() => { + const sessionInfo = sessionInfoByInstance().get(props.instanceId) || { + tokens: 0, + cost: 0, + contextWindow: 0, + isSubscriptionModel: false, + } + return formatSessionInfo( + sessionInfo.tokens, + sessionInfo.cost, + sessionInfo.contextWindow, + sessionInfo.isSubscriptionModel, + ) + }) + function scrollToBottom() { if (containerRef) { containerRef.scrollTop = containerRef.scrollHeight @@ -212,20 +238,7 @@ export default function MessageStream(props: MessageStreamProps) {
- - {(() => { - const sessionInfo = calculateSessionInfo(props.messagesInfo, props.instanceId) - console.log("[MessageStream] sessionInfo:", sessionInfo) - const result = formatSessionInfo( - sessionInfo.tokens, - sessionInfo.cost, - sessionInfo.contextWindow, - sessionInfo.isSubscriptionModel, - ) - console.log("[MessageStream] formatted result:", result) - return result - })()} - + {formattedSessionInfo()}
diff --git a/src/stores/sessions.ts b/src/stores/sessions.ts index 12aef537..9e591163 100644 --- a/src/stores/sessions.ts +++ b/src/stores/sessions.ts @@ -4,6 +4,13 @@ import type { Message } from "../types/message" import { instances } from "./instances" import { sseManager } from "../lib/sse-manager" +interface SessionInfo { + tokens: number + cost: number + contextWindow: number + isSubscriptionModel: boolean +} + const [sessions, setSessions] = createSignal>>(new Map()) const [activeSessionId, setActiveSessionId] = createSignal>(new Map()) const [activeParentSessionId, setActiveParentSessionId] = createSignal>(new Map()) @@ -18,6 +25,7 @@ const [loading, setLoading] = createSignal({ }) const [messagesLoaded, setMessagesLoaded] = createSignal>>(new Map()) +const [sessionInfoByInstance, setSessionInfoByInstance] = createSignal>(new Map()) async function fetchSessions(instanceId: string): Promise { const instance = instances().get(instanceId) @@ -125,6 +133,83 @@ async function getDefaultModel( return { providerId: "", modelId: "" } } +function updateSessionInfo(instanceId: string) { + const instanceSessions = sessions().get(instanceId) + if (!instanceSessions) return + + let totalTokens = 0 + let totalCost = 0 + let contextWindow = 0 + let isSubscriptionModel = false + let modelID = "" + let providerID = "" + + // Calculate from last assistant message in each session (like original calculateSessionInfo) + for (const session of instanceSessions.values()) { + if (session.messagesInfo.size === 0) continue + + // Go backwards through messagesInfo to find the last relevant assistant message (like TUI) + const messageArray = Array.from(session.messagesInfo.values()).reverse() + + for (const info of messageArray) { + if (info.role === "assistant" && info.tokens) { + const usage = info.tokens + + if (usage.output > 0) { + if (info.summary) { + // If summary message, only count output tokens and stop (like TUI) + totalTokens = usage.output || 0 + totalCost = info.cost || 0 + } else { + // Regular message - count all token types (like TUI) + totalTokens = + (usage.input || 0) + + (usage.cache?.read || 0) + + (usage.cache?.write || 0) + + (usage.output || 0) + + (usage.reasoning || 0) + totalCost = info.cost || 0 + } + + // Get model info for context window and subscription check + modelID = info.modelID || "" + providerID = info.providerID || "" + isSubscriptionModel = totalCost === 0 + + break // Break after finding the last assistant message + } + } + } + } + + // Get context window from providers + if (modelID && providerID) { + const instanceProviders = providers().get(instanceId) || [] + const provider = instanceProviders.find((p) => p.id === providerID) + if (provider) { + const model = provider.models.find((m) => m.id === modelID) + if (model?.limit?.context) { + contextWindow = model.limit.context + } + // Check if it's a subscription model (cost is 0 for both input and output) + if (model?.cost?.input === 0 && model?.cost?.output === 0) { + isSubscriptionModel = true + } + } + } + + setSessionInfoByInstance((prev) => { + const next = new Map(prev) + next.set(instanceId, { + tokens: totalTokens, + cost: totalCost, + contextWindow, + isSubscriptionModel, + }) + return next + }) +} + async function createSession(instanceId: string, agent?: string): Promise { const instance = instances().get(instanceId) if (!instance || !instance.client) { @@ -181,6 +266,7 @@ async function createSession(instanceId: string, agent?: string): Promise