From cc081337b7207df6640318931301101a846539b6 Mon Sep 17 00:00:00 2001 From: Jacob MacDonald Date: Thu, 30 Oct 2025 11:05:49 -0700 Subject: [PATCH] Initial support for reloading extensions in the CLI - mcp servers only (#12239) --- integration-tests/extensions-reload.test.ts | 116 ++++++++++ integration-tests/test-helper.ts | 9 + integration-tests/test-mcp-server.ts | 36 ++- .../a2a-server/src/utils/testing_utils.ts | 1 + .../cli/src/commands/extensions/update.ts | 4 +- packages/cli/src/config/config.ts | 1 + packages/cli/src/config/extensions/update.ts | 54 +++-- packages/cli/src/config/settingsSchema.ts | 10 + packages/cli/src/ui/AppContainer.tsx | 6 +- .../components/views/ExtensionsList.test.tsx | 4 + .../ui/components/views/ExtensionsList.tsx | 1 + .../src/ui/hooks/atCommandProcessor.test.ts | 1 + .../src/ui/hooks/useExtensionUpdates.test.tsx | 8 +- .../cli/src/ui/hooks/useExtensionUpdates.ts | 10 +- packages/cli/src/ui/state/extensions.ts | 1 + packages/core/src/config/config.ts | 13 +- .../core/src/tools/mcp-client-manager.test.ts | 55 +++-- packages/core/src/tools/mcp-client-manager.ts | 207 ++++++++++++++---- packages/core/src/tools/mcp-client.ts | 1 + packages/core/src/tools/tool-registry.ts | 6 +- 20 files changed, 437 insertions(+), 107 deletions(-) create mode 100644 integration-tests/extensions-reload.test.ts diff --git a/integration-tests/extensions-reload.test.ts b/integration-tests/extensions-reload.test.ts new file mode 100644 index 0000000000..d28097f2c0 --- /dev/null +++ b/integration-tests/extensions-reload.test.ts @@ -0,0 +1,116 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { expect, it, describe } from 'vitest'; +import { TestRig } from './test-helper.js'; +import { TestMcpServer } from './test-mcp-server.js'; +import { writeFileSync } from 'node:fs'; +import { join } from 'node:path'; +import { safeJsonStringify } from '@google/gemini-cli-core/src/utils/safeJsonStringify.js'; +import { env } from 'node:process'; +import { platform } from 'node:os'; + +const itIf = (condition: boolean) => (condition ? it : it.skip); + +describe('extension reloading', () => { + const sandboxEnv = env['GEMINI_SANDBOX']; + + // Fails in sandbox mode, can't check for local extension updates. + itIf((!sandboxEnv || sandboxEnv === 'false') && platform() !== 'win32')( + 'installs a local extension, updates it, checks it was reloaded properly', + async () => { + const serverA = new TestMcpServer(); + const portA = await serverA.start({ + hello: () => ({ content: [{ type: 'text', text: 'world' }] }), + }); + const extension = { + name: 'test-extension', + version: '0.0.1', + mcpServers: { + 'test-server': { + httpUrl: `http://localhost:${portA}/mcp`, + }, + }, + }; + + const rig = new TestRig(); + rig.setup('extension reload test', { + settings: { + experimental: { extensionReloading: true }, + }, + }); + const testServerPath = join(rig.testDir!, 'gemini-extension.json'); + writeFileSync(testServerPath, safeJsonStringify(extension, 2)); + // defensive cleanup from previous tests. + try { + await rig.runCommand(['extensions', 'uninstall', 'test-extension']); + } catch { + /* empty */ + } + + const result = await rig.runCommand( + ['extensions', 'install', `${rig.testDir!}`], + { stdin: 'y\n' }, + ); + expect(result).toContain('test-extension'); + + // Now create the update, but its not installed yet + const serverB = new TestMcpServer(); + const portB = await serverB.start({ + goodbye: () => ({ content: [{ type: 'text', text: 'world' }] }), + }); + extension.version = '0.0.2'; + extension.mcpServers['test-server'].httpUrl = + `http://localhost:${portB}/mcp`; + writeFileSync(testServerPath, safeJsonStringify(extension, 2)); + + // Start the CLI. + const run = await rig.runInteractive('--debug'); + await run.expectText('You have 1 extension with an update available'); + // See the outdated extension + await run.sendText('/extensions list'); + await run.type('\r'); + await run.expectText( + 'test-extension (v0.0.1) - active (update available)', + ); + await run.sendText('/mcp list'); + await run.type('\r'); + await run.expectText( + 'test-server (from test-extension) - Ready (1 tool)', + ); + await run.expectText('- hello'); + + // Update the extension, expect the list to update, and mcp servers as well. + await run.sendText('/extensions update test-extension'); + await run.type('\r'); + await run.expectText( + ` * test-server (remote): http://localhost:${portB}/mcp`, + ); + await run.type('\r'); // consent + await run.expectText( + 'Extension "test-extension" successfully updated: 0.0.1 → 0.0.2', + ); + await new Promise((resolve) => setTimeout(resolve, 1000)); + await run.sendText('/extensions list'); + await run.type('\r'); + await run.expectText('test-extension (v0.0.2) - active (updated)'); + await run.sendText('/mcp list'); + await run.type('\r'); + await run.expectText( + 'test-server (from test-extension) - Ready (1 tool)', + ); + await run.expectText('- goodbye'); + await run.sendText('/quit'); + await run.sendKeys('\r'); + + // Clean things up. + await serverA.stop(); + await serverB.stop(); + await rig.runCommand(['extensions', 'uninstall', 'test-extension']); + await rig.cleanup(); + }, + ); +}); diff --git a/integration-tests/test-helper.ts b/integration-tests/test-helper.ts index f937639ac4..8ba6dc5708 100644 --- a/integration-tests/test-helper.ts +++ b/integration-tests/test-helper.ts @@ -220,6 +220,13 @@ export class InteractiveRun { } } + // Types an entire string at once, necessary for some things like commands + // but may run into paste detection issues for larger strings. + async sendText(text: string) { + this.ptyProcess.write(text); + await new Promise((resolve) => setTimeout(resolve, 5)); + } + // Simulates typing a string one character at a time to avoid paste detection. async sendKeys(text: string) { const delay = 5; @@ -311,6 +318,8 @@ export class TestRig { model: DEFAULT_GEMINI_MODEL, sandbox: env['GEMINI_SANDBOX'] !== 'false' ? env['GEMINI_SANDBOX'] : false, + // Don't show the IDE connection dialog when running from VsCode + ide: { enabled: false, hasSeenNudge: true }, ...options.settings, // Allow tests to override/add settings }; writeFileSync( diff --git a/integration-tests/test-mcp-server.ts b/integration-tests/test-mcp-server.ts index 121d6ed0a9..c0b696032b 100644 --- a/integration-tests/test-mcp-server.ts +++ b/integration-tests/test-mcp-server.ts @@ -4,17 +4,21 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { + McpServer, + type ToolCallback, +} from '@modelcontextprotocol/sdk/server/mcp.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import express from 'express'; import { type Server as HTTPServer } from 'node:http'; - -import { randomUUID } from 'node:crypto'; +import { type ZodRawShape } from 'zod'; export class TestMcpServer { private server: HTTPServer | undefined; - async start(): Promise { + async start( + tools?: Record>, + ): Promise { const app = express(); app.use(express.json()); const mcpServer = new McpServer( @@ -22,18 +26,30 @@ export class TestMcpServer { name: 'test-mcp-server', version: '1.0.0', }, - { capabilities: {} }, + { capabilities: { tools: {} } }, ); - - const transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: () => randomUUID(), - }); - mcpServer.connect(transport); + if (tools) { + for (const [name, cb] of Object.entries(tools)) { + mcpServer.registerTool(name, {}, cb); + } + } app.post('/mcp', async (req, res) => { + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: undefined, + enableJsonResponse: true, + }); + res.on('close', () => { + transport.close(); + }); + await mcpServer.connect(transport); await transport.handleRequest(req, res, req.body); }); + app.get('/mcp', async (req, res) => { + res.status(405).send('Not supported'); + }); + return new Promise((resolve, reject) => { this.server = app.listen(0, () => { const address = this.server!.address(); diff --git a/packages/a2a-server/src/utils/testing_utils.ts b/packages/a2a-server/src/utils/testing_utils.ts index 865fc9d5ac..10e9bf00a4 100644 --- a/packages/a2a-server/src/utils/testing_utils.ts +++ b/packages/a2a-server/src/utils/testing_utils.ts @@ -53,6 +53,7 @@ export function createMockConfig( getEnableMessageBusIntegration: vi.fn().mockReturnValue(false), getMessageBus: vi.fn(), getPolicyEngine: vi.fn(), + getEnableExtensionReloading: vi.fn().mockReturnValue(false), ...overrides, } as unknown as Config; diff --git a/packages/cli/src/commands/extensions/update.ts b/packages/cli/src/commands/extensions/update.ts index f3e78f2cca..5488bacde7 100644 --- a/packages/cli/src/commands/extensions/update.ts +++ b/packages/cli/src/commands/extensions/update.ts @@ -30,11 +30,12 @@ const updateOutput = (info: ExtensionUpdateInfo) => export async function handleUpdate(args: UpdateArgs) { const workspaceDir = process.cwd(); + const settings = loadSettings(workspaceDir).merged; const extensionManager = new ExtensionManager({ workspaceDir, requestConsent: requestConsentNonInteractive, requestSetting: promptForSetting, - settings: loadSettings(workspaceDir).merged, + settings, }); const extensions = await extensionManager.loadExtensions(); @@ -67,6 +68,7 @@ export async function handleUpdate(args: UpdateArgs) { extensionManager, updateState, () => {}, + settings.experimental?.extensionReloading, ))!; if ( updatedExtensionInfo.originalVersion !== diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index ffc4d95353..a5a2a1e58d 100755 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -680,6 +680,7 @@ export async function loadCliConfig( listExtensions: argv.listExtensions || false, enabledExtensions: argv.extensions, extensionLoader: extensionManager, + enableExtensionReloading: settings.experimental?.extensionReloading, blockedMcpServers, noBrowser: !!process.env['NO_BROWSER'], summarizeToolOutput: settings.model?.summarizeToolOutput, diff --git a/packages/cli/src/config/extensions/update.ts b/packages/cli/src/config/extensions/update.ts index 7bfa253651..20f1de8a9b 100644 --- a/packages/cli/src/config/extensions/update.ts +++ b/packages/cli/src/config/extensions/update.ts @@ -28,6 +28,7 @@ export async function updateExtension( extensionManager: ExtensionManager, currentState: ExtensionUpdateState, dispatchExtensionStateUpdate: (action: ExtensionUpdateAction) => void, + enableExtensionReloading?: boolean, ): Promise { if (currentState === ExtensionUpdateState.UPDATING) { return undefined; @@ -81,7 +82,9 @@ export async function updateExtension( type: 'SET_STATE', payload: { name: extension.name, - state: ExtensionUpdateState.UPDATED_NEEDS_RESTART, + state: enableExtensionReloading + ? ExtensionUpdateState.UPDATED + : ExtensionUpdateState.UPDATED_NEEDS_RESTART, }, }); return { @@ -109,6 +112,7 @@ export async function updateAllUpdatableExtensions( extensionsState: Map, extensionManager: ExtensionManager, dispatch: (action: ExtensionUpdateAction) => void, + enableExtensionReloading?: boolean, ): Promise { return ( await Promise.all( @@ -124,6 +128,7 @@ export async function updateAllUpdatableExtensions( extensionManager, extensionsState.get(extension.name)!.status, dispatch, + enableExtensionReloading, ), ), ) @@ -141,34 +146,37 @@ export async function checkForAllExtensionUpdates( dispatch: (action: ExtensionUpdateAction) => void, ): Promise { dispatch({ type: 'BATCH_CHECK_START' }); - const promises: Array> = []; - for (const extension of extensions) { - if (!extension.installMetadata) { + try { + const promises: Array> = []; + for (const extension of extensions) { + if (!extension.installMetadata) { + dispatch({ + type: 'SET_STATE', + payload: { + name: extension.name, + state: ExtensionUpdateState.NOT_UPDATABLE, + }, + }); + continue; + } dispatch({ type: 'SET_STATE', payload: { name: extension.name, - state: ExtensionUpdateState.NOT_UPDATABLE, + state: ExtensionUpdateState.CHECKING_FOR_UPDATES, }, }); - continue; + promises.push( + checkForExtensionUpdate(extension, extensionManager).then((state) => + dispatch({ + type: 'SET_STATE', + payload: { name: extension.name, state }, + }), + ), + ); } - dispatch({ - type: 'SET_STATE', - payload: { - name: extension.name, - state: ExtensionUpdateState.CHECKING_FOR_UPDATES, - }, - }); - promises.push( - checkForExtensionUpdate(extension, extensionManager).then((state) => - dispatch({ - type: 'SET_STATE', - payload: { name: extension.name, state }, - }), - ), - ); + await Promise.all(promises); + } finally { + dispatch({ type: 'BATCH_CHECK_END' }); } - await Promise.all(promises); - dispatch({ type: 'BATCH_CHECK_END' }); } diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index 2c3fc21ff4..7de0c85e86 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -1075,6 +1075,16 @@ const SETTINGS_SCHEMA = { description: 'Enable extension management features.', showInDialog: false, }, + extensionReloading: { + type: 'boolean', + label: 'Extension Reloading', + category: 'Experimental', + requiresRestart: true, + default: false, + description: + 'Enables extension loading/unloading within the CLI session.', + showInDialog: false, + }, useModelRouter: { type: 'boolean', label: 'Use Model Router', diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index 8ba5fb3ed8..5983d2a09d 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -183,7 +183,11 @@ export const AppContainer = (props: AppContainerProps) => { extensionsUpdateState, extensionsUpdateStateInternal, dispatchExtensionStateUpdate, - } = useExtensionUpdates(extensionManager, historyManager.addItem); + } = useExtensionUpdates( + extensionManager, + historyManager.addItem, + config.getEnableExtensionReloading(), + ); const [isPermissionsDialogOpen, setPermissionsDialogOpen] = useState(false); const openPermissionsDialog = useCallback( diff --git a/packages/cli/src/ui/components/views/ExtensionsList.test.tsx b/packages/cli/src/ui/components/views/ExtensionsList.test.tsx index cfb5306d2c..fcb2320dcf 100644 --- a/packages/cli/src/ui/components/views/ExtensionsList.test.tsx +++ b/packages/cli/src/ui/components/views/ExtensionsList.test.tsx @@ -97,6 +97,10 @@ describe('', () => { state: ExtensionUpdateState.UPDATED_NEEDS_RESTART, expectedText: '(updated, needs restart)', }, + { + state: ExtensionUpdateState.UPDATED, + expectedText: '(updated)', + }, { state: ExtensionUpdateState.ERROR, expectedText: '(error)', diff --git a/packages/cli/src/ui/components/views/ExtensionsList.tsx b/packages/cli/src/ui/components/views/ExtensionsList.tsx index b37648d78c..9297d2496a 100644 --- a/packages/cli/src/ui/components/views/ExtensionsList.tsx +++ b/packages/cli/src/ui/components/views/ExtensionsList.tsx @@ -48,6 +48,7 @@ export const ExtensionsList: React.FC = ({ extensions }) => { break; case ExtensionUpdateState.UP_TO_DATE: case ExtensionUpdateState.NOT_UPDATABLE: + case ExtensionUpdateState.UPDATED: stateColor = 'green'; break; case undefined: diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts index 5758006216..d79714d8f2 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts @@ -84,6 +84,7 @@ describe('handleAtCommand', () => { getReadManyFilesExcludes: () => [], }), getUsageStatisticsEnabled: () => false, + getEnableExtensionReloading: () => false, } as unknown as Config; const registry = new ToolRegistry(mockConfig); diff --git a/packages/cli/src/ui/hooks/useExtensionUpdates.test.tsx b/packages/cli/src/ui/hooks/useExtensionUpdates.test.tsx index 8e36311dc0..2efceef3fa 100644 --- a/packages/cli/src/ui/hooks/useExtensionUpdates.test.tsx +++ b/packages/cli/src/ui/hooks/useExtensionUpdates.test.tsx @@ -96,7 +96,7 @@ describe('useExtensionUpdates', () => { ); function TestComponent() { - useExtensionUpdates(extensionManager, addItem); + useExtensionUpdates(extensionManager, addItem, false); return null; } @@ -146,7 +146,7 @@ describe('useExtensionUpdates', () => { }); function TestComponent() { - useExtensionUpdates(extensionManager, addItem); + useExtensionUpdates(extensionManager, addItem, false); return null; } @@ -224,7 +224,7 @@ describe('useExtensionUpdates', () => { }); function TestComponent() { - useExtensionUpdates(extensionManager, addItem); + useExtensionUpdates(extensionManager, addItem, false); return null; } @@ -307,7 +307,7 @@ describe('useExtensionUpdates', () => { ); function TestComponent() { - useExtensionUpdates(extensionManager, addItem); + useExtensionUpdates(extensionManager, addItem, false); return null; } diff --git a/packages/cli/src/ui/hooks/useExtensionUpdates.ts b/packages/cli/src/ui/hooks/useExtensionUpdates.ts index 43dc5f2e20..6ff5dcb37a 100644 --- a/packages/cli/src/ui/hooks/useExtensionUpdates.ts +++ b/packages/cli/src/ui/hooks/useExtensionUpdates.ts @@ -80,6 +80,7 @@ export const useConfirmUpdateRequests = () => { export const useExtensionUpdates = ( extensionManager: ExtensionManager, addItem: UseHistoryManagerReturn['addItem'], + enableExtensionReloading: boolean, ) => { const [extensionsUpdateState, dispatchExtensionStateUpdate] = useReducer( extensionUpdatesReducer, @@ -163,6 +164,7 @@ export const useExtensionUpdates = ( extensionManager, currentState.status, dispatchExtensionStateUpdate, + enableExtensionReloading, ); updatePromises.push(updatePromise); updatePromise @@ -209,7 +211,13 @@ export const useExtensionUpdates = ( }); }); } - }, [extensions, extensionManager, extensionsUpdateState, addItem]); + }, [ + extensions, + extensionManager, + extensionsUpdateState, + addItem, + enableExtensionReloading, + ]); const extensionsUpdateStateComputed = useMemo(() => { const result = new Map(); diff --git a/packages/cli/src/ui/state/extensions.ts b/packages/cli/src/ui/state/extensions.ts index 49295f5c15..353cf79668 100644 --- a/packages/cli/src/ui/state/extensions.ts +++ b/packages/cli/src/ui/state/extensions.ts @@ -10,6 +10,7 @@ import { checkExhaustive } from '../../utils/checks.js'; export enum ExtensionUpdateState { CHECKING_FOR_UPDATES = 'checking for updates', UPDATED_NEEDS_RESTART = 'updated, needs restart', + UPDATED = 'updated', UPDATING = 'updating', UPDATE_AVAILABLE = 'update available', UP_TO_DATE = 'up to date', diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index c1bcf9e592..9b87868871 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -255,6 +255,7 @@ export interface ConfigParameters { listExtensions?: boolean; extensionLoader?: ExtensionLoader; enabledExtensions?: string[]; + enableExtensionReloading?: boolean; blockedMcpServers?: Array<{ name: string; extensionName: string }>; noBrowser?: boolean; summarizeToolOutput?: Record; @@ -312,7 +313,7 @@ export class Config { private readonly toolDiscoveryCommand: string | undefined; private readonly toolCallCommand: string | undefined; private readonly mcpServerCommand: string | undefined; - private readonly mcpServers: Record | undefined; + private mcpServers: Record | undefined; private userMemory: string; private geminiMdFileCount: number; private geminiMdFilePaths: string[]; @@ -346,6 +347,7 @@ export class Config { private readonly listExtensions: boolean; private readonly _extensionLoader: ExtensionLoader; private readonly _enabledExtensions: string[]; + private readonly enableExtensionReloading: boolean; private readonly _blockedMcpServers: Array<{ name: string; extensionName: string; @@ -501,6 +503,7 @@ export class Config { this.enableShellOutputEfficiency = params.enableShellOutputEfficiency ?? true; this.extensionManagement = params.extensionManagement ?? true; + this.enableExtensionReloading = params.enableExtensionReloading ?? false; this.storage = new Storage(this.targetDir); this.fakeResponses = params.fakeResponses; this.recordResponses = params.recordResponses; @@ -749,6 +752,10 @@ export class Config { return this.mcpServers; } + setMcpServers(mcpServers: Record): void { + this.mcpServers = mcpServers; + } + getUserMemory(): string { return this.userMemory; } @@ -924,6 +931,10 @@ export class Config { return this._enabledExtensions; } + getEnableExtensionReloading(): boolean { + return this.enableExtensionReloading; + } + getBlockedMcpServers(): Array<{ name: string; extensionName: string }> { return this._blockedMcpServers; } diff --git a/packages/core/src/tools/mcp-client-manager.test.ts b/packages/core/src/tools/mcp-client-manager.test.ts index dc0560107f..6f160d1989 100644 --- a/packages/core/src/tools/mcp-client-manager.test.ts +++ b/packages/core/src/tools/mcp-client-manager.test.ts @@ -9,6 +9,7 @@ import { McpClientManager } from './mcp-client-manager.js'; import { McpClient } from './mcp-client.js'; import type { ToolRegistry } from './tool-registry.js'; import type { Config } from '../config/config.js'; +import { SimpleExtensionLoader } from '../utils/extensionLoader.js'; vi.mock('./mcp-client.js', async () => { const originalModule = await vi.importActual('./mcp-client.js'); @@ -36,17 +37,22 @@ describe('McpClientManager', () => { vi.mocked(McpClient).mockReturnValue( mockedMcpClient as unknown as McpClient, ); - const manager = new McpClientManager({} as ToolRegistry); - await manager.discoverAllMcpTools({ - isTrustedFolder: () => true, - getMcpServers: () => ({ - 'test-server': {}, - }), - getMcpServerCommand: () => '', - getPromptRegistry: () => {}, - getDebugMode: () => false, - getWorkspaceContext: () => {}, - } as unknown as Config); + const manager = new McpClientManager( + {} as ToolRegistry, + { + isTrustedFolder: () => true, + getExtensionLoader: () => new SimpleExtensionLoader([]), + getMcpServers: () => ({ + 'test-server': {}, + }), + getMcpServerCommand: () => '', + getPromptRegistry: () => {}, + getDebugMode: () => false, + getWorkspaceContext: () => {}, + getEnableExtensionReloading: () => false, + } as unknown as Config, + ); + await manager.discoverAllMcpTools(); expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); }); @@ -61,17 +67,22 @@ describe('McpClientManager', () => { vi.mocked(McpClient).mockReturnValue( mockedMcpClient as unknown as McpClient, ); - const manager = new McpClientManager({} as ToolRegistry); - await manager.discoverAllMcpTools({ - isTrustedFolder: () => false, - getMcpServers: () => ({ - 'test-server': {}, - }), - getMcpServerCommand: () => '', - getPromptRegistry: () => {}, - getDebugMode: () => false, - getWorkspaceContext: () => {}, - } as unknown as Config); + const manager = new McpClientManager( + {} as ToolRegistry, + { + isTrustedFolder: () => false, + getExtensionLoader: () => new SimpleExtensionLoader([]), + getMcpServers: () => ({ + 'test-server': {}, + }), + getMcpServerCommand: () => '', + getPromptRegistry: () => {}, + getDebugMode: () => false, + getWorkspaceContext: () => {}, + getEnableExtensionReloading: () => false, + } as unknown as Config, + ); + await manager.discoverAllMcpTools(); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); expect(mockedMcpClient.discover).not.toHaveBeenCalled(); }); diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index d482da3722..1116778125 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -4,7 +4,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { Config } from '../config/config.js'; +import type { + Config, + GeminiCLIExtension, + MCPServerConfig, +} from '../config/config.js'; import type { ToolRegistry } from './tool-registry.js'; import { McpClient, @@ -14,6 +18,7 @@ import { import { getErrorMessage } from '../utils/errors.js'; import type { EventEmitter } from 'node:events'; import { coreEvents } from '../utils/events.js'; +import { debugLogger } from '../utils/debugLogger.js'; /** * Manages the lifecycle of multiple MCP clients, including local child processes. @@ -23,12 +28,162 @@ import { coreEvents } from '../utils/events.js'; export class McpClientManager { private clients: Map = new Map(); private readonly toolRegistry: ToolRegistry; + private readonly cliConfig: Config; + // If we have ongoing MCP client discovery, this completes once that is done. + private discoveryPromise: Promise | undefined; private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED; private readonly eventEmitter?: EventEmitter; - constructor(toolRegistry: ToolRegistry, eventEmitter?: EventEmitter) { + constructor( + toolRegistry: ToolRegistry, + cliConfig: Config, + eventEmitter?: EventEmitter, + ) { this.toolRegistry = toolRegistry; + this.cliConfig = cliConfig; this.eventEmitter = eventEmitter; + if (this.cliConfig.getEnableExtensionReloading()) { + this.cliConfig + .getExtensionLoader() + .extensionEvents() + .on('extensionLoaded', (event) => this.loadExtension(event.extension)) + .on('extensionEnabled', (event) => this.loadExtension(event.extension)) + .on('extensionDisabled', (event) => + this.unloadExtension(event.extension), + ) + .on('extensionUnloaded', (event) => + this.unloadExtension(event.extension), + ); + } + } + + /** + * For all the MCP servers associated with this extension: + * + * - Removes all its MCP servers from the global configuration object. + * - Disconnects all MCP clients from their servers. + * - Updates the Gemini chat configuration to load the new tools. + */ + private async unloadExtension(extension: GeminiCLIExtension) { + debugLogger.log(`Unloading extension: ${extension.name}`); + await Promise.all( + Object.keys(extension.mcpServers ?? {}).map((name) => { + const newMcpServers = { + ...this.cliConfig.getMcpServers(), + }; + delete newMcpServers[name]; + this.cliConfig.setMcpServers(newMcpServers); + return this.disconnectClient(name); + }), + ); + // This is required to update the content generator configuration with the + // new tool configuration. + this.cliConfig.getGeminiClient().setTools(); + } + + /** + * For all the MCP servers associated with this extension: + * + * - Adds all its MCP servers to the global configuration object. + * - Connects MCP clients to each server and discovers their tools. + * - Updates the Gemini chat configuration to load the new tools. + */ + private async loadExtension(extension: GeminiCLIExtension) { + debugLogger.log(`Loading extension: ${extension.name}`); + await Promise.all( + Object.entries(extension.mcpServers ?? {}).map(([name, config]) => { + this.cliConfig.setMcpServers({ + ...this.cliConfig.getMcpServers(), + [name]: config, + }); + return this.discoverMcpTools(name, config); + }), + ); + // This is required to update the content generator configuration with the + // new tool configuration. + this.cliConfig.getGeminiClient().setTools(); + } + + private async disconnectClient(name: string) { + const existing = this.clients.get(name); + if (existing) { + try { + this.clients.delete(name); + this.eventEmitter?.emit('mcp-client-update', this.clients); + await existing.disconnect(); + } catch (error) { + debugLogger.warn( + `Error stopping client '${name}': ${getErrorMessage(error)}`, + ); + } + } + } + + discoverMcpTools( + name: string, + config: MCPServerConfig, + ): Promise | void { + if (!this.cliConfig.isTrustedFolder()) { + return; + } + if (config.extension && !config.extension.isActive) { + return; + } + + const currentDiscoveryPromise = new Promise((resolve, _reject) => { + (async () => { + try { + await this.disconnectClient(name); + + const client = new McpClient( + name, + config, + this.toolRegistry, + this.cliConfig.getPromptRegistry(), + this.cliConfig.getWorkspaceContext(), + this.cliConfig.getDebugMode(), + ); + this.clients.set(name, client); + this.eventEmitter?.emit('mcp-client-update', this.clients); + try { + await client.connect(); + await client.discover(this.cliConfig); + this.eventEmitter?.emit('mcp-client-update', this.clients); + } catch (error) { + this.eventEmitter?.emit('mcp-client-update', this.clients); + // Log the error but don't let a single failed server stop the others + coreEvents.emitFeedback( + 'error', + `Error during discovery for server '${name}': ${getErrorMessage( + error, + )}`, + error, + ); + } + } finally { + resolve(); + } + })(); + }); + + if (this.discoveryPromise) { + this.discoveryPromise = this.discoveryPromise.then( + () => currentDiscoveryPromise, + ); + } else { + this.discoveryState = MCPDiscoveryState.IN_PROGRESS; + this.discoveryPromise = currentDiscoveryPromise; + } + const currentPromise = this.discoveryPromise; + currentPromise.then((_) => { + // If we are the last recorded discoveryPromise, then we are done, reset + // the world. + if (currentPromise === this.discoveryPromise) { + this.discoveryPromise = undefined; + this.discoveryState = MCPDiscoveryState.COMPLETED; + } + }); + return currentPromise; } /** @@ -36,53 +191,23 @@ export class McpClientManager { * It connects to each server, discovers its available tools, and registers * them with the `ToolRegistry`. */ - async discoverAllMcpTools(cliConfig: Config): Promise { - if (!cliConfig.isTrustedFolder()) { + async discoverAllMcpTools(): Promise { + if (!this.cliConfig.isTrustedFolder()) { return; } await this.stop(); const servers = populateMcpServerCommand( - cliConfig.getMcpServers() || {}, - cliConfig.getMcpServerCommand(), + this.cliConfig.getMcpServers() || {}, + this.cliConfig.getMcpServerCommand(), ); - this.discoveryState = MCPDiscoveryState.IN_PROGRESS; - this.eventEmitter?.emit('mcp-client-update', this.clients); - const discoveryPromises = Object.entries(servers) - .filter(([_, config]) => !config.extension || config.extension.isActive) - .map(async ([name, config]) => { - const client = new McpClient( - name, - config, - this.toolRegistry, - cliConfig.getPromptRegistry(), - cliConfig.getWorkspaceContext(), - cliConfig.getDebugMode(), - ); - this.clients.set(name, client); - - this.eventEmitter?.emit('mcp-client-update', this.clients); - try { - await client.connect(); - await client.discover(cliConfig); - this.eventEmitter?.emit('mcp-client-update', this.clients); - } catch (error) { - this.eventEmitter?.emit('mcp-client-update', this.clients); - // Log the error but don't let a single failed server stop the others - coreEvents.emitFeedback( - 'error', - `Error during discovery for server '${name}': ${getErrorMessage( - error, - )}`, - error, - ); - } - }); - - await Promise.all(discoveryPromises); - this.discoveryState = MCPDiscoveryState.COMPLETED; + await Promise.all( + Object.entries(servers).map(async ([name, config]) => + this.discoverMcpTools(name, config), + ), + ); } /** diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 6457fc1cb3..da649894a8 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -160,6 +160,7 @@ export class McpClient { if (this.status !== MCPServerStatus.CONNECTED) { return; } + this.toolRegistry.removeMcpToolsByServer(this.serverName); this.updateStatus(MCPServerStatus.DISCONNECTING); const client = this.client; this.client = undefined; diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index c7d8e35305..c71fba5ab7 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -181,7 +181,7 @@ export class ToolRegistry { constructor(config: Config, eventEmitter?: EventEmitter) { this.config = config; - this.mcpClientManager = new McpClientManager(this, eventEmitter); + this.mcpClientManager = new McpClientManager(this, config, eventEmitter); } setMessageBus(messageBus: MessageBus): void { @@ -244,7 +244,7 @@ export class ToolRegistry { await this.discoverAndRegisterToolsFromCommand(); // discover tools using MCP servers, if configured - await this.mcpClientManager.discoverAllMcpTools(this.config); + await this.mcpClientManager.discoverAllMcpTools(); } /** @@ -259,7 +259,7 @@ export class ToolRegistry { this.config.getPromptRegistry().clear(); // discover tools using MCP servers, if configured - await this.mcpClientManager.discoverAllMcpTools(this.config); + await this.mcpClientManager.discoverAllMcpTools(); } /**