diff --git a/src/App.tsx b/src/App.tsx index 917b248b..12d5bdca 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -399,12 +399,21 @@ const App: Component = () => { keywords: ["model", "llm", "ai"], shortcut: { key: "M", meta: true, shift: true }, action: () => { - const modelControl = document.querySelector("[data-model-selector]") as HTMLElement - modelControl?.click() - setTimeout(() => { - const modelInput = document.querySelector("[data-model-selector] input") as HTMLInputElement - modelInput?.focus() - }, 100) + const modelInput = document.querySelector("[data-model-selector]") as HTMLInputElement + if (modelInput) { + modelInput.focus() + setTimeout(() => { + const event = new KeyboardEvent("keydown", { + key: "ArrowDown", + code: "ArrowDown", + keyCode: 40, + which: 40, + bubbles: true, + cancelable: true, + }) + modelInput.dispatchEvent(event) + }, 10) + } }, }) @@ -570,8 +579,21 @@ const App: Component = () => { handleCycleAgent, handleCycleAgentReverse, () => { - const modelInput = document.querySelector("[data-model-selector] input") as HTMLInputElement - modelInput?.focus() + const modelInput = document.querySelector("[data-model-selector]") as HTMLInputElement + if (modelInput) { + modelInput.focus() + setTimeout(() => { + const event = new KeyboardEvent("keydown", { + key: "ArrowDown", + code: "ArrowDown", + keyCode: 40, + which: 40, + bubbles: true, + cancelable: true, + }) + modelInput.dispatchEvent(event) + }, 10) + } }, () => { const agentTrigger = document.querySelector("[data-agent-selector]") as HTMLElement @@ -617,9 +639,9 @@ const App: Component = () => { const isInCombobox = target.closest('[role="combobox"]') !== null const isInListbox = target.closest('[role="listbox"]') !== null - const isInSelect = target.closest('[role="button"][data-agent-selector]') !== null + const isInAgentSelect = target.closest('[role="button"][data-agent-selector]') !== null - if (isInCombobox || isInListbox || isInSelect) { + if (isInCombobox || isInListbox || isInAgentSelect) { return } diff --git a/src/components/model-selector.tsx b/src/components/model-selector.tsx index 0ac40ab5..a3ac4418 100644 --- a/src/components/model-selector.tsx +++ b/src/components/model-selector.tsx @@ -1,8 +1,8 @@ import { Combobox } from "@kobalte/core/combobox" -import { For, Show, createEffect, createMemo } from "solid-js" +import { createEffect, createMemo, createSignal } from "solid-js" import { providers, fetchProviders } from "../stores/sessions" import { ChevronDown } from "lucide-solid" -import type { Provider, Model } from "../types/session" +import type { Model } from "../types/session" import Kbd from "./kbd" interface ModelSelectorProps { @@ -14,11 +14,15 @@ interface ModelSelectorProps { interface FlatModel extends Model { providerName: string + key: string + searchText: string } export default function ModelSelector(props: ModelSelectorProps) { const instanceProviders = () => providers().get(props.instanceId) || [] - let inputRef!: HTMLInputElement + const [isOpen, setIsOpen] = createSignal(false) + let triggerRef!: HTMLButtonElement + let searchInputRef!: HTMLInputElement createEffect(() => { if (instanceProviders().length === 0) { @@ -26,16 +30,13 @@ export default function ModelSelector(props: ModelSelectorProps) { } }) - const handleFocus = (e: FocusEvent) => { - const input = e.target as HTMLInputElement - input.select() - } - const allModels = createMemo(() => instanceProviders().flatMap((p) => p.models.map((m) => ({ ...m, providerName: p.name, + key: `${m.providerId}/${m.id}`, + searchText: `${m.name} ${p.name} ${m.providerId} ${m.id} ${m.providerId}/${m.id}`, })), ), ) @@ -46,29 +47,38 @@ export default function ModelSelector(props: ModelSelectorProps) { const handleChange = async (value: FlatModel | null) => { if (!value) return - - if (value.providerId !== props.currentModel.providerId || value.id !== props.currentModel.modelId) { - await props.onModelChange({ providerId: value.providerId, modelId: value.id }) - } + await props.onModelChange({ providerId: value.providerId, modelId: value.id }) } + const customFilter = (option: FlatModel, inputValue: string) => { + return option.searchText.toLowerCase().includes(inputValue.toLowerCase()) + } + + createEffect(() => { + if (isOpen()) { + setTimeout(() => { + searchInputRef?.focus() + }, 100) + } + }) + return (
- value={currentModelValue()} onChange={handleChange} + onOpenChange={setIsOpen} options={allModels()} - optionValue={(m) => `${m.providerId}/${m.id}`} - optionTextValue={(m) => `${m.name} ${m.providerName} ${m.providerId}/${m.id}`} + optionValue="key" + optionTextValue="searchText" optionLabel="name" placeholder="Search models..." - defaultFilter="contains" - triggerMode="focus" - allowsEmptyCollection={false} + defaultFilter={customFilter} + allowsEmptyCollection itemComponent={(itemProps) => (
@@ -87,16 +97,13 @@ export default function ModelSelector(props: ModelSelectorProps) { )} > - - - + + + + Model: {currentModelValue()?.name ?? "None"} @@ -104,8 +111,15 @@ export default function ModelSelector(props: ModelSelectorProps) { - - + +
+ +
+