variance peristance

This commit is contained in:
Simon Klee
2026-03-30 15:27:14 +02:00
parent 1e00672517
commit 4e45169eec
2 changed files with 228 additions and 3 deletions

View File

@@ -1,5 +1,8 @@
import path from "path"
import { createCliRenderer, type CliRenderer, type ScrollbackWriter } from "@opentui/core"
import { TuiConfig } from "../../../config/tui"
import { Global } from "../../../global"
import { Filesystem } from "../../../util/filesystem"
import { Locale } from "../../../util/locale"
import { RunFooter } from "./footer"
import { entrySplash, exitSplash, splashMeta } from "./splash"
@@ -9,6 +12,7 @@ import type { FooterApi, FooterKeybinds, RunInput } from "./types"
const FOOTER_HEIGHT = 6
const HISTORY_LIMIT = 200
const MODEL_FILE = path.join(Global.Path.state, "model.json")
const DEFAULT_KEYBINDS: FooterKeybinds = {
leader: "ctrl+x",
@@ -65,10 +69,20 @@ type ModelInfo = {
limits: Record<string, number>
}
type SessionMessages = Awaited<ReturnType<RunInput["sdk"]["session"]["messages"]>>["data"]
type ModelState = {
variant?: Record<string, string | undefined>
}
function modelKey(provider: string, model: string): string {
return `${provider}/${model}`
}
function variantKey(model: NonNullable<RunInput["model"]>): string {
return modelKey(model.providerID, model.modelID)
}
async function resolveModelInfo(sdk: RunInput["sdk"], model: RunInput["model"]): Promise<ModelInfo> {
try {
const response = await sdk.provider.list()
@@ -150,6 +164,120 @@ async function resolvePromptHistory(sdk: RunInput["sdk"], sessionID: string): Pr
}
}
/** @internal Exported for testing */
export function pickVariant(model: RunInput["model"], messages: SessionMessages): string | undefined {
if (!model || !messages || messages.length === 0) {
return undefined
}
for (let index = messages.length - 1; index >= 0; index -= 1) {
const info = messages[index]?.info
if (!info || info.role !== "user") {
continue
}
if (info.model.providerID !== model.providerID || info.model.modelID !== model.modelID) {
continue
}
return info.variant
}
return undefined
}
function fitVariant(value: string | undefined, variants: string[]): string | undefined {
if (!value) {
return undefined
}
if (variants.length === 0 || variants.includes(value)) {
return value
}
return undefined
}
/** @internal Exported for testing */
export function resolveVariant(
input: string | undefined,
session: string | undefined,
saved: string | undefined,
variants: string[],
): string | undefined {
if (input !== undefined) {
return input
}
const fallback = fitVariant(saved, variants)
const current = fitVariant(session, variants)
if (current !== undefined) {
return current
}
return fallback
}
async function resolveStoredVariant(
sdk: RunInput["sdk"],
sessionID: string,
model: RunInput["model"],
): Promise<string | undefined> {
if (!model) {
return undefined
}
try {
const response = await sdk.session.messages({
sessionID,
limit: HISTORY_LIMIT,
})
return pickVariant(model, response.data)
} catch {
return undefined
}
}
async function resolveSavedVariant(model: RunInput["model"]): Promise<string | undefined> {
if (!model) {
return undefined
}
try {
const state = await Filesystem.readJson<ModelState>(MODEL_FILE)
return state.variant?.[variantKey(model)]
} catch {
return undefined
}
}
function saveVariant(model: RunInput["model"], variant: string | undefined): void {
if (!model) {
return
}
void (async () => {
const state = await Filesystem.readJson<ModelState>(MODEL_FILE).catch(() => ({}) as ModelState)
const map = {
...(state.variant ?? {}),
}
const key = variantKey(model)
if (variant) {
map[key] = variant
}
if (!variant) {
delete map[key]
}
await Filesystem.writeJson(MODEL_FILE, {
...state,
variant: map,
})
})().catch(() => {})
}
async function resolveFooterKeybinds(): Promise<FooterKeybinds> {
try {
const config = await TuiConfig.get()
@@ -370,11 +498,13 @@ export async function runPromptQueue(input: QueueInput): Promise<void> {
}
export async function runInteractiveMode(input: RunInput): Promise<void> {
const [keybinds, info, first, history] = await Promise.all([
const [keybinds, info, first, history, storedVariant, savedVariant] = await Promise.all([
resolveFooterKeybinds(),
resolveModelInfo(input.sdk, input.model),
resolveFirstPrompt(input.sdk, input.sessionID),
resolvePromptHistory(input.sdk, input.sessionID),
resolveStoredVariant(input.sdk, input.sessionID, input.model),
resolveSavedVariant(input.model),
])
const meta = splashMeta({
title: input.sessionTitle,
@@ -385,7 +515,7 @@ export async function runInteractiveMode(input: RunInput): Promise<void> {
exit: false,
}
const variants = info.variants
let activeVariant = input.variant
let activeVariant = resolveVariant(input.variant, storedVariant, savedVariant, variants)
let aborting = false
const renderer = await createCliRenderer({
@@ -423,6 +553,7 @@ export async function runInteractiveMode(input: RunInput): Promise<void> {
}
activeVariant = cycleVariant(activeVariant, variants)
saveVariant(input.model, activeVariant)
return {
status: activeVariant ? `variant ${activeVariant}` : "variant default",
modelLabel: formatModelLabel(input.model, activeVariant),

View File

@@ -1,5 +1,5 @@
import { describe, expect, test } from "bun:test"
import { queueSplash, runPromptQueue } from "../../../src/cli/cmd/run/runtime"
import { pickVariant, queueSplash, resolveVariant, runPromptQueue } from "../../../src/cli/cmd/run/runtime"
import type { EntryKind, FooterApi, FooterPatch } from "../../../src/cli/cmd/run/types"
function createFooter() {
@@ -75,6 +75,100 @@ function createFooter() {
}
describe("run runtime", () => {
test("restores variant from latest matching user message", () => {
expect(
pickVariant(
{
providerID: "openai",
modelID: "gpt-5",
},
[
{
info: {
role: "user",
model: {
providerID: "openai",
modelID: "gpt-5",
},
variant: "high",
},
},
{
info: {
role: "user",
model: {
providerID: "anthropic",
modelID: "claude-3",
},
variant: "max",
},
},
{
info: {
role: "user",
model: {
providerID: "openai",
modelID: "gpt-5",
},
variant: "minimal",
},
},
] as unknown as Parameters<typeof pickVariant>[1],
),
).toBe("minimal")
})
test("respects default variant from latest matching user message", () => {
expect(
pickVariant(
{
providerID: "openai",
modelID: "gpt-5",
},
[
{
info: {
role: "user",
model: {
providerID: "openai",
modelID: "gpt-5",
},
variant: "high",
},
},
{
info: {
role: "assistant",
providerID: "openai",
modelID: "gpt-5",
},
},
{
info: {
role: "user",
model: {
providerID: "openai",
modelID: "gpt-5",
},
},
},
] as unknown as Parameters<typeof pickVariant>[1],
),
).toBeUndefined()
})
test("keeps saved variant when session variant is default", () => {
expect(resolveVariant(undefined, undefined, "high", ["high", "minimal"])).toBe("high")
})
test("session variant overrides saved variant", () => {
expect(resolveVariant(undefined, "minimal", "high", ["high", "minimal"])).toBe("minimal")
})
test("cli variant overrides session and saved variant", () => {
expect(resolveVariant("custom", "minimal", "high", ["high", "minimal"])).toBe("custom")
})
test("queues entry and exit splash only once", () => {
const writes: unknown[] = []
let renders = 0