Initial support for reloading extensions in the CLI - mcp servers only (#12239)

This commit is contained in:
Jacob MacDonald
2025-10-30 11:05:49 -07:00
committed by GitHub
parent d4cad0cdcc
commit cc081337b7
20 changed files with 437 additions and 107 deletions

View File

@@ -0,0 +1,116 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { expect, it, describe } from 'vitest';
import { TestRig } from './test-helper.js';
import { TestMcpServer } from './test-mcp-server.js';
import { writeFileSync } from 'node:fs';
import { join } from 'node:path';
import { safeJsonStringify } from '@google/gemini-cli-core/src/utils/safeJsonStringify.js';
import { env } from 'node:process';
import { platform } from 'node:os';
const itIf = (condition: boolean) => (condition ? it : it.skip);
describe('extension reloading', () => {
const sandboxEnv = env['GEMINI_SANDBOX'];
// Fails in sandbox mode, can't check for local extension updates.
itIf((!sandboxEnv || sandboxEnv === 'false') && platform() !== 'win32')(
'installs a local extension, updates it, checks it was reloaded properly',
async () => {
const serverA = new TestMcpServer();
const portA = await serverA.start({
hello: () => ({ content: [{ type: 'text', text: 'world' }] }),
});
const extension = {
name: 'test-extension',
version: '0.0.1',
mcpServers: {
'test-server': {
httpUrl: `http://localhost:${portA}/mcp`,
},
},
};
const rig = new TestRig();
rig.setup('extension reload test', {
settings: {
experimental: { extensionReloading: true },
},
});
const testServerPath = join(rig.testDir!, 'gemini-extension.json');
writeFileSync(testServerPath, safeJsonStringify(extension, 2));
// defensive cleanup from previous tests.
try {
await rig.runCommand(['extensions', 'uninstall', 'test-extension']);
} catch {
/* empty */
}
const result = await rig.runCommand(
['extensions', 'install', `${rig.testDir!}`],
{ stdin: 'y\n' },
);
expect(result).toContain('test-extension');
// Now create the update, but its not installed yet
const serverB = new TestMcpServer();
const portB = await serverB.start({
goodbye: () => ({ content: [{ type: 'text', text: 'world' }] }),
});
extension.version = '0.0.2';
extension.mcpServers['test-server'].httpUrl =
`http://localhost:${portB}/mcp`;
writeFileSync(testServerPath, safeJsonStringify(extension, 2));
// Start the CLI.
const run = await rig.runInteractive('--debug');
await run.expectText('You have 1 extension with an update available');
// See the outdated extension
await run.sendText('/extensions list');
await run.type('\r');
await run.expectText(
'test-extension (v0.0.1) - active (update available)',
);
await run.sendText('/mcp list');
await run.type('\r');
await run.expectText(
'test-server (from test-extension) - Ready (1 tool)',
);
await run.expectText('- hello');
// Update the extension, expect the list to update, and mcp servers as well.
await run.sendText('/extensions update test-extension');
await run.type('\r');
await run.expectText(
` * test-server (remote): http://localhost:${portB}/mcp`,
);
await run.type('\r'); // consent
await run.expectText(
'Extension "test-extension" successfully updated: 0.0.1 → 0.0.2',
);
await new Promise((resolve) => setTimeout(resolve, 1000));
await run.sendText('/extensions list');
await run.type('\r');
await run.expectText('test-extension (v0.0.2) - active (updated)');
await run.sendText('/mcp list');
await run.type('\r');
await run.expectText(
'test-server (from test-extension) - Ready (1 tool)',
);
await run.expectText('- goodbye');
await run.sendText('/quit');
await run.sendKeys('\r');
// Clean things up.
await serverA.stop();
await serverB.stop();
await rig.runCommand(['extensions', 'uninstall', 'test-extension']);
await rig.cleanup();
},
);
});

View File

@@ -220,6 +220,13 @@ export class InteractiveRun {
}
}
// Types an entire string at once, necessary for some things like commands
// but may run into paste detection issues for larger strings.
async sendText(text: string) {
this.ptyProcess.write(text);
await new Promise((resolve) => setTimeout(resolve, 5));
}
// Simulates typing a string one character at a time to avoid paste detection.
async sendKeys(text: string) {
const delay = 5;
@@ -311,6 +318,8 @@ export class TestRig {
model: DEFAULT_GEMINI_MODEL,
sandbox:
env['GEMINI_SANDBOX'] !== 'false' ? env['GEMINI_SANDBOX'] : false,
// Don't show the IDE connection dialog when running from VsCode
ide: { enabled: false, hasSeenNudge: true },
...options.settings, // Allow tests to override/add settings
};
writeFileSync(

View File

@@ -4,17 +4,21 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
import {
McpServer,
type ToolCallback,
} from '@modelcontextprotocol/sdk/server/mcp.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import express from 'express';
import { type Server as HTTPServer } from 'node:http';
import { randomUUID } from 'node:crypto';
import { type ZodRawShape } from 'zod';
export class TestMcpServer {
private server: HTTPServer | undefined;
async start(): Promise<number> {
async start(
tools?: Record<string, ToolCallback<ZodRawShape>>,
): Promise<number> {
const app = express();
app.use(express.json());
const mcpServer = new McpServer(
@@ -22,18 +26,30 @@ export class TestMcpServer {
name: 'test-mcp-server',
version: '1.0.0',
},
{ capabilities: {} },
{ capabilities: { tools: {} } },
);
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
});
mcpServer.connect(transport);
if (tools) {
for (const [name, cb] of Object.entries(tools)) {
mcpServer.registerTool(name, {}, cb);
}
}
app.post('/mcp', async (req, res) => {
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: undefined,
enableJsonResponse: true,
});
res.on('close', () => {
transport.close();
});
await mcpServer.connect(transport);
await transport.handleRequest(req, res, req.body);
});
app.get('/mcp', async (req, res) => {
res.status(405).send('Not supported');
});
return new Promise((resolve, reject) => {
this.server = app.listen(0, () => {
const address = this.server!.address();

View File

@@ -53,6 +53,7 @@ export function createMockConfig(
getEnableMessageBusIntegration: vi.fn().mockReturnValue(false),
getMessageBus: vi.fn(),
getPolicyEngine: vi.fn(),
getEnableExtensionReloading: vi.fn().mockReturnValue(false),
...overrides,
} as unknown as Config;

View File

@@ -30,11 +30,12 @@ const updateOutput = (info: ExtensionUpdateInfo) =>
export async function handleUpdate(args: UpdateArgs) {
const workspaceDir = process.cwd();
const settings = loadSettings(workspaceDir).merged;
const extensionManager = new ExtensionManager({
workspaceDir,
requestConsent: requestConsentNonInteractive,
requestSetting: promptForSetting,
settings: loadSettings(workspaceDir).merged,
settings,
});
const extensions = await extensionManager.loadExtensions();
@@ -67,6 +68,7 @@ export async function handleUpdate(args: UpdateArgs) {
extensionManager,
updateState,
() => {},
settings.experimental?.extensionReloading,
))!;
if (
updatedExtensionInfo.originalVersion !==

View File

@@ -680,6 +680,7 @@ export async function loadCliConfig(
listExtensions: argv.listExtensions || false,
enabledExtensions: argv.extensions,
extensionLoader: extensionManager,
enableExtensionReloading: settings.experimental?.extensionReloading,
blockedMcpServers,
noBrowser: !!process.env['NO_BROWSER'],
summarizeToolOutput: settings.model?.summarizeToolOutput,

View File

@@ -28,6 +28,7 @@ export async function updateExtension(
extensionManager: ExtensionManager,
currentState: ExtensionUpdateState,
dispatchExtensionStateUpdate: (action: ExtensionUpdateAction) => void,
enableExtensionReloading?: boolean,
): Promise<ExtensionUpdateInfo | undefined> {
if (currentState === ExtensionUpdateState.UPDATING) {
return undefined;
@@ -81,7 +82,9 @@ export async function updateExtension(
type: 'SET_STATE',
payload: {
name: extension.name,
state: ExtensionUpdateState.UPDATED_NEEDS_RESTART,
state: enableExtensionReloading
? ExtensionUpdateState.UPDATED
: ExtensionUpdateState.UPDATED_NEEDS_RESTART,
},
});
return {
@@ -109,6 +112,7 @@ export async function updateAllUpdatableExtensions(
extensionsState: Map<string, ExtensionUpdateStatus>,
extensionManager: ExtensionManager,
dispatch: (action: ExtensionUpdateAction) => void,
enableExtensionReloading?: boolean,
): Promise<ExtensionUpdateInfo[]> {
return (
await Promise.all(
@@ -124,6 +128,7 @@ export async function updateAllUpdatableExtensions(
extensionManager,
extensionsState.get(extension.name)!.status,
dispatch,
enableExtensionReloading,
),
),
)
@@ -141,34 +146,37 @@ export async function checkForAllExtensionUpdates(
dispatch: (action: ExtensionUpdateAction) => void,
): Promise<void> {
dispatch({ type: 'BATCH_CHECK_START' });
const promises: Array<Promise<void>> = [];
for (const extension of extensions) {
if (!extension.installMetadata) {
try {
const promises: Array<Promise<void>> = [];
for (const extension of extensions) {
if (!extension.installMetadata) {
dispatch({
type: 'SET_STATE',
payload: {
name: extension.name,
state: ExtensionUpdateState.NOT_UPDATABLE,
},
});
continue;
}
dispatch({
type: 'SET_STATE',
payload: {
name: extension.name,
state: ExtensionUpdateState.NOT_UPDATABLE,
state: ExtensionUpdateState.CHECKING_FOR_UPDATES,
},
});
continue;
promises.push(
checkForExtensionUpdate(extension, extensionManager).then((state) =>
dispatch({
type: 'SET_STATE',
payload: { name: extension.name, state },
}),
),
);
}
dispatch({
type: 'SET_STATE',
payload: {
name: extension.name,
state: ExtensionUpdateState.CHECKING_FOR_UPDATES,
},
});
promises.push(
checkForExtensionUpdate(extension, extensionManager).then((state) =>
dispatch({
type: 'SET_STATE',
payload: { name: extension.name, state },
}),
),
);
await Promise.all(promises);
} finally {
dispatch({ type: 'BATCH_CHECK_END' });
}
await Promise.all(promises);
dispatch({ type: 'BATCH_CHECK_END' });
}

View File

@@ -1075,6 +1075,16 @@ const SETTINGS_SCHEMA = {
description: 'Enable extension management features.',
showInDialog: false,
},
extensionReloading: {
type: 'boolean',
label: 'Extension Reloading',
category: 'Experimental',
requiresRestart: true,
default: false,
description:
'Enables extension loading/unloading within the CLI session.',
showInDialog: false,
},
useModelRouter: {
type: 'boolean',
label: 'Use Model Router',

View File

@@ -183,7 +183,11 @@ export const AppContainer = (props: AppContainerProps) => {
extensionsUpdateState,
extensionsUpdateStateInternal,
dispatchExtensionStateUpdate,
} = useExtensionUpdates(extensionManager, historyManager.addItem);
} = useExtensionUpdates(
extensionManager,
historyManager.addItem,
config.getEnableExtensionReloading(),
);
const [isPermissionsDialogOpen, setPermissionsDialogOpen] = useState(false);
const openPermissionsDialog = useCallback(

View File

@@ -97,6 +97,10 @@ describe('<ExtensionsList />', () => {
state: ExtensionUpdateState.UPDATED_NEEDS_RESTART,
expectedText: '(updated, needs restart)',
},
{
state: ExtensionUpdateState.UPDATED,
expectedText: '(updated)',
},
{
state: ExtensionUpdateState.ERROR,
expectedText: '(error)',

View File

@@ -48,6 +48,7 @@ export const ExtensionsList: React.FC<ExtensionsList> = ({ extensions }) => {
break;
case ExtensionUpdateState.UP_TO_DATE:
case ExtensionUpdateState.NOT_UPDATABLE:
case ExtensionUpdateState.UPDATED:
stateColor = 'green';
break;
case undefined:

View File

@@ -84,6 +84,7 @@ describe('handleAtCommand', () => {
getReadManyFilesExcludes: () => [],
}),
getUsageStatisticsEnabled: () => false,
getEnableExtensionReloading: () => false,
} as unknown as Config;
const registry = new ToolRegistry(mockConfig);

View File

@@ -96,7 +96,7 @@ describe('useExtensionUpdates', () => {
);
function TestComponent() {
useExtensionUpdates(extensionManager, addItem);
useExtensionUpdates(extensionManager, addItem, false);
return null;
}
@@ -146,7 +146,7 @@ describe('useExtensionUpdates', () => {
});
function TestComponent() {
useExtensionUpdates(extensionManager, addItem);
useExtensionUpdates(extensionManager, addItem, false);
return null;
}
@@ -224,7 +224,7 @@ describe('useExtensionUpdates', () => {
});
function TestComponent() {
useExtensionUpdates(extensionManager, addItem);
useExtensionUpdates(extensionManager, addItem, false);
return null;
}
@@ -307,7 +307,7 @@ describe('useExtensionUpdates', () => {
);
function TestComponent() {
useExtensionUpdates(extensionManager, addItem);
useExtensionUpdates(extensionManager, addItem, false);
return null;
}

View File

@@ -80,6 +80,7 @@ export const useConfirmUpdateRequests = () => {
export const useExtensionUpdates = (
extensionManager: ExtensionManager,
addItem: UseHistoryManagerReturn['addItem'],
enableExtensionReloading: boolean,
) => {
const [extensionsUpdateState, dispatchExtensionStateUpdate] = useReducer(
extensionUpdatesReducer,
@@ -163,6 +164,7 @@ export const useExtensionUpdates = (
extensionManager,
currentState.status,
dispatchExtensionStateUpdate,
enableExtensionReloading,
);
updatePromises.push(updatePromise);
updatePromise
@@ -209,7 +211,13 @@ export const useExtensionUpdates = (
});
});
}
}, [extensions, extensionManager, extensionsUpdateState, addItem]);
}, [
extensions,
extensionManager,
extensionsUpdateState,
addItem,
enableExtensionReloading,
]);
const extensionsUpdateStateComputed = useMemo(() => {
const result = new Map<string, ExtensionUpdateState>();

View File

@@ -10,6 +10,7 @@ import { checkExhaustive } from '../../utils/checks.js';
export enum ExtensionUpdateState {
CHECKING_FOR_UPDATES = 'checking for updates',
UPDATED_NEEDS_RESTART = 'updated, needs restart',
UPDATED = 'updated',
UPDATING = 'updating',
UPDATE_AVAILABLE = 'update available',
UP_TO_DATE = 'up to date',

View File

@@ -255,6 +255,7 @@ export interface ConfigParameters {
listExtensions?: boolean;
extensionLoader?: ExtensionLoader;
enabledExtensions?: string[];
enableExtensionReloading?: boolean;
blockedMcpServers?: Array<{ name: string; extensionName: string }>;
noBrowser?: boolean;
summarizeToolOutput?: Record<string, SummarizeToolOutputSettings>;
@@ -312,7 +313,7 @@ export class Config {
private readonly toolDiscoveryCommand: string | undefined;
private readonly toolCallCommand: string | undefined;
private readonly mcpServerCommand: string | undefined;
private readonly mcpServers: Record<string, MCPServerConfig> | undefined;
private mcpServers: Record<string, MCPServerConfig> | undefined;
private userMemory: string;
private geminiMdFileCount: number;
private geminiMdFilePaths: string[];
@@ -346,6 +347,7 @@ export class Config {
private readonly listExtensions: boolean;
private readonly _extensionLoader: ExtensionLoader;
private readonly _enabledExtensions: string[];
private readonly enableExtensionReloading: boolean;
private readonly _blockedMcpServers: Array<{
name: string;
extensionName: string;
@@ -501,6 +503,7 @@ export class Config {
this.enableShellOutputEfficiency =
params.enableShellOutputEfficiency ?? true;
this.extensionManagement = params.extensionManagement ?? true;
this.enableExtensionReloading = params.enableExtensionReloading ?? false;
this.storage = new Storage(this.targetDir);
this.fakeResponses = params.fakeResponses;
this.recordResponses = params.recordResponses;
@@ -749,6 +752,10 @@ export class Config {
return this.mcpServers;
}
setMcpServers(mcpServers: Record<string, MCPServerConfig>): void {
this.mcpServers = mcpServers;
}
getUserMemory(): string {
return this.userMemory;
}
@@ -924,6 +931,10 @@ export class Config {
return this._enabledExtensions;
}
getEnableExtensionReloading(): boolean {
return this.enableExtensionReloading;
}
getBlockedMcpServers(): Array<{ name: string; extensionName: string }> {
return this._blockedMcpServers;
}

View File

@@ -9,6 +9,7 @@ import { McpClientManager } from './mcp-client-manager.js';
import { McpClient } from './mcp-client.js';
import type { ToolRegistry } from './tool-registry.js';
import type { Config } from '../config/config.js';
import { SimpleExtensionLoader } from '../utils/extensionLoader.js';
vi.mock('./mcp-client.js', async () => {
const originalModule = await vi.importActual('./mcp-client.js');
@@ -36,17 +37,22 @@ describe('McpClientManager', () => {
vi.mocked(McpClient).mockReturnValue(
mockedMcpClient as unknown as McpClient,
);
const manager = new McpClientManager({} as ToolRegistry);
await manager.discoverAllMcpTools({
isTrustedFolder: () => true,
getMcpServers: () => ({
'test-server': {},
}),
getMcpServerCommand: () => '',
getPromptRegistry: () => {},
getDebugMode: () => false,
getWorkspaceContext: () => {},
} as unknown as Config);
const manager = new McpClientManager(
{} as ToolRegistry,
{
isTrustedFolder: () => true,
getExtensionLoader: () => new SimpleExtensionLoader([]),
getMcpServers: () => ({
'test-server': {},
}),
getMcpServerCommand: () => '',
getPromptRegistry: () => {},
getDebugMode: () => false,
getWorkspaceContext: () => {},
getEnableExtensionReloading: () => false,
} as unknown as Config,
);
await manager.discoverAllMcpTools();
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
});
@@ -61,17 +67,22 @@ describe('McpClientManager', () => {
vi.mocked(McpClient).mockReturnValue(
mockedMcpClient as unknown as McpClient,
);
const manager = new McpClientManager({} as ToolRegistry);
await manager.discoverAllMcpTools({
isTrustedFolder: () => false,
getMcpServers: () => ({
'test-server': {},
}),
getMcpServerCommand: () => '',
getPromptRegistry: () => {},
getDebugMode: () => false,
getWorkspaceContext: () => {},
} as unknown as Config);
const manager = new McpClientManager(
{} as ToolRegistry,
{
isTrustedFolder: () => false,
getExtensionLoader: () => new SimpleExtensionLoader([]),
getMcpServers: () => ({
'test-server': {},
}),
getMcpServerCommand: () => '',
getPromptRegistry: () => {},
getDebugMode: () => false,
getWorkspaceContext: () => {},
getEnableExtensionReloading: () => false,
} as unknown as Config,
);
await manager.discoverAllMcpTools();
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
});

View File

@@ -4,7 +4,11 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '../config/config.js';
import type {
Config,
GeminiCLIExtension,
MCPServerConfig,
} from '../config/config.js';
import type { ToolRegistry } from './tool-registry.js';
import {
McpClient,
@@ -14,6 +18,7 @@ import {
import { getErrorMessage } from '../utils/errors.js';
import type { EventEmitter } from 'node:events';
import { coreEvents } from '../utils/events.js';
import { debugLogger } from '../utils/debugLogger.js';
/**
* Manages the lifecycle of multiple MCP clients, including local child processes.
@@ -23,12 +28,162 @@ import { coreEvents } from '../utils/events.js';
export class McpClientManager {
private clients: Map<string, McpClient> = new Map();
private readonly toolRegistry: ToolRegistry;
private readonly cliConfig: Config;
// If we have ongoing MCP client discovery, this completes once that is done.
private discoveryPromise: Promise<void> | undefined;
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
private readonly eventEmitter?: EventEmitter;
constructor(toolRegistry: ToolRegistry, eventEmitter?: EventEmitter) {
constructor(
toolRegistry: ToolRegistry,
cliConfig: Config,
eventEmitter?: EventEmitter,
) {
this.toolRegistry = toolRegistry;
this.cliConfig = cliConfig;
this.eventEmitter = eventEmitter;
if (this.cliConfig.getEnableExtensionReloading()) {
this.cliConfig
.getExtensionLoader()
.extensionEvents()
.on('extensionLoaded', (event) => this.loadExtension(event.extension))
.on('extensionEnabled', (event) => this.loadExtension(event.extension))
.on('extensionDisabled', (event) =>
this.unloadExtension(event.extension),
)
.on('extensionUnloaded', (event) =>
this.unloadExtension(event.extension),
);
}
}
/**
* For all the MCP servers associated with this extension:
*
* - Removes all its MCP servers from the global configuration object.
* - Disconnects all MCP clients from their servers.
* - Updates the Gemini chat configuration to load the new tools.
*/
private async unloadExtension(extension: GeminiCLIExtension) {
debugLogger.log(`Unloading extension: ${extension.name}`);
await Promise.all(
Object.keys(extension.mcpServers ?? {}).map((name) => {
const newMcpServers = {
...this.cliConfig.getMcpServers(),
};
delete newMcpServers[name];
this.cliConfig.setMcpServers(newMcpServers);
return this.disconnectClient(name);
}),
);
// This is required to update the content generator configuration with the
// new tool configuration.
this.cliConfig.getGeminiClient().setTools();
}
/**
* For all the MCP servers associated with this extension:
*
* - Adds all its MCP servers to the global configuration object.
* - Connects MCP clients to each server and discovers their tools.
* - Updates the Gemini chat configuration to load the new tools.
*/
private async loadExtension(extension: GeminiCLIExtension) {
debugLogger.log(`Loading extension: ${extension.name}`);
await Promise.all(
Object.entries(extension.mcpServers ?? {}).map(([name, config]) => {
this.cliConfig.setMcpServers({
...this.cliConfig.getMcpServers(),
[name]: config,
});
return this.discoverMcpTools(name, config);
}),
);
// This is required to update the content generator configuration with the
// new tool configuration.
this.cliConfig.getGeminiClient().setTools();
}
private async disconnectClient(name: string) {
const existing = this.clients.get(name);
if (existing) {
try {
this.clients.delete(name);
this.eventEmitter?.emit('mcp-client-update', this.clients);
await existing.disconnect();
} catch (error) {
debugLogger.warn(
`Error stopping client '${name}': ${getErrorMessage(error)}`,
);
}
}
}
discoverMcpTools(
name: string,
config: MCPServerConfig,
): Promise<void> | void {
if (!this.cliConfig.isTrustedFolder()) {
return;
}
if (config.extension && !config.extension.isActive) {
return;
}
const currentDiscoveryPromise = new Promise<void>((resolve, _reject) => {
(async () => {
try {
await this.disconnectClient(name);
const client = new McpClient(
name,
config,
this.toolRegistry,
this.cliConfig.getPromptRegistry(),
this.cliConfig.getWorkspaceContext(),
this.cliConfig.getDebugMode(),
);
this.clients.set(name, client);
this.eventEmitter?.emit('mcp-client-update', this.clients);
try {
await client.connect();
await client.discover(this.cliConfig);
this.eventEmitter?.emit('mcp-client-update', this.clients);
} catch (error) {
this.eventEmitter?.emit('mcp-client-update', this.clients);
// Log the error but don't let a single failed server stop the others
coreEvents.emitFeedback(
'error',
`Error during discovery for server '${name}': ${getErrorMessage(
error,
)}`,
error,
);
}
} finally {
resolve();
}
})();
});
if (this.discoveryPromise) {
this.discoveryPromise = this.discoveryPromise.then(
() => currentDiscoveryPromise,
);
} else {
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
this.discoveryPromise = currentDiscoveryPromise;
}
const currentPromise = this.discoveryPromise;
currentPromise.then((_) => {
// If we are the last recorded discoveryPromise, then we are done, reset
// the world.
if (currentPromise === this.discoveryPromise) {
this.discoveryPromise = undefined;
this.discoveryState = MCPDiscoveryState.COMPLETED;
}
});
return currentPromise;
}
/**
@@ -36,53 +191,23 @@ export class McpClientManager {
* It connects to each server, discovers its available tools, and registers
* them with the `ToolRegistry`.
*/
async discoverAllMcpTools(cliConfig: Config): Promise<void> {
if (!cliConfig.isTrustedFolder()) {
async discoverAllMcpTools(): Promise<void> {
if (!this.cliConfig.isTrustedFolder()) {
return;
}
await this.stop();
const servers = populateMcpServerCommand(
cliConfig.getMcpServers() || {},
cliConfig.getMcpServerCommand(),
this.cliConfig.getMcpServers() || {},
this.cliConfig.getMcpServerCommand(),
);
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
this.eventEmitter?.emit('mcp-client-update', this.clients);
const discoveryPromises = Object.entries(servers)
.filter(([_, config]) => !config.extension || config.extension.isActive)
.map(async ([name, config]) => {
const client = new McpClient(
name,
config,
this.toolRegistry,
cliConfig.getPromptRegistry(),
cliConfig.getWorkspaceContext(),
cliConfig.getDebugMode(),
);
this.clients.set(name, client);
this.eventEmitter?.emit('mcp-client-update', this.clients);
try {
await client.connect();
await client.discover(cliConfig);
this.eventEmitter?.emit('mcp-client-update', this.clients);
} catch (error) {
this.eventEmitter?.emit('mcp-client-update', this.clients);
// Log the error but don't let a single failed server stop the others
coreEvents.emitFeedback(
'error',
`Error during discovery for server '${name}': ${getErrorMessage(
error,
)}`,
error,
);
}
});
await Promise.all(discoveryPromises);
this.discoveryState = MCPDiscoveryState.COMPLETED;
await Promise.all(
Object.entries(servers).map(async ([name, config]) =>
this.discoverMcpTools(name, config),
),
);
}
/**

View File

@@ -160,6 +160,7 @@ export class McpClient {
if (this.status !== MCPServerStatus.CONNECTED) {
return;
}
this.toolRegistry.removeMcpToolsByServer(this.serverName);
this.updateStatus(MCPServerStatus.DISCONNECTING);
const client = this.client;
this.client = undefined;

View File

@@ -181,7 +181,7 @@ export class ToolRegistry {
constructor(config: Config, eventEmitter?: EventEmitter) {
this.config = config;
this.mcpClientManager = new McpClientManager(this, eventEmitter);
this.mcpClientManager = new McpClientManager(this, config, eventEmitter);
}
setMessageBus(messageBus: MessageBus): void {
@@ -244,7 +244,7 @@ export class ToolRegistry {
await this.discoverAndRegisterToolsFromCommand();
// discover tools using MCP servers, if configured
await this.mcpClientManager.discoverAllMcpTools(this.config);
await this.mcpClientManager.discoverAllMcpTools();
}
/**
@@ -259,7 +259,7 @@ export class ToolRegistry {
this.config.getPromptRegistry().clear();
// discover tools using MCP servers, if configured
await this.mcpClientManager.discoverAllMcpTools(this.config);
await this.mcpClientManager.discoverAllMcpTools();
}
/**