diff --git a/packages/opencode/src/cli/cmd/run/runtime.ts b/packages/opencode/src/cli/cmd/run/runtime.ts index 41207a912d..a745d7ed33 100644 --- a/packages/opencode/src/cli/cmd/run/runtime.ts +++ b/packages/opencode/src/cli/cmd/run/runtime.ts @@ -1,5 +1,8 @@ +import path from "path" import { createCliRenderer, type CliRenderer, type ScrollbackWriter } from "@opentui/core" import { TuiConfig } from "../../../config/tui" +import { Global } from "../../../global" +import { Filesystem } from "../../../util/filesystem" import { Locale } from "../../../util/locale" import { RunFooter } from "./footer" import { entrySplash, exitSplash, splashMeta } from "./splash" @@ -9,6 +12,7 @@ import type { FooterApi, FooterKeybinds, RunInput } from "./types" const FOOTER_HEIGHT = 6 const HISTORY_LIMIT = 200 +const MODEL_FILE = path.join(Global.Path.state, "model.json") const DEFAULT_KEYBINDS: FooterKeybinds = { leader: "ctrl+x", @@ -65,10 +69,20 @@ type ModelInfo = { limits: Record } +type SessionMessages = Awaited>["data"] + +type ModelState = { + variant?: Record +} + function modelKey(provider: string, model: string): string { return `${provider}/${model}` } +function variantKey(model: NonNullable): string { + return modelKey(model.providerID, model.modelID) +} + async function resolveModelInfo(sdk: RunInput["sdk"], model: RunInput["model"]): Promise { try { const response = await sdk.provider.list() @@ -150,6 +164,120 @@ async function resolvePromptHistory(sdk: RunInput["sdk"], sessionID: string): Pr } } +/** @internal Exported for testing */ +export function pickVariant(model: RunInput["model"], messages: SessionMessages): string | undefined { + if (!model || !messages || messages.length === 0) { + return undefined + } + + for (let index = messages.length - 1; index >= 0; index -= 1) { + const info = messages[index]?.info + if (!info || info.role !== "user") { + continue + } + + if (info.model.providerID !== model.providerID || info.model.modelID !== model.modelID) { + continue + } + + return info.variant + } + + return undefined +} + +function fitVariant(value: string | undefined, variants: string[]): string | undefined { + if (!value) { + return undefined + } + + if (variants.length === 0 || variants.includes(value)) { + return value + } + + return undefined +} + +/** @internal Exported for testing */ +export function resolveVariant( + input: string | undefined, + session: string | undefined, + saved: string | undefined, + variants: string[], +): string | undefined { + if (input !== undefined) { + return input + } + + const fallback = fitVariant(saved, variants) + const current = fitVariant(session, variants) + if (current !== undefined) { + return current + } + + return fallback +} + +async function resolveStoredVariant( + sdk: RunInput["sdk"], + sessionID: string, + model: RunInput["model"], +): Promise { + if (!model) { + return undefined + } + + try { + const response = await sdk.session.messages({ + sessionID, + limit: HISTORY_LIMIT, + }) + + return pickVariant(model, response.data) + } catch { + return undefined + } +} + +async function resolveSavedVariant(model: RunInput["model"]): Promise { + if (!model) { + return undefined + } + + try { + const state = await Filesystem.readJson(MODEL_FILE) + return state.variant?.[variantKey(model)] + } catch { + return undefined + } +} + +function saveVariant(model: RunInput["model"], variant: string | undefined): void { + if (!model) { + return + } + + void (async () => { + const state = await Filesystem.readJson(MODEL_FILE).catch(() => ({}) as ModelState) + const map = { + ...(state.variant ?? {}), + } + const key = variantKey(model) + if (variant) { + map[key] = variant + } + + if (!variant) { + delete map[key] + } + + await Filesystem.writeJson(MODEL_FILE, { + ...state, + variant: map, + }) + })().catch(() => {}) +} + async function resolveFooterKeybinds(): Promise { try { const config = await TuiConfig.get() @@ -370,11 +498,13 @@ export async function runPromptQueue(input: QueueInput): Promise { } export async function runInteractiveMode(input: RunInput): Promise { - const [keybinds, info, first, history] = await Promise.all([ + const [keybinds, info, first, history, storedVariant, savedVariant] = await Promise.all([ resolveFooterKeybinds(), resolveModelInfo(input.sdk, input.model), resolveFirstPrompt(input.sdk, input.sessionID), resolvePromptHistory(input.sdk, input.sessionID), + resolveStoredVariant(input.sdk, input.sessionID, input.model), + resolveSavedVariant(input.model), ]) const meta = splashMeta({ title: input.sessionTitle, @@ -385,7 +515,7 @@ export async function runInteractiveMode(input: RunInput): Promise { exit: false, } const variants = info.variants - let activeVariant = input.variant + let activeVariant = resolveVariant(input.variant, storedVariant, savedVariant, variants) let aborting = false const renderer = await createCliRenderer({ @@ -423,6 +553,7 @@ export async function runInteractiveMode(input: RunInput): Promise { } activeVariant = cycleVariant(activeVariant, variants) + saveVariant(input.model, activeVariant) return { status: activeVariant ? `variant ${activeVariant}` : "variant default", modelLabel: formatModelLabel(input.model, activeVariant), diff --git a/packages/opencode/test/cli/run/direct-runtime.test.ts b/packages/opencode/test/cli/run/direct-runtime.test.ts index 45274ee8e6..9e72bbedfc 100644 --- a/packages/opencode/test/cli/run/direct-runtime.test.ts +++ b/packages/opencode/test/cli/run/direct-runtime.test.ts @@ -1,5 +1,5 @@ import { describe, expect, test } from "bun:test" -import { queueSplash, runPromptQueue } from "../../../src/cli/cmd/run/runtime" +import { pickVariant, queueSplash, resolveVariant, runPromptQueue } from "../../../src/cli/cmd/run/runtime" import type { EntryKind, FooterApi, FooterPatch } from "../../../src/cli/cmd/run/types" function createFooter() { @@ -75,6 +75,100 @@ function createFooter() { } describe("run runtime", () => { + test("restores variant from latest matching user message", () => { + expect( + pickVariant( + { + providerID: "openai", + modelID: "gpt-5", + }, + [ + { + info: { + role: "user", + model: { + providerID: "openai", + modelID: "gpt-5", + }, + variant: "high", + }, + }, + { + info: { + role: "user", + model: { + providerID: "anthropic", + modelID: "claude-3", + }, + variant: "max", + }, + }, + { + info: { + role: "user", + model: { + providerID: "openai", + modelID: "gpt-5", + }, + variant: "minimal", + }, + }, + ] as unknown as Parameters[1], + ), + ).toBe("minimal") + }) + + test("respects default variant from latest matching user message", () => { + expect( + pickVariant( + { + providerID: "openai", + modelID: "gpt-5", + }, + [ + { + info: { + role: "user", + model: { + providerID: "openai", + modelID: "gpt-5", + }, + variant: "high", + }, + }, + { + info: { + role: "assistant", + providerID: "openai", + modelID: "gpt-5", + }, + }, + { + info: { + role: "user", + model: { + providerID: "openai", + modelID: "gpt-5", + }, + }, + }, + ] as unknown as Parameters[1], + ), + ).toBeUndefined() + }) + + test("keeps saved variant when session variant is default", () => { + expect(resolveVariant(undefined, undefined, "high", ["high", "minimal"])).toBe("high") + }) + + test("session variant overrides saved variant", () => { + expect(resolveVariant(undefined, "minimal", "high", ["high", "minimal"])).toBe("minimal") + }) + + test("cli variant overrides session and saved variant", () => { + expect(resolveVariant("custom", "minimal", "high", ["high", "minimal"])).toBe("custom") + }) + test("queues entry and exit splash only once", () => { const writes: unknown[] = [] let renders = 0