Compare commits

...

2 Commits

Author SHA1 Message Date
Shoubhit Dash
a34148bc06 Merge branch 'dev' into default-explore-models 2026-03-06 12:40:05 +05:30
Shoubhit Dash
ef288cf93d yeah 2026-03-06 12:38:06 +05:30
3 changed files with 293 additions and 51 deletions

View File

@@ -1231,6 +1231,42 @@ export namespace Provider {
}
}
// Returns the first available model whose id contains one of the query terms,
// matched case-insensitively in query order. Amazon Bedrock gets special
// handling for its cross-region inference prefixes.
async function pick(providerID: string, query: string[]) {
  const provider = await state().then((s) => s.providers[providerID])
  if (!provider) return
  const models = Object.keys(provider.models)
  for (const term of query) {
    const needle = term.toLowerCase()
    if (providerID !== "amazon-bedrock") {
      const match = models.find((id) => id.toLowerCase().includes(needle))
      if (match) return getModel(providerID, match)
      continue
    }
    const prefixes = ["global.", "us.", "eu."]
    const candidates = models.filter((id) => id.toLowerCase().includes(needle))
    // Model selection priority:
    // 1. global. prefix (works everywhere)
    // 2. User's region prefix (us., eu.)
    // 3. Unprefixed model
    const global = candidates.find((id) => id.startsWith("global."))
    if (global) return getModel(providerID, global)
    const region = provider.options?.region
    if (region) {
      const regionPrefix = region.split("-")[0]
      if (regionPrefix === "us" || regionPrefix === "eu") {
        const regional = candidates.find((id) => id.startsWith(`${regionPrefix}.`))
        if (regional) return getModel(providerID, regional)
      }
    }
    const unprefixed = candidates.find((id) => !prefixes.some((prefix) => id.startsWith(prefix)))
    if (unprefixed) return getModel(providerID, unprefixed)
  }
}
export async function getSmallModel(providerID: string) {
const cfg = await Config.get()
@@ -1239,54 +1275,25 @@ export namespace Provider {
return getModel(parsed.providerID, parsed.modelID)
}
const provider = await state().then((state) => state.providers[providerID])
if (provider) {
let priority = [
"claude-haiku-4-5",
"claude-haiku-4.5",
"3-5-haiku",
"3.5-haiku",
"gemini-3-flash",
"gemini-2.5-flash",
"gpt-5-nano",
]
if (providerID.startsWith("opencode")) {
priority = ["gpt-5-nano"]
}
if (providerID.startsWith("github-copilot")) {
// prioritize free models for github copilot
priority = ["gpt-5-mini", "claude-haiku-4.5", ...priority]
}
for (const item of priority) {
if (providerID === "amazon-bedrock") {
const crossRegionPrefixes = ["global.", "us.", "eu."]
const candidates = Object.keys(provider.models).filter((m) => m.includes(item))
// Model selection priority:
// 1. global. prefix (works everywhere)
// 2. User's region prefix (us., eu.)
// 3. Unprefixed model
const globalMatch = candidates.find((m) => m.startsWith("global."))
if (globalMatch) return getModel(providerID, globalMatch)
const region = provider.options?.region
if (region) {
const regionPrefix = region.split("-")[0]
if (regionPrefix === "us" || regionPrefix === "eu") {
const regionalMatch = candidates.find((m) => m.startsWith(`${regionPrefix}.`))
if (regionalMatch) return getModel(providerID, regionalMatch)
}
}
const unprefixed = candidates.find((m) => !crossRegionPrefixes.some((p) => m.startsWith(p)))
if (unprefixed) return getModel(providerID, unprefixed)
} else {
for (const model of Object.keys(provider.models)) {
if (model.includes(item)) return getModel(providerID, model)
}
}
}
let query = [
"claude-haiku-4-5",
"claude-haiku-4.5",
"3-5-haiku",
"3.5-haiku",
"gemini-3-flash",
"gemini-2.5-flash",
"gpt-5-nano",
]
if (providerID.startsWith("opencode")) {
query = ["gpt-5-nano"]
}
if (providerID.startsWith("github-copilot")) {
// prioritize free models for github copilot
query = ["gpt-5-mini", "claude-haiku-4.5", ...query]
}
const model = await pick(providerID, query)
if (model) return model
// Check if opencode provider is available before using it
const opencodeProvider = await state().then((state) => state.providers["opencode"])
@@ -1297,6 +1304,22 @@ export namespace Provider {
return undefined
}
// Picks a cheap/fast model suited for the "explore" agent from the given
// provider, in priority order. Returns undefined when the provider offers none
// of the candidates. Where model families publish ids with both "." and "-"
// separators, both variants are listed because pick() does plain substring
// matching — "gpt-5-3-codex-spark" does not contain "gpt-5.3-codex-spark".
export async function getExploreModel(providerID: string) {
  const model = await pick(providerID, [
    "gpt-5.3-codex-spark",
    "gpt-5-3-codex-spark",
    "claude-haiku-4-5",
    "claude-haiku-4.5",
    "gemini-3-flash",
    "minimax-m2.5",
    "minimax-m2-5",
    "glm-5",
    "kimi-k2.5",
    "kimi-k2-5",
  ])
  if (model) return model
  return undefined
}
// Preferred model-id substrings, highest priority first — presumably consumed
// by sort() below for ranking model lists; confirm against sort's comparator.
const priority = ["gpt-5", "claude-sonnet-4", "big-pickle", "gemini-3-pro"]
export function sort(models: Model[]) {
return sortBy(

View File

@@ -5,6 +5,7 @@ import { Session } from "../session"
import { MessageV2 } from "../session/message-v2"
import { Identifier } from "../id/id"
import { Agent } from "../agent/agent"
import { Provider } from "../provider/provider"
import { SessionPrompt } from "../session/prompt"
import { iife } from "@/util/iife"
import { defer } from "@/util/defer"
@@ -102,11 +103,30 @@ export const TaskTool = Tool.define("task", async (ctx) => {
})
const msg = await MessageV2.get({ sessionID: ctx.sessionID, messageID: ctx.messageID })
if (msg.info.role !== "assistant") throw new Error("Not an assistant message")
const info = msg.info
const model = agent.model ?? {
modelID: msg.info.modelID,
providerID: msg.info.providerID,
}
const model = await iife(async () => {
if (agent.model) return agent.model
if (agent.name !== "explore") {
return {
modelID: info.modelID,
providerID: info.providerID,
}
}
const pick = await Provider.getExploreModel(info.providerID)
if (pick) {
return {
modelID: pick.id,
providerID: pick.providerID,
}
}
return {
modelID: info.modelID,
providerID: info.providerID,
}
})
ctx.metadata({
title: params.description,

View File

@@ -964,6 +964,205 @@ test("getSmallModel respects config small_model override", async () => {
})
})
// When several explore-model candidates exist on one provider, the highest
// priority entry must win. The provider registers the dash-separated id
// "gpt-5-3-codex-spark" (note: not the dotted "gpt-5.3-codex-spark"), and the
// test expects it to be chosen over haiku, gemini, minimax, glm, and kimi.
test("getExploreModel returns preferred explore model", async () => {
  await using tmp = await tmpdir({
    config: {
      provider: {
        "custom-provider": {
          name: "Custom Provider",
          npm: "@ai-sdk/openai-compatible",
          api: "https://api.custom.com/v1",
          env: ["CUSTOM_API_KEY"],
          models: {
            // top-priority candidate — expected winner
            "gpt-5-3-codex-spark": {
              name: "GPT-5.3 Codex Spark",
              tool_call: true,
              limit: {
                context: 128000,
                output: 4096,
              },
            },
            "claude-haiku-4.5": {
              name: "Claude Haiku 4.5",
              tool_call: true,
              limit: {
                context: 128000,
                output: 4096,
              },
            },
            // matches the "gemini-3-flash" query term by substring
            "gemini-3-flash-preview": {
              name: "Gemini 3 Flash",
              tool_call: true,
              limit: {
                context: 128000,
                output: 4096,
              },
            },
            // mixed casing on purpose — matching is case-insensitive
            "MiniMax-M2-5": {
              name: "MiniMax M2.5",
              tool_call: true,
              limit: {
                context: 128000,
                output: 4096,
              },
            },
            "GLM-5": {
              name: "GLM-5",
              tool_call: true,
              limit: {
                context: 128000,
                output: 4096,
              },
            },
            "Kimi-K2-5": {
              name: "Kimi K2.5",
              tool_call: true,
              limit: {
                context: 128000,
                output: 4096,
              },
            },
          },
          options: {
            apiKey: "custom-key",
          },
        },
      },
    },
  })
  await Instance.provide({
    directory: tmp.path,
    fn: async () => {
      const model = await Provider.getExploreModel("custom-provider")
      expect(model).toBeDefined()
      expect(model?.id).toBe("gpt-5-3-codex-spark")
    },
  })
})
// Provider ids here use mixed casing ("MiniMax-M2-5", "GLM-5", "Kimi-K2-5");
// getExploreModel must still match them because lookup lowercases both sides.
test("getExploreModel matches fallback models case-insensitively", async () => {
  const config = {
    provider: {
      "custom-provider": {
        name: "Custom Provider",
        npm: "@ai-sdk/openai-compatible",
        api: "https://api.custom.com/v1",
        env: ["CUSTOM_API_KEY"],
        models: {
          "MiniMax-M2-5": { name: "MiniMax M2.5", tool_call: true, limit: { context: 128000, output: 4096 } },
          "GLM-5": { name: "GLM-5", tool_call: true, limit: { context: 128000, output: 4096 } },
          "Kimi-K2-5": { name: "Kimi K2.5", tool_call: true, limit: { context: 128000, output: 4096 } },
        },
        options: { apiKey: "custom-key" },
      },
    },
  }
  await using tmp = await tmpdir({ config })
  await Instance.provide({
    directory: tmp.path,
    fn: async () => {
      const model = await Provider.getExploreModel("custom-provider")
      expect(model).toBeDefined()
      // minimax outranks glm/kimi in the explore priority list
      expect(model?.id).toBe("MiniMax-M2-5")
    },
  })
})
// The kimi family publishes ids with a dash separator ("Kimi-K2-5") rather
// than a dot ("kimi-k2.5"); the dash variant must still be found.
test("getExploreModel matches kimi separator variant", async () => {
  const config = {
    provider: {
      "custom-provider": {
        name: "Custom Provider",
        npm: "@ai-sdk/openai-compatible",
        api: "https://api.custom.com/v1",
        env: ["CUSTOM_API_KEY"],
        models: {
          "Kimi-K2-5": { name: "Kimi K2.5", tool_call: true, limit: { context: 128000, output: 4096 } },
        },
        options: { apiKey: "custom-key" },
      },
    },
  }
  await using tmp = await tmpdir({ config })
  await Instance.provide({
    directory: tmp.path,
    fn: async () => {
      const model = await Provider.getExploreModel("custom-provider")
      expect(model).toBeDefined()
      expect(model?.id).toBe("Kimi-K2-5")
    },
  })
})
// A provider whose catalog contains none of the explore candidates must yield
// undefined rather than an arbitrary model.
test("getExploreModel returns undefined when no explore model matches", async () => {
  const config = {
    provider: {
      "custom-provider": {
        name: "Custom Provider",
        npm: "@ai-sdk/openai-compatible",
        api: "https://api.custom.com/v1",
        env: ["CUSTOM_API_KEY"],
        models: {
          "custom-model": { name: "Custom Model", tool_call: true, limit: { context: 128000, output: 4096 } },
        },
        options: { apiKey: "custom-key" },
      },
    },
  }
  await using tmp = await tmpdir({ config })
  await Instance.provide({
    directory: tmp.path,
    fn: async () => {
      expect(await Provider.getExploreModel("custom-provider")).toBeUndefined()
    },
  })
})
test("provider.sort prioritizes preferred models", () => {
const models = [
{ id: "random-model", name: "Random" },