feat(core): implement towards policy-driven model fallback mechanism (#13781)

This commit is contained in:
Adam Weidman
2025-11-26 12:36:42 -08:00
committed by GitHub
parent 0f12d6c426
commit 87edeb4e32
8 changed files with 550 additions and 40 deletions

View File

@@ -64,7 +64,7 @@ describe('ModelAvailabilityService', () => {
healthyModel,
]);
expect(first).toEqual({
selected: stickyModel,
selectedModel: stickyModel,
attempts: 1,
skipped: [
{
@@ -81,7 +81,7 @@ describe('ModelAvailabilityService', () => {
healthyModel,
]);
expect(second).toEqual({
selected: healthyModel,
selectedModel: healthyModel,
skipped: [
{
model,
@@ -101,7 +101,7 @@ describe('ModelAvailabilityService', () => {
healthyModel,
]);
expect(third).toEqual({
selected: stickyModel,
selectedModel: stickyModel,
attempts: 1,
skipped: [
{

View File

@@ -30,7 +30,7 @@ export interface ModelAvailabilitySnapshot {
}
export interface ModelSelectionResult {
selected: ModelId | null;
selectedModel: ModelId | null;
attempts?: number;
skipped: Array<{
model: ModelId;
@@ -107,12 +107,12 @@ export class ModelAvailabilityService {
const state = this.health.get(model);
// A sticky model is being attempted, so note that.
const attempts = state?.status === 'sticky_retry' ? 1 : undefined;
return { selected: model, skipped, attempts };
return { selectedModel: model, skipped, attempts };
} else {
skipped.push({ model, reason: snapshot.reason ?? 'unknown' });
}
}
return { selected: null, skipped };
return { selectedModel: null, skipped };
}
resetTurn() {

View File

@@ -0,0 +1,59 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import {
resolvePolicyChain,
buildFallbackPolicyContext,
} from './policyHelpers.js';
import { createDefaultPolicy } from './policyCatalog.js';
import type { Config } from '../config/config.js';
describe('policyHelpers', () => {
  /**
   * Builds a Config stub exposing only the getters resolvePolicyChain calls,
   * with preview features off, no user tier and fallback mode disabled.
   */
  const stubConfigFor = (activeModel: string): Config =>
    ({
      getModel: () => activeModel,
      getPreviewFeatures: () => false,
      getUserTier: () => undefined,
      isInFallbackMode: () => false,
    }) as unknown as Config;

  /** Creates one default policy per model name, preserving order. */
  const policiesFor = (...models: string[]) =>
    models.map((model) => createDefaultPolicy(model));

  describe('resolvePolicyChain', () => {
    it('inserts the active model when missing from the catalog', () => {
      const [head] = resolvePolicyChain(stubConfigFor('custom-model'));

      expect(head?.model).toBe('custom-model');
    });

    it('leaves catalog order untouched when active model already present', () => {
      const [head] = resolvePolicyChain(stubConfigFor('gemini-2.5-pro'));

      expect(head?.model).toBe('gemini-2.5-pro');
    });
  });

  describe('buildFallbackPolicyContext', () => {
    it('returns remaining candidates after the failed model', () => {
      const chain = policiesFor('a', 'b', 'c');

      const { failedPolicy, candidates } = buildFallbackPolicyContext(
        chain,
        'b',
      );

      expect(failedPolicy?.model).toBe('b');
      expect(candidates.map(({ model }) => model)).toEqual(['c']);
    });

    it('returns full chain when model is not in policy list', () => {
      const chain = policiesFor('a', 'b');

      const { failedPolicy, candidates } = buildFallbackPolicyContext(
        chain,
        'x',
      );

      expect(failedPolicy).toBeUndefined();
      expect(candidates).toEqual(chain);
    });
  });
});

View File

@@ -0,0 +1,66 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '../config/config.js';
import type {
FailureKind,
FallbackAction,
ModelPolicy,
ModelPolicyChain,
} from './modelPolicy.js';
import { createDefaultPolicy, getModelPolicyChain } from './policyCatalog.js';
import { getEffectiveModel } from '../config/models.js';
/**
 * Builds the policy chain that applies to `config` and guarantees that the
 * currently effective model is part of it.
 *
 * The catalog chain is looked up from the preview-features flag and the user
 * tier. When the effective model is already one of the catalog entries the
 * catalog chain is returned as-is; otherwise a new array is returned with a
 * default policy for that model placed in front of the catalog entries.
 *
 * @param config Source of the preview flag, user tier, fallback mode and
 *   selected model.
 * @returns The policy chain, containing a policy for the effective model.
 */
export function resolvePolicyChain(config: Config): ModelPolicyChain {
  const catalogChain = getModelPolicyChain({
    previewEnabled: !!config.getPreviewFeatures(),
    userTier: config.getUserTier(),
  });

  // TODO: This will be replaced when we get rid of Fallback Modes
  const effectiveModel = getEffectiveModel(
    config.isInFallbackMode(),
    config.getModel(),
    config.getPreviewFeatures(),
  );

  const alreadyListed = catalogChain.some(
    ({ model }) => model === effectiveModel,
  );
  return alreadyListed
    ? catalogChain
    : [createDefaultPolicy(effectiveModel), ...catalogChain];
}
/**
 * Splits a policy chain around the model that just failed.
 *
 * @param chain Ordered policy chain to search.
 * @param failedModel Model name to look for in `chain`.
 * @returns The first policy whose `model` equals `failedModel` together with
 *   the policies that come after it. When no policy matches, `failedPolicy`
 *   is `undefined` and `candidates` is the `chain` array itself (not a copy).
 */
export function buildFallbackPolicyContext(
  chain: ModelPolicyChain,
  failedModel: string,
): {
  failedPolicy?: ModelPolicy;
  candidates: ModelPolicy[];
} {
  for (const [position, policy] of chain.entries()) {
    if (policy.model === failedModel) {
      // Everything after the failed entry is still eligible as a fallback.
      return {
        failedPolicy: policy,
        candidates: chain.slice(position + 1),
      };
    }
  }
  // Unknown model: every policy in the chain remains a candidate.
  return { failedPolicy: undefined, candidates: chain };
}
/**
 * Looks up the action a policy prescribes for a given kind of failure.
 *
 * @param failureKind Failure classification used as the key into
 *   `policy.actions`.
 * @param policy Policy whose optional `actions` map is consulted.
 * @returns The action stored for `failureKind`, or `'prompt'` when the policy
 *   has no `actions` map or no entry for that kind.
 */
export function resolvePolicyAction(
  failureKind: FailureKind,
  policy: ModelPolicy,
): FallbackAction {
  const { actions } = policy;
  const configuredAction = actions?.[failureKind];
  return configuredAction ?? 'prompt';
}

View File

@@ -59,6 +59,7 @@ import { StandardFileSystemService } from '../services/fileSystemService.js';
import { logRipgrepFallback } from '../telemetry/loggers.js';
import { RipgrepFallbackEvent } from '../telemetry/types.js';
import type { FallbackModelHandler } from '../fallback/types.js';
import { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
import { ModelRouterService } from '../routing/modelRouterService.js';
import { OutputFormat } from '../output/types.js';
import type { ModelConfigServiceConfig } from '../services/modelConfigService.js';
@@ -347,6 +348,7 @@ export class Config {
private geminiClient!: GeminiClient;
private baseLlmClient!: BaseLlmClient;
private modelRouterService: ModelRouterService;
private readonly modelAvailabilityService: ModelAvailabilityService;
private readonly fileFiltering: {
respectGitIgnore: boolean;
respectGeminiIgnore: boolean;
@@ -483,6 +485,7 @@ export class Config {
this.model = params.model;
this.enableModelAvailabilityService =
params.enableModelAvailabilityService ?? false;
this.modelAvailabilityService = new ModelAvailabilityService();
this.previewFeatures = params.previewFeatures ?? undefined;
this.maxSessionTurns = params.maxSessionTurns ?? -1;
this.experimentalZedIntegration =
@@ -1044,6 +1047,10 @@ export class Config {
return this.modelRouterService;
}
/** Returns the ModelAvailabilityService instance held by this Config. */
getModelAvailabilityService(): ModelAvailabilityService {
return this.modelAvailabilityService;
}
/** Returns the `enableRecursiveFileSearch` flag from the file-filtering settings. */
getEnableRecursiveFileSearch(): boolean {
return this.fileFiltering.enableRecursiveFileSearch;
}

View File

@@ -16,6 +16,7 @@ import {
} from 'vitest';
import { handleFallback } from './handler.js';
import type { Config } from '../config/config.js';
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
import { AuthType } from '../core/contentGenerator.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
@@ -25,6 +26,11 @@ import {
import { logFlashFallback } from '../telemetry/index.js';
import type { FallbackModelHandler } from './types.js';
import { ModelNotFoundError } from '../utils/httpErrors.js';
import { openBrowserSecurely } from '../utils/secure-browser-launcher.js';
import { coreEvents } from '../utils/events.js';
import { debugLogger } from '../utils/debugLogger.js';
import * as policyHelpers from '../availability/policyHelpers.js';
import { createDefaultPolicy } from '../availability/policyCatalog.js';
import {
RetryableQuotaError,
TerminalQuotaError,
@@ -35,22 +41,46 @@ vi.mock('../telemetry/index.js', () => ({
logFlashFallback: vi.fn(),
FlashFallbackEvent: class {},
}));
vi.mock('../utils/secure-browser-launcher.js', () => ({
openBrowserSecurely: vi.fn(),
}));
const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL;
const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL;
const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE;
const AUTH_API_KEY = AuthType.USE_GEMINI;
/**
 * Creates a ModelAvailabilityService test double in which every method is a
 * `vi.fn()`.
 *
 * @param result Value that the mocked `selectFirstAvailable` returns on every
 *   call.
 */
function createAvailabilityMock(
  result: ReturnType<ModelAvailabilityService['selectFirstAvailable']>,
): ModelAvailabilityService {
  const selectFirstAvailable = vi.fn().mockReturnValue(result);
  const double = {
    markTerminal: vi.fn(),
    markHealthy: vi.fn(),
    markRetryOncePerTurn: vi.fn(),
    consumeStickyAttempt: vi.fn(),
    snapshot: vi.fn(),
    selectFirstAvailable,
    resetTurn: vi.fn(),
  };
  return double as unknown as ModelAvailabilityService;
}
/**
 * Creates a Config test double for the fallback handler tests.
 *
 * Defaults: every mode flag reports `false`, the availability service is
 * disabled, the active model is MOCK_PRO_MODEL, and each call to
 * `getModelAvailabilityService` hands out a fresh availability mock that
 * selects FALLBACK_MODEL. Entries in `overrides` replace the defaults.
 */
const createMockConfig = (overrides: Partial<Config> = {}): Config => {
  const defaults = {
    isInFallbackMode: vi.fn(() => false),
    setFallbackMode: vi.fn(),
    isModelAvailabilityServiceEnabled: vi.fn(() => false),
    isPreviewModelFallbackMode: vi.fn(() => false),
    setPreviewModelFallbackMode: vi.fn(),
    isPreviewModelBypassMode: vi.fn(() => false),
    setPreviewModelBypassMode: vi.fn(),
    fallbackHandler: undefined,
    getFallbackModelHandler: vi.fn(),
    getModelAvailabilityService: vi.fn(() =>
      createAvailabilityMock({ selectedModel: FALLBACK_MODEL, skipped: [] }),
    ),
    getModel: vi.fn(() => MOCK_PRO_MODEL),
    getPreviewFeatures: vi.fn(() => false),
    getUserTier: vi.fn(() => undefined),
    isInteractive: vi.fn(() => false),
  };
  // Overrides are spread last so callers can replace any default member.
  return { ...defaults, ...overrides } as unknown as Config;
};
@@ -59,6 +89,7 @@ describe('handleFallback', () => {
let mockConfig: Config;
let mockHandler: Mock<FallbackModelHandler>;
let consoleErrorSpy: MockInstance;
let fallbackEventSpy: MockInstance;
beforeEach(() => {
vi.clearAllMocks();
@@ -68,10 +99,12 @@ describe('handleFallback', () => {
fallbackModelHandler: mockHandler,
});
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
fallbackEventSpy = vi.spyOn(coreEvents, 'emitFallbackModeChanged');
});
afterEach(() => {
consoleErrorSpy.mockRestore();
fallbackEventSpy.mockRestore();
});
it('should return null immediately if authType is not OAuth', async () => {
@@ -140,6 +173,53 @@ describe('handleFallback', () => {
});
});
it('should return false without toggling fallback when handler returns "retry_later"', async () => {
  // The UI handler asks to stop now and try the same model again later.
  mockHandler.mockResolvedValue('retry_later');

  const outcome = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);

  expect(outcome).toBe(false);
  // Neither fallback state, telemetry, nor the mode-changed event is touched.
  expect(fallbackEventSpy).not.toHaveBeenCalled();
  expect(logFlashFallback).not.toHaveBeenCalled();
  expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
it('should launch upgrade flow and avoid fallback mode when handler returns "upgrade"', async () => {
  const upgradeUrl = 'https://goo.gle/set-up-gemini-code-assist';
  vi.mocked(openBrowserSecurely).mockResolvedValue(undefined);
  mockHandler.mockResolvedValue('upgrade');

  const outcome = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);

  expect(outcome).toBe(false);
  // The browser is pointed at the upgrade page...
  expect(openBrowserSecurely).toHaveBeenCalledWith(upgradeUrl);
  // ...and no fallback state, telemetry, or event is produced.
  expect(fallbackEventSpy).not.toHaveBeenCalled();
  expect(logFlashFallback).not.toHaveBeenCalled();
  expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
it('should log a warning and continue when upgrade flow fails to open a browser', async () => {
  mockHandler.mockResolvedValue('upgrade');
  const debugWarnSpy = vi.spyOn(debugLogger, 'warn');
  // Silence console.warn for the duration of this test only.
  const consoleWarnSpy = vi
    .spyOn(console, 'warn')
    .mockImplementation(() => {});
  vi.mocked(openBrowserSecurely).mockRejectedValue(new Error('blocked'));

  try {
    const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);

    expect(result).toBe(false);
    expect(debugWarnSpy).toHaveBeenCalledWith(
      'Failed to open browser automatically:',
      'blocked',
    );
    expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
    expect(fallbackEventSpy).not.toHaveBeenCalled();
  } finally {
    // Restore in `finally` so a failing assertion cannot leak the spies (and
    // the silenced console.warn) into subsequent tests.
    debugWarnSpy.mockRestore();
    consoleWarnSpy.mockRestore();
  }
});
describe('when handler returns an unexpected value', () => {
it('should log an error and return null', async () => {
mockHandler.mockResolvedValue(null);
@@ -450,4 +530,142 @@ describe('handleFallback', () => {
expect(result).toBe(true);
expect(mockHandler).toHaveBeenCalled();
});
describe('policy-driven flow', () => {
  let policyConfig: Config;
  let availability: ModelAvailabilityService;
  let policyHandler: Mock<FallbackModelHandler>;

  beforeEach(() => {
    vi.clearAllMocks();
    availability = createAvailabilityMock({
      selectedModel: 'gemini-1.5-flash',
      skipped: [],
    });
    policyHandler = vi.fn().mockResolvedValue('retry_once');
    policyConfig = createMockConfig();
    // Enable the availability service so handleFallback takes the
    // policy-driven branch, and wire in the per-test doubles.
    vi.spyOn(
      policyConfig,
      'isModelAvailabilityServiceEnabled',
    ).mockReturnValue(true);
    vi.spyOn(policyConfig, 'getModelAvailabilityService').mockReturnValue(
      availability,
    );
    vi.spyOn(policyConfig, 'getFallbackModelHandler').mockReturnValue(
      policyHandler,
    );
  });

  it('uses availability selection when enabled', async () => {
    await handleFallback(policyConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
    expect(availability.selectFirstAvailable).toHaveBeenCalled();
  });

  it('falls back to last resort when availability returns null', async () => {
    availability.selectFirstAvailable = vi
      .fn()
      .mockReturnValue({ selectedModel: null, skipped: [] });
    policyHandler.mockResolvedValue('retry_once');

    await handleFallback(policyConfig, MOCK_PRO_MODEL, AUTH_OAUTH);

    expect(policyHandler).toHaveBeenCalledWith(
      MOCK_PRO_MODEL,
      DEFAULT_GEMINI_FLASH_MODEL,
      undefined,
    );
  });

  it('executes silent policy action without invoking UI handler', async () => {
    const proPolicy = createDefaultPolicy(MOCK_PRO_MODEL);
    const flashPolicy = createDefaultPolicy(DEFAULT_GEMINI_FLASH_MODEL);
    flashPolicy.actions = {
      ...flashPolicy.actions,
      terminal: 'silent',
      unknown: 'silent',
    };
    flashPolicy.isLastResort = true;
    const silentChain = [proPolicy, flashPolicy];
    const chainSpy = vi
      .spyOn(policyHelpers, 'resolvePolicyChain')
      .mockReturnValue(silentChain);

    try {
      availability.selectFirstAvailable = vi.fn().mockReturnValue({
        selectedModel: DEFAULT_GEMINI_FLASH_MODEL,
        skipped: [],
      });

      const result = await handleFallback(
        policyConfig,
        MOCK_PRO_MODEL,
        AUTH_OAUTH,
      );

      expect(result).toBe(true);
      expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
      expect(policyConfig.setFallbackMode).toHaveBeenCalledWith(true);
    } finally {
      chainSpy.mockRestore();
    }
  });

  it('logs and returns null when handler resolves to null', async () => {
    policyHandler.mockResolvedValue(null);
    const debugLoggerErrorSpy = vi.spyOn(debugLogger, 'error');

    try {
      const result = await handleFallback(
        policyConfig,
        MOCK_PRO_MODEL,
        AUTH_OAUTH,
      );

      expect(result).toBeNull();
      expect(debugLoggerErrorSpy).toHaveBeenCalledWith(
        'Fallback handler failed:',
        new Error(
          'Unexpected fallback intent received from fallbackModelHandler: "null"',
        ),
      );
    } finally {
      // Restore in `finally` so a failing assertion cannot leak the spy into
      // later tests (same pattern as the silent-policy test above).
      debugLoggerErrorSpy.mockRestore();
    }
  });

  it('successfully follows expected availability response for Preview Chain', async () => {
    availability.selectFirstAvailable = vi
      .fn()
      .mockReturnValue({ selectedModel: DEFAULT_GEMINI_MODEL, skipped: [] });
    policyHandler.mockResolvedValue('retry_once');
    vi.spyOn(policyConfig, 'getPreviewFeatures').mockReturnValue(true);
    vi.spyOn(policyConfig, 'getModel').mockReturnValue(PREVIEW_GEMINI_MODEL);

    const result = await handleFallback(
      policyConfig,
      PREVIEW_GEMINI_MODEL,
      AUTH_OAUTH,
    );

    expect(result).toBe(true);
    expect(availability.selectFirstAvailable).toHaveBeenCalledWith([
      DEFAULT_GEMINI_MODEL,
      DEFAULT_GEMINI_FLASH_MODEL,
    ]);
    expect(policyHandler).toHaveBeenCalledWith(
      PREVIEW_GEMINI_MODEL,
      DEFAULT_GEMINI_MODEL,
      undefined,
    );
  });

  it('short-circuits when the failed model is already the last-resort policy', async () => {
    const result = await handleFallback(
      policyConfig,
      DEFAULT_GEMINI_FLASH_MODEL,
      AUTH_OAUTH,
    );

    expect(result).toBeNull();
    // With no candidates left, neither the availability service nor the UI
    // handler should be consulted.
    expect(policyConfig.getModelAvailabilityService).not.toHaveBeenCalled();
    expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
  });
});
});

View File

@@ -17,7 +17,17 @@ import { openBrowserSecurely } from '../utils/secure-browser-launcher.js';
import { debugLogger } from '../utils/debugLogger.js';
import { getErrorMessage } from '../utils/errors.js';
import { ModelNotFoundError } from '../utils/httpErrors.js';
import { TerminalQuotaError } from '../utils/googleQuotaErrors.js';
import {
RetryableQuotaError,
TerminalQuotaError,
} from '../utils/googleQuotaErrors.js';
import type { FallbackIntent, FallbackRecommendation } from './types.js';
import type { FailureKind } from '../availability/modelPolicy.js';
import {
buildFallbackPolicyContext,
resolvePolicyChain,
resolvePolicyAction,
} from '../availability/policyHelpers.js';
const UPGRADE_URL_PAGE = 'https://goo.gle/set-up-gemini-code-assist';
@@ -27,7 +37,21 @@ export async function handleFallback(
authType?: string,
error?: unknown,
): Promise<string | boolean | null> {
// Applicability Checks
if (config.isModelAvailabilityServiceEnabled()) {
return handlePolicyDrivenFallback(config, failedModel, authType, error);
}
return legacyHandleFallback(config, failedModel, authType, error);
}
/**
* Old fallback logic relying on hard coded strings
*/
async function legacyHandleFallback(
config: Config,
failedModel: string,
authType?: string,
error?: unknown,
): Promise<string | boolean | null> {
if (authType !== AuthType.LOGIN_WITH_GOOGLE) return null;
// Guardrail: If it's a ModelNotFoundError but NOT the preview model, do not handle it.
@@ -70,39 +94,105 @@ export async function handleFallback(
);
// Process Intent and Update State
switch (intent) {
case 'retry_always':
// If the error is non-retryable, e.g. TerminalQuota Error, trigger a regular fallback to flash.
// For all other errors, activate previewModel fallback.
if (shouldActivatePreviewFallback) {
activatePreviewModelFallbackMode(config);
} else {
activateFallbackMode(config, authType);
}
return true; // Signal retryWithBackoff to continue.
case 'retry_once':
// Just retry this time, do NOT set sticky fallback mode.
return true;
case 'stop':
activateFallbackMode(config, authType);
return false;
case 'retry_later':
return false;
case 'upgrade':
await handleUpgrade();
return false;
default:
throw new Error(
`Unexpected fallback intent received from fallbackModelHandler: "${intent}"`,
);
}
return await processIntent(
config,
intent,
failedModel,
fallbackModel,
authType,
error,
);
} catch (handlerError) {
debugLogger.error('Fallback UI handler failed:', handlerError);
console.error('Fallback UI handler failed:', handlerError);
return null;
}
}
/**
 * New fallback logic using the ModelAvailabilityService.
 *
 * Steps:
 *  1. Only `AuthType.LOGIN_WITH_GOOGLE` is handled; other auth types yield
 *     `null`.
 *  2. The policy chain is resolved and the policies following `failedModel`
 *     become the fallback candidates. No candidates yields `null`.
 *  3. The availability service picks the first available candidate; when it
 *     picks none, the candidate flagged `isLastResort` (or, failing that, the
 *     final candidate) is used.
 *  4. The selected policy's action for the classified failure decides whether
 *     to fall back silently (as a `retry_always` intent) or to ask the
 *     registered fallback handler for an intent.
 *
 * @param config Runtime configuration providing the chain inputs, the
 *   availability service and the fallback handler.
 * @param failedModel Model whose request failed.
 * @param authType Auth type of the current session.
 * @param error Error that triggered the fallback, if any.
 * @returns The boolean produced by `processIntent`, or `null` when no
 *   fallback applies, no handler is registered, or handling the intent throws.
 */
async function handlePolicyDrivenFallback(
  config: Config,
  failedModel: string,
  authType?: string,
  error?: unknown,
): Promise<string | boolean | null> {
  if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
    return null;
  }

  const chain = resolvePolicyChain(config);
  const { failedPolicy, candidates } = buildFallbackPolicyContext(
    chain,
    failedModel,
  );
  if (!candidates.length) {
    // The failed model is the end of the chain: nothing left to fall back to.
    return null;
  }

  const availability = config.getModelAvailabilityService();
  const selection = availability.selectFirstAvailable(
    candidates.map((policy) => policy.model),
  );

  let lastResortPolicy = candidates.find((policy) => policy.isLastResort);
  if (!lastResortPolicy) {
    debugLogger.warn(
      'No isLastResort policy found in candidates, using last candidate as fallback.',
    );
    lastResortPolicy = candidates[candidates.length - 1];
  }

  const fallbackModel = selection.selectedModel ?? lastResortPolicy.model;
  const selectedPolicy =
    candidates.find((policy) => policy.model === fallbackModel) ??
    lastResortPolicy;
  if (!fallbackModel || fallbackModel === failedModel) {
    return null;
  }

  const failureKind = classifyFailureKind(error);
  const action = resolvePolicyAction(failureKind, selectedPolicy);
  if (action === 'silent') {
    // Silent policies skip the UI handler and behave like "retry_always".
    return processIntent(
      config,
      'retry_always',
      failedModel,
      fallbackModel,
      authType,
      error,
    );
  }

  // This will be used in the future when FallbackRecommendation is passed through UI
  const recommendation: FallbackRecommendation = {
    ...selection,
    selectedModel: fallbackModel,
    action,
    failureKind,
    failedPolicy,
    selectedPolicy,
  };
  void recommendation;

  const handler = config.getFallbackModelHandler();
  if (typeof handler !== 'function') {
    return null;
  }

  try {
    const intent = await handler(failedModel, fallbackModel, error);
    return await processIntent(
      config,
      intent,
      failedModel,
      fallbackModel,
      authType,
      // Forward the error so processIntent can tell a TerminalQuotaError on
      // the preview model apart from other failures, exactly as the silent
      // branch above does.
      error,
    );
  } catch (handlerError) {
    debugLogger.error('Fallback handler failed:', handlerError);
    return null;
  }
}
@@ -118,6 +208,49 @@ async function handleUpgrade() {
}
}
/**
 * Applies the state changes for a fallback intent and reports whether the
 * caller should retry.
 *
 * @param config Configuration whose fallback modes may be activated.
 * @param intent Intent to act on; `null` or any unrecognised value throws.
 * @param failedModel Model whose request failed.
 * @param fallbackModel Model chosen as the fallback (not read by this function).
 * @param authType Auth type forwarded to `activateFallbackMode`.
 * @param error Error that triggered the fallback, if any.
 * @returns `true` for `retry_always` and `retry_once`; `false` for `stop`,
 *   `retry_later` and `upgrade`.
 * @throws Error when `intent` is not one of the handled values.
 */
async function processIntent(
  config: Config,
  intent: FallbackIntent | null,
  failedModel: string,
  fallbackModel: string,
  authType?: string,
  error?: unknown,
): Promise<boolean> {
  if (intent === 'retry_once') {
    // Retry this one time only; no mode is changed.
    return true;
  }

  if (intent === 'retry_later') {
    return false;
  }

  if (intent === 'retry_always') {
    // If the error is non-retryable, e.g. TerminalQuota Error, trigger a regular fallback to flash.
    // For all other errors, activate previewModel fallback.
    const usePreviewFallback =
      failedModel === PREVIEW_GEMINI_MODEL &&
      !(error instanceof TerminalQuotaError);
    if (usePreviewFallback) {
      activatePreviewModelFallbackMode(config);
    } else {
      activateFallbackMode(config, authType);
    }
    return true;
  }

  if (intent === 'stop') {
    activateFallbackMode(config, authType);
    return false;
  }

  if (intent === 'upgrade') {
    await handleUpgrade();
    return false;
  }

  throw new Error(
    `Unexpected fallback intent received from fallbackModelHandler: "${intent}"`,
  );
}
function activateFallbackMode(config: Config, authType: string | undefined) {
if (!config.isInFallbackMode()) {
config.setFallbackMode(true);
@@ -134,3 +267,16 @@ function activatePreviewModelFallbackMode(config: Config) {
// We might want a specific event for Preview Model fallback, but for now we just set the mode.
}
}
/**
 * Maps an error to the FailureKind used for policy lookups.
 *
 * Checks run in this order: TerminalQuotaError gives `'terminal'`,
 * RetryableQuotaError gives `'transient'`, ModelNotFoundError gives
 * `'not_found'`. Anything else, including a missing error, gives `'unknown'`.
 */
function classifyFailureKind(error?: unknown): FailureKind {
  let kind: FailureKind = 'unknown';
  if (error instanceof TerminalQuotaError) {
    kind = 'terminal';
  } else if (error instanceof RetryableQuotaError) {
    kind = 'transient';
  } else if (error instanceof ModelNotFoundError) {
    kind = 'not_found';
  }
  return kind;
}

View File

@@ -4,6 +4,13 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { ModelSelectionResult } from '../availability/modelAvailabilityService.js';
import type {
FailureKind,
FallbackAction,
ModelPolicy,
} from '../availability/modelPolicy.js';
/**
* Defines the intent returned by the UI layer during a fallback scenario.
*/
@@ -14,6 +21,13 @@ export type FallbackIntent =
| 'retry_later' // Stop the current request and do not fallback. Intend to try again later with the same model.
| 'upgrade'; // Give user an option to upgrade the tier.
/**
 * A model selection result enriched with the policy information that led to
 * it.
 */
export interface FallbackRecommendation extends ModelSelectionResult {
  /** Classification of the failure that triggered the fallback. */
  failureKind: FailureKind;
  /** Policy action resolved for `failureKind`. */
  action: FallbackAction;
  /** Policy of the model chosen as the fallback. */
  selectedPolicy: ModelPolicy;
  /** Policy of the model that failed, when one is known. */
  failedPolicy?: ModelPolicy;
}
/**
* The interface for the handler provided by the UI layer (e.g., the CLI)
* to interact with the user during a fallback scenario.