diff --git a/packages/app/src/context/local.tsx b/packages/app/src/context/local.tsx index f467e9034f..4465a0261d 100644 --- a/packages/app/src/context/local.tsx +++ b/packages/app/src/context/local.tsx @@ -44,7 +44,7 @@ const migrate = (value: unknown) => { } const clone = (value: State | undefined) => { - if (!value) return undefined + if (!value) return return { ...value, model: value.model ? { ...value.model } : undefined, @@ -104,7 +104,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ const pickAgent = (name: string | undefined) => { const items = list() - if (items.length === 0) return undefined + if (items.length === 0) return return items.find((item) => item.name === name) ?? items[0] } @@ -227,14 +227,14 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ () => agent.current()?.model, fallback, ) - if (!item) return undefined + if (!item) return return models.find(item) } const configured = () => { const item = agent.current() const model = current() - if (!item || !model) return undefined + if (!item || !model) return return getConfiguredAgentVariant({ agent: { model: item.model, variant: item.variant }, model: { providerID: model.provider.id, modelID: model.id, variants: model.variants }, @@ -314,11 +314,16 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ configured, selected, current() { - return resolveModelVariant({ + const resolved = resolveModelVariant({ variants: this.list(), selected: this.selected(), configured: this.configured(), }) + if (resolved) return resolved + const model = current() + if (!model) return + const saved = models.variant.get({ providerID: model.provider.id, modelID: model.id }) + if (saved && this.list().includes(saved)) return saved }, list() { const item = current() @@ -335,6 +340,9 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ variant: value ?? null, }) write({ variant: value ?? null }) + if (model) { + models.variant.set({ providerID: model.provider.id, modelID: model.id }, value ?? undefined) + } }) }, cycle() {