mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-02-01 14:44:29 +00:00
Initial support for reloading extensions in the CLI - mcp servers only (#12239)
This commit is contained in:
116
integration-tests/extensions-reload.test.ts
Normal file
116
integration-tests/extensions-reload.test.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { expect, it, describe } from 'vitest';
|
||||
import { TestRig } from './test-helper.js';
|
||||
import { TestMcpServer } from './test-mcp-server.js';
|
||||
import { writeFileSync } from 'node:fs';
|
||||
import { join } from 'node:path';
|
||||
import { safeJsonStringify } from '@google/gemini-cli-core/src/utils/safeJsonStringify.js';
|
||||
import { env } from 'node:process';
|
||||
import { platform } from 'node:os';
|
||||
|
||||
const itIf = (condition: boolean) => (condition ? it : it.skip);
|
||||
|
||||
describe('extension reloading', () => {
|
||||
const sandboxEnv = env['GEMINI_SANDBOX'];
|
||||
|
||||
// Fails in sandbox mode, can't check for local extension updates.
|
||||
itIf((!sandboxEnv || sandboxEnv === 'false') && platform() !== 'win32')(
|
||||
'installs a local extension, updates it, checks it was reloaded properly',
|
||||
async () => {
|
||||
const serverA = new TestMcpServer();
|
||||
const portA = await serverA.start({
|
||||
hello: () => ({ content: [{ type: 'text', text: 'world' }] }),
|
||||
});
|
||||
const extension = {
|
||||
name: 'test-extension',
|
||||
version: '0.0.1',
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
httpUrl: `http://localhost:${portA}/mcp`,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const rig = new TestRig();
|
||||
rig.setup('extension reload test', {
|
||||
settings: {
|
||||
experimental: { extensionReloading: true },
|
||||
},
|
||||
});
|
||||
const testServerPath = join(rig.testDir!, 'gemini-extension.json');
|
||||
writeFileSync(testServerPath, safeJsonStringify(extension, 2));
|
||||
// defensive cleanup from previous tests.
|
||||
try {
|
||||
await rig.runCommand(['extensions', 'uninstall', 'test-extension']);
|
||||
} catch {
|
||||
/* empty */
|
||||
}
|
||||
|
||||
const result = await rig.runCommand(
|
||||
['extensions', 'install', `${rig.testDir!}`],
|
||||
{ stdin: 'y\n' },
|
||||
);
|
||||
expect(result).toContain('test-extension');
|
||||
|
||||
// Now create the update, but its not installed yet
|
||||
const serverB = new TestMcpServer();
|
||||
const portB = await serverB.start({
|
||||
goodbye: () => ({ content: [{ type: 'text', text: 'world' }] }),
|
||||
});
|
||||
extension.version = '0.0.2';
|
||||
extension.mcpServers['test-server'].httpUrl =
|
||||
`http://localhost:${portB}/mcp`;
|
||||
writeFileSync(testServerPath, safeJsonStringify(extension, 2));
|
||||
|
||||
// Start the CLI.
|
||||
const run = await rig.runInteractive('--debug');
|
||||
await run.expectText('You have 1 extension with an update available');
|
||||
// See the outdated extension
|
||||
await run.sendText('/extensions list');
|
||||
await run.type('\r');
|
||||
await run.expectText(
|
||||
'test-extension (v0.0.1) - active (update available)',
|
||||
);
|
||||
await run.sendText('/mcp list');
|
||||
await run.type('\r');
|
||||
await run.expectText(
|
||||
'test-server (from test-extension) - Ready (1 tool)',
|
||||
);
|
||||
await run.expectText('- hello');
|
||||
|
||||
// Update the extension, expect the list to update, and mcp servers as well.
|
||||
await run.sendText('/extensions update test-extension');
|
||||
await run.type('\r');
|
||||
await run.expectText(
|
||||
` * test-server (remote): http://localhost:${portB}/mcp`,
|
||||
);
|
||||
await run.type('\r'); // consent
|
||||
await run.expectText(
|
||||
'Extension "test-extension" successfully updated: 0.0.1 → 0.0.2',
|
||||
);
|
||||
await new Promise((resolve) => setTimeout(resolve, 1000));
|
||||
await run.sendText('/extensions list');
|
||||
await run.type('\r');
|
||||
await run.expectText('test-extension (v0.0.2) - active (updated)');
|
||||
await run.sendText('/mcp list');
|
||||
await run.type('\r');
|
||||
await run.expectText(
|
||||
'test-server (from test-extension) - Ready (1 tool)',
|
||||
);
|
||||
await run.expectText('- goodbye');
|
||||
await run.sendText('/quit');
|
||||
await run.sendKeys('\r');
|
||||
|
||||
// Clean things up.
|
||||
await serverA.stop();
|
||||
await serverB.stop();
|
||||
await rig.runCommand(['extensions', 'uninstall', 'test-extension']);
|
||||
await rig.cleanup();
|
||||
},
|
||||
);
|
||||
});
|
||||
@@ -220,6 +220,13 @@ export class InteractiveRun {
|
||||
}
|
||||
}
|
||||
|
||||
// Types an entire string at once, necessary for some things like commands
|
||||
// but may run into paste detection issues for larger strings.
|
||||
async sendText(text: string) {
|
||||
this.ptyProcess.write(text);
|
||||
await new Promise((resolve) => setTimeout(resolve, 5));
|
||||
}
|
||||
|
||||
// Simulates typing a string one character at a time to avoid paste detection.
|
||||
async sendKeys(text: string) {
|
||||
const delay = 5;
|
||||
@@ -311,6 +318,8 @@ export class TestRig {
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
sandbox:
|
||||
env['GEMINI_SANDBOX'] !== 'false' ? env['GEMINI_SANDBOX'] : false,
|
||||
// Don't show the IDE connection dialog when running from VsCode
|
||||
ide: { enabled: false, hasSeenNudge: true },
|
||||
...options.settings, // Allow tests to override/add settings
|
||||
};
|
||||
writeFileSync(
|
||||
|
||||
@@ -4,17 +4,21 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||
import {
|
||||
McpServer,
|
||||
type ToolCallback,
|
||||
} from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
||||
import express from 'express';
|
||||
import { type Server as HTTPServer } from 'node:http';
|
||||
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import { type ZodRawShape } from 'zod';
|
||||
|
||||
export class TestMcpServer {
|
||||
private server: HTTPServer | undefined;
|
||||
|
||||
async start(): Promise<number> {
|
||||
async start(
|
||||
tools?: Record<string, ToolCallback<ZodRawShape>>,
|
||||
): Promise<number> {
|
||||
const app = express();
|
||||
app.use(express.json());
|
||||
const mcpServer = new McpServer(
|
||||
@@ -22,18 +26,30 @@ export class TestMcpServer {
|
||||
name: 'test-mcp-server',
|
||||
version: '1.0.0',
|
||||
},
|
||||
{ capabilities: {} },
|
||||
{ capabilities: { tools: {} } },
|
||||
);
|
||||
|
||||
const transport = new StreamableHTTPServerTransport({
|
||||
sessionIdGenerator: () => randomUUID(),
|
||||
});
|
||||
mcpServer.connect(transport);
|
||||
if (tools) {
|
||||
for (const [name, cb] of Object.entries(tools)) {
|
||||
mcpServer.registerTool(name, {}, cb);
|
||||
}
|
||||
}
|
||||
|
||||
app.post('/mcp', async (req, res) => {
|
||||
const transport = new StreamableHTTPServerTransport({
|
||||
sessionIdGenerator: undefined,
|
||||
enableJsonResponse: true,
|
||||
});
|
||||
res.on('close', () => {
|
||||
transport.close();
|
||||
});
|
||||
await mcpServer.connect(transport);
|
||||
await transport.handleRequest(req, res, req.body);
|
||||
});
|
||||
|
||||
app.get('/mcp', async (req, res) => {
|
||||
res.status(405).send('Not supported');
|
||||
});
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
this.server = app.listen(0, () => {
|
||||
const address = this.server!.address();
|
||||
|
||||
@@ -53,6 +53,7 @@ export function createMockConfig(
|
||||
getEnableMessageBusIntegration: vi.fn().mockReturnValue(false),
|
||||
getMessageBus: vi.fn(),
|
||||
getPolicyEngine: vi.fn(),
|
||||
getEnableExtensionReloading: vi.fn().mockReturnValue(false),
|
||||
...overrides,
|
||||
} as unknown as Config;
|
||||
|
||||
|
||||
@@ -30,11 +30,12 @@ const updateOutput = (info: ExtensionUpdateInfo) =>
|
||||
|
||||
export async function handleUpdate(args: UpdateArgs) {
|
||||
const workspaceDir = process.cwd();
|
||||
const settings = loadSettings(workspaceDir).merged;
|
||||
const extensionManager = new ExtensionManager({
|
||||
workspaceDir,
|
||||
requestConsent: requestConsentNonInteractive,
|
||||
requestSetting: promptForSetting,
|
||||
settings: loadSettings(workspaceDir).merged,
|
||||
settings,
|
||||
});
|
||||
|
||||
const extensions = await extensionManager.loadExtensions();
|
||||
@@ -67,6 +68,7 @@ export async function handleUpdate(args: UpdateArgs) {
|
||||
extensionManager,
|
||||
updateState,
|
||||
() => {},
|
||||
settings.experimental?.extensionReloading,
|
||||
))!;
|
||||
if (
|
||||
updatedExtensionInfo.originalVersion !==
|
||||
|
||||
@@ -680,6 +680,7 @@ export async function loadCliConfig(
|
||||
listExtensions: argv.listExtensions || false,
|
||||
enabledExtensions: argv.extensions,
|
||||
extensionLoader: extensionManager,
|
||||
enableExtensionReloading: settings.experimental?.extensionReloading,
|
||||
blockedMcpServers,
|
||||
noBrowser: !!process.env['NO_BROWSER'],
|
||||
summarizeToolOutput: settings.model?.summarizeToolOutput,
|
||||
|
||||
@@ -28,6 +28,7 @@ export async function updateExtension(
|
||||
extensionManager: ExtensionManager,
|
||||
currentState: ExtensionUpdateState,
|
||||
dispatchExtensionStateUpdate: (action: ExtensionUpdateAction) => void,
|
||||
enableExtensionReloading?: boolean,
|
||||
): Promise<ExtensionUpdateInfo | undefined> {
|
||||
if (currentState === ExtensionUpdateState.UPDATING) {
|
||||
return undefined;
|
||||
@@ -81,7 +82,9 @@ export async function updateExtension(
|
||||
type: 'SET_STATE',
|
||||
payload: {
|
||||
name: extension.name,
|
||||
state: ExtensionUpdateState.UPDATED_NEEDS_RESTART,
|
||||
state: enableExtensionReloading
|
||||
? ExtensionUpdateState.UPDATED
|
||||
: ExtensionUpdateState.UPDATED_NEEDS_RESTART,
|
||||
},
|
||||
});
|
||||
return {
|
||||
@@ -109,6 +112,7 @@ export async function updateAllUpdatableExtensions(
|
||||
extensionsState: Map<string, ExtensionUpdateStatus>,
|
||||
extensionManager: ExtensionManager,
|
||||
dispatch: (action: ExtensionUpdateAction) => void,
|
||||
enableExtensionReloading?: boolean,
|
||||
): Promise<ExtensionUpdateInfo[]> {
|
||||
return (
|
||||
await Promise.all(
|
||||
@@ -124,6 +128,7 @@ export async function updateAllUpdatableExtensions(
|
||||
extensionManager,
|
||||
extensionsState.get(extension.name)!.status,
|
||||
dispatch,
|
||||
enableExtensionReloading,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -141,34 +146,37 @@ export async function checkForAllExtensionUpdates(
|
||||
dispatch: (action: ExtensionUpdateAction) => void,
|
||||
): Promise<void> {
|
||||
dispatch({ type: 'BATCH_CHECK_START' });
|
||||
const promises: Array<Promise<void>> = [];
|
||||
for (const extension of extensions) {
|
||||
if (!extension.installMetadata) {
|
||||
try {
|
||||
const promises: Array<Promise<void>> = [];
|
||||
for (const extension of extensions) {
|
||||
if (!extension.installMetadata) {
|
||||
dispatch({
|
||||
type: 'SET_STATE',
|
||||
payload: {
|
||||
name: extension.name,
|
||||
state: ExtensionUpdateState.NOT_UPDATABLE,
|
||||
},
|
||||
});
|
||||
continue;
|
||||
}
|
||||
dispatch({
|
||||
type: 'SET_STATE',
|
||||
payload: {
|
||||
name: extension.name,
|
||||
state: ExtensionUpdateState.NOT_UPDATABLE,
|
||||
state: ExtensionUpdateState.CHECKING_FOR_UPDATES,
|
||||
},
|
||||
});
|
||||
continue;
|
||||
promises.push(
|
||||
checkForExtensionUpdate(extension, extensionManager).then((state) =>
|
||||
dispatch({
|
||||
type: 'SET_STATE',
|
||||
payload: { name: extension.name, state },
|
||||
}),
|
||||
),
|
||||
);
|
||||
}
|
||||
dispatch({
|
||||
type: 'SET_STATE',
|
||||
payload: {
|
||||
name: extension.name,
|
||||
state: ExtensionUpdateState.CHECKING_FOR_UPDATES,
|
||||
},
|
||||
});
|
||||
promises.push(
|
||||
checkForExtensionUpdate(extension, extensionManager).then((state) =>
|
||||
dispatch({
|
||||
type: 'SET_STATE',
|
||||
payload: { name: extension.name, state },
|
||||
}),
|
||||
),
|
||||
);
|
||||
await Promise.all(promises);
|
||||
} finally {
|
||||
dispatch({ type: 'BATCH_CHECK_END' });
|
||||
}
|
||||
await Promise.all(promises);
|
||||
dispatch({ type: 'BATCH_CHECK_END' });
|
||||
}
|
||||
|
||||
@@ -1075,6 +1075,16 @@ const SETTINGS_SCHEMA = {
|
||||
description: 'Enable extension management features.',
|
||||
showInDialog: false,
|
||||
},
|
||||
extensionReloading: {
|
||||
type: 'boolean',
|
||||
label: 'Extension Reloading',
|
||||
category: 'Experimental',
|
||||
requiresRestart: true,
|
||||
default: false,
|
||||
description:
|
||||
'Enables extension loading/unloading within the CLI session.',
|
||||
showInDialog: false,
|
||||
},
|
||||
useModelRouter: {
|
||||
type: 'boolean',
|
||||
label: 'Use Model Router',
|
||||
|
||||
@@ -183,7 +183,11 @@ export const AppContainer = (props: AppContainerProps) => {
|
||||
extensionsUpdateState,
|
||||
extensionsUpdateStateInternal,
|
||||
dispatchExtensionStateUpdate,
|
||||
} = useExtensionUpdates(extensionManager, historyManager.addItem);
|
||||
} = useExtensionUpdates(
|
||||
extensionManager,
|
||||
historyManager.addItem,
|
||||
config.getEnableExtensionReloading(),
|
||||
);
|
||||
|
||||
const [isPermissionsDialogOpen, setPermissionsDialogOpen] = useState(false);
|
||||
const openPermissionsDialog = useCallback(
|
||||
|
||||
@@ -97,6 +97,10 @@ describe('<ExtensionsList />', () => {
|
||||
state: ExtensionUpdateState.UPDATED_NEEDS_RESTART,
|
||||
expectedText: '(updated, needs restart)',
|
||||
},
|
||||
{
|
||||
state: ExtensionUpdateState.UPDATED,
|
||||
expectedText: '(updated)',
|
||||
},
|
||||
{
|
||||
state: ExtensionUpdateState.ERROR,
|
||||
expectedText: '(error)',
|
||||
|
||||
@@ -48,6 +48,7 @@ export const ExtensionsList: React.FC<ExtensionsList> = ({ extensions }) => {
|
||||
break;
|
||||
case ExtensionUpdateState.UP_TO_DATE:
|
||||
case ExtensionUpdateState.NOT_UPDATABLE:
|
||||
case ExtensionUpdateState.UPDATED:
|
||||
stateColor = 'green';
|
||||
break;
|
||||
case undefined:
|
||||
|
||||
@@ -84,6 +84,7 @@ describe('handleAtCommand', () => {
|
||||
getReadManyFilesExcludes: () => [],
|
||||
}),
|
||||
getUsageStatisticsEnabled: () => false,
|
||||
getEnableExtensionReloading: () => false,
|
||||
} as unknown as Config;
|
||||
|
||||
const registry = new ToolRegistry(mockConfig);
|
||||
|
||||
@@ -96,7 +96,7 @@ describe('useExtensionUpdates', () => {
|
||||
);
|
||||
|
||||
function TestComponent() {
|
||||
useExtensionUpdates(extensionManager, addItem);
|
||||
useExtensionUpdates(extensionManager, addItem, false);
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ describe('useExtensionUpdates', () => {
|
||||
});
|
||||
|
||||
function TestComponent() {
|
||||
useExtensionUpdates(extensionManager, addItem);
|
||||
useExtensionUpdates(extensionManager, addItem, false);
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -224,7 +224,7 @@ describe('useExtensionUpdates', () => {
|
||||
});
|
||||
|
||||
function TestComponent() {
|
||||
useExtensionUpdates(extensionManager, addItem);
|
||||
useExtensionUpdates(extensionManager, addItem, false);
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -307,7 +307,7 @@ describe('useExtensionUpdates', () => {
|
||||
);
|
||||
|
||||
function TestComponent() {
|
||||
useExtensionUpdates(extensionManager, addItem);
|
||||
useExtensionUpdates(extensionManager, addItem, false);
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -80,6 +80,7 @@ export const useConfirmUpdateRequests = () => {
|
||||
export const useExtensionUpdates = (
|
||||
extensionManager: ExtensionManager,
|
||||
addItem: UseHistoryManagerReturn['addItem'],
|
||||
enableExtensionReloading: boolean,
|
||||
) => {
|
||||
const [extensionsUpdateState, dispatchExtensionStateUpdate] = useReducer(
|
||||
extensionUpdatesReducer,
|
||||
@@ -163,6 +164,7 @@ export const useExtensionUpdates = (
|
||||
extensionManager,
|
||||
currentState.status,
|
||||
dispatchExtensionStateUpdate,
|
||||
enableExtensionReloading,
|
||||
);
|
||||
updatePromises.push(updatePromise);
|
||||
updatePromise
|
||||
@@ -209,7 +211,13 @@ export const useExtensionUpdates = (
|
||||
});
|
||||
});
|
||||
}
|
||||
}, [extensions, extensionManager, extensionsUpdateState, addItem]);
|
||||
}, [
|
||||
extensions,
|
||||
extensionManager,
|
||||
extensionsUpdateState,
|
||||
addItem,
|
||||
enableExtensionReloading,
|
||||
]);
|
||||
|
||||
const extensionsUpdateStateComputed = useMemo(() => {
|
||||
const result = new Map<string, ExtensionUpdateState>();
|
||||
|
||||
@@ -10,6 +10,7 @@ import { checkExhaustive } from '../../utils/checks.js';
|
||||
export enum ExtensionUpdateState {
|
||||
CHECKING_FOR_UPDATES = 'checking for updates',
|
||||
UPDATED_NEEDS_RESTART = 'updated, needs restart',
|
||||
UPDATED = 'updated',
|
||||
UPDATING = 'updating',
|
||||
UPDATE_AVAILABLE = 'update available',
|
||||
UP_TO_DATE = 'up to date',
|
||||
|
||||
@@ -255,6 +255,7 @@ export interface ConfigParameters {
|
||||
listExtensions?: boolean;
|
||||
extensionLoader?: ExtensionLoader;
|
||||
enabledExtensions?: string[];
|
||||
enableExtensionReloading?: boolean;
|
||||
blockedMcpServers?: Array<{ name: string; extensionName: string }>;
|
||||
noBrowser?: boolean;
|
||||
summarizeToolOutput?: Record<string, SummarizeToolOutputSettings>;
|
||||
@@ -312,7 +313,7 @@ export class Config {
|
||||
private readonly toolDiscoveryCommand: string | undefined;
|
||||
private readonly toolCallCommand: string | undefined;
|
||||
private readonly mcpServerCommand: string | undefined;
|
||||
private readonly mcpServers: Record<string, MCPServerConfig> | undefined;
|
||||
private mcpServers: Record<string, MCPServerConfig> | undefined;
|
||||
private userMemory: string;
|
||||
private geminiMdFileCount: number;
|
||||
private geminiMdFilePaths: string[];
|
||||
@@ -346,6 +347,7 @@ export class Config {
|
||||
private readonly listExtensions: boolean;
|
||||
private readonly _extensionLoader: ExtensionLoader;
|
||||
private readonly _enabledExtensions: string[];
|
||||
private readonly enableExtensionReloading: boolean;
|
||||
private readonly _blockedMcpServers: Array<{
|
||||
name: string;
|
||||
extensionName: string;
|
||||
@@ -501,6 +503,7 @@ export class Config {
|
||||
this.enableShellOutputEfficiency =
|
||||
params.enableShellOutputEfficiency ?? true;
|
||||
this.extensionManagement = params.extensionManagement ?? true;
|
||||
this.enableExtensionReloading = params.enableExtensionReloading ?? false;
|
||||
this.storage = new Storage(this.targetDir);
|
||||
this.fakeResponses = params.fakeResponses;
|
||||
this.recordResponses = params.recordResponses;
|
||||
@@ -749,6 +752,10 @@ export class Config {
|
||||
return this.mcpServers;
|
||||
}
|
||||
|
||||
setMcpServers(mcpServers: Record<string, MCPServerConfig>): void {
|
||||
this.mcpServers = mcpServers;
|
||||
}
|
||||
|
||||
getUserMemory(): string {
|
||||
return this.userMemory;
|
||||
}
|
||||
@@ -924,6 +931,10 @@ export class Config {
|
||||
return this._enabledExtensions;
|
||||
}
|
||||
|
||||
getEnableExtensionReloading(): boolean {
|
||||
return this.enableExtensionReloading;
|
||||
}
|
||||
|
||||
getBlockedMcpServers(): Array<{ name: string; extensionName: string }> {
|
||||
return this._blockedMcpServers;
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import { McpClientManager } from './mcp-client-manager.js';
|
||||
import { McpClient } from './mcp-client.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { SimpleExtensionLoader } from '../utils/extensionLoader.js';
|
||||
|
||||
vi.mock('./mcp-client.js', async () => {
|
||||
const originalModule = await vi.importActual('./mcp-client.js');
|
||||
@@ -36,17 +37,22 @@ describe('McpClientManager', () => {
|
||||
vi.mocked(McpClient).mockReturnValue(
|
||||
mockedMcpClient as unknown as McpClient,
|
||||
);
|
||||
const manager = new McpClientManager({} as ToolRegistry);
|
||||
await manager.discoverAllMcpTools({
|
||||
isTrustedFolder: () => true,
|
||||
getMcpServers: () => ({
|
||||
'test-server': {},
|
||||
}),
|
||||
getMcpServerCommand: () => '',
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
} as unknown as Config);
|
||||
const manager = new McpClientManager(
|
||||
{} as ToolRegistry,
|
||||
{
|
||||
isTrustedFolder: () => true,
|
||||
getExtensionLoader: () => new SimpleExtensionLoader([]),
|
||||
getMcpServers: () => ({
|
||||
'test-server': {},
|
||||
}),
|
||||
getMcpServerCommand: () => '',
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
getEnableExtensionReloading: () => false,
|
||||
} as unknown as Config,
|
||||
);
|
||||
await manager.discoverAllMcpTools();
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
});
|
||||
@@ -61,17 +67,22 @@ describe('McpClientManager', () => {
|
||||
vi.mocked(McpClient).mockReturnValue(
|
||||
mockedMcpClient as unknown as McpClient,
|
||||
);
|
||||
const manager = new McpClientManager({} as ToolRegistry);
|
||||
await manager.discoverAllMcpTools({
|
||||
isTrustedFolder: () => false,
|
||||
getMcpServers: () => ({
|
||||
'test-server': {},
|
||||
}),
|
||||
getMcpServerCommand: () => '',
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
} as unknown as Config);
|
||||
const manager = new McpClientManager(
|
||||
{} as ToolRegistry,
|
||||
{
|
||||
isTrustedFolder: () => false,
|
||||
getExtensionLoader: () => new SimpleExtensionLoader([]),
|
||||
getMcpServers: () => ({
|
||||
'test-server': {},
|
||||
}),
|
||||
getMcpServerCommand: () => '',
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
getEnableExtensionReloading: () => false,
|
||||
} as unknown as Config,
|
||||
);
|
||||
await manager.discoverAllMcpTools();
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
@@ -4,7 +4,11 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import type {
|
||||
Config,
|
||||
GeminiCLIExtension,
|
||||
MCPServerConfig,
|
||||
} from '../config/config.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import {
|
||||
McpClient,
|
||||
@@ -14,6 +18,7 @@ import {
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import type { EventEmitter } from 'node:events';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
/**
|
||||
* Manages the lifecycle of multiple MCP clients, including local child processes.
|
||||
@@ -23,12 +28,162 @@ import { coreEvents } from '../utils/events.js';
|
||||
export class McpClientManager {
|
||||
private clients: Map<string, McpClient> = new Map();
|
||||
private readonly toolRegistry: ToolRegistry;
|
||||
private readonly cliConfig: Config;
|
||||
// If we have ongoing MCP client discovery, this completes once that is done.
|
||||
private discoveryPromise: Promise<void> | undefined;
|
||||
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
|
||||
private readonly eventEmitter?: EventEmitter;
|
||||
|
||||
constructor(toolRegistry: ToolRegistry, eventEmitter?: EventEmitter) {
|
||||
constructor(
|
||||
toolRegistry: ToolRegistry,
|
||||
cliConfig: Config,
|
||||
eventEmitter?: EventEmitter,
|
||||
) {
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.cliConfig = cliConfig;
|
||||
this.eventEmitter = eventEmitter;
|
||||
if (this.cliConfig.getEnableExtensionReloading()) {
|
||||
this.cliConfig
|
||||
.getExtensionLoader()
|
||||
.extensionEvents()
|
||||
.on('extensionLoaded', (event) => this.loadExtension(event.extension))
|
||||
.on('extensionEnabled', (event) => this.loadExtension(event.extension))
|
||||
.on('extensionDisabled', (event) =>
|
||||
this.unloadExtension(event.extension),
|
||||
)
|
||||
.on('extensionUnloaded', (event) =>
|
||||
this.unloadExtension(event.extension),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* For all the MCP servers associated with this extension:
|
||||
*
|
||||
* - Removes all its MCP servers from the global configuration object.
|
||||
* - Disconnects all MCP clients from their servers.
|
||||
* - Updates the Gemini chat configuration to load the new tools.
|
||||
*/
|
||||
private async unloadExtension(extension: GeminiCLIExtension) {
|
||||
debugLogger.log(`Unloading extension: ${extension.name}`);
|
||||
await Promise.all(
|
||||
Object.keys(extension.mcpServers ?? {}).map((name) => {
|
||||
const newMcpServers = {
|
||||
...this.cliConfig.getMcpServers(),
|
||||
};
|
||||
delete newMcpServers[name];
|
||||
this.cliConfig.setMcpServers(newMcpServers);
|
||||
return this.disconnectClient(name);
|
||||
}),
|
||||
);
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration.
|
||||
this.cliConfig.getGeminiClient().setTools();
|
||||
}
|
||||
|
||||
/**
|
||||
* For all the MCP servers associated with this extension:
|
||||
*
|
||||
* - Adds all its MCP servers to the global configuration object.
|
||||
* - Connects MCP clients to each server and discovers their tools.
|
||||
* - Updates the Gemini chat configuration to load the new tools.
|
||||
*/
|
||||
private async loadExtension(extension: GeminiCLIExtension) {
|
||||
debugLogger.log(`Loading extension: ${extension.name}`);
|
||||
await Promise.all(
|
||||
Object.entries(extension.mcpServers ?? {}).map(([name, config]) => {
|
||||
this.cliConfig.setMcpServers({
|
||||
...this.cliConfig.getMcpServers(),
|
||||
[name]: config,
|
||||
});
|
||||
return this.discoverMcpTools(name, config);
|
||||
}),
|
||||
);
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration.
|
||||
this.cliConfig.getGeminiClient().setTools();
|
||||
}
|
||||
|
||||
private async disconnectClient(name: string) {
|
||||
const existing = this.clients.get(name);
|
||||
if (existing) {
|
||||
try {
|
||||
this.clients.delete(name);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
await existing.disconnect();
|
||||
} catch (error) {
|
||||
debugLogger.warn(
|
||||
`Error stopping client '${name}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
discoverMcpTools(
|
||||
name: string,
|
||||
config: MCPServerConfig,
|
||||
): Promise<void> | void {
|
||||
if (!this.cliConfig.isTrustedFolder()) {
|
||||
return;
|
||||
}
|
||||
if (config.extension && !config.extension.isActive) {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentDiscoveryPromise = new Promise<void>((resolve, _reject) => {
|
||||
(async () => {
|
||||
try {
|
||||
await this.disconnectClient(name);
|
||||
|
||||
const client = new McpClient(
|
||||
name,
|
||||
config,
|
||||
this.toolRegistry,
|
||||
this.cliConfig.getPromptRegistry(),
|
||||
this.cliConfig.getWorkspaceContext(),
|
||||
this.cliConfig.getDebugMode(),
|
||||
);
|
||||
this.clients.set(name, client);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
try {
|
||||
await client.connect();
|
||||
await client.discover(this.cliConfig);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
} catch (error) {
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
// Log the error but don't let a single failed server stop the others
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Error during discovery for server '${name}': ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
resolve();
|
||||
}
|
||||
})();
|
||||
});
|
||||
|
||||
if (this.discoveryPromise) {
|
||||
this.discoveryPromise = this.discoveryPromise.then(
|
||||
() => currentDiscoveryPromise,
|
||||
);
|
||||
} else {
|
||||
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
this.discoveryPromise = currentDiscoveryPromise;
|
||||
}
|
||||
const currentPromise = this.discoveryPromise;
|
||||
currentPromise.then((_) => {
|
||||
// If we are the last recorded discoveryPromise, then we are done, reset
|
||||
// the world.
|
||||
if (currentPromise === this.discoveryPromise) {
|
||||
this.discoveryPromise = undefined;
|
||||
this.discoveryState = MCPDiscoveryState.COMPLETED;
|
||||
}
|
||||
});
|
||||
return currentPromise;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -36,53 +191,23 @@ export class McpClientManager {
|
||||
* It connects to each server, discovers its available tools, and registers
|
||||
* them with the `ToolRegistry`.
|
||||
*/
|
||||
async discoverAllMcpTools(cliConfig: Config): Promise<void> {
|
||||
if (!cliConfig.isTrustedFolder()) {
|
||||
async discoverAllMcpTools(): Promise<void> {
|
||||
if (!this.cliConfig.isTrustedFolder()) {
|
||||
return;
|
||||
}
|
||||
await this.stop();
|
||||
|
||||
const servers = populateMcpServerCommand(
|
||||
cliConfig.getMcpServers() || {},
|
||||
cliConfig.getMcpServerCommand(),
|
||||
this.cliConfig.getMcpServers() || {},
|
||||
this.cliConfig.getMcpServerCommand(),
|
||||
);
|
||||
|
||||
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
const discoveryPromises = Object.entries(servers)
|
||||
.filter(([_, config]) => !config.extension || config.extension.isActive)
|
||||
.map(async ([name, config]) => {
|
||||
const client = new McpClient(
|
||||
name,
|
||||
config,
|
||||
this.toolRegistry,
|
||||
cliConfig.getPromptRegistry(),
|
||||
cliConfig.getWorkspaceContext(),
|
||||
cliConfig.getDebugMode(),
|
||||
);
|
||||
this.clients.set(name, client);
|
||||
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
try {
|
||||
await client.connect();
|
||||
await client.discover(cliConfig);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
} catch (error) {
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
// Log the error but don't let a single failed server stop the others
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Error during discovery for server '${name}': ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
await Promise.all(discoveryPromises);
|
||||
this.discoveryState = MCPDiscoveryState.COMPLETED;
|
||||
await Promise.all(
|
||||
Object.entries(servers).map(async ([name, config]) =>
|
||||
this.discoverMcpTools(name, config),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -160,6 +160,7 @@ export class McpClient {
|
||||
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||
return;
|
||||
}
|
||||
this.toolRegistry.removeMcpToolsByServer(this.serverName);
|
||||
this.updateStatus(MCPServerStatus.DISCONNECTING);
|
||||
const client = this.client;
|
||||
this.client = undefined;
|
||||
|
||||
@@ -181,7 +181,7 @@ export class ToolRegistry {
|
||||
|
||||
constructor(config: Config, eventEmitter?: EventEmitter) {
|
||||
this.config = config;
|
||||
this.mcpClientManager = new McpClientManager(this, eventEmitter);
|
||||
this.mcpClientManager = new McpClientManager(this, config, eventEmitter);
|
||||
}
|
||||
|
||||
setMessageBus(messageBus: MessageBus): void {
|
||||
@@ -244,7 +244,7 @@ export class ToolRegistry {
|
||||
await this.discoverAndRegisterToolsFromCommand();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await this.mcpClientManager.discoverAllMcpTools(this.config);
|
||||
await this.mcpClientManager.discoverAllMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -259,7 +259,7 @@ export class ToolRegistry {
|
||||
this.config.getPromptRegistry().clear();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await this.mcpClientManager.discoverAllMcpTools(this.config);
|
||||
await this.mcpClientManager.discoverAllMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user