From da4fa5ad75ccea4d8e320b1c0d552614e654f806 Mon Sep 17 00:00:00 2001 From: Jacob MacDonald Date: Tue, 4 Nov 2025 07:51:18 -0800 Subject: [PATCH] Extensions MCP refactor (#12413) --- packages/a2a-server/src/agent/task.ts | 2 +- packages/a2a-server/src/config/config.ts | 24 +- packages/cli/src/commands/mcp/list.test.ts | 48 ++-- packages/cli/src/config/config.test.ts | 39 +--- packages/cli/src/config/config.ts | 108 +-------- packages/cli/src/config/extension-manager.ts | 57 +++-- packages/cli/src/config/extension.test.ts | 30 ++- packages/cli/src/gemini.test.tsx | 2 + .../cli/src/services/McpPromptLoader.test.ts | 6 +- packages/cli/src/services/McpPromptLoader.ts | 5 +- packages/cli/src/ui/AppContainer.test.tsx | 1 + .../cli/src/ui/commands/mcpCommand.test.ts | 5 + packages/cli/src/ui/commands/mcpCommand.ts | 29 +-- .../cli/src/ui/components/Composer.test.tsx | 5 +- packages/cli/src/ui/components/Composer.tsx | 6 +- .../src/ui/components/ConfigInitDisplay.tsx | 10 +- packages/cli/src/utils/events.ts | 12 +- packages/core/src/config/config.ts | 64 +++++- packages/core/src/telemetry/loggers.test.ts | 15 +- packages/core/src/telemetry/types.ts | 3 +- .../core/src/tools/mcp-client-manager.test.ts | 217 +++++++++++++----- packages/core/src/tools/mcp-client-manager.ts | 193 +++++++++++----- packages/core/src/tools/mcp-client.test.ts | 65 ++++++ packages/core/src/tools/mcp-client.ts | 5 + packages/core/src/tools/tool-registry.test.ts | 31 --- packages/core/src/tools/tool-registry.ts | 64 +----- .../core/src/utils/extensionLoader.test.ts | 108 +++++++++ packages/core/src/utils/extensionLoader.ts | 201 +++++++++++++--- 28 files changed, 877 insertions(+), 478 deletions(-) create mode 100644 packages/core/src/utils/extensionLoader.test.ts diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index f0061bc6a9..2c58343e6c 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -113,7 +113,7 @@ export class Task { // state managed within the @gemini-cli/core module. async getMetadata(): Promise { const toolRegistry = await this.config.getToolRegistry(); - const mcpServers = this.config.getMcpServers() || {}; + const mcpServers = this.config.getMcpClientManager()?.getMcpServers() || {}; const serverStatuses = getAllMCPServerStatuses(); const servers = Object.keys(mcpServers).map((serverName) => ({ name: serverName, diff --git a/packages/a2a-server/src/config/config.ts b/packages/a2a-server/src/config/config.ts index 5492bb9b0a..97a343eb8d 100644 --- a/packages/a2a-server/src/config/config.ts +++ b/packages/a2a-server/src/config/config.ts @@ -20,9 +20,7 @@ import { GEMINI_DIR, DEFAULT_GEMINI_EMBEDDING_MODEL, DEFAULT_GEMINI_MODEL, - type GeminiCLIExtension, type ExtensionLoader, - debugLogger, } from '@google/gemini-cli-core'; import { logger } from '../utils/logger.js'; @@ -34,7 +32,6 @@ export async function loadConfig( extensionLoader: ExtensionLoader, taskId: string, ): Promise { - const mcpServers = mergeMcpServers(settings, extensionLoader.getExtensions()); const workspaceDir = process.cwd(); const adcFilePath = process.env['GOOGLE_APPLICATION_CREDENTIALS']; @@ -54,7 +51,7 @@ export async function loadConfig( process.env['GEMINI_YOLO_MODE'] === 'true' ? ApprovalMode.YOLO : ApprovalMode.DEFAULT, - mcpServers, + mcpServers: settings.mcpServers, cwd: workspaceDir, telemetry: { enabled: settings.telemetry?.enabled, @@ -120,25 +117,6 @@ export async function loadConfig( return config; } -export function mergeMcpServers( - settings: Settings, - extensions: GeminiCLIExtension[], -) { - const mcpServers = { ...(settings.mcpServers || {}) }; - for (const extension of extensions) { - Object.entries(extension.mcpServers || {}).forEach(([key, server]) => { - if (mcpServers[key]) { - debugLogger.warn( - `Skipping extension MCP config for server with key "${key}" as it already exists.`, - ); - return; - } - mcpServers[key] = server; - }); - } - return mcpServers; -} - export function setTargetDir(agentSettings: AgentSettings | undefined): string { const originalCWD = process.cwd(); const targetDir = diff --git a/packages/cli/src/commands/mcp/list.test.ts b/packages/cli/src/commands/mcp/list.test.ts index ee9cf9395c..e8226f0957 100644 --- a/packages/cli/src/commands/mcp/list.test.ts +++ b/packages/cli/src/commands/mcp/list.test.ts @@ -21,27 +21,33 @@ vi.mock('../../config/extensions/storage.js', () => ({ }, })); vi.mock('../../config/extension-manager.js'); -vi.mock('@google/gemini-cli-core', () => ({ - createTransport: vi.fn(), - MCPServerStatus: { - CONNECTED: 'CONNECTED', - CONNECTING: 'CONNECTING', - DISCONNECTED: 'DISCONNECTED', - }, - Storage: vi.fn().mockImplementation((_cwd: string) => ({ - getGlobalSettingsPath: () => '/tmp/gemini/settings.json', - getWorkspaceSettingsPath: () => '/tmp/gemini/workspace-settings.json', - getProjectTempDir: () => '/test/home/.gemini/tmp/mocked_hash', - })), - GEMINI_DIR: '.gemini', - getErrorMessage: (e: unknown) => (e instanceof Error ? e.message : String(e)), - debugLogger: { - log: vi.fn(), - warn: vi.fn(), - error: vi.fn(), - debug: vi.fn(), - }, -})); +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const original = + await importOriginal(); + return { + ...original, + createTransport: vi.fn(), + MCPServerStatus: { + CONNECTED: 'CONNECTED', + CONNECTING: 'CONNECTING', + DISCONNECTED: 'DISCONNECTED', + }, + Storage: vi.fn().mockImplementation((_cwd: string) => ({ + getGlobalSettingsPath: () => '/tmp/gemini/settings.json', + getWorkspaceSettingsPath: () => '/tmp/gemini/workspace-settings.json', + getProjectTempDir: () => '/test/home/.gemini/tmp/mocked_hash', + })), + GEMINI_DIR: '.gemini', + getErrorMessage: (e: unknown) => + e instanceof Error ? e.message : String(e), + debugLogger: { + log: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, + }; +}); vi.mock('@modelcontextprotocol/sdk/client/index.js'); const mockedGetUserExtensionsDir = diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index 99dd616703..a7d00affdb 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -1146,9 +1146,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => { ]; const argv = await parseArguments({} as Settings); const config = await loadCliConfig(baseSettings, 'test-session', argv); - expect(config.getMcpServers()).toEqual({ - server1: { url: 'http://localhost:8080' }, - }); + expect(config.getAllowedMcpServers()).toEqual(['server1']); }); it('should allow multiple specified MCP servers', async () => { @@ -1162,10 +1160,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => { ]; const argv = await parseArguments({} as Settings); const config = await loadCliConfig(baseSettings, 'test-session', argv); - expect(config.getMcpServers()).toEqual({ - server1: { url: 'http://localhost:8080' }, - server3: { url: 'http://localhost:8082' }, - }); + expect(config.getAllowedMcpServers()).toEqual(['server1', 'server3']); }); it('should handle server names that do not exist', async () => { @@ -1179,16 +1174,14 @@ describe('loadCliConfig with allowed-mcp-server-names', () => { ]; const argv = await parseArguments({} as Settings); const config = await loadCliConfig(baseSettings, 'test-session', argv); - expect(config.getMcpServers()).toEqual({ - server1: { url: 'http://localhost:8080' }, - }); + expect(config.getAllowedMcpServers()).toEqual(['server1', 'server4']); }); it('should allow no MCP servers if the flag is provided but empty', async () => { process.argv = ['node', 'script.js', '--allowed-mcp-server-names', '']; const argv = await parseArguments({} as Settings); const config = await loadCliConfig(baseSettings, 'test-session', argv); - expect(config.getMcpServers()).toEqual({}); + expect(config.getAllowedMcpServers()).toEqual(['']); }); it('should read allowMCPServers from settings', async () => { @@ -1199,10 +1192,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => { mcp: { allowed: ['server1', 'server2'] }, }; const config = await loadCliConfig(settings, 'test-session', argv); - expect(config.getMcpServers()).toEqual({ - server1: { url: 'http://localhost:8080' }, - server2: { url: 'http://localhost:8081' }, - }); + expect(config.getAllowedMcpServers()).toEqual(['server1', 'server2']); }); it('should read excludeMCPServers from settings', async () => { @@ -1213,9 +1203,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => { mcp: { excluded: ['server1', 'server2'] }, }; const config = await loadCliConfig(settings, 'test-session', argv); - expect(config.getMcpServers()).toEqual({ - server3: { url: 'http://localhost:8082' }, - }); + expect(config.getBlockedMcpServers()).toEqual(['server1', 'server2']); }); it('should override allowMCPServers with excludeMCPServers if overlapping', async () => { @@ -1229,9 +1217,8 @@ describe('loadCliConfig with allowed-mcp-server-names', () => { }, }; const config = await loadCliConfig(settings, 'test-session', argv); - expect(config.getMcpServers()).toEqual({ - server2: { url: 'http://localhost:8081' }, - }); + expect(config.getAllowedMcpServers()).toEqual(['server1', 'server2']); + expect(config.getBlockedMcpServers()).toEqual(['server1']); }); it('should prioritize mcp server flag if set', async () => { @@ -1250,9 +1237,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => { }, }; const config = await loadCliConfig(settings, 'test-session', argv); - expect(config.getMcpServers()).toEqual({ - server1: { url: 'http://localhost:8080' }, - }); + expect(config.getAllowedMcpServers()).toEqual(['server1']); }); it('should prioritize CLI flag over both allowed and excluded settings', async () => { @@ -1273,10 +1258,8 @@ describe('loadCliConfig with allowed-mcp-server-names', () => { }, }; const config = await loadCliConfig(settings, 'test-session', argv); - expect(config.getMcpServers()).toEqual({ - server2: { url: 'http://localhost:8081' }, - server3: { url: 'http://localhost:8082' }, - }); + expect(config.getAllowedMcpServers()).toEqual(['server2', 'server3']); + expect(config.getBlockedMcpServers()).toEqual([]); }); }); diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index d0e6a8a355..553a1ce760 100755 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -13,9 +13,7 @@ import process from 'node:process'; import { mcpCommand } from '../commands/mcp.js'; import type { FileFilteringOptions, - MCPServerConfig, OutputFormat, - GeminiCLIExtension, } from '@google/gemini-cli-core'; import { extensionsCommand } from '../commands/extensions.js'; import { @@ -49,9 +47,13 @@ import { appEvents } from '../utils/events.js'; import { isWorkspaceTrusted } from './trustedFolders.js'; import { createPolicyEngineConfig } from './policy.js'; import { ExtensionManager } from './extension-manager.js'; -import type { ExtensionLoader } from '@google/gemini-cli-core/src/utils/extensionLoader.js'; +import type { + ExtensionEvents, + ExtensionLoader, +} from '@google/gemini-cli-core/src/utils/extensionLoader.js'; import { requestConsentNonInteractive } from './extensions/consent.js'; import { promptForSetting } from './extensions/extensionSettings.js'; +import type { EventEmitter } from 'node:stream'; export interface CliArgs { query: string | undefined; @@ -429,6 +431,7 @@ export async function loadCliConfig( requestSetting: promptForSetting, workspaceDir: cwd, enabledExtensionOverrides: argv.extensions, + eventEmitter: appEvents as EventEmitter, }); await extensionManager.loadExtensions(); @@ -448,7 +451,6 @@ export async function loadCliConfig( memoryFileFiltering, ); - let mcpServers = mergeMcpServers(settings, extensionManager.getExtensions()); const question = argv.promptInteractive || argv.prompt || ''; // Determine approval mode with backward compatibility @@ -565,37 +567,8 @@ export async function loadCliConfig( const excludeTools = mergeExcludeTools( settings, - extensionManager.getExtensions(), extraExcludes.length > 0 ? extraExcludes : undefined, ); - const blockedMcpServers: Array<{ name: string; extensionName: string }> = []; - - if (!argv.allowedMcpServerNames) { - if (settings.mcp?.allowed) { - mcpServers = allowedMcpServers( - mcpServers, - settings.mcp.allowed, - blockedMcpServers, - ); - } - - if (settings.mcp?.excluded) { - const excludedNames = new Set(settings.mcp.excluded.filter(Boolean)); - if (excludedNames.size > 0) { - mcpServers = Object.fromEntries( - Object.entries(mcpServers).filter(([key]) => !excludedNames.has(key)), - ); - } - } - } - - if (argv.allowedMcpServerNames) { - mcpServers = allowedMcpServers( - mcpServers, - argv.allowedMcpServerNames, - blockedMcpServers, - ); - } const useModelRouter = settings.experimental?.useModelRouter ?? true; const defaultModel = useModelRouter @@ -633,7 +606,11 @@ export async function loadCliConfig( toolDiscoveryCommand: settings.tools?.discoveryCommand, toolCallCommand: settings.tools?.callCommand, mcpServerCommand: settings.mcp?.serverCommand, - mcpServers, + mcpServers: settings.mcpServers, + allowedMcpServers: argv.allowedMcpServerNames ?? settings.mcp?.allowed, + blockedMcpServers: argv.allowedMcpServerNames + ? [] // explicitly allowed servers overrides everything + : settings.mcp?.excluded, userMemory: memoryContent, geminiMdFileCount: fileCount, geminiMdFilePaths: filePaths, @@ -663,7 +640,6 @@ export async function loadCliConfig( enabledExtensions: argv.extensions, extensionLoader: extensionManager, enableExtensionReloading: settings.experimental?.extensionReloading, - blockedMcpServers, noBrowser: !!process.env['NO_BROWSER'], summarizeToolOutput: settings.model?.summarizeToolOutput, ideMode, @@ -699,75 +675,13 @@ export async function loadCliConfig( }); } -function allowedMcpServers( - mcpServers: { [x: string]: MCPServerConfig }, - allowMCPServers: string[], - blockedMcpServers: Array<{ name: string; extensionName: string }>, -) { - const allowedNames = new Set(allowMCPServers.filter(Boolean)); - if (allowedNames.size > 0) { - mcpServers = Object.fromEntries( - Object.entries(mcpServers).filter(([key, server]) => { - const isAllowed = allowedNames.has(key); - if (!isAllowed) { - blockedMcpServers.push({ - name: key, - extensionName: server.extension?.name || '', - }); - } - return isAllowed; - }), - ); - } else { - blockedMcpServers.push( - ...Object.entries(mcpServers).map(([key, server]) => ({ - name: key, - extensionName: server.extension?.name || '', - })), - ); - mcpServers = {}; - } - return mcpServers; -} - -function mergeMcpServers(settings: Settings, extensions: GeminiCLIExtension[]) { - const mcpServers = { ...(settings.mcpServers || {}) }; - for (const extension of extensions) { - if (!extension.isActive) { - continue; - } - Object.entries(extension.mcpServers || {}).forEach(([key, server]) => { - if (mcpServers[key]) { - debugLogger.warn( - `Skipping extension MCP config for server with key "${key}" as it already exists.`, - ); - return; - } - mcpServers[key] = { - ...server, - extension, - }; - }); - } - return mcpServers; -} - function mergeExcludeTools( settings: Settings, - extensions: GeminiCLIExtension[], extraExcludes?: string[] | undefined, ): string[] { const allExcludeTools = new Set([ ...(settings.tools?.exclude || []), ...(extraExcludes || []), ]); - for (const extension of extensions) { - if (!extension.isActive) { - continue; - } - for (const tool of extension.excludeTools || []) { - allExcludeTools.add(tool); - } - } return [...allExcludeTools]; } diff --git a/packages/cli/src/config/extension-manager.ts b/packages/cli/src/config/extension-manager.ts index 55e66c811d..dda9b25c6e 100644 --- a/packages/cli/src/config/extension-manager.ts +++ b/packages/cli/src/config/extension-manager.ts @@ -28,6 +28,7 @@ import { ExtensionDisableEvent, ExtensionEnableEvent, ExtensionInstallEvent, + ExtensionLoader, ExtensionUninstallEvent, ExtensionUpdateEvent, getErrorMessage, @@ -36,6 +37,7 @@ import { logExtensionInstallEvent, logExtensionUninstall, logExtensionUpdateEvent, + type ExtensionEvents, type MCPServerConfig, type ExtensionInstallMetadata, type GeminiCLIExtension, @@ -54,11 +56,7 @@ import { maybePromptForSettings, type ExtensionSetting, } from './extensions/extensionSettings.js'; -import type { - ExtensionEvents, - ExtensionLoader, -} from '@google/gemini-cli-core/src/utils/extensionLoader.js'; -import { EventEmitter } from 'node:events'; +import type { EventEmitter } from 'node:stream'; interface ExtensionManagerParams { enabledExtensionOverrides?: string[]; @@ -66,6 +64,7 @@ interface ExtensionManagerParams { requestConsent: (consent: string) => Promise; requestSetting: ((setting: ExtensionSetting) => Promise) | null; workspaceDir: string; + eventEmitter?: EventEmitter; } /** @@ -73,7 +72,7 @@ interface ExtensionManagerParams { * * You must call `loadExtensions` prior to calling other methods on this class. */ -export class ExtensionManager implements ExtensionLoader { +export class ExtensionManager extends ExtensionLoader { private extensionEnablementManager: ExtensionEnablementManager; private settings: Settings; private requestConsent: (consent: string) => Promise; @@ -83,9 +82,9 @@ export class ExtensionManager implements ExtensionLoader { private telemetryConfig: Config; private workspaceDir: string; private loadedExtensions: GeminiCLIExtension[] | undefined; - private eventEmitter: EventEmitter; constructor(options: ExtensionManagerParams) { + super(options.eventEmitter); this.workspaceDir = options.workspaceDir; this.extensionEnablementManager = new ExtensionEnablementManager( options.enabledExtensionOverrides, @@ -102,7 +101,6 @@ export class ExtensionManager implements ExtensionLoader { }); this.requestConsent = options.requestConsent; this.requestSetting = options.requestSetting ?? undefined; - this.eventEmitter = new EventEmitter(); } setRequestConsent( @@ -126,10 +124,6 @@ export class ExtensionManager implements ExtensionLoader { return this.loadedExtensions!; } - extensionEvents(): EventEmitter { - return this.eventEmitter; - } - async installOrUpdateExtension( installMetadata: ExtensionInstallMetadata, previousExtensionConfig?: ExtensionConfig, @@ -303,7 +297,7 @@ export class ExtensionManager implements ExtensionLoader { await fs.promises.writeFile(metadataPath, metadataString); // TODO: Gracefully handle this call failing, we should back up the old - // extension prior to overwriting it and then restore it. + // extension prior to overwriting it and then restore and restart it. extension = await this.loadExtension(destinationPath)!; if (!extension) { throw new Error(`Extension not found`); @@ -320,7 +314,6 @@ export class ExtensionManager implements ExtensionLoader { 'success', ), ); - this.eventEmitter.emit('extensionUpdated', { extension }); } else { logExtensionInstallEvent( this.telemetryConfig, @@ -332,7 +325,6 @@ export class ExtensionManager implements ExtensionLoader { 'success', ), ); - this.eventEmitter.emit('extensionInstalled', { extension }); this.enableExtension(newExtensionConfig.name, SettingScope.User); } } finally { @@ -397,7 +389,7 @@ export class ExtensionManager implements ExtensionLoader { if (!extension) { throw new Error(`Extension not found.`); } - this.unloadExtension(extension); + await this.unloadExtension(extension); const storage = new ExtensionStorage(extension.name); await fs.promises.rm(storage.getExtensionDir(), { @@ -419,9 +411,11 @@ export class ExtensionManager implements ExtensionLoader { 'success', ), ); - this.eventEmitter.emit('extensionUninstalled', { extension }); } + /** + * Loads all installed extensions, should only be called once. + */ async loadExtensions(): Promise { if (this.loadedExtensions) { throw new Error('Extensions already loaded, only load extensions once.'); @@ -433,12 +427,14 @@ export class ExtensionManager implements ExtensionLoader { } for (const subdir of fs.readdirSync(extensionsDir)) { const extensionDir = path.join(extensionsDir, subdir); - await this.loadExtension(extensionDir); } return this.loadedExtensions; } + /** + * Adds `extension` to the list of extensions and starts it if appropriate. + */ private async loadExtension( extensionDir: string, ): Promise { @@ -499,8 +495,9 @@ export class ExtensionManager implements ExtensionLoader { ), id: getExtensionId(config, installMetadata), }; - this.eventEmitter.emit('extensionLoaded', { extension }); - this.getExtensions().push(extension); + this.loadedExtensions = [...this.loadedExtensions, extension]; + + await this.maybeStartExtension(extension); return extension; } catch (e) { debugLogger.error( @@ -512,11 +509,17 @@ export class ExtensionManager implements ExtensionLoader { } } - private unloadExtension(extension: GeminiCLIExtension) { + /** + * Removes `extension` from the list of extensions and stops it if + * appropriate. + */ + private unloadExtension( + extension: GeminiCLIExtension, + ): Promise | undefined { this.loadedExtensions = this.getExtensions().filter( (entry) => extension !== entry, ); - this.eventEmitter.emit('extensionUnloaded', { extension }); + return this.maybeStopExtension(extension); } loadExtensionConfig(extensionDir: string): ExtensionConfig { @@ -616,14 +619,18 @@ export class ExtensionManager implements ExtensionLoader { const scopePath = scope === SettingScope.Workspace ? this.workspaceDir : os.homedir(); this.extensionEnablementManager.disable(name, true, scopePath); + extension.isActive = false; + await this.maybeStopExtension(extension); logExtensionDisable( this.telemetryConfig, new ExtensionDisableEvent(hashValue(name), extension.id, scope), ); - extension.isActive = false; - this.eventEmitter.emit('extensionDisabled', { extension }); } + /** + * Enables an existing extension for a given scope, and starts it if + * appropriate. + */ async enableExtension(name: string, scope: SettingScope) { if ( scope === SettingScope.System || @@ -645,7 +652,7 @@ export class ExtensionManager implements ExtensionLoader { new ExtensionEnableEvent(hashValue(name), extension.id, scope), ); extension.isActive = true; - this.eventEmitter.emit('extensionEnabled', { extension }); + await this.maybeStartExtension(extension); } } diff --git a/packages/cli/src/config/extension.test.ts b/packages/cli/src/config/extension.test.ts index 8a4af9ac31..c762dd3295 100644 --- a/packages/cli/src/config/extension.test.ts +++ b/packages/cli/src/config/extension.test.ts @@ -1666,7 +1666,10 @@ This extension will run the following MCP servers: }); await extensionManager.loadExtensions(); - extensionManager.disableExtension('my-extension', SettingScope.User); + await extensionManager.disableExtension( + 'my-extension', + SettingScope.User, + ); expect( isEnabled({ name: 'my-extension', @@ -1683,7 +1686,10 @@ This extension will run the following MCP servers: }); await extensionManager.loadExtensions(); - extensionManager.disableExtension('my-extension', SettingScope.Workspace); + await extensionManager.disableExtension( + 'my-extension', + SettingScope.Workspace, + ); expect( isEnabled({ name: 'my-extension', @@ -1706,8 +1712,14 @@ This extension will run the following MCP servers: }); await extensionManager.loadExtensions(); - extensionManager.disableExtension('my-extension', SettingScope.User); - extensionManager.disableExtension('my-extension', SettingScope.User); + await extensionManager.disableExtension( + 'my-extension', + SettingScope.User, + ); + await extensionManager.disableExtension( + 'my-extension', + SettingScope.User, + ); expect( isEnabled({ name: 'my-extension', @@ -1738,7 +1750,7 @@ This extension will run the following MCP servers: }); await extensionManager.loadExtensions(); - extensionManager.disableExtension('ext1', SettingScope.Workspace); + await extensionManager.disableExtension('ext1', SettingScope.Workspace); expect(mockLogExtensionDisable).toHaveBeenCalled(); expect(ExtensionDisableEvent).toHaveBeenCalledWith( @@ -1766,7 +1778,7 @@ This extension will run the following MCP servers: version: '1.0.0', }); await extensionManager.loadExtensions(); - extensionManager.disableExtension('ext1', SettingScope.User); + await extensionManager.disableExtension('ext1', SettingScope.User); let activeExtensions = getActiveExtensions(); expect(activeExtensions).toHaveLength(0); @@ -1783,7 +1795,7 @@ This extension will run the following MCP servers: version: '1.0.0', }); await extensionManager.loadExtensions(); - extensionManager.disableExtension('ext1', SettingScope.Workspace); + await extensionManager.disableExtension('ext1', SettingScope.Workspace); let activeExtensions = getActiveExtensions(); expect(activeExtensions).toHaveLength(0); @@ -1804,8 +1816,8 @@ This extension will run the following MCP servers: }, }); await extensionManager.loadExtensions(); - extensionManager.disableExtension('ext1', SettingScope.Workspace); - extensionManager.enableExtension('ext1', SettingScope.Workspace); + await extensionManager.disableExtension('ext1', SettingScope.Workspace); + await extensionManager.enableExtension('ext1', SettingScope.Workspace); expect(mockLogExtensionEnable).toHaveBeenCalled(); expect(ExtensionEnableEvent).toHaveBeenCalledWith( diff --git a/packages/cli/src/gemini.test.tsx b/packages/cli/src/gemini.test.tsx index aac8b3ac60..3b88d6bb31 100644 --- a/packages/cli/src/gemini.test.tsx +++ b/packages/cli/src/gemini.test.tsx @@ -190,6 +190,7 @@ describe('gemini.tsx main function', () => { getDebugMode: () => false, getListExtensions: () => false, getMcpServers: () => ({}), + getMcpClientManager: vi.fn(), initialize: vi.fn(), getIdeMode: () => false, getExperimentalZedIntegration: () => false, @@ -339,6 +340,7 @@ describe('gemini.tsx main function kitty protocol', () => { getDebugMode: () => false, getListExtensions: () => false, getMcpServers: () => ({}), + getMcpClientManager: vi.fn(), initialize: vi.fn(), getIdeMode: () => false, getExperimentalZedIntegration: () => false, diff --git a/packages/cli/src/services/McpPromptLoader.test.ts b/packages/cli/src/services/McpPromptLoader.test.ts index 3ba4c012ad..e90a3720b4 100644 --- a/packages/cli/src/services/McpPromptLoader.test.ts +++ b/packages/cli/src/services/McpPromptLoader.test.ts @@ -159,8 +159,10 @@ describe('McpPromptLoader', () => { describe('loadCommands', () => { const mockConfigWithPrompts = { - getMcpServers: () => ({ - 'test-server': { httpUrl: 'https://test-server.com' }, + getMcpClientManager: () => ({ + getMcpServers: () => ({ + 'test-server': { httpUrl: 'https://test-server.com' }, + }), }), } as unknown as Config; diff --git a/packages/cli/src/services/McpPromptLoader.ts b/packages/cli/src/services/McpPromptLoader.ts index c402fa82e0..35cc80313e 100644 --- a/packages/cli/src/services/McpPromptLoader.ts +++ b/packages/cli/src/services/McpPromptLoader.ts @@ -34,7 +34,7 @@ export class McpPromptLoader implements ICommandLoader { if (!this.config) { return Promise.resolve([]); } - const mcpServers = this.config.getMcpServers() || {}; + const mcpServers = this.config.getMcpClientManager()?.getMcpServers() || {}; for (const serverName in mcpServers) { const prompts = getMCPServerPrompts(this.config, serverName) || []; for (const prompt of prompts) { @@ -101,7 +101,8 @@ export class McpPromptLoader implements ICommandLoader { } try { - const mcpServers = this.config.getMcpServers() || {}; + const mcpServers = + this.config.getMcpClientManager()?.getMcpServers() || {}; const mcpServerConfig = mcpServers[serverName]; if (!mcpServerConfig) { return { diff --git a/packages/cli/src/ui/AppContainer.test.tsx b/packages/cli/src/ui/AppContainer.test.tsx index 98d50e977e..63298829bf 100644 --- a/packages/cli/src/ui/AppContainer.test.tsx +++ b/packages/cli/src/ui/AppContainer.test.tsx @@ -295,6 +295,7 @@ describe('AppContainer State Management', () => { getExtensions: vi.fn().mockReturnValue([]), setRequestConsent: vi.fn(), setRequestSetting: vi.fn(), + start: vi.fn(), } as unknown as ExtensionManager); vi.spyOn(mockConfig, 'getExtensionLoader').mockReturnValue( mockExtensionManager, diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts index f2b5865ccc..91eea4acd7 100644 --- a/packages/cli/src/ui/commands/mcpCommand.test.ts +++ b/packages/cli/src/ui/commands/mcpCommand.test.ts @@ -62,6 +62,7 @@ describe('mcpCommand', () => { getBlockedMcpServers: ReturnType; getPromptRegistry: ReturnType; getGeminiClient: ReturnType; + getMcpClientManager: ReturnType; }; beforeEach(() => { @@ -88,6 +89,10 @@ describe('mcpCommand', () => { getPromptsByServer: vi.fn().mockReturnValue([]), }), getGeminiClient: vi.fn(), + getMcpClientManager: vi.fn().mockImplementation(() => ({ + getBlockedMcpServers: vi.fn(), + getMcpServers: vi.fn(), + })), }; mockContext = createMockCommandContext({ diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index e6a071f437..8663965c22 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -43,7 +43,7 @@ const authCommand: SlashCommand = { }; } - const mcpServers = config.getMcpServers() || {}; + const mcpServers = config.getMcpClientManager()?.getMcpServers() ?? {}; if (!serverName) { // List servers that support OAuth @@ -119,20 +119,20 @@ const authCommand: SlashCommand = { ); // Trigger tool re-discovery to pick up authenticated server - const toolRegistry = config.getToolRegistry(); - if (toolRegistry) { + const mcpClientManager = config.getMcpClientManager(); + if (mcpClientManager) { context.ui.addItem( { type: 'info', - text: `Re-discovering tools from '${serverName}'...`, + text: `Restarting MCP server '${serverName}'...`, }, Date.now(), ); - await toolRegistry.discoverToolsForServer(serverName); + await mcpClientManager.restartServer(serverName); } // Update the client with the new tools const geminiClient = config.getGeminiClient(); - if (geminiClient) { + if (geminiClient?.isInitialized()) { await geminiClient.setTools(); } @@ -158,7 +158,7 @@ const authCommand: SlashCommand = { const { config } = context.services; if (!config) return []; - const mcpServers = config.getMcpServers() || {}; + const mcpServers = config.getMcpClientManager()?.getMcpServers() || {}; return Object.keys(mcpServers).filter((name) => name.startsWith(partialArg), ); @@ -188,9 +188,10 @@ const listAction = async ( }; } - const mcpServers = config.getMcpServers() || {}; + const mcpServers = config.getMcpClientManager()?.getMcpServers() || {}; const serverNames = Object.keys(mcpServers); - const blockedMcpServers = config.getBlockedMcpServers() || []; + const blockedMcpServers = + config.getMcpClientManager()?.getBlockedMcpServers() || []; const connectingServers = serverNames.filter( (name) => getMCPServerStatus(name) === MCPServerStatus.CONNECTING, @@ -299,12 +300,12 @@ const refreshCommand: SlashCommand = { }; } - const toolRegistry = config.getToolRegistry(); - if (!toolRegistry) { + const mcpClientManager = config.getMcpClientManager(); + if (!mcpClientManager) { return { type: 'message', messageType: 'error', - content: 'Could not retrieve tool registry.', + content: 'Could not retrieve mcp client manager.', }; } @@ -316,11 +317,11 @@ const refreshCommand: SlashCommand = { Date.now(), ); - await toolRegistry.restartMcpServers(); + await mcpClientManager.restart(); // Update the client with the new tools const geminiClient = config.getGeminiClient(); - if (geminiClient) { + if (geminiClient?.isInitialized()) { await geminiClient.setTools(); } diff --git a/packages/cli/src/ui/components/Composer.test.tsx b/packages/cli/src/ui/components/Composer.test.tsx index 78b2654f2e..181ac31cc6 100644 --- a/packages/cli/src/ui/components/Composer.test.tsx +++ b/packages/cli/src/ui/components/Composer.test.tsx @@ -144,7 +144,10 @@ const createMockConfig = (overrides = {}) => ({ getDebugMode: vi.fn(() => false), getAccessibility: vi.fn(() => ({})), getMcpServers: vi.fn(() => ({})), - getBlockedMcpServers: vi.fn(() => []), + getMcpClientManager: vi.fn().mockImplementation(() => ({ + getBlockedMcpServers: vi.fn(), + getMcpServers: vi.fn(), + })), ...overrides, }); diff --git a/packages/cli/src/ui/components/Composer.tsx b/packages/cli/src/ui/components/Composer.tsx index 327d8ee3e8..eb8c628257 100644 --- a/packages/cli/src/ui/components/Composer.tsx +++ b/packages/cli/src/ui/components/Composer.tsx @@ -101,8 +101,10 @@ export const Composer = () => { ideContext={uiState.ideContextState} geminiMdFileCount={uiState.geminiMdFileCount} contextFileNames={contextFileNames} - mcpServers={config.getMcpServers()} - blockedMcpServers={config.getBlockedMcpServers()} + mcpServers={config.getMcpClientManager()?.getMcpServers() ?? {}} + blockedMcpServers={ + config.getMcpClientManager()?.getBlockedMcpServers() ?? [] + } /> ) )} diff --git a/packages/cli/src/ui/components/ConfigInitDisplay.tsx b/packages/cli/src/ui/components/ConfigInitDisplay.tsx index 6925acb6bb..8180bbe8de 100644 --- a/packages/cli/src/ui/components/ConfigInitDisplay.tsx +++ b/packages/cli/src/ui/components/ConfigInitDisplay.tsx @@ -5,15 +5,13 @@ */ import { useEffect, useState } from 'react'; -import { appEvents } from './../../utils/events.js'; +import { AppEvent, appEvents } from './../../utils/events.js'; import { Box, Text } from 'ink'; -import { useConfig } from '../contexts/ConfigContext.js'; import { type McpClient, MCPServerStatus } from '@google/gemini-cli-core'; import { GeminiSpinner } from './GeminiRespondingSpinner.js'; import { theme } from '../semantic-colors.js'; export const ConfigInitDisplay = () => { - const config = useConfig(); const [message, setMessage] = useState('Initializing...'); useEffect(() => { @@ -31,11 +29,11 @@ export const ConfigInitDisplay = () => { setMessage(`Connecting to MCP servers... (${connected}/${clients.size})`); }; - appEvents.on('mcp-client-update', onChange); + appEvents.on(AppEvent.McpClientUpdate, onChange); return () => { - appEvents.off('mcp-client-update', onChange); + appEvents.off(AppEvent.McpClientUpdate, onChange); }; - }, [config]); + }, []); return ( diff --git a/packages/cli/src/utils/events.ts b/packages/cli/src/utils/events.ts index fabd32828f..ac714fd8e6 100644 --- a/packages/cli/src/utils/events.ts +++ b/packages/cli/src/utils/events.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { ExtensionEvents, McpClient } from '@google/gemini-cli-core'; import { EventEmitter } from 'node:events'; export enum AppEvent { @@ -11,6 +12,15 @@ export enum AppEvent { LogError = 'log-error', OauthDisplayMessage = 'oauth-display-message', Flicker = 'flicker', + McpClientUpdate = 'mcp-client-update', } -export const appEvents = new EventEmitter(); +export interface AppEvents extends ExtensionEvents { + [AppEvent.OpenDebugConsole]: never[]; + [AppEvent.LogError]: string[]; + [AppEvent.OauthDisplayMessage]: string[]; + [AppEvent.Flicker]: never[]; + [AppEvent.McpClientUpdate]: Array | never>; +} + +export const appEvents = new EventEmitter(); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index eb1d3418d9..7e505b122e 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -154,6 +154,7 @@ import { type ExtensionLoader, SimpleExtensionLoader, } from '../utils/extensionLoader.js'; +import { McpClientManager } from '../tools/mcp-client-manager.js'; export type { FileFilteringOptions }; export { @@ -251,7 +252,8 @@ export interface ConfigParameters { extensionLoader?: ExtensionLoader; enabledExtensions?: string[]; enableExtensionReloading?: boolean; - blockedMcpServers?: Array<{ name: string; extensionName: string }>; + allowedMcpServers?: string[]; + blockedMcpServers?: string[]; noBrowser?: boolean; summarizeToolOutput?: Record; folderTrust?: boolean; @@ -293,6 +295,9 @@ export interface ConfigParameters { export class Config { private toolRegistry!: ToolRegistry; + private mcpClientManager?: McpClientManager; + private allowedMcpServers: string[]; + private blockedMcpServers: string[]; private promptRegistry!: PromptRegistry; private agentRegistry!: AgentRegistry; private readonly sessionId: string; @@ -347,10 +352,6 @@ export class Config { private readonly _extensionLoader: ExtensionLoader; private readonly _enabledExtensions: string[]; private readonly enableExtensionReloading: boolean; - private readonly _blockedMcpServers: Array<{ - name: string; - extensionName: string; - }>; fallbackModelHandler?: FallbackModelHandler; private quotaErrorOccurred: boolean = false; private readonly summarizeToolOutput: @@ -417,6 +418,8 @@ export class Config { this.toolCallCommand = params.toolCallCommand; this.mcpServerCommand = params.mcpServerCommand; this.mcpServers = params.mcpServers; + this.allowedMcpServers = params.allowedMcpServers ?? []; + this.blockedMcpServers = params.blockedMcpServers ?? []; this.userMemory = params.userMemory ?? ''; this.geminiMdFileCount = params.geminiMdFileCount ?? 0; this.geminiMdFilePaths = params.geminiMdFilePaths ?? []; @@ -458,7 +461,6 @@ export class Config { this._extensionLoader = params.extensionLoader ?? new SimpleExtensionLoader([]); this._enabledExtensions = params.enabledExtensions ?? []; - this._blockedMcpServers = params.blockedMcpServers ?? []; this.noBrowser = params.noBrowser ?? false; this.summarizeToolOutput = params.summarizeToolOutput; this.folderTrust = params.folderTrust ?? false; @@ -572,6 +574,15 @@ export class Config { await this.agentRegistry.initialize(); this.toolRegistry = await this.createToolRegistry(); + this.mcpClientManager = new McpClientManager( + this.toolRegistry, + this, + this.eventEmitter, + ); + await Promise.all([ + await this.mcpClientManager.startConfiguredMcpServers(), + await this.getExtensionLoader().start(this), + ]); await this.geminiClient.initialize(); } @@ -752,8 +763,23 @@ export class Config { return this.allowedTools; } + /** + * All the excluded tools from static configuration, loaded extensions, or + * other sources. + * + * May change over time. + */ getExcludeTools(): string[] | undefined { - return this.excludeTools; + const excludeToolsSet = new Set([...(this.excludeTools ?? [])]); + for (const extension of this.getExtensionLoader().getExtensions()) { + if (!extension.isActive) { + continue; + } + for (const tool of extension.excludeTools || []) { + excludeToolsSet.add(tool); + } + } + return [...excludeToolsSet]; } getToolDiscoveryCommand(): string | undefined { @@ -768,10 +794,27 @@ export class Config { return this.mcpServerCommand; } + /** + * The user configured MCP servers (via gemini settings files). + * + * Does NOT include mcp servers configured by extensions. + */ getMcpServers(): Record | undefined { return this.mcpServers; } + getMcpClientManager(): McpClientManager | undefined { + return this.mcpClientManager; + } + + getAllowedMcpServers(): string[] | undefined { + return this.allowedMcpServers; + } + + getBlockedMcpServers(): string[] | undefined { + return this.blockedMcpServers; + } + setMcpServers(mcpServers: Record): void { this.mcpServers = mcpServers; } @@ -955,10 +998,6 @@ export class Config { return this.enableExtensionReloading; } - getBlockedMcpServers(): Array<{ name: string; extensionName: string }> { - return this._blockedMcpServers; - } - getNoBrowser(): boolean { return this.noBrowser; } @@ -1155,7 +1194,7 @@ export class Config { } async createToolRegistry(): Promise { - const registry = new ToolRegistry(this, this.eventEmitter); + const registry = new ToolRegistry(this); // Set message bus on tool registry before discovery so MCP tools can access it if (this.getEnableMessageBusIntegration()) { @@ -1250,6 +1289,7 @@ export class Config { if (definition) { // We must respect the main allowed/exclude lists for agents too. const excludeTools = this.getExcludeTools() || []; + const allowedTools = this.getAllowedTools(); const isExcluded = excludeTools.includes(definition.name); diff --git a/packages/core/src/telemetry/loggers.test.ts b/packages/core/src/telemetry/loggers.test.ts index 74ebac5fe3..825666055e 100644 --- a/packages/core/src/telemetry/loggers.test.ts +++ b/packages/core/src/telemetry/loggers.test.ts @@ -192,11 +192,9 @@ describe('loggers', () => { getFileFilteringRespectGitIgnore: () => true, getFileFilteringAllowBuildArtifacts: () => false, getDebugMode: () => true, - getMcpServers: () => ({ - 'test-server': { - command: 'test-command', - }, - }), + getMcpServers: () => { + throw new Error('Should not call'); + }, getQuestion: () => 'test-question', getTargetDir: () => 'target-dir', getProxy: () => 'http://test.proxy.com:8080', @@ -206,6 +204,13 @@ describe('loggers', () => { { name: 'ext-one', id: 'id-one' }, { name: 'ext-two', id: 'id-two' }, ] as GeminiCLIExtension[], + getMcpClientManager: () => ({ + getMcpServers: () => ({ + 'test-server': { + command: 'test-command', + }, + }), + }), } as unknown as Config; const startSessionEvent = new StartSessionEvent(mockConfig); diff --git a/packages/core/src/telemetry/types.ts b/packages/core/src/telemetry/types.ts index 0ce8bb8b0e..7461689e2f 100644 --- a/packages/core/src/telemetry/types.ts +++ b/packages/core/src/telemetry/types.ts @@ -74,7 +74,8 @@ export class StartSessionEvent implements BaseTelemetryEvent { constructor(config: Config, toolRegistry?: ToolRegistry) { const generatorConfig = config.getContentGeneratorConfig(); - const mcpServers = config.getMcpServers(); + const mcpServers = + config.getMcpClientManager()?.getMcpServers() ?? config.getMcpServers(); let useGemini = false; let useVertex = false; diff --git a/packages/core/src/tools/mcp-client-manager.test.ts b/packages/core/src/tools/mcp-client-manager.test.ts index 6f160d1989..0cfddb61ff 100644 --- a/packages/core/src/tools/mcp-client-manager.test.ts +++ b/packages/core/src/tools/mcp-client-manager.test.ts @@ -4,86 +4,193 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { afterEach, describe, expect, it, vi } from 'vitest'; +import { + afterEach, + beforeEach, + describe, + expect, + it, + vi, + type MockedObject, +} from 'vitest'; import { McpClientManager } from './mcp-client-manager.js'; import { McpClient } from './mcp-client.js'; import type { ToolRegistry } from './tool-registry.js'; import type { Config } from '../config/config.js'; -import { SimpleExtensionLoader } from '../utils/extensionLoader.js'; vi.mock('./mcp-client.js', async () => { const originalModule = await vi.importActual('./mcp-client.js'); return { ...originalModule, McpClient: vi.fn(), - populateMcpServerCommand: vi.fn(() => ({ - 'test-server': {}, - })), }; }); describe('McpClientManager', () => { - afterEach(() => { - vi.restoreAllMocks(); - }); + let mockedMcpClient: MockedObject; + let mockConfig: MockedObject; - it('should discover tools from all servers', async () => { - const mockedMcpClient = { + beforeEach(() => { + mockedMcpClient = vi.mockObject({ connect: vi.fn(), discover: vi.fn(), disconnect: vi.fn(), getStatus: vi.fn(), - }; - vi.mocked(McpClient).mockReturnValue( - mockedMcpClient as unknown as McpClient, - ); - const manager = new McpClientManager( - {} as ToolRegistry, - { - isTrustedFolder: () => true, - getExtensionLoader: () => new SimpleExtensionLoader([]), - getMcpServers: () => ({ - 'test-server': {}, - }), - getMcpServerCommand: () => '', - getPromptRegistry: () => {}, - getDebugMode: () => false, - getWorkspaceContext: () => {}, - getEnableExtensionReloading: () => false, - } as unknown as Config, - ); - await manager.discoverAllMcpTools(); + getServerConfig: vi.fn(), + } as unknown as McpClient); + vi.mocked(McpClient).mockReturnValue(mockedMcpClient); + mockConfig = vi.mockObject({ + isTrustedFolder: vi.fn().mockReturnValue(true), + getMcpServers: vi.fn().mockReturnValue({}), + getPromptRegistry: () => {}, + getDebugMode: () => false, + getWorkspaceContext: () => {}, + getAllowedMcpServers: vi.fn().mockReturnValue([]), + getBlockedMcpServers: vi.fn().mockReturnValue([]), + getMcpServerCommand: vi.fn().mockReturnValue(''), + getGeminiClient: vi.fn().mockReturnValue({ + isInitialized: vi.fn(), + }), + } as unknown as Config); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should discover tools from all configured', async () => { + mockConfig.getMcpServers.mockReturnValue({ + 'test-server': {}, + }); + const manager = new McpClientManager({} as ToolRegistry, mockConfig); + await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); }); it('should not discover tools if folder is not trusted', async () => { - const mockedMcpClient = { - connect: vi.fn(), - discover: vi.fn(), - disconnect: vi.fn(), - getStatus: vi.fn(), - }; - vi.mocked(McpClient).mockReturnValue( - mockedMcpClient as unknown as McpClient, - ); - const manager = new McpClientManager( - {} as ToolRegistry, - { - isTrustedFolder: () => false, - getExtensionLoader: () => new SimpleExtensionLoader([]), - getMcpServers: () => ({ - 'test-server': {}, - }), - getMcpServerCommand: () => '', - getPromptRegistry: () => {}, - getDebugMode: () => false, - getWorkspaceContext: () => {}, - getEnableExtensionReloading: () => false, - } as unknown as Config, - ); - await manager.discoverAllMcpTools(); + mockConfig.getMcpServers.mockReturnValue({ + 'test-server': {}, + }); + mockConfig.isTrustedFolder.mockReturnValue(false); + const manager = new McpClientManager({} as ToolRegistry, mockConfig); + await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); expect(mockedMcpClient.discover).not.toHaveBeenCalled(); }); + + it('should not start blocked servers', async () => { + mockConfig.getMcpServers.mockReturnValue({ + 'test-server': {}, + }); + mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']); + const manager = new McpClientManager({} as ToolRegistry, mockConfig); + await manager.startConfiguredMcpServers(); + expect(mockedMcpClient.connect).not.toHaveBeenCalled(); + expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + }); + + it('should only start allowed servers if allow list is not empty', async () => { + mockConfig.getMcpServers.mockReturnValue({ + 'test-server': {}, + 'another-server': {}, + }); + mockConfig.getAllowedMcpServers.mockReturnValue(['another-server']); + const manager = new McpClientManager({} as ToolRegistry, mockConfig); + await manager.startConfiguredMcpServers(); + expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); + expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); + }); + + it('should start servers from extensions', async () => { + const manager = new McpClientManager({} as ToolRegistry, mockConfig); + await manager.startExtension({ + name: 'test-extension', + mcpServers: { + 'test-server': {}, + }, + isActive: true, + version: '1.0.0', + path: '/some-path', + contextFiles: [], + id: '123', + }); + expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); + expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); + }); + + it('should not start servers from disabled extensions', async () => { + const manager = new McpClientManager({} as ToolRegistry, mockConfig); + await manager.startExtension({ + name: 'test-extension', + mcpServers: { + 'test-server': {}, + }, + isActive: false, + version: '1.0.0', + path: '/some-path', + contextFiles: [], + id: '123', + }); + expect(mockedMcpClient.connect).not.toHaveBeenCalled(); + expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + }); + + it('should add blocked servers to the blockedMcpServers list', async () => { + mockConfig.getMcpServers.mockReturnValue({ + 'test-server': {}, + }); + mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']); + const manager = new McpClientManager({} as ToolRegistry, mockConfig); + await manager.startConfiguredMcpServers(); + expect(manager.getBlockedMcpServers()).toEqual([ + { name: 'test-server', extensionName: '' }, + ]); + }); + + describe('restart', () => { + it('should restart all running servers', async () => { + mockConfig.getMcpServers.mockReturnValue({ + 'test-server': {}, + }); + mockedMcpClient.getServerConfig.mockReturnValue({}); + const manager = new McpClientManager({} as ToolRegistry, mockConfig); + await manager.startConfiguredMcpServers(); + + expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1); + expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1); + await manager.restart(); + + expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1); + expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2); + expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2); + }); + }); + + describe('restartServer', () => { + it('should restart the specified server', async () => { + mockConfig.getMcpServers.mockReturnValue({ + 'test-server': {}, + }); + mockedMcpClient.getServerConfig.mockReturnValue({}); + const manager = new McpClientManager({} as ToolRegistry, mockConfig); + await manager.startConfiguredMcpServers(); + + expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1); + expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1); + + await manager.restartServer('test-server'); + + expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1); + expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2); + expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2); + }); + + it('should throw an error if the server does not exist', async () => { + const manager = new McpClientManager({} as ToolRegistry, mockConfig); + await expect(manager.restartServer('non-existent')).rejects.toThrow( + 'No MCP server registered with the name "non-existent"', + ); + }); + }); }); diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index 0b8f2180d2..aaf5fedac6 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -33,6 +33,10 @@ export class McpClientManager { private discoveryPromise: Promise | undefined; private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED; private readonly eventEmitter?: EventEmitter; + private readonly blockedMcpServers: Array<{ + name: string; + extensionName: string; + }> = []; constructor( toolRegistry: ToolRegistry, @@ -42,19 +46,10 @@ export class McpClientManager { this.toolRegistry = toolRegistry; this.cliConfig = cliConfig; this.eventEmitter = eventEmitter; - if (this.cliConfig.getEnableExtensionReloading()) { - this.cliConfig - .getExtensionLoader() - .extensionEvents() - .on('extensionLoaded', (event) => this.loadExtension(event.extension)) - .on('extensionEnabled', (event) => this.loadExtension(event.extension)) - .on('extensionDisabled', (event) => - this.unloadExtension(event.extension), - ) - .on('extensionUnloaded', (event) => - this.unloadExtension(event.extension), - ); - } + } + + getBlockedMcpServers() { + return this.blockedMcpServers; } /** @@ -64,21 +59,13 @@ export class McpClientManager { * - Disconnects all MCP clients from their servers. * - Updates the Gemini chat configuration to load the new tools. */ - private async unloadExtension(extension: GeminiCLIExtension) { + async stopExtension(extension: GeminiCLIExtension) { debugLogger.log(`Unloading extension: ${extension.name}`); await Promise.all( - Object.keys(extension.mcpServers ?? {}).map((name) => { - const newMcpServers = { - ...this.cliConfig.getMcpServers(), - }; - delete newMcpServers[name]; - this.cliConfig.setMcpServers(newMcpServers); - return this.disconnectClient(name); - }), + Object.keys(extension.mcpServers ?? {}).map( + this.disconnectClient.bind(this), + ), ); - // This is required to update the content generator configuration with the - // new tool configuration. - this.cliConfig.getGeminiClient().setTools(); } /** @@ -88,20 +75,36 @@ export class McpClientManager { * - Connects MCP clients to each server and discovers their tools. * - Updates the Gemini chat configuration to load the new tools. */ - private async loadExtension(extension: GeminiCLIExtension) { + async startExtension(extension: GeminiCLIExtension) { debugLogger.log(`Loading extension: ${extension.name}`); await Promise.all( - Object.entries(extension.mcpServers ?? {}).map(([name, config]) => { - this.cliConfig.setMcpServers({ - ...this.cliConfig.getMcpServers(), - [name]: config, - }); - return this.discoverMcpTools(name, config); - }), + Object.entries(extension.mcpServers ?? {}).map(([name, config]) => + this.maybeDiscoverMcpServer(name, { + ...config, + extension, + }), + ), ); - // This is required to update the content generator configuration with the - // new tool configuration. - this.cliConfig.getGeminiClient().setTools(); + } + + private isAllowedMcpServer(name: string) { + const allowedNames = this.cliConfig.getAllowedMcpServers(); + if ( + allowedNames && + allowedNames.length > 0 && + allowedNames.indexOf(name) === -1 + ) { + return false; + } + const blockedNames = this.cliConfig.getBlockedMcpServers(); + if ( + blockedNames && + blockedNames.length > 0 && + blockedNames.indexOf(name) !== -1 + ) { + return false; + } + return true; } private async disconnectClient(name: string) { @@ -115,36 +118,68 @@ export class McpClientManager { debugLogger.warn( `Error stopping client '${name}': ${getErrorMessage(error)}`, ); + } finally { + // This is required to update the content generator configuration with the + // new tool configuration. + const geminiClient = this.cliConfig.getGeminiClient(); + if (geminiClient.isInitialized()) { + await geminiClient.setTools(); + } } } } - discoverMcpTools( + maybeDiscoverMcpServer( name: string, config: MCPServerConfig, ): Promise | void { + if (!this.isAllowedMcpServer(name)) { + if (!this.blockedMcpServers.find((s) => s.name === name)) { + this.blockedMcpServers?.push({ + name, + extensionName: config.extension?.name ?? '', + }); + } + return; + } if (!this.cliConfig.isTrustedFolder()) { return; } if (config.extension && !config.extension.isActive) { return; } + const existing = this.clients.get(name); + if (existing && existing.getServerConfig().extension !== config.extension) { + const extensionText = config.extension + ? ` from extension "${config.extension.name}"` + : ''; + debugLogger.warn( + `Skipping MCP config for server with name "${name}"${extensionText} as it already exists.`, + ); + return; + } const currentDiscoveryPromise = new Promise((resolve, _reject) => { (async () => { try { - await this.disconnectClient(name); + if (existing) { + await existing.disconnect(); + } - const client = new McpClient( - name, - config, - this.toolRegistry, - this.cliConfig.getPromptRegistry(), - this.cliConfig.getWorkspaceContext(), - this.cliConfig.getDebugMode(), - ); - this.clients.set(name, client); - this.eventEmitter?.emit('mcp-client-update', this.clients); + const client = + existing ?? + new McpClient( + name, + config, + this.toolRegistry, + this.cliConfig.getPromptRegistry(), + this.cliConfig.getWorkspaceContext(), + this.cliConfig.getDebugMode(), + ); + if (!existing) { + this.clients.set(name, client); + this.eventEmitter?.emit('mcp-client-update', this.clients); + } try { await client.connect(); await client.discover(this.cliConfig); @@ -161,6 +196,12 @@ export class McpClientManager { ); } } finally { + // This is required to update the content generator configuration with the + // new tool configuration. + const geminiClient = this.cliConfig.getGeminiClient(); + if (geminiClient.isInitialized()) { + await geminiClient.setTools(); + } resolve(); } })(); @@ -174,6 +215,7 @@ export class McpClientManager { this.discoveryState = MCPDiscoveryState.IN_PROGRESS; this.discoveryPromise = currentDiscoveryPromise; } + this.eventEmitter?.emit('mcp-client-update', this.clients); const currentPromise = this.discoveryPromise; currentPromise.then((_) => { // If we are the last recorded discoveryPromise, then we are done, reset @@ -187,15 +229,21 @@ export class McpClientManager { } /** - * Initiates the tool discovery process for all configured MCP servers. + * Initiates the tool discovery process for all configured MCP servers (via + * gemini settings or command line arguments). + * * It connects to each server, discovers its available tools, and registers * them with the `ToolRegistry`. + * + * For any server which is already connected, it will first be disconnected. + * + * This does NOT load extension MCP servers - this happens when the + * ExtensionLoader explicitly calls `loadExtension`. */ - async discoverAllMcpTools(): Promise { + async startConfiguredMcpServers(): Promise { if (!this.cliConfig.isTrustedFolder()) { return; } - await this.stop(); const servers = populateMcpServerCommand( this.cliConfig.getMcpServers() || {}, @@ -204,12 +252,40 @@ export class McpClientManager { this.eventEmitter?.emit('mcp-client-update', this.clients); await Promise.all( - Object.entries(servers).map(async ([name, config]) => - this.discoverMcpTools(name, config), + Object.entries(servers).map(([name, config]) => + this.maybeDiscoverMcpServer(name, config), ), ); } + /** + * Restarts all active MCP Clients. + */ + async restart(): Promise { + await Promise.all( + Array.from(this.clients.entries()).map(async ([name, client]) => { + try { + await this.maybeDiscoverMcpServer(name, client.getServerConfig()); + } catch (error) { + debugLogger.error( + `Error restarting client '${name}': ${getErrorMessage(error)}`, + ); + } + }), + ); + } + + /** + * Restart a single MCP server by name. + */ + async restartServer(name: string) { + const client = this.clients.get(name); + if (!client) { + throw new Error(`No MCP server registered with the name "${name}"`); + } + await this.maybeDiscoverMcpServer(name, client.getServerConfig()); + } + /** * Stops all running local MCP servers and closes all client connections. * This is the cleanup method to be called on application exit. @@ -236,4 +312,15 @@ export class McpClientManager { getDiscoveryState(): MCPDiscoveryState { return this.discoveryState; } + + /** + * All of the MCP server configurations currently loaded. + */ + getMcpServers(): Record { + const mcpServers: Record = {}; + for (const [name, client] of this.clients.entries()) { + mcpServers[name] = client.getServerConfig(); + } + return mcpServers; + } } diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 5ff64eb0a8..7811888eec 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -303,6 +303,71 @@ describe('mcp-client', () => { expect(mockedMcpToTool).toHaveBeenCalledOnce(); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); }); + + it('should remove tools and prompts on disconnect', async () => { + const mockedClient = { + connect: vi.fn(), + close: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + getServerCapabilities: vi + .fn() + .mockReturnValue({ tools: {}, prompts: {} }), + request: vi.fn().mockResolvedValue({ + prompts: [{ id: 'prompt1', text: 'a prompt' }], + }), + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ + tool: () => + Promise.resolve({ + functionDeclarations: [ + { + name: 'testTool', + description: 'A test tool', + }, + ], + }), + } as unknown as GenAiLib.CallableTool); + const mockedToolRegistry = { + registerTool: vi.fn(), + unregisterTool: vi.fn(), + getMessageBus: vi.fn().mockReturnValue(undefined), + removeMcpToolsByServer: vi.fn(), + } as unknown as ToolRegistry; + const mockedPromptRegistry = { + registerPrompt: vi.fn(), + unregisterPrompt: vi.fn(), + removePromptsByServer: vi.fn(), + } as unknown as PromptRegistry; + const client = new McpClient( + 'test-server', + { + command: 'test-command', + }, + mockedToolRegistry, + mockedPromptRegistry, + workspaceContext, + false, + ); + await client.connect(); + await client.discover({} as Config); + + expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); + expect(mockedPromptRegistry.registerPrompt).toHaveBeenCalledOnce(); + + await client.disconnect(); + + expect(mockedClient.close).toHaveBeenCalledOnce(); + expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledOnce(); + expect(mockedPromptRegistry.removePromptsByServer).toHaveBeenCalledOnce(); + }); }); describe('appendMcpServerCommand', () => { it('should do nothing if no MCP servers or command are configured', () => { diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index f789dd3ee1..c5b1dc6caa 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -161,6 +161,7 @@ export class McpClient { return; } this.toolRegistry.removeMcpToolsByServer(this.serverName); + this.promptRegistry.removePromptsByServer(this.serverName); this.updateStatus(MCPServerStatus.DISCONNECTING); const client = this.client; this.client = undefined; @@ -208,6 +209,10 @@ export class McpClient { this.assertConnected(); return discoverPrompts(this.serverName, this.client!, this.promptRegistry); } + + getServerConfig(): MCPServerConfig { + return this.serverConfig; + } } /** diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index afa0a77179..f002250910 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -19,20 +19,10 @@ import { spawn } from 'node:child_process'; import fs from 'node:fs'; import { MockTool } from '../test-utils/mock-tool.js'; - -import { McpClientManager } from './mcp-client-manager.js'; import { ToolErrorType } from './tool-error.js'; vi.mock('node:fs'); -// Mock ./mcp-client.js to control its behavior within tool-registry tests -vi.mock('./mcp-client.js', async () => { - const originalModule = await vi.importActual('./mcp-client.js'); - return { - ...originalModule, - }; -}); - // Mock node:child_process vi.mock('node:child_process', async () => { const actual = await vi.importActual('node:child_process'); @@ -401,27 +391,6 @@ describe('ToolRegistry', () => { expect(result.llmContent).toContain('Stderr: Something went wrong'); expect(result.llmContent).toContain('Exit Code: 1'); }); - - it('should discover tools using MCP servers defined in getMcpServers', async () => { - const discoverSpy = vi.spyOn( - McpClientManager.prototype, - 'discoverAllMcpTools', - ); - mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined); - vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined); - const mcpServerConfigVal = { - 'my-mcp-server': { - command: 'mcp-server-cmd', - args: ['--port', '1234'], - trust: true, - }, - }; - vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal); - - await toolRegistry.discoverAllTools(); - - expect(discoverSpy).toHaveBeenCalled(); - }); }); describe('DiscoveredToolInvocation', () => { diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index c71fba5ab7..abb03d5329 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -14,13 +14,10 @@ import { Kind, BaseDeclarativeTool, BaseToolInvocation } from './tools.js'; import type { Config } from '../config/config.js'; import { spawn } from 'node:child_process'; import { StringDecoder } from 'node:string_decoder'; -import { connectAndDiscover } from './mcp-client.js'; -import { McpClientManager } from './mcp-client-manager.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; import { parse } from 'shell-quote'; import { ToolErrorType } from './tool-error.js'; import { safeJsonStringify } from '../utils/safeJsonStringify.js'; -import type { EventEmitter } from 'node:events'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { debugLogger } from '../utils/debugLogger.js'; import { coreEvents } from '../utils/events.js'; @@ -176,12 +173,10 @@ export class ToolRegistry { // The tools keyed by tool name as seen by the LLM. private tools: Map = new Map(); private config: Config; - private mcpClientManager: McpClientManager; private messageBus?: MessageBus; - constructor(config: Config, eventEmitter?: EventEmitter) { + constructor(config: Config) { this.config = config; - this.mcpClientManager = new McpClientManager(this, config, eventEmitter); } setMessageBus(messageBus: MessageBus): void { @@ -238,64 +233,7 @@ export class ToolRegistry { async discoverAllTools(): Promise { // remove any previously discovered tools this.removeDiscoveredTools(); - - this.config.getPromptRegistry().clear(); - await this.discoverAndRegisterToolsFromCommand(); - - // discover tools using MCP servers, if configured - await this.mcpClientManager.discoverAllMcpTools(); - } - - /** - * Discovers tools from project (if available and configured). - * Can be called multiple times to update discovered tools. - * This will NOT discover tools from the command line, only from MCP servers. - */ - async discoverMcpTools(): Promise { - // remove any previously discovered tools - this.removeDiscoveredTools(); - - this.config.getPromptRegistry().clear(); - - // discover tools using MCP servers, if configured - await this.mcpClientManager.discoverAllMcpTools(); - } - - /** - * Restarts all MCP servers and re-discovers tools. - */ - async restartMcpServers(): Promise { - await this.discoverMcpTools(); - } - - /** - * Discover or re-discover tools for a single MCP server. - * @param serverName - The name of the server to discover tools from. - */ - async discoverToolsForServer(serverName: string): Promise { - // Remove any previously discovered tools from this server - for (const [name, tool] of this.tools.entries()) { - if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) { - this.tools.delete(name); - } - } - - this.config.getPromptRegistry().removePromptsByServer(serverName); - - const mcpServers = this.config.getMcpServers() ?? {}; - const serverConfig = mcpServers[serverName]; - if (serverConfig) { - await connectAndDiscover( - serverName, - serverConfig, - this, - this.config.getPromptRegistry(), - this.config.getDebugMode(), - this.config.getWorkspaceContext(), - this.config, - ); - } } private async discoverAndRegisterToolsFromCommand(): Promise { diff --git a/packages/core/src/utils/extensionLoader.test.ts b/packages/core/src/utils/extensionLoader.test.ts new file mode 100644 index 0000000000..3329dc36f7 --- /dev/null +++ b/packages/core/src/utils/extensionLoader.test.ts @@ -0,0 +1,108 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest'; +import { SimpleExtensionLoader } from './extensionLoader.js'; +import type { Config } from '../config/config.js'; +import { type McpClientManager } from '../tools/mcp-client-manager.js'; + +describe('SimpleExtensionLoader', () => { + let mockConfig: Config; + let extensionReloadingEnabled: boolean; + let mockMcpClientManager: McpClientManager; + const activeExtension = { + name: 'test-extension', + isActive: true, + version: '1.0.0', + path: '/path/to/extension', + contextFiles: [], + id: '123', + }; + const inactiveExtension = { + name: 'test-extension', + isActive: false, + version: '1.0.0', + path: '/path/to/extension', + contextFiles: [], + id: '123', + }; + + beforeEach(() => { + mockMcpClientManager = { + startExtension: vi.fn(), + stopExtension: vi.fn(), + } as unknown as McpClientManager; + extensionReloadingEnabled = false; + mockConfig = { + getMcpClientManager: () => mockMcpClientManager, + getEnableExtensionReloading: () => extensionReloadingEnabled, + } as unknown as Config; + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should start active extensions', async () => { + const loader = new SimpleExtensionLoader([activeExtension]); + await loader.start(mockConfig); + expect(mockMcpClientManager.startExtension).toHaveBeenCalledExactlyOnceWith( + activeExtension, + ); + }); + + it('should not start inactive extensions', async () => { + const loader = new SimpleExtensionLoader([inactiveExtension]); + await loader.start(mockConfig); + expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled(); + }); + + describe('interactive extension loading and unloading', () => { + it('should not call `start` or `stop` if the loader is not already started', async () => { + const loader = new SimpleExtensionLoader([]); + await loader.loadExtension(activeExtension); + expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled(); + await loader.unloadExtension(activeExtension); + expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled(); + }); + + it('should start extensions that were explicitly loaded prior to initializing the loader', async () => { + const loader = new SimpleExtensionLoader([]); + await loader.loadExtension(activeExtension); + expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled(); + await loader.start(mockConfig); + expect( + mockMcpClientManager.startExtension, + ).toHaveBeenCalledExactlyOnceWith(activeExtension); + }); + + it.each([true, false])( + 'should only call `start` and `stop` if extension reloading is enabled ($i)', + async (reloadingEnabled) => { + extensionReloadingEnabled = reloadingEnabled; + const loader = new SimpleExtensionLoader([]); + await loader.start(mockConfig); + expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled(); + await loader.loadExtension(activeExtension); + if (reloadingEnabled) { + expect( + mockMcpClientManager.startExtension, + ).toHaveBeenCalledExactlyOnceWith(activeExtension); + } else { + expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled(); + } + await loader.unloadExtension(activeExtension); + if (reloadingEnabled) { + expect( + mockMcpClientManager.stopExtension, + ).toHaveBeenCalledExactlyOnceWith(activeExtension); + } else { + expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled(); + } + }, + ); + }); +}); diff --git a/packages/core/src/utils/extensionLoader.ts b/packages/core/src/utils/extensionLoader.ts index d42fcf6084..b65f227143 100644 --- a/packages/core/src/utils/extensionLoader.ts +++ b/packages/core/src/utils/extensionLoader.ts @@ -4,45 +4,194 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { EventEmitter } from 'node:events'; -import type { GeminiCLIExtension } from '../config/config.js'; +import type { EventEmitter } from 'node:events'; +import type { Config, GeminiCLIExtension } from '../config/config.js'; -export interface ExtensionLoader { - getExtensions(): GeminiCLIExtension[]; +export abstract class ExtensionLoader { + // Assigned in `start`. + protected config: Config | undefined; - extensionEvents(): EventEmitter; + // Used to track the count of currently starting and stopping extensions and + // fire appropriate events. + protected startingCount: number = 0; + protected startCompletedCount: number = 0; + protected stoppingCount: number = 0; + protected stopCompletedCount: number = 0; + + constructor(private readonly eventEmitter?: EventEmitter) {} + + /** + * All currently known extensions, both active and inactive. + */ + abstract getExtensions(): GeminiCLIExtension[]; + + /** + * Fully initializes all active extensions. + * + * Called within `Config.initialize`, which must already have an + * McpClientManager, PromptRegistry, and GeminiChat set up. + */ + async start(config: Config): Promise { + if (!this.config) { + this.config = config; + } else { + throw new Error('Already started, you may only call `start` once.'); + } + await Promise.all( + this.getExtensions() + .filter((e) => e.isActive) + .map(this.startExtension.bind(this)), + ); + } + + /** + * Unconditionally starts an `extension` and loads all its MCP servers, + * context, custom commands, etc. Assumes that `start` has already been called + * and we have a Config object. + * + * This should typically only be called from `start`, most other calls should + * go through `maybeStartExtension` which will only start the extension if + * extension reloading is enabled and the `config` object is initialized. + */ + protected async startExtension(extension: GeminiCLIExtension) { + if (!this.config) { + throw new Error('Cannot call `startExtension` prior to calling `start`.'); + } + this.startingCount++; + this.eventEmitter?.emit('extensionsStarting', { + total: this.startingCount, + completed: this.startCompletedCount, + }); + try { + await this.config.getMcpClientManager()!.startExtension(extension); + // TODO: Move all extension features here, including at least: + // - context file loading + // - custom command loading + // - excluded tool configuration + } finally { + this.startCompletedCount++; + this.eventEmitter?.emit('extensionsStarting', { + total: this.startingCount, + completed: this.startCompletedCount, + }); + if (this.startingCount === this.startCompletedCount) { + this.startingCount = 0; + this.startCompletedCount = 0; + } + } + } + + /** + * If extension reloading is enabled and `start` has already been called, + * then calls `startExtension` to include all extension features into the + * program. + */ + protected maybeStartExtension( + extension: GeminiCLIExtension, + ): Promise | undefined { + if (this.config && this.config.getEnableExtensionReloading()) { + return this.startExtension(extension); + } + return; + } + + /** + * Unconditionally stops an `extension` and unloads all its MCP servers, + * context, custom commands, etc. Assumes that `start` has already been called + * and we have a Config object. + * + * Most calls should go through `maybeStopExtension` which will only stop the + * extension if extension reloading is enabled and the `config` object is + * initialized. + */ + protected async stopExtension(extension: GeminiCLIExtension) { + if (!this.config) { + throw new Error('Cannot call `stopExtension` prior to calling `start`.'); + } + this.stoppingCount++; + this.eventEmitter?.emit('extensionsStopping', { + total: this.stoppingCount, + completed: this.stopCompletedCount, + }); + + try { + await this.config.getMcpClientManager()!.stopExtension(extension); + // TODO: Remove all extension features here, including at least: + // - context files + // - custom commands + // - excluded tools + } finally { + this.stopCompletedCount++; + this.eventEmitter?.emit('extensionsStopping', { + total: this.stoppingCount, + completed: this.stopCompletedCount, + }); + if (this.stoppingCount === this.stopCompletedCount) { + this.stoppingCount = 0; + this.stopCompletedCount = 0; + } + } + } + + /** + * If extension reloading is enabled and `start` has already been called, + * then this also performs all necessary steps to remove all extension + * features from the rest of the system. + */ + protected maybeStopExtension( + extension: GeminiCLIExtension, + ): Promise | undefined { + if (this.config && this.config.getEnableExtensionReloading()) { + return this.stopExtension(extension); + } + return; + } } export interface ExtensionEvents { - extensionEnabled: ExtensionEnableEvent[]; - extensionDisabled: ExtensionDisableEvent[]; - extensionLoaded: ExtensionLoadEvent[]; - extensionUnloaded: ExtensionUnloadEvent[]; - extensionInstalled: ExtensionInstallEvent[]; - extensionUninstalled: ExtensionUninstallEvent[]; - extensionUpdated: ExtensionUpdateEvent[]; + extensionsStarting: ExtensionsStartingEvent[]; + extensionsStopping: ExtensionsStoppingEvent[]; } -interface BaseExtensionEvent { - extension: GeminiCLIExtension; +export interface ExtensionsStartingEvent { + total: number; + completed: number; } -export type ExtensionDisableEvent = BaseExtensionEvent; -export type ExtensionEnableEvent = BaseExtensionEvent; -export type ExtensionInstallEvent = BaseExtensionEvent; -export type ExtensionLoadEvent = BaseExtensionEvent; -export type ExtensionUnloadEvent = BaseExtensionEvent; -export type ExtensionUninstallEvent = BaseExtensionEvent; -export type ExtensionUpdateEvent = BaseExtensionEvent; -export class SimpleExtensionLoader implements ExtensionLoader { - private _eventEmitter = new EventEmitter(); - constructor(private readonly extensions: GeminiCLIExtension[]) {} +export interface ExtensionsStoppingEvent { + total: number; + completed: number; +} - extensionEvents(): EventEmitter { - return this._eventEmitter; +export class SimpleExtensionLoader extends ExtensionLoader { + constructor( + protected readonly extensions: GeminiCLIExtension[], + eventEmitter?: EventEmitter, + ) { + super(eventEmitter); } getExtensions(): GeminiCLIExtension[] { return this.extensions; } + + /// Adds `extension` to the list of extensions and calls + /// `maybeStartExtension`. + /// + /// This is intended for dynamic loading of extensions after calling `start`. + async loadExtension(extension: GeminiCLIExtension) { + this.extensions.push(extension); + await this.maybeStartExtension(extension); + } + + /// Removes `extension` from the list of extensions and calls + // `maybeStopExtension` if it was found. + /// + /// This is intended for dynamic unloading of extensions after calling `start`. + async unloadExtension(extension: GeminiCLIExtension) { + const index = this.extensions.indexOf(extension); + if (index === -1) return; + this.extensions.splice(index, 1); + await this.maybeStopExtension(extension); + } }