Add agent and model selection with automatic defaults and subagent support
- Implement automatic agent/model selection for new sessions (first agent, prioritize Anthropic default) - Extract and restore agent/model from last assistant message when resuming sessions - Add subagent support: child sessions can select subagents, parent sessions cannot - Display subagent badge in agent dropdown for identification - Truncate agent descriptions to 50 characters in dropdown - Improve model search to include full provider/model ID path - Auto-select input text on focus for easy model search - Add getDefaultModel() with priority: agent model → Anthropic → first provider
This commit is contained in:
20
src/App.tsx
20
src/App.tsx
@@ -37,6 +37,8 @@ import {
|
||||
getParentSessions,
|
||||
loadMessages,
|
||||
sendMessage,
|
||||
updateSessionAgent,
|
||||
updateSessionModel,
|
||||
} from "./stores/sessions"
|
||||
import { setupTabKeyboardShortcuts } from "./lib/keyboard"
|
||||
|
||||
@@ -58,6 +60,14 @@ const SessionView: Component<{
|
||||
await sendMessage(props.instanceId, props.sessionId, prompt)
|
||||
}
|
||||
|
||||
async function handleAgentChange(agent: string) {
|
||||
await updateSessionAgent(props.instanceId, props.sessionId, agent)
|
||||
}
|
||||
|
||||
async function handleModelChange(model: { providerId: string; modelId: string }) {
|
||||
await updateSessionModel(props.instanceId, props.sessionId, model)
|
||||
}
|
||||
|
||||
return (
|
||||
<Show
|
||||
when={session()}
|
||||
@@ -75,7 +85,15 @@ const SessionView: Component<{
|
||||
messages={s().messages || []}
|
||||
messagesInfo={s().messagesInfo}
|
||||
/>
|
||||
<PromptInput instanceId={props.instanceId} sessionId={s().id} onSend={handleSendMessage} />
|
||||
<PromptInput
|
||||
instanceId={props.instanceId}
|
||||
sessionId={s().id}
|
||||
onSend={handleSendMessage}
|
||||
agent={s().agent}
|
||||
model={s().model}
|
||||
onAgentChange={handleAgentChange}
|
||||
onModelChange={handleModelChange}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</Show>
|
||||
|
||||
101
src/components/agent-selector.tsx
Normal file
101
src/components/agent-selector.tsx
Normal file
@@ -0,0 +1,101 @@
|
||||
import { Select } from "@kobalte/core/select"
|
||||
import { For, Show, createEffect, createMemo } from "solid-js"
|
||||
import { agents, fetchAgents, sessions } from "../stores/sessions"
|
||||
import { ChevronDown } from "lucide-solid"
|
||||
import type { Agent } from "../types/session"
|
||||
|
||||
interface AgentSelectorProps {
|
||||
instanceId: string
|
||||
sessionId: string
|
||||
currentAgent: string
|
||||
onAgentChange: (agent: string) => Promise<void>
|
||||
}
|
||||
|
||||
export default function AgentSelector(props: AgentSelectorProps) {
|
||||
const instanceAgents = () => agents().get(props.instanceId) || []
|
||||
|
||||
const session = createMemo(() => {
|
||||
const instanceSessions = sessions().get(props.instanceId)
|
||||
return instanceSessions?.get(props.sessionId)
|
||||
})
|
||||
|
||||
const isChildSession = createMemo(() => {
|
||||
return session()?.parentId !== null && session()?.parentId !== undefined
|
||||
})
|
||||
|
||||
const availableAgents = createMemo(() => {
|
||||
const allAgents = instanceAgents()
|
||||
if (isChildSession()) {
|
||||
return allAgents
|
||||
}
|
||||
|
||||
const filtered = allAgents.filter((agent) => agent.mode !== "subagent")
|
||||
|
||||
const currentAgent = allAgents.find((a) => a.name === props.currentAgent)
|
||||
if (currentAgent && !filtered.find((a) => a.name === props.currentAgent)) {
|
||||
return [currentAgent, ...filtered]
|
||||
}
|
||||
|
||||
return filtered
|
||||
})
|
||||
|
||||
createEffect(() => {
|
||||
if (instanceAgents().length === 0) {
|
||||
fetchAgents(props.instanceId).catch(console.error)
|
||||
}
|
||||
})
|
||||
|
||||
const handleChange = async (value: Agent | null) => {
|
||||
if (value && value.name !== props.currentAgent) {
|
||||
await props.onAgentChange(value.name)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Select
|
||||
value={availableAgents().find((a) => a.name === props.currentAgent)}
|
||||
onChange={handleChange}
|
||||
options={availableAgents()}
|
||||
optionValue="name"
|
||||
optionTextValue="name"
|
||||
placeholder="Select agent..."
|
||||
itemComponent={(itemProps) => (
|
||||
<Select.Item
|
||||
item={itemProps.item}
|
||||
class="px-3 py-2 cursor-pointer hover:bg-gray-100 rounded outline-none focus:bg-gray-100"
|
||||
>
|
||||
<div class="flex flex-col">
|
||||
<Select.ItemLabel class="font-medium text-sm text-gray-900 flex items-center gap-2">
|
||||
<span>{itemProps.item.rawValue.name}</span>
|
||||
<Show when={itemProps.item.rawValue.mode === "subagent"}>
|
||||
<span class="text-xs font-normal text-blue-600 bg-blue-50 px-1.5 py-0.5 rounded">subagent</span>
|
||||
</Show>
|
||||
</Select.ItemLabel>
|
||||
<Show when={itemProps.item.rawValue.description}>
|
||||
<Select.ItemDescription class="text-xs text-gray-600">
|
||||
{itemProps.item.rawValue.description.length > 50
|
||||
? itemProps.item.rawValue.description.slice(0, 50) + "..."
|
||||
: itemProps.item.rawValue.description}
|
||||
</Select.ItemDescription>
|
||||
</Show>
|
||||
</div>
|
||||
</Select.Item>
|
||||
)}
|
||||
>
|
||||
<Select.Trigger class="inline-flex items-center justify-between gap-2 px-2 py-1 bg-white border border-gray-300 rounded hover:bg-gray-50 outline-none focus:ring-2 focus:ring-blue-500 text-xs min-w-[100px]">
|
||||
<Select.Value<Agent>>
|
||||
{(state) => <span class="text-gray-700">Agent: {state.selectedOption()?.name ?? "None"}</span>}
|
||||
</Select.Value>
|
||||
<Select.Icon>
|
||||
<ChevronDown class="w-3 h-3 text-gray-500" />
|
||||
</Select.Icon>
|
||||
</Select.Trigger>
|
||||
|
||||
<Select.Portal>
|
||||
<Select.Content class="bg-white border border-gray-300 rounded-md shadow-lg max-h-80 overflow-auto p-1 z-50">
|
||||
<Select.Listbox />
|
||||
</Select.Content>
|
||||
</Select.Portal>
|
||||
</Select>
|
||||
)
|
||||
}
|
||||
@@ -59,9 +59,7 @@ export default function MessageItem(props: MessageItemProps) {
|
||||
</div>
|
||||
</Show>
|
||||
|
||||
<For each={props.message.parts}>
|
||||
{(part) => <MessagePart part={part} key={part.id || `${part.type}-${Math.random()}`} />}
|
||||
</For>
|
||||
<For each={props.message.parts}>{(part) => <MessagePart part={part} />}</For>
|
||||
</div>
|
||||
|
||||
<Show when={props.message.status === "sending"}>
|
||||
|
||||
108
src/components/model-selector.tsx
Normal file
108
src/components/model-selector.tsx
Normal file
@@ -0,0 +1,108 @@
|
||||
import { Combobox } from "@kobalte/core/combobox"
|
||||
import { For, Show, createEffect, createMemo } from "solid-js"
|
||||
import { providers, fetchProviders } from "../stores/sessions"
|
||||
import { ChevronDown } from "lucide-solid"
|
||||
import type { Provider, Model } from "../types/session"
|
||||
|
||||
interface ModelSelectorProps {
|
||||
instanceId: string
|
||||
sessionId: string
|
||||
currentModel: { providerId: string; modelId: string }
|
||||
onModelChange: (model: { providerId: string; modelId: string }) => Promise<void>
|
||||
}
|
||||
|
||||
interface FlatModel extends Model {
|
||||
providerName: string
|
||||
}
|
||||
|
||||
export default function ModelSelector(props: ModelSelectorProps) {
|
||||
const instanceProviders = () => providers().get(props.instanceId) || []
|
||||
let listboxRef!: HTMLUListElement
|
||||
let inputRef!: HTMLInputElement
|
||||
|
||||
createEffect(() => {
|
||||
if (instanceProviders().length === 0) {
|
||||
fetchProviders(props.instanceId).catch(console.error)
|
||||
}
|
||||
})
|
||||
|
||||
const handleFocus = (e: FocusEvent) => {
|
||||
const input = e.target as HTMLInputElement
|
||||
input.select()
|
||||
}
|
||||
|
||||
const allModels = createMemo<FlatModel[]>(() =>
|
||||
instanceProviders().flatMap((p) =>
|
||||
p.models.map((m) => ({
|
||||
...m,
|
||||
providerName: p.name,
|
||||
})),
|
||||
),
|
||||
)
|
||||
|
||||
const currentModelValue = createMemo(() =>
|
||||
allModels().find((m) => m.providerId === props.currentModel.providerId && m.id === props.currentModel.modelId),
|
||||
)
|
||||
|
||||
const handleChange = async (value: FlatModel | null) => {
|
||||
if (!value) return
|
||||
|
||||
if (value.providerId !== props.currentModel.providerId || value.id !== props.currentModel.modelId) {
|
||||
await props.onModelChange({ providerId: value.providerId, modelId: value.id })
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Combobox
|
||||
value={currentModelValue()}
|
||||
onChange={handleChange}
|
||||
options={allModels()}
|
||||
optionValue={(m) => `${m.providerId}/${m.id}`}
|
||||
optionTextValue={(m) => `${m.name} ${m.providerName} ${m.providerId}/${m.id}`}
|
||||
optionLabel="name"
|
||||
placeholder="Search models..."
|
||||
defaultFilter="contains"
|
||||
triggerMode="focus"
|
||||
allowsEmptyCollection={false}
|
||||
itemComponent={(itemProps) => (
|
||||
<Combobox.Item
|
||||
item={itemProps.item}
|
||||
class="px-3 py-2 cursor-pointer hover:bg-gray-100 rounded outline-none focus:bg-gray-100 flex items-start gap-2"
|
||||
>
|
||||
<div class="flex flex-col flex-1 min-w-0">
|
||||
<Combobox.ItemLabel class="font-medium text-sm text-gray-900">
|
||||
{itemProps.item.rawValue.name}
|
||||
</Combobox.ItemLabel>
|
||||
<Combobox.ItemDescription class="text-xs text-gray-600">
|
||||
{itemProps.item.rawValue.providerName} • {itemProps.item.rawValue.providerId}/{itemProps.item.rawValue.id}
|
||||
</Combobox.ItemDescription>
|
||||
</div>
|
||||
<Combobox.ItemIndicator class="flex-shrink-0 mt-0.5">
|
||||
<svg class="w-4 h-4 text-blue-600" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7" />
|
||||
</svg>
|
||||
</Combobox.ItemIndicator>
|
||||
</Combobox.Item>
|
||||
)}
|
||||
>
|
||||
<Combobox.Control class="inline-flex items-center justify-between gap-1 px-2 py-1 bg-white border border-gray-300 rounded hover:bg-gray-50 outline-none focus-within:ring-2 focus-within:ring-blue-500 text-xs min-w-[140px]">
|
||||
<Combobox.Input
|
||||
ref={inputRef}
|
||||
onFocus={handleFocus}
|
||||
class="bg-transparent border-none outline-none text-xs text-gray-700 placeholder:text-gray-500 w-full min-w-0 px-0"
|
||||
/>
|
||||
<Combobox.Trigger class="flex items-center justify-center">
|
||||
<Combobox.Icon>
|
||||
<ChevronDown class="w-3 h-3 text-gray-500" />
|
||||
</Combobox.Icon>
|
||||
</Combobox.Trigger>
|
||||
</Combobox.Control>
|
||||
|
||||
<Combobox.Portal>
|
||||
<Combobox.Content class="bg-white border border-gray-300 rounded-md shadow-lg max-h-80 overflow-hidden p-1 z-50 min-w-[300px]">
|
||||
<Combobox.Listbox ref={listboxRef} scrollRef={() => listboxRef} class="max-h-80 overflow-auto" />
|
||||
</Combobox.Content>
|
||||
</Combobox.Portal>
|
||||
</Combobox>
|
||||
)
|
||||
}
|
||||
@@ -1,10 +1,16 @@
|
||||
import { createSignal, Show } from "solid-js"
|
||||
import AgentSelector from "./agent-selector"
|
||||
import ModelSelector from "./model-selector"
|
||||
|
||||
interface PromptInputProps {
|
||||
instanceId: string
|
||||
sessionId: string
|
||||
onSend: (prompt: string) => Promise<void>
|
||||
disabled?: boolean
|
||||
agent: string
|
||||
model: { providerId: string; modelId: string }
|
||||
onAgentChange: (agent: string) => Promise<void>
|
||||
onModelChange: (model: { providerId: string; modelId: string }) => Promise<void>
|
||||
}
|
||||
|
||||
export default function PromptInput(props: PromptInputProps) {
|
||||
@@ -73,6 +79,20 @@ export default function PromptInput(props: PromptInputProps) {
|
||||
<span class="hint">
|
||||
<kbd>Enter</kbd> to send, <kbd>Shift+Enter</kbd> for new line
|
||||
</span>
|
||||
<div class="flex items-center gap-2">
|
||||
<AgentSelector
|
||||
instanceId={props.instanceId}
|
||||
sessionId={props.sessionId}
|
||||
currentAgent={props.agent}
|
||||
onAgentChange={props.onAgentChange}
|
||||
/>
|
||||
<ModelSelector
|
||||
instanceId={props.instanceId}
|
||||
sessionId={props.sessionId}
|
||||
currentModel={props.model}
|
||||
onModelChange={props.onModelChange}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -74,12 +74,61 @@ async function fetchSessions(instanceId: string): Promise<void> {
|
||||
}
|
||||
}
|
||||
|
||||
async function getDefaultModel(
|
||||
instanceId: string,
|
||||
agentName?: string,
|
||||
): Promise<{ providerId: string; modelId: string }> {
|
||||
const instanceProviders = providers().get(instanceId) || []
|
||||
const instanceAgents = agents().get(instanceId) || []
|
||||
|
||||
if (agentName) {
|
||||
const agent = instanceAgents.find((a) => a.name === agentName)
|
||||
if (agent?.model?.providerId && agent.model.modelId) {
|
||||
return {
|
||||
providerId: agent.model.providerId,
|
||||
modelId: agent.model.modelId,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const anthropicProvider = instanceProviders.find((p) => p.id === "anthropic")
|
||||
if (anthropicProvider) {
|
||||
const defaultModelId = anthropicProvider.defaultModelId || anthropicProvider.models[0]?.id
|
||||
if (defaultModelId) {
|
||||
return {
|
||||
providerId: "anthropic",
|
||||
modelId: defaultModelId,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (instanceProviders.length > 0) {
|
||||
const firstProvider = instanceProviders[0]
|
||||
const defaultModelId = firstProvider.defaultModelId || firstProvider.models[0]?.id
|
||||
|
||||
if (defaultModelId) {
|
||||
return {
|
||||
providerId: firstProvider.id,
|
||||
modelId: defaultModelId,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { providerId: "", modelId: "" }
|
||||
}
|
||||
|
||||
async function createSession(instanceId: string, agent?: string): Promise<Session> {
|
||||
const instance = instances().get(instanceId)
|
||||
if (!instance || !instance.client) {
|
||||
throw new Error("Instance not ready")
|
||||
}
|
||||
|
||||
const instanceAgents = agents().get(instanceId) || []
|
||||
const nonSubagents = instanceAgents.filter((a) => a.mode !== "subagent")
|
||||
const selectedAgent = agent || (nonSubagents.length > 0 ? nonSubagents[0].name : "")
|
||||
|
||||
const defaultModel = await getDefaultModel(instanceId, selectedAgent)
|
||||
|
||||
setLoading((prev) => {
|
||||
const next = { ...prev }
|
||||
next.creatingSession.set(instanceId, true)
|
||||
@@ -98,8 +147,8 @@ async function createSession(instanceId: string, agent?: string): Promise<Sessio
|
||||
instanceId,
|
||||
title: response.data.title || "New Session",
|
||||
parentId: null,
|
||||
agent: agent || "",
|
||||
model: { providerId: "", modelId: "" },
|
||||
agent: selectedAgent,
|
||||
model: defaultModel,
|
||||
time: {
|
||||
created: response.data.time.created,
|
||||
updated: response.data.time.updated,
|
||||
@@ -185,13 +234,17 @@ async function fetchAgents(instanceId: string): Promise<void> {
|
||||
|
||||
try {
|
||||
const response = await instance.client.app.agents()
|
||||
const agentList = (response.data ?? [])
|
||||
.filter((agent) => agent.mode !== "subagent")
|
||||
.map((agent) => ({
|
||||
name: agent.name,
|
||||
description: agent.description || "",
|
||||
mode: agent.mode,
|
||||
}))
|
||||
const agentList = (response.data ?? []).map((agent) => ({
|
||||
name: agent.name,
|
||||
description: agent.description || "",
|
||||
mode: agent.mode,
|
||||
model: agent.model?.modelID
|
||||
? {
|
||||
providerId: agent.model.providerID || "",
|
||||
modelId: agent.model.modelID,
|
||||
}
|
||||
: undefined,
|
||||
}))
|
||||
|
||||
setAgents((prev) => {
|
||||
const next = new Map(prev)
|
||||
@@ -216,6 +269,7 @@ async function fetchProviders(instanceId: string): Promise<void> {
|
||||
const providerList = response.data.providers.map((provider) => ({
|
||||
id: provider.id,
|
||||
name: provider.name,
|
||||
defaultModelId: response.data?.default?.[provider.id],
|
||||
models: Object.entries(provider.models).map(([id, model]) => ({
|
||||
id,
|
||||
name: model.name,
|
||||
@@ -359,14 +413,44 @@ async function loadMessages(instanceId: string, sessionId: string): Promise<void
|
||||
}
|
||||
})
|
||||
|
||||
let agentName = ""
|
||||
let providerID = ""
|
||||
let modelID = ""
|
||||
|
||||
for (let i = response.data.length - 1; i >= 0; i--) {
|
||||
const apiMessage = response.data[i]
|
||||
const info = apiMessage.info || apiMessage
|
||||
|
||||
if (info.role === "assistant") {
|
||||
agentName = (info as any).mode || (info as any).agent || ""
|
||||
providerID = (info as any).providerID || ""
|
||||
modelID = (info as any).modelID || ""
|
||||
if (agentName && providerID && modelID) break
|
||||
}
|
||||
}
|
||||
|
||||
if (!agentName && !providerID && !modelID) {
|
||||
const defaultModel = await getDefaultModel(instanceId, session.agent)
|
||||
agentName = session.agent
|
||||
providerID = defaultModel.providerId
|
||||
modelID = defaultModel.modelId
|
||||
}
|
||||
|
||||
setSessions((prev) => {
|
||||
const next = new Map(prev)
|
||||
const instanceSessions = next.get(instanceId)
|
||||
if (instanceSessions) {
|
||||
const session = instanceSessions.get(sessionId)
|
||||
if (session) {
|
||||
const updatedSession = {
|
||||
...session,
|
||||
messages,
|
||||
messagesInfo,
|
||||
agent: agentName || session.agent,
|
||||
model: providerID && modelID ? { providerId: providerID, modelId: modelID } : session.model,
|
||||
}
|
||||
const updatedInstanceSessions = new Map(instanceSessions)
|
||||
updatedInstanceSessions.set(sessionId, { ...session, messages, messagesInfo })
|
||||
updatedInstanceSessions.set(sessionId, updatedSession)
|
||||
next.set(instanceId, updatedInstanceSessions)
|
||||
}
|
||||
}
|
||||
@@ -595,6 +679,48 @@ async function sendMessage(
|
||||
}
|
||||
}
|
||||
|
||||
async function updateSessionAgent(instanceId: string, sessionId: string, agent: string): Promise<void> {
|
||||
const instanceSessions = sessions().get(instanceId)
|
||||
const session = instanceSessions?.get(sessionId)
|
||||
if (!session) {
|
||||
throw new Error("Session not found")
|
||||
}
|
||||
|
||||
setSessions((prev) => {
|
||||
const next = new Map(prev)
|
||||
const instanceSessions = new Map(prev.get(instanceId))
|
||||
const session = instanceSessions.get(sessionId)
|
||||
if (session) {
|
||||
instanceSessions.set(sessionId, { ...session, agent })
|
||||
next.set(instanceId, instanceSessions)
|
||||
}
|
||||
return next
|
||||
})
|
||||
}
|
||||
|
||||
async function updateSessionModel(
|
||||
instanceId: string,
|
||||
sessionId: string,
|
||||
model: { providerId: string; modelId: string },
|
||||
): Promise<void> {
|
||||
const instanceSessions = sessions().get(instanceId)
|
||||
const session = instanceSessions?.get(sessionId)
|
||||
if (!session) {
|
||||
throw new Error("Session not found")
|
||||
}
|
||||
|
||||
setSessions((prev) => {
|
||||
const next = new Map(prev)
|
||||
const instanceSessions = new Map(prev.get(instanceId))
|
||||
const session = instanceSessions.get(sessionId)
|
||||
if (session) {
|
||||
instanceSessions.set(sessionId, { ...session, model })
|
||||
next.set(instanceId, instanceSessions)
|
||||
}
|
||||
return next
|
||||
})
|
||||
}
|
||||
|
||||
sseManager.onMessageUpdate = handleMessageUpdate
|
||||
sseManager.onSessionUpdate = handleSessionUpdate
|
||||
|
||||
@@ -621,4 +747,7 @@ export {
|
||||
getParentSessions,
|
||||
getChildSessions,
|
||||
getSessionFamily,
|
||||
updateSessionAgent,
|
||||
updateSessionModel,
|
||||
getDefaultModel,
|
||||
}
|
||||
|
||||
@@ -22,12 +22,17 @@ export interface Agent {
|
||||
name: string
|
||||
description: string
|
||||
mode: string
|
||||
model?: {
|
||||
providerId: string
|
||||
modelId: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface Provider {
|
||||
id: string
|
||||
name: string
|
||||
models: Model[]
|
||||
defaultModelId?: string
|
||||
}
|
||||
|
||||
export interface Model {
|
||||
|
||||
Reference in New Issue
Block a user