fix: azure issue where azure sdk was being used instead of anthropic one for anthropic models

This commit is contained in:
Aiden Cline
2026-01-18 22:00:01 -06:00
parent 0d8e706fac
commit 84c4fe971a

View File

@@ -66,7 +66,7 @@ export namespace Provider {
"@ai-sdk/github-copilot": createGitHubCopilotOpenAICompatible, "@ai-sdk/github-copilot": createGitHubCopilotOpenAICompatible,
} }
type CustomModelLoader = (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any> type CustomModelLoader = (sdk: any, model: Model, options?: Record<string, any>) => Promise<any>
type CustomLoader = (provider: Info) => Promise<{ type CustomLoader = (provider: Info) => Promise<{
autoload: boolean autoload: boolean
getModel?: CustomModelLoader getModel?: CustomModelLoader
@@ -110,8 +110,8 @@ export namespace Provider {
openai: async () => { openai: async () => {
return { return {
autoload: false, autoload: false,
async getModel(sdk: any, modelID: string, _options?: Record<string, any>) { async getModel(sdk: any, model: Model, _options?: Record<string, any>) {
return sdk.responses(modelID) return sdk.responses(model.api.id)
}, },
options: {}, options: {},
} }
@@ -119,11 +119,11 @@ export namespace Provider {
"github-copilot": async () => { "github-copilot": async () => {
return { return {
autoload: false, autoload: false,
async getModel(sdk: any, modelID: string, _options?: Record<string, any>) { async getModel(sdk: any, model: Model, _options?: Record<string, any>) {
if (modelID.includes("codex")) { if (model.api.id.includes("codex")) {
return sdk.responses(modelID) return sdk.responses(model.api.id)
} }
return sdk.chat(modelID) return sdk.chat(model.api.id)
}, },
options: {}, options: {},
} }
@@ -131,11 +131,11 @@ export namespace Provider {
"github-copilot-enterprise": async () => { "github-copilot-enterprise": async () => {
return { return {
autoload: false, autoload: false,
async getModel(sdk: any, modelID: string, _options?: Record<string, any>) { async getModel(sdk: any, model: Model, _options?: Record<string, any>) {
if (modelID.includes("codex")) { if (model.api.id.includes("codex")) {
return sdk.responses(modelID) return sdk.responses(model.api.id)
} }
return sdk.chat(modelID) return sdk.chat(model.api.id)
}, },
options: {}, options: {},
} }
@@ -143,12 +143,14 @@ export namespace Provider {
azure: async () => { azure: async () => {
return { return {
autoload: false, autoload: false,
async getModel(sdk: any, modelID: string, options?: Record<string, any>) { async getModel(sdk: any, model: Model, options?: Record<string, any>) {
if (options?.["useCompletionUrls"]) { if (model && model.api.npm !== "@ai-sdk/azure") {
return sdk.chat(modelID) return sdk.languageModel(model.api.id)
} else {
return sdk.responses(modelID)
} }
if (options?.["useCompletionUrls"]) {
return sdk.chat(model.api.id)
}
return sdk.responses(model.api.id)
}, },
options: {}, options: {},
} }
@@ -157,12 +159,14 @@ export namespace Provider {
const resourceName = Env.get("AZURE_COGNITIVE_SERVICES_RESOURCE_NAME") const resourceName = Env.get("AZURE_COGNITIVE_SERVICES_RESOURCE_NAME")
return { return {
autoload: false, autoload: false,
async getModel(sdk: any, modelID: string, options?: Record<string, any>) { async getModel(sdk: any, model: Model, options?: Record<string, any>) {
if (options?.["useCompletionUrls"]) { if (model && model.api.npm !== "@ai-sdk/azure") {
return sdk.chat(modelID) return sdk.languageModel(model.api.id)
} else {
return sdk.responses(modelID)
} }
if (options?.["useCompletionUrls"]) {
return sdk.chat(model.api.id)
}
return sdk.responses(model.api.id)
}, },
options: { options: {
baseURL: resourceName ? `https://${resourceName}.cognitiveservices.azure.com/openai` : undefined, baseURL: resourceName ? `https://${resourceName}.cognitiveservices.azure.com/openai` : undefined,
@@ -225,7 +229,8 @@ export namespace Provider {
return { return {
autoload: true, autoload: true,
options: providerOptions, options: providerOptions,
async getModel(sdk: any, modelID: string, options?: Record<string, any>) { async getModel(sdk: any, model: Model, options?: Record<string, any>) {
let modelID = model.api.id
// Skip region prefixing if model already has a cross-region inference profile prefix // Skip region prefixing if model already has a cross-region inference profile prefix
if (modelID.startsWith("global.") || modelID.startsWith("jp.")) { if (modelID.startsWith("global.") || modelID.startsWith("jp.")) {
return sdk.languageModel(modelID) return sdk.languageModel(modelID)
@@ -343,8 +348,8 @@ export namespace Provider {
project, project,
location, location,
}, },
async getModel(sdk: any, modelID: string) { async getModel(sdk: any, model: Model) {
const id = String(modelID).trim() const id = String(model.api.id).trim()
return sdk.languageModel(id) return sdk.languageModel(id)
}, },
} }
@@ -360,8 +365,8 @@ export namespace Provider {
project, project,
location, location,
}, },
async getModel(sdk: any, modelID) { async getModel(sdk: any, model: Model) {
const id = String(modelID).trim() const id = String(model.api.id).trim()
return sdk.languageModel(id) return sdk.languageModel(id)
}, },
} }
@@ -383,8 +388,8 @@ export namespace Provider {
return { return {
autoload: !!envServiceKey, autoload: !!envServiceKey,
options: envServiceKey ? { deploymentId, resourceGroup } : {}, options: envServiceKey ? { deploymentId, resourceGroup } : {},
async getModel(sdk: any, modelID: string) { async getModel(sdk: any, model: Model) {
return sdk(modelID) return sdk(model.api.id)
}, },
} }
}, },
@@ -423,8 +428,8 @@ export namespace Provider {
...(providerConfig?.options?.featureFlags || {}), ...(providerConfig?.options?.featureFlags || {}),
}, },
}, },
async getModel(sdk: ReturnType<typeof createGitLab>, modelID: string) { async getModel(sdk: ReturnType<typeof createGitLab>, model: Model) {
return sdk.agenticChat(modelID, { return sdk.agenticChat(model.api.id, {
featureFlags: { featureFlags: {
duo_agent_platform_agentic_chat: true, duo_agent_platform_agentic_chat: true,
duo_agent_platform: true, duo_agent_platform: true,
@@ -451,8 +456,8 @@ export namespace Provider {
return { return {
autoload: true, autoload: true,
async getModel(sdk: any, modelID: string, _options?: Record<string, any>) { async getModel(sdk: any, model: Model, _options?: Record<string, any>) {
return sdk.languageModel(modelID) return sdk.languageModel(model.api.id)
}, },
options: { options: {
baseURL: `https://gateway.ai.cloudflare.com/v1/${accountId}/${gateway}/compat`, baseURL: `https://gateway.ai.cloudflare.com/v1/${accountId}/${gateway}/compat`,
@@ -1093,9 +1098,8 @@ export namespace Provider {
const sdk = await getSDK(model) const sdk = await getSDK(model)
try { try {
const language = s.modelLoaders[model.providerID] const loader = s.modelLoaders[model.providerID]
? await s.modelLoaders[model.providerID](sdk, model.api.id, provider.options) const language = loader ? await loader(sdk, model, provider.options) : sdk.languageModel(model.api.id)
: sdk.languageModel(model.api.id)
s.models.set(key, language) s.models.set(key, language)
return language return language
} catch (e) { } catch (e) {